Mirror of https://github.com/onyx-dot-app/onyx.git, synced 2026-03-30 12:02:42 +00:00
Compare commits: cli/v0.2.0 ... config_cha (70 commits)
Commits in this compare (70):
376a2aee52, 8cac462353, 1b849f6df6, d105b96c97, 7cd555653f, 38fb3f96ec, 339a111a8f, 89696a7f57, 09b7e6fc9b, 135238014f,
303e37bf53, 6a888e9900, e90a7767c6, 108e913017, 1ded3af63c, c53546c000, 9afa12edda, 32046de962, bdf93704db, 81b4dbb5b0,
1124a220cd, f4e14bfa06, 02d29a1a83, 709261aec9, ec9c494a15, f160692d19, 631dcf8222, 334ae3ab17, 6904ee5de0, 4d913d7c3d,
d7f67d92b8, bc52fcc4ce, 6ec04dc446, 2b90453748, db04ade260, da31f91631, dbaefae827, fd407e5657, 65a18a720e, a1ed112972,
40221968a7, ae325a8b01, 9579dd4be4, a461ca9710, 767dba1096, 44ed8632a5, 8d12302378, 0b00d56e18, 590d844ea7, cfa0f6e4fa,
81829025b7, a24270c13b, cd62834902, 0e93bd2d6f, 5fde172464, 0593107528, 4e2b9e3a32, 6b9870375c, 62ee239a1a, 2932fe948d,
07b2b4e0ff, 6b1005240c, cd54511605, e556a13f46, acd81c99d4, d2eccbb0e0, 2383caaccd, 3559e7f2bf, d10895336c, 94873c8a8e
@@ -0,0 +1,107 @@
"""Generalise model config and differentiate via flow

Revision ID: 0c5fbcd15bdd
Revises: f220515df7b4
Create Date: 2026-01-26 14:43:16.932376

"""

from alembic import op

# revision identifiers, used by Alembic.
revision = "0c5fbcd15bdd"
down_revision = "f220515df7b4"
branch_labels = None
depends_on = None


def upgrade() -> None:
    # Add each model config to the text flow, setting the global default if it exists
    # Exclude models that are part of ImageGenerationConfig
    op.execute(
        """
        INSERT INTO flow_mapping (flow_type, is_default, model_configuration_id)
        SELECT
            'text' AS flow_type,
            COALESCE(
                (lp.is_default_provider IS TRUE AND lp.default_model_name = mc.name),
                FALSE
            ) AS is_default,
            mc.id AS model_configuration_id
        FROM model_configuration mc
        LEFT JOIN llm_provider lp
            ON lp.id = mc.llm_provider_id
        WHERE NOT EXISTS (
            SELECT 1 FROM image_generation_config igc
            WHERE igc.model_configuration_id = mc.id
        );
        """
    )

    # Add vision models to the vision flow
    op.execute(
        """
        INSERT INTO flow_mapping (flow_type, is_default, model_configuration_id)
        SELECT
            'vision' AS flow_type,
            COALESCE(
                (lp.is_default_vision_provider IS TRUE AND lp.default_vision_model = mc.name),
                FALSE
            ) AS is_default,
            mc.id AS model_configuration_id
        FROM model_configuration mc
        LEFT JOIN llm_provider lp
            ON lp.id = mc.llm_provider_id;
        """
    )


def downgrade() -> None:
    # Populate vision defaults from flow_mapping
    op.execute(
        """
        UPDATE llm_provider AS lp
        SET
            is_default_vision_provider = TRUE,
            default_vision_model = mc.name
        FROM flow_mapping fm
        JOIN model_configuration mc ON mc.id = fm.model_configuration_id
        WHERE fm.flow_type = 'vision'
            AND fm.is_default = TRUE
            AND mc.llm_provider_id = lp.id;
        """
    )

    # Populate text defaults from flow_mapping
    op.execute(
        """
        UPDATE llm_provider AS lp
        SET
            is_default_provider = TRUE,
            default_model_name = mc.name
        FROM flow_mapping fm
        JOIN model_configuration mc ON mc.id = fm.model_configuration_id
        WHERE fm.flow_type = 'text'
            AND fm.is_default = TRUE
            AND mc.llm_provider_id = lp.id;
        """
    )

    # For providers that have text flow mappings but aren't the default,
    # we still need a default_model_name (it was NOT NULL originally)
    # Pick the first visible model or any model for that provider
    op.execute(
        """
        UPDATE llm_provider AS lp
        SET default_model_name = (
            SELECT mc.name
            FROM model_configuration mc
            JOIN flow_mapping fm ON fm.model_configuration_id = mc.id
            WHERE mc.llm_provider_id = lp.id
                AND fm.flow_type = 'text'
            ORDER BY mc.is_visible DESC, mc.id ASC
            LIMIT 1
        )
        WHERE lp.default_model_name IS NULL;
        """
    )
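A sanity check one might run after this backfill, sketched with SQLAlchemy; the connection URL is an assumption, and the table and column names are taken from the SQL above:

from sqlalchemy import create_engine, text

# Assumed connection string; point this at the actual database.
engine = create_engine("postgresql+psycopg2://localhost/onyx")

with engine.connect() as conn:
    # After the backfill there should be at most one default row per flow type.
    rows = conn.execute(
        text(
            """
            SELECT flow_type, COUNT(*) AS default_count
            FROM flow_mapping
            WHERE is_default IS TRUE
            GROUP BY flow_type
            HAVING COUNT(*) > 1
            """
        )
    ).all()
    assert not rows, f"flow types with more than one default: {rows}"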
@@ -0,0 +1,97 @@
"""Persona uses model configuration id

Revision ID: 1b5455f99105
Revises: e7f8a9b0c1d2
Create Date: 2026-01-27 19:11:34.510574

"""

from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision = "1b5455f99105"
down_revision = "e7f8a9b0c1d2"
branch_labels = None
depends_on = None


def upgrade() -> None:
    # We need to migrate from llm_model_provider_override and llm_model_version_override to default_model_configuration_id
    # Migration Strategy:
    # 1. Port over where provider_override + model_version are both present
    # 2. Port over where only the provider is present (provider default)
    # 3. Where only the model is provided, or neither provider nor model is provided, we go for the global default

    conn = op.get_bind()

    # Strategy 1: Both provider_override and model_version are present
    # Match against model_configuration using provider name and model name
    conn.execute(
        sa.text(
            """
            UPDATE persona
            SET default_model_configuration_id = mc.id
            FROM model_configuration mc
            JOIN llm_provider lp ON mc.llm_provider_id = lp.id
            WHERE persona.llm_model_provider_override IS NOT NULL
              AND persona.llm_model_version_override IS NOT NULL
              AND lp.name = persona.llm_model_provider_override
              AND mc.name = persona.llm_model_version_override
            """
        )
    )

    # Strategy 2: Only provider is present (use Provider's default model)
    conn.execute(
        sa.text(
            """
            UPDATE persona
            SET default_model_configuration_id = mc.id
            FROM model_configuration mc
            JOIN llm_provider lp ON mc.llm_provider_id = lp.id
            WHERE persona.llm_model_provider_override IS NOT NULL
              AND persona.llm_model_version_override IS NULL
              AND persona.default_model_configuration_id IS NULL
              AND lp.name = persona.llm_model_provider_override
              AND mc.name = lp.default_model_name
            """
        )
    )

    # For remaining personas with only model set but no match found,
    # fall back to global default (provider's default model where is_default_provider = true)
    conn.execute(
        sa.text(
            """
            UPDATE persona
            SET default_model_configuration_id = mc.id
            FROM model_configuration mc
            JOIN llm_provider lp ON mc.llm_provider_id = lp.id
            WHERE persona.llm_model_provider_override IS NULL
              AND persona.llm_model_version_override IS NOT NULL
              AND persona.default_model_configuration_id IS NULL
              AND lp.is_default_provider = true
              AND mc.name = lp.default_model_name
            """
        )
    )


def downgrade() -> None:
    # Migrate data back from default_model_configuration_id to old columns
    conn = op.get_bind()
    conn.execute(
        sa.text(
            """
            UPDATE persona
            SET llm_model_provider_override = lp.name,
                llm_model_version_override = mc.name
            FROM model_configuration mc
            JOIN llm_provider lp ON mc.llm_provider_id = lp.id
            WHERE persona.default_model_configuration_id IS NOT NULL
              AND persona.default_model_configuration_id = mc.id
            """
        )
    )
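A plain-Python restatement (not part of the migration) of the resolution order the three UPDATE statements above implement; all names and data structures here are illustrative:

def resolve_model_configuration_id(
    provider_override: str | None,
    model_override: str | None,
    model_configs: dict[tuple[str, str], int],  # (provider name, model name) -> model_configuration id
    provider_defaults: dict[str, str],          # provider name -> default model name
    global_default_provider: str | None,
) -> int | None:
    # Strategy 1: both provider and model are set -> exact match on the pair
    if provider_override and model_override:
        return model_configs.get((provider_override, model_override))
    # Strategy 2: only the provider is set -> that provider's default model
    if provider_override:
        default = provider_defaults.get(provider_override)
        return model_configs.get((provider_override, default)) if default else None
    # Strategy 3: otherwise fall back to the global default provider's default model
    if global_default_provider:
        default = provider_defaults.get(global_default_provider)
        return model_configs.get((global_default_provider, default)) if default else None
    return None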
@@ -0,0 +1,57 @@
"""Add flow mapping table

Revision ID: f220515df7b4
Revises: be87a654d5af
Create Date: 2026-01-30 12:21:24.955922

"""

from onyx.db.enums import ModelFlowType

from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision = "f220515df7b4"
down_revision = "be87a654d5af"
branch_labels = None
depends_on = None


def upgrade() -> None:
    op.create_table(
        "model_flow",
        sa.Column("id", sa.Integer(), nullable=False),
        sa.Column(
            "model_flow_type",
            sa.Enum(ModelFlowType, name="modelflowtype", native_enum=False),
            nullable=False,
        ),
        sa.Column(
            "is_default", sa.Boolean(), nullable=False, server_default=sa.text("false")
        ),
        sa.Column("model_configuration_id", sa.Integer(), nullable=False),
        sa.PrimaryKeyConstraint("id"),
        sa.ForeignKeyConstraint(
            ["model_configuration_id"], ["model_configuration.id"], ondelete="CASCADE"
        ),
        sa.UniqueConstraint(
            "model_flow_type",
            "model_configuration_id",
            name="uq_model_config_per_flow_type",
        ),
    )

    # Partial unique index so that there is at most one default for each flow type
    op.create_index(
        "ix_one_default_per_model_flow",
        "model_flow",
        ["model_flow_type"],
        unique=True,
        postgresql_where=sa.text("is_default IS TRUE"),
    )


def downgrade() -> None:
    # Drop the model_flow table (index is dropped automatically with table)
    op.drop_table("model_flow")
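The partial unique index only constrains rows where is_default is true, so any number of non-default mappings per flow type are allowed. A self-contained sketch of the behaviour it enforces, written against in-memory SQLite purely for illustration (the real table is created in Postgres by the migration above, and the stored flow-type values are illustrative):

from sqlalchemy import create_engine, text
from sqlalchemy.exc import IntegrityError

engine = create_engine("sqlite://")

with engine.begin() as conn:
    conn.execute(text(
        "CREATE TABLE model_flow ("
        " id INTEGER PRIMARY KEY,"
        " model_flow_type TEXT NOT NULL,"
        " is_default BOOLEAN NOT NULL DEFAULT 0,"
        " model_configuration_id INTEGER NOT NULL)"
    ))
    # Same idea as ix_one_default_per_model_flow: unique flow type among default rows only.
    conn.execute(text(
        "CREATE UNIQUE INDEX ix_one_default_per_model_flow"
        " ON model_flow (model_flow_type) WHERE is_default"
    ))
    conn.execute(text(
        "INSERT INTO model_flow (model_flow_type, is_default, model_configuration_id)"
        " VALUES ('CONVERSATION', 1, 1)"
    ))
    # A second, non-default mapping for the same flow type is fine.
    conn.execute(text(
        "INSERT INTO model_flow (model_flow_type, is_default, model_configuration_id)"
        " VALUES ('CONVERSATION', 0, 2)"
    ))

try:
    with engine.begin() as conn:
        # A second default for the same flow type violates the partial unique index.
        conn.execute(text(
            "INSERT INTO model_flow (model_flow_type, is_default, model_configuration_id)"
            " VALUES ('CONVERSATION', 1, 3)"
        ))
except IntegrityError:
    print("rejected: at most one default per flow type")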
@@ -20,7 +20,7 @@ from ee.onyx.server.enterprise_settings.store import (
from ee.onyx.server.enterprise_settings.store import upload_logo
from onyx.context.search.enums import RecencyBiasSetting
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.llm import update_default_provider
from onyx.db.llm import update_default_text_provider
from onyx.db.llm import upsert_llm_provider
from onyx.db.models import Tool
from onyx.db.persona import upsert_persona

@@ -123,9 +123,12 @@ def _seed_llms(
        upsert_llm_provider(llm_upsert_request, db_session)
        for llm_upsert_request in llm_upsert_requests
    ]
    update_default_provider(
        provider_id=seeded_providers[0].id, db_session=db_session
    )
    if len(seeded_providers[0].model_configurations) > 0:
        update_default_text_provider(
            provider_id=seeded_providers[0].id,
            model=seeded_providers[0].model_configurations[0].name,
            db_session=db_session,
        )


def _seed_personas(db_session: Session, personas: list[PersonaUpsertRequest]) -> None:

@@ -144,8 +147,7 @@ def _seed_personas(db_session: Session, personas: list[PersonaUpsertRequest]) ->
            llm_filter_extraction=persona.llm_filter_extraction,
            recency_bias=RecencyBiasSetting.AUTO,
            document_set_ids=persona.document_set_ids,
            llm_model_provider_override=persona.llm_model_provider_override,
            llm_model_version_override=persona.llm_model_version_override,
            default_model_configuration_id=persona.default_model_configuration_id,
            starter_messages=persona.starter_messages,
            is_public=persona.is_public,
            db_session=db_session,
@@ -33,7 +33,7 @@ from onyx.configs.constants import MilestoneRecordType
from onyx.db.engine.sql_engine import get_session_with_shared_schema
from onyx.db.engine.sql_engine import get_session_with_tenant
from onyx.db.image_generation import create_default_image_gen_config_from_api_key
from onyx.db.llm import update_default_provider
from onyx.db.llm import update_default_text_provider
from onyx.db.llm import upsert_cloud_embedding_provider
from onyx.db.llm import upsert_llm_provider
from onyx.db.models import AvailableTenant

@@ -300,12 +300,14 @@ def configure_default_api_keys(db_session: Session) -> None:

    has_set_default_provider = False

    def _upsert(request: LLMProviderUpsertRequest) -> None:
    def _upsert(request: LLMProviderUpsertRequest, default_model: str) -> None:
        nonlocal has_set_default_provider
        try:
            provider = upsert_llm_provider(request, db_session)
            if not has_set_default_provider:
                update_default_provider(provider.id, db_session)
                update_default_text_provider(
                    provider_id=provider.id, model=default_model, db_session=db_session
                )
                has_set_default_provider = True
        except Exception as e:
            logger.error(f"Failed to configure {request.provider} provider: {e}")

@@ -330,7 +332,7 @@ def configure_default_api_keys(db_session: Session) -> None:
            api_key_changed=True,
            is_auto_mode=True,
        )
        _upsert(openai_provider)
        _upsert(openai_provider, default_model_name)

        # Create default image generation config using the OpenAI API key
        try:

@@ -366,7 +368,7 @@ def configure_default_api_keys(db_session: Session) -> None:
            api_key_changed=True,
            is_auto_mode=True,
        )
        _upsert(anthropic_provider)
        _upsert(anthropic_provider, default_model_name)
    else:
        logger.info(
            "ANTHROPIC_DEFAULT_API_KEY not set, skipping Anthropic provider configuration"

@@ -398,7 +400,7 @@ def configure_default_api_keys(db_session: Session) -> None:
            api_key_changed=True,
            is_auto_mode=True,
        )
        _upsert(vertexai_provider)
        _upsert(vertexai_provider, default_model_name)
    else:
        logger.info(
            "VERTEXAI_DEFAULT_CREDENTIALS not set, skipping Vertex AI provider configuration"

@@ -435,7 +437,7 @@ def configure_default_api_keys(db_session: Session) -> None:
            api_key_changed=True,
            is_auto_mode=True,
        )
        _upsert(openrouter_provider)
        _upsert(openrouter_provider, default_model_name)
    else:
        logger.info(
            "OPENROUTER_DEFAULT_API_KEY not set, skipping OpenRouter provider configuration"

@@ -293,8 +293,7 @@ def create_temporary_persona(
        llm_relevance_filter=persona_config.llm_relevance_filter,
        llm_filter_extraction=persona_config.llm_filter_extraction,
        recency_bias=RecencyBiasSetting.BASE_DECAY,
        llm_model_provider_override=persona_config.llm_model_provider_override,
        llm_model_version_override=persona_config.llm_model_version_override,
        default_model_configuration_id=persona_config.default_model_configuration_id,
    )

    if persona_config.prompts:

@@ -118,8 +118,7 @@ class PersonaOverrideConfig(BaseModel):
    num_chunks: float | None = None
    llm_relevance_filter: bool = False
    llm_filter_extraction: bool = False
    llm_model_provider_override: str | None = None
    llm_model_version_override: str | None = None
    default_model_configuration_id: int | None = None

    prompts: list[PromptOverrideConfig] = Field(default_factory=list)
    # Note: prompt_ids removed - prompts are now embedded in personas

@@ -269,3 +269,9 @@ class HierarchyNodeType(str, PyEnum):

    # Slack
    CHANNEL = "channel"


class ModelFlowType(str, PyEnum):
    CONVERSATION = "conversation"
    VISION = "vision"
    EMBEDDINGS = "embeddings"
@@ -1,9 +1,13 @@
from collections.abc import Callable

from sqlalchemy import delete
from sqlalchemy import select
from sqlalchemy import update
from sqlalchemy.dialects.postgresql import insert
from sqlalchemy.orm import selectinload
from sqlalchemy.orm import Session

from onyx.db.enums import ModelFlowType
from onyx.db.models import CloudEmbeddingProvider as CloudEmbeddingProviderModel
from onyx.db.models import DocumentSet
from onyx.db.models import ImageGenerationConfig

@@ -11,6 +15,7 @@ from onyx.db.models import LLMProvider as LLMProviderModel
from onyx.db.models import LLMProvider__Persona
from onyx.db.models import LLMProvider__UserGroup
from onyx.db.models import ModelConfiguration
from onyx.db.models import ModelFlow
from onyx.db.models import Persona
from onyx.db.models import SearchSettings
from onyx.db.models import Tool as ToolModel

@@ -20,8 +25,10 @@ from onyx.llm.utils import model_supports_image_input
from onyx.llm.well_known_providers.auto_update_models import LLMRecommendations
from onyx.server.manage.embedding.models import CloudEmbeddingProvider
from onyx.server.manage.embedding.models import CloudEmbeddingProviderCreationRequest
from onyx.server.manage.llm.models import DefaultModel
from onyx.server.manage.llm.models import LLMProviderUpsertRequest
from onyx.server.manage.llm.models import LLMProviderView
from onyx.server.manage.llm.models import ModelConfigurationView
from shared_configs.enums import EmbeddingProvider
@@ -163,15 +170,16 @@ def validate_persona_ids_exist(


def get_personas_using_provider(
    db_session: Session, provider_name: str
    db_session: Session, provider: LLMProviderModel
) -> list[Persona]:
    """Get all non-deleted personas that use a specific LLM provider."""
    return list(
        db_session.scalars(
            select(Persona).where(
                Persona.llm_model_provider_override == provider_name,
                Persona.deleted == False,  # noqa: E712
            select(Persona)
            .join(
                ModelConfiguration,
                Persona.default_model_configuration_id == ModelConfiguration.id,
            )
            .where(ModelConfiguration.llm_provider_id == provider.id)
        ).all()
    )

@@ -233,9 +241,6 @@ def upsert_llm_provider(
        existing_llm_provider.api_base = llm_provider_upsert_request.api_base
        existing_llm_provider.api_version = llm_provider_upsert_request.api_version
        existing_llm_provider.custom_config = custom_config
        existing_llm_provider.default_model_name = (
            llm_provider_upsert_request.default_model_name
        )
        existing_llm_provider.is_public = llm_provider_upsert_request.is_public
        existing_llm_provider.is_auto_mode = llm_provider_upsert_request.is_auto_mode
        existing_llm_provider.deployment_name = llm_provider_upsert_request.deployment_name

@@ -263,17 +268,14 @@ def upsert_llm_provider(
                model_provider=llm_provider_upsert_request.provider,
            )

            db_session.execute(
                insert(ModelConfiguration)
                .values(
                    llm_provider_id=existing_llm_provider.id,
                    name=model_configuration.name,
                    is_visible=model_configuration.is_visible,
                    max_input_tokens=max_input_tokens,
                    supports_image_input=model_configuration.supports_image_input,
                    display_name=model_configuration.display_name,
                )
                .on_conflict_do_nothing()
            insert_new_model_configuration__no_commit(
                db_session=db_session,
                llm_provider_id=existing_llm_provider.id,
                model_name=model_configuration.name,
                is_visible=model_configuration.is_visible,
                max_input_tokens=max_input_tokens,
                supports_image_input=model_configuration.supports_image_input or False,
                display_name=model_configuration.display_name,
            )

    # Make sure the relationship table stays up to date
@@ -331,23 +333,19 @@ def sync_model_configurations(
        model_name = model["name"]
        if model_name not in existing_names:
            # Insert new model with is_visible=False (user must explicitly enable)
            db_session.execute(
                insert(ModelConfiguration)
                .values(
                    llm_provider_id=provider.id,
                    name=model_name,
                    is_visible=False,
                    max_input_tokens=model.get("max_input_tokens"),
                    supports_image_input=model.get("supports_image_input", False),
                    display_name=model.get("display_name"),
                )
                .on_conflict_do_nothing()
            insert_new_model_configuration__no_commit(
                db_session=db_session,
                llm_provider_id=provider.id,
                model_name=model_name,
                is_visible=False,
                max_input_tokens=model.get("max_input_tokens"),
                supports_image_input=model.get("supports_image_input", False),
                display_name=model.get("display_name"),
            )
            new_count += 1

    if new_count > 0:
        db_session.commit()

    return new_count
@@ -371,38 +369,85 @@ def fetch_existing_tools(db_session: Session, tool_ids: list[int]) -> list[ToolM
    )


def fetch_existing_llm_providers(
def fetch_existing_llm_providers_supporting_flows(
    db_session: Session,
    flows: list[ModelFlowType],
    only_public: bool = False,
    exclude_image_generation_providers: bool = True,
) -> list[LLMProviderModel]:
    """Fetch all LLM providers with optional filtering.
    """Fetch LLM providers that have at least one model supporting one of the specified flows.

    Args:
        db_session: Database session
        flows: List of flow types to filter by (providers must have at least one model
            with a ModelFlow matching any of these flow types)
        only_public: If True, only return public providers
        exclude_image_generation_providers: If True, exclude providers that are
            used for image generation configs
        exclude_image_generation_providers: If True, only match on flows.
            If False, include image generation providers in addition to
            flow-filtered providers.

    Returns:
        List of LLM providers matching the criteria
    """
    stmt = select(LLMProviderModel).options(
    # Subquery to find provider IDs that have at least one model supporting any of the flows
    providers_with_flows = (
        select(ModelConfiguration.llm_provider_id)
        .join(ModelFlow)
        .where(ModelFlow.model_flow_type.in_(flows))
        .distinct()
    )

    if exclude_image_generation_providers:
        # Only include providers with matching flows
        stmt = select(LLMProviderModel).where(
            LLMProviderModel.id.in_(providers_with_flows)
        )
    else:
        # Include providers with matching flows OR image generation providers
        image_gen_provider_ids = select(ModelConfiguration.llm_provider_id).join(
            ImageGenerationConfig
        )
        stmt = select(LLMProviderModel).where(
            LLMProviderModel.id.in_(providers_with_flows)
            | LLMProviderModel.id.in_(image_gen_provider_ids)
        )

    stmt = stmt.options(
        selectinload(LLMProviderModel.model_configurations),
        selectinload(LLMProviderModel.groups),
        selectinload(LLMProviderModel.personas),
    )

    if exclude_image_generation_providers:
        # Get LLM provider IDs used by ImageGenerationConfig
        image_gen_provider_ids = select(ModelConfiguration.llm_provider_id).join(
            ImageGenerationConfig
        )
        stmt = stmt.where(LLMProviderModel.id.not_in(image_gen_provider_ids))

    providers = list(db_session.scalars(stmt).all())

    if only_public:
        return [provider for provider in providers if provider.is_public]

    return providers


def fetch_existing_model_configs_for_flow(
    db_session: Session,
    flows: list[ModelFlowType],
) -> list[ModelConfiguration]:
    """Fetch all model configurations that support any of the specified flow types.

    Args:
        db_session: Database session
        flows: List of flow types to filter by

    Returns:
        List of model configurations that have a ModelFlow matching any of the flows
    """
    stmt = (
        select(ModelConfiguration)
        .join(ModelFlow)
        .where(ModelFlow.model_flow_type.in_(flows))
        .options(selectinload(ModelConfiguration.llm_provider))
    )
    return list(db_session.scalars(stmt).all())


def fetch_existing_llm_provider(
    name: str, db_session: Session
) -> LLMProviderModel | None:
@@ -419,6 +464,12 @@ def fetch_existing_llm_provider(
    return provider_model


def fetch_existing_llm_provider_by_id(
    id: int, db_session: Session
) -> LLMProviderModel | None:
    return db_session.scalar(select(LLMProviderModel).where(LLMProviderModel.id == id))


def fetch_embedding_provider(
    db_session: Session, provider_type: EmbeddingProvider
) -> CloudEmbeddingProviderModel | None:

@@ -429,26 +480,23 @@ def fetch_embedding_provider(
    )


def fetch_default_provider(db_session: Session) -> LLMProviderView | None:
    provider_model = db_session.scalar(
        select(LLMProviderModel)
        .where(LLMProviderModel.is_default_provider == True)  # noqa: E712
        .options(selectinload(LLMProviderModel.model_configurations))
def fetch_default_model(
    db_session: Session, flow_type: ModelFlowType
) -> DefaultModel | None:
    flow_mapping = db_session.scalar(
        select(ModelFlow).where(
            ModelFlow.model_flow_type == flow_type,
            ModelFlow.is_default == True,  # noqa: E712
        )
    )
    if not provider_model:

    if not flow_mapping:
        return None
    return LLMProviderView.from_model(provider_model)


def fetch_default_vision_provider(db_session: Session) -> LLMProviderView | None:
    provider_model = db_session.scalar(
        select(LLMProviderModel)
        .where(LLMProviderModel.is_default_vision_provider == True)  # noqa: E712
        .options(selectinload(LLMProviderModel.model_configurations))
    return DefaultModel(
        provider_id=flow_mapping.model_configuration.llm_provider.id,
        model_name=flow_mapping.model_configuration.name,
    )
    if not provider_model:
        return None
    return LLMProviderView.from_model(provider_model)


def fetch_llm_provider_view(
@@ -462,6 +510,48 @@ def fetch_llm_provider_view(
    return LLMProviderView.from_model(provider_model)


def fetch_llm_provider_view_from_id(
    db_session: Session, provider_id: int
) -> LLMProviderView | None:
    provider_model = fetch_existing_llm_provider_by_id(
        id=provider_id, db_session=db_session
    )
    if not provider_model:
        return None
    return LLMProviderView.from_model(provider_model)


def fetch_model_configuration_view(
    db_session: Session, model_configuration_id: int
) -> ModelConfigurationView | None:
    model_configuration_model = db_session.scalar(
        select(ModelConfiguration).where(
            ModelConfiguration.id == model_configuration_id
        )
    )

    if not model_configuration_model:
        return None

    return ModelConfigurationView.from_model(
        model_configuration_model=model_configuration_model,
        provider=model_configuration_model.llm_provider.provider,
    )


def fetch_llm_provider_view_from_model_id(
    db_session: Session, model_configuration_id: int
) -> LLMProviderView | None:
    model_configuration_model = db_session.scalar(
        select(ModelConfiguration).where(
            ModelConfiguration.id == model_configuration_id
        )
    )
    if not model_configuration_model:
        return None
    return LLMProviderView.from_model(model_configuration_model.llm_provider)


def remove_embedding_provider(
    db_session: Session, provider_type: EmbeddingProvider
) -> None:

@@ -479,16 +569,18 @@ def remove_embedding_provider(
    db_session.commit()


def remove_llm_provider(db_session: Session, provider_id: int) -> None:
def remove_llm_provider(
    db_session: Session, provider_id: int, commit: bool = True
) -> None:
    provider = db_session.get(LLMProviderModel, provider_id)
    if not provider:
        raise ValueError("LLM Provider not found")

    # Clear the provider override from any personas using it
    # This causes them to fall back to the default provider
    personas_using_provider = get_personas_using_provider(db_session, provider.name)
    personas_using_provider = get_personas_using_provider(db_session, provider)
    for persona in personas_using_provider:
        persona.llm_model_provider_override = None
        persona.default_model_configuration_id = None

    db_session.execute(
        delete(LLMProvider__UserGroup).where(
@@ -499,88 +591,86 @@ def remove_llm_provider(db_session: Session, provider_id: int) -> None:
    db_session.execute(
        delete(LLMProviderModel).where(LLMProviderModel.id == provider_id)
    )
    db_session.commit()


def remove_llm_provider__no_commit(db_session: Session, provider_id: int) -> None:
    """Remove LLM provider."""
    provider = db_session.get(LLMProviderModel, provider_id)
    if not provider:
        raise ValueError("LLM Provider not found")

    # Clear the provider override from any personas using it
    # This causes them to fall back to the default provider
    personas_using_provider = get_personas_using_provider(db_session, provider.name)
    for persona in personas_using_provider:
        persona.llm_model_provider_override = None

    db_session.execute(
        delete(LLMProvider__UserGroup).where(
            LLMProvider__UserGroup.llm_provider_id == provider_id
        )
    )
    # Remove LLMProvider
    db_session.execute(
        delete(LLMProviderModel).where(LLMProviderModel.id == provider_id)
    )
    db_session.flush()


def update_default_provider(provider_id: int, db_session: Session) -> None:
    new_default = db_session.scalar(
        select(LLMProviderModel).where(LLMProviderModel.id == provider_id)
    )
    if not new_default:
        raise ValueError(f"LLM Provider with id {provider_id} does not exist")

    existing_default = db_session.scalar(
        select(LLMProviderModel).where(
            LLMProviderModel.is_default_provider == True  # noqa: E712
        )
    )
    if existing_default:
        existing_default.is_default_provider = None
        # required to ensure that the below does not cause a unique constraint violation
    if commit:
        db_session.commit()
    else:
        db_session.flush()

    new_default.is_default_provider = True

def _update_default_provider(
    provider_id: int,
    model: str,
    db_session: Session,
    flow_type: ModelFlowType,
    validation_func: (
        Callable[[ModelConfiguration], str] | None
    ) = None,  # Returns error message
) -> None:
    # Single query with join to get both model_config and its flow mapping
    result = db_session.execute(
        select(ModelConfiguration, ModelFlow)
        .join(ModelFlow, ModelFlow.model_configuration_id == ModelConfiguration.id)
        .where(
            ModelConfiguration.llm_provider_id == provider_id,
            ModelConfiguration.name == model,
            ModelFlow.model_flow_type == flow_type,
        )
    ).first()

    if not result:
        raise ValueError(
            f"Model '{model}' is not a valid TEXT model for provider '{provider_id}'"
        )

    model_config, new_default = result

    if validation_func and (error_message := validation_func(model_config)):
        raise ValueError(error_message)

    # Clear existing default and set now in a single atomic operation
    db_session.execute(
        update(ModelFlow)
        .where(
            ModelFlow.model_flow_type == flow_type,
            ModelFlow.is_default == True,  # noqa: E712
        )
        .values(is_default=False)
    )

    new_default.is_default = True
    model_config.is_visible = True

    db_session.commit()


def update_default_text_provider(
    provider_id: int, model: str, db_session: Session
) -> None:
    _update_default_provider(
        provider_id=provider_id,
        model=model,
        db_session=db_session,
        flow_type=ModelFlowType.CONVERSATION,
    )


def update_default_vision_provider(
    provider_id: int, vision_model: str | None, db_session: Session
    provider_id: int, vision_model: str, db_session: Session
) -> None:
    new_default = db_session.scalar(
        select(LLMProviderModel).where(LLMProviderModel.id == provider_id)
    def validate_vision_model(model_config: ModelConfiguration) -> str:
        if not model_supports_image_input(
            vision_model, model_config.llm_provider.provider
        ):
            return f"Model '{vision_model}' for provider '{model_config.llm_provider.provider}' does not support image input"
        return ""

    _update_default_provider(
        provider_id=provider_id,
        model=vision_model,
        db_session=db_session,
        flow_type=ModelFlowType.VISION,
        validation_func=validate_vision_model,
    )
    if not new_default:
        raise ValueError(f"LLM Provider with id {provider_id} does not exist")

    # Validate that the specified vision model supports image input
    model_to_validate = vision_model or new_default.default_model_name
    if model_to_validate:
        if not model_supports_image_input(model_to_validate, new_default.provider):
            raise ValueError(
                f"Model '{model_to_validate}' for provider '{new_default.provider}' does not support image input"
            )
    else:
        raise ValueError(
            f"Model '{vision_model}' is not a valid model for provider '{new_default.provider}'"
        )

    existing_default = db_session.scalar(
        select(LLMProviderModel).where(
            LLMProviderModel.is_default_vision_provider == True  # noqa: E712
        )
    )
    if existing_default:
        existing_default.is_default_vision_provider = None
        # required to ensure that the below does not cause a unique constraint violation
        db_session.flush()

    new_default.is_default_vision_provider = True
    new_default.default_vision_model = vision_model
    db_session.commit()


def fetch_auto_mode_providers(db_session: Session) -> list[LLMProviderModel]:
@@ -671,10 +761,84 @@ def sync_auto_mode_models(
            changes += 1

        # In Auto mode, default model is always set from GitHub config
        default_model = llm_recommendations.get_default_model(provider.provider)
        if default_model and provider.default_model_name != default_model.name:
            provider.default_model_name = default_model.name
        # Check if this provider is the set default provider
        default_model = fetch_default_model(db_session, ModelFlowType.CONVERSATION)
        recommended = llm_recommendations.get_default_model(provider.provider)
        if (
            default_model
            and default_model.provider_id == provider.id
            and recommended
            and recommended.name != default_model.model_name
        ):
            update_default_text_provider(provider.id, recommended.name, db_session)
            changes += 1

    db_session.commit()
    return changes


def create_new_flow_mapping__no_commit(
    db_session: Session,
    model_configuration_id: int,
    flow_type: ModelFlowType,
    is_default: bool = False,
) -> ModelFlow:
    """Create a new flow mapping without committing.

    Uses flush() to get the ID while keeping the transaction open,
    allowing callers like upsert_llm_provider to maintain atomicity.
    """
    flow_mapping = ModelFlow(
        model_configuration_id=model_configuration_id,
        flow_type=flow_type,
        is_default=is_default,
    )

    db_session.add(flow_mapping)
    db_session.flush()
    db_session.refresh(flow_mapping)
    return flow_mapping


def insert_new_model_configuration__no_commit(
    db_session: Session,
    llm_provider_id: int,
    model_name: str,
    is_visible: bool,
    max_input_tokens: int | None,
    supports_image_input: bool,
    display_name: str | None,
) -> int | None:
    result = db_session.execute(
        insert(ModelConfiguration)
        .values(
            llm_provider_id=llm_provider_id,
            name=model_name,
            is_visible=is_visible,
            max_input_tokens=max_input_tokens,
            display_name=display_name,
        )
        .on_conflict_do_nothing()
        .returning(ModelConfiguration.id)
    )

    model_config_id = result.scalar_one_or_none()

    if not model_config_id:
        return None

    create_new_flow_mapping__no_commit(
        db_session=db_session,
        model_configuration_id=model_config_id,
        flow_type=ModelFlowType.CONVERSATION,
        is_default=False,
    )

    if supports_image_input:
        create_new_flow_mapping__no_commit(
            db_session=db_session,
            model_configuration_id=model_config_id,
            flow_type=ModelFlowType.VISION,
            is_default=False,
        )

    return model_config_id
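A usage sketch (not part of the diff) of how the new flow-based helpers compose. It assumes a configured deployment of this codebase; the provider id and model name are placeholders to replace with real values:

from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.enums import ModelFlowType
from onyx.db.llm import fetch_default_model
from onyx.db.llm import update_default_text_provider

with get_session_with_current_tenant() as db_session:
    # Promote one (provider, model) pair to be the conversation default.
    update_default_text_provider(provider_id=1, model="gpt-4o", db_session=db_session)

    # Later, resolve the default for the conversation flow instead of reading
    # the deprecated is_default_provider / default_model_name columns.
    default = fetch_default_model(db_session, ModelFlowType.CONVERSATION)
    if default:
        print(default.provider_id, default.model_name)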
@@ -7,6 +7,7 @@ from uuid import uuid4

from pydantic import BaseModel
from sqlalchemy.orm import validates

from typing_extensions import TypedDict  # noreorder
from uuid import UUID
from pydantic import ValidationError

@@ -70,6 +71,7 @@ from onyx.db.enums import (
    MCPAuthenticationPerformer,
    MCPTransport,
    MCPServerStatus,
    ModelFlowType,
    ThemePreference,
    SwitchoverType,
)

@@ -2613,14 +2615,16 @@ class LLMProvider(Base):
    custom_config: Mapped[dict[str, str] | None] = mapped_column(
        postgresql.JSONB(), nullable=True
    )
    # Deprecated: use ModelFlow.is_default
    default_model_name: Mapped[str] = mapped_column(String)

    deployment_name: Mapped[str | None] = mapped_column(String, nullable=True)

    # should only be set for a single provider
    # Deprecated: Use ModelFlows instead
    is_default_provider: Mapped[bool | None] = mapped_column(Boolean, unique=True)
    is_default_vision_provider: Mapped[bool | None] = mapped_column(Boolean)
    default_vision_model: Mapped[str | None] = mapped_column(String, nullable=True)

    # EE only
    is_public: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True)
    # Auto mode: models, visibility, and defaults are managed by GitHub config

@@ -2670,6 +2674,7 @@ class ModelConfiguration(Base):
    # - The end-user is configuring a model and chooses not to set a max-input-tokens limit.
    max_input_tokens: Mapped[int | None] = mapped_column(Integer, nullable=True)

    # Deprecated: Use ModelFlows instead
    supports_image_input: Mapped[bool | None] = mapped_column(Boolean, nullable=True)

    # Human-readable display name for the model.

@@ -2682,6 +2687,51 @@ class ModelConfiguration(Base):
        back_populates="model_configurations",
    )

    model_flows: Mapped[list["ModelFlow"]] = relationship(
        "ModelFlow",
        back_populates="model_configuration",
        cascade="all, delete-orphan",
        passive_deletes=True,
    )

    @property
    def model_flow_types(self) -> list[ModelFlowType]:
        return [flow.model_flow_type for flow in self.model_flows]


class ModelFlow(Base):
    __tablename__ = "model_flow"

    id: Mapped[int] = mapped_column(Integer, primary_key=True)

    model_flow_type: Mapped[ModelFlowType] = mapped_column(
        Enum(ModelFlowType), nullable=False
    )
    model_configuration_id: Mapped[int] = mapped_column(
        ForeignKey("model_configuration.id", ondelete="CASCADE"),
        nullable=False,
    )
    is_default: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)

    model_configuration: Mapped["ModelConfiguration"] = relationship(
        "ModelConfiguration",
        back_populates="model_flows",
    )

    __table_args__ = (
        UniqueConstraint(
            "model_flow_type",
            "model_configuration_id",
            name="uq_model_config_per_flow_type",
        ),
        Index(
            "ix_one_default_per_model_flow",
            "model_flow_type",
            unique=True,
            postgresql_where=(is_default == True),  # noqa: E712
        ),
    )


class ImageGenerationConfig(Base):
    __tablename__ = "image_generation_config"

@@ -3013,6 +3063,12 @@ class Persona(Base):
    # Allows the persona to specify a specific default LLM model
    # NOTE: only is applied on the actual response generation - is not used for things like
    # auto-detected time filters, relevance filters, etc.
    default_model_configuration_id: Mapped[int | None] = mapped_column(
        Integer,
        ForeignKey("model_configuration.id", ondelete="SET NULL"),
        nullable=True,
    )
    # llm_model_provider_override and llm_model_version_override are deprecated and will be removed in a future release
    llm_model_provider_override: Mapped[str | None] = mapped_column(
        String, nullable=True
    )
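With the relationship above in place, a model's capabilities become a question of flow membership rather than the deprecated supports_image_input column. A minimal sketch of the check this enables, assuming a ModelConfiguration instance loaded with its model_flows (it mirrors the model_supports_image_input change further down in this compare):

from onyx.db.enums import ModelFlowType
from onyx.db.models import ModelConfiguration


def supports_vision(model_config: ModelConfiguration) -> bool:
    # True when the model participates in the VISION flow.
    return ModelFlowType.VISION in model_config.model_flow_types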
@@ -284,8 +284,7 @@ def create_update_persona(
        tool_ids=create_persona_request.tool_ids,
        is_public=create_persona_request.is_public,
        recency_bias=create_persona_request.recency_bias,
        llm_model_provider_override=create_persona_request.llm_model_provider_override,
        llm_model_version_override=create_persona_request.llm_model_version_override,
        default_model_configuration_id=create_persona_request.default_model_configuration_id,
        starter_messages=create_persona_request.starter_messages,
        system_prompt=create_persona_request.system_prompt,
        task_prompt=create_persona_request.task_prompt,

@@ -817,8 +816,7 @@ def upsert_persona(
    llm_relevance_filter: bool,
    llm_filter_extraction: bool,
    recency_bias: RecencyBiasSetting,
    llm_model_provider_override: str | None,
    llm_model_version_override: str | None,
    default_model_configuration_id: int | None,
    starter_messages: list[StarterMessage] | None,
    # Embedded prompt fields
    system_prompt: str | None,

@@ -958,8 +956,7 @@ def upsert_persona(
        existing_persona.llm_relevance_filter = llm_relevance_filter
        existing_persona.llm_filter_extraction = llm_filter_extraction
        existing_persona.recency_bias = recency_bias
        existing_persona.llm_model_provider_override = llm_model_provider_override
        existing_persona.llm_model_version_override = llm_model_version_override
        existing_persona.default_model_configuration_id = default_model_configuration_id
        existing_persona.starter_messages = starter_messages
        existing_persona.deleted = False  # Un-delete if previously deleted
        existing_persona.is_public = is_public

@@ -1032,8 +1029,7 @@ def upsert_persona(
            datetime_aware=(datetime_aware if datetime_aware is not None else True),
            replace_base_system_prompt=replace_base_system_prompt,
            document_sets=document_sets or [],
            llm_model_provider_override=llm_model_provider_override,
            llm_model_version_override=llm_model_version_override,
            default_model_configuration_id=default_model_configuration_id,
            starter_messages=starter_messages,
            tools=tools or [],
            uploaded_image_id=uploaded_image_id,

@@ -79,8 +79,7 @@ def create_slack_channel_persona(
        recency_bias=RecencyBiasSetting.AUTO,
        tool_ids=[search_tool.id],
        document_set_ids=document_set_ids,
        llm_model_provider_override=None,
        llm_model_version_override=None,
        default_model_configuration_id=None,
        starter_messages=None,
        is_public=True,
        is_default_persona=False,
@@ -5,22 +5,24 @@ from sqlalchemy.orm import Session
from onyx.chat.models import PersonaOverrideConfig
from onyx.configs.model_configs import GEN_AI_TEMPERATURE
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.enums import ModelFlowType
from onyx.db.llm import can_user_access_llm_provider
from onyx.db.llm import fetch_default_provider
from onyx.db.llm import fetch_default_vision_provider
from onyx.db.llm import fetch_default_model
from onyx.db.llm import fetch_existing_llm_provider
from onyx.db.llm import fetch_existing_llm_providers
from onyx.db.llm import fetch_existing_model_configs_for_flow
from onyx.db.llm import fetch_llm_provider_view
from onyx.db.llm import fetch_llm_provider_view_from_id
from onyx.db.llm import fetch_llm_provider_view_from_model_id
from onyx.db.llm import fetch_model_configuration_view
from onyx.db.llm import fetch_user_group_ids
from onyx.db.models import LLMProvider
from onyx.db.models import Persona
from onyx.db.models import User
from onyx.llm.constants import LlmProviderNames
from onyx.llm.interfaces import LLM
from onyx.llm.interfaces import LLMConfig
from onyx.llm.multi_llm import LitellmLLM
from onyx.llm.override_models import LLMOverride
from onyx.llm.utils import get_max_input_tokens_from_llm_provider
from onyx.llm.utils import model_supports_image_input
from onyx.llm.well_known_providers.constants import OLLAMA_API_KEY_CONFIG_KEY
from onyx.natural_language_processing.utils import get_tokenizer
from onyx.server.manage.llm.models import LLMProviderView
@@ -53,54 +55,6 @@ def _build_provider_extra_headers(
    return {}


def get_llm_config_for_persona(
    persona: Persona,
    db_session: Session,
    llm_override: LLMOverride | None = None,
) -> LLMConfig:
    """Get LLM config from persona without access checks.

    This function assumes access to the persona has already been verified.
    Use this when you need the LLM config but don't need to create the full LLM object.
    """
    provider_name_override = llm_override.model_provider if llm_override else None
    model_version_override = llm_override.model_version if llm_override else None
    temperature_override = llm_override.temperature if llm_override else None

    provider_name = provider_name_override or persona.llm_model_provider_override
    if not provider_name:
        llm_provider = fetch_default_provider(db_session)
        if not llm_provider:
            raise ValueError("No default LLM provider found")
        model_name: str | None = llm_provider.default_model_name
    else:
        llm_provider = fetch_llm_provider_view(db_session, provider_name)
        if not llm_provider:
            raise ValueError(f"No LLM provider found with name: {provider_name}")
        model_name = model_version_override or persona.llm_model_version_override
        if not model_name:
            model_name = llm_provider.default_model_name

    if not model_name:
        raise ValueError("No model name found")

    max_input_tokens = get_max_input_tokens_from_llm_provider(
        llm_provider=llm_provider, model_name=model_name
    )

    return LLMConfig(
        model_provider=llm_provider.provider,
        model_name=model_name,
        temperature=temperature_override or GEN_AI_TEMPERATURE,
        api_key=llm_provider.api_key,
        api_base=llm_provider.api_base,
        api_version=llm_provider.api_version,
        deployment_name=llm_provider.deployment_name,
        custom_config=llm_provider.custom_config,
        max_input_tokens=max_input_tokens,
    )


def get_llm_for_persona(
    persona: Persona | PersonaOverrideConfig | None,
    user: User,

@@ -108,6 +62,11 @@ def get_llm_for_persona(
    additional_headers: dict[str, str] | None = None,
    long_term_logger: LongTermLogger | None = None,
) -> LLM:
    """Get the appropriate LLM for a persona, with the following priority:
    1. LLM override (provider + model version)
    2. Persona's model configuration override
    3. Default LLM
    """
    if persona is None:
        logger.warning("No persona provided, using default LLM")
        return get_default_llm()

@@ -116,8 +75,7 @@ def get_llm_for_persona(
    model_version_override = llm_override.model_version if llm_override else None
    temperature_override = llm_override.temperature if llm_override else None

    provider_name = provider_name_override or persona.llm_model_provider_override
    if not provider_name:
    if not provider_name_override and not persona.default_model_configuration_id:
        return get_default_llm(
            temperature=temperature_override or GEN_AI_TEMPERATURE,
            additional_headers=additional_headers,

@@ -125,7 +83,11 @@ def get_llm_for_persona(
        )

    with get_session_with_current_tenant() as db_session:
        provider_model = fetch_existing_llm_provider(provider_name, db_session)
        # Resolve the provider model
        # At least one of the vars is non-None due to the above check, so we should get something
        provider_model = _resolve_provider_model(
            db_session, provider_name_override, persona.default_model_configuration_id
        )
        if not provider_model:
            raise ValueError("No LLM provider found")

@@ -155,7 +117,14 @@ def get_llm_for_persona(

        llm_provider = LLMProviderView.from_model(provider_model)

        model = model_version_override or persona.llm_model_version_override
        # Resolve the model name
        model = model_version_override
        if not model and persona.default_model_configuration_id:
            model_config = fetch_model_configuration_view(
                db_session, persona.default_model_configuration_id
            )
            model = model_config.name if model_config else None

        if not model:
            raise ValueError("No model name found")

@@ -176,6 +145,25 @@ def get_llm_for_persona(
    )


def _resolve_provider_model(
    db_session: Session,
    provider_name_override: str | None,
    model_config_id_override: int | None,
) -> LLMProvider | None:
    """Resolve the LLM provider model from overrides."""
    if provider_name_override:
        return fetch_existing_llm_provider(provider_name_override, db_session)

    if model_config_id_override:
        llm_view = fetch_llm_provider_view_from_model_id(
            db_session, model_config_id_override
        )
        if llm_view:
            return fetch_existing_llm_provider(llm_view.name, db_session)

    return None


def get_default_llm_with_vision(
    timeout: int | None = None,
    temperature: float | None = None,
@@ -210,48 +198,40 @@ def get_default_llm_with_vision(

    with get_session_with_current_tenant() as db_session:
        # Try the default vision provider first
        default_provider = fetch_default_vision_provider(db_session)
        if default_provider and default_provider.default_vision_model:
            if model_supports_image_input(
                default_provider.default_vision_model, default_provider.provider
            ):
                return create_vision_llm(
                    default_provider, default_provider.default_vision_model
                )
        default_model = fetch_default_model(
            db_session=db_session, flow_type=ModelFlowType.VISION
        )
        if default_model:
            provider_view = fetch_llm_provider_view_from_id(
                db_session, default_model.provider_id
            )
            if provider_view:
                return create_vision_llm(provider_view, default_model.model_name)
        # Fall back to searching all vision models
        models = fetch_existing_model_configs_for_flow(
            db_session=db_session,
            flows=[ModelFlowType.VISION],
        )

        # Fall back to searching all providers
        providers = fetch_existing_llm_providers(db_session)

        if not providers:
        if not models:
            return None

        # Check all providers for viable vision models
        for provider in providers:
            provider_view = LLMProviderView.from_model(provider)
        # Check for viable vision models
        non_public_vision_llm: LLM | None = None

            # First priority: Check if provider has a default_vision_model
            if provider.default_vision_model and model_supports_image_input(
                provider.default_vision_model, provider.provider
            ):
                return create_vision_llm(provider_view, provider.default_vision_model)
        for model in models:
            if model.is_visible:
                return create_vision_llm(
                    provider=LLMProviderView.from_model(model.llm_provider),
                    model=model.name,
                )
            elif not non_public_vision_llm:
                non_public_vision_llm = create_vision_llm(
                    provider=LLMProviderView.from_model(model.llm_provider),
                    model=model.name,
                )

            # If no model-configurations are specified, try default model
            if not provider.model_configurations:
                # Try default_model_name
                if provider.default_model_name and model_supports_image_input(
                    provider.default_model_name, provider.provider
                ):
                    return create_vision_llm(provider_view, provider.default_model_name)

            # Otherwise, if model-configurations are specified, check each model
            else:
                for model_configuration in provider.model_configurations:
                    if model_supports_image_input(
                        model_configuration.name, provider.provider
                    ):
                        return create_vision_llm(provider_view, model_configuration.name)

        return None
        return non_public_vision_llm


def llm_from_provider(
@@ -298,18 +278,25 @@ def get_default_llm(
    long_term_logger: LongTermLogger | None = None,
) -> LLM:
    with get_session_with_current_tenant() as db_session:
        llm_provider = fetch_default_provider(db_session)
        default_model = fetch_default_model(
            db_session=db_session, flow_type=ModelFlowType.CONVERSATION
        )

        if not llm_provider:
            raise ValueError("No default LLM provider found")
        if not default_model:
            raise ValueError("No default LLM provider found")

        model_name = llm_provider.default_model_name
        if not model_name:
            raise ValueError("No default model name found")
        llm_provider_view = fetch_llm_provider_view_from_id(
            db_session, default_model.provider_id
        )

        if not llm_provider_view:
            raise ValueError(
                "No LLM provider found with id {}".format(default_model.provider_id)
            )

    return llm_from_provider(
        model_name=model_name,
        llm_provider=llm_provider,
        model_name=default_model.model_name,
        llm_provider=llm_provider_view,
        timeout=timeout,
        temperature=temperature,
        additional_headers=additional_headers,
@@ -17,6 +17,7 @@ from onyx.configs.model_configs import GEN_AI_MAX_TOKENS
from onyx.configs.model_configs import GEN_AI_MODEL_FALLBACK_MAX_TOKENS
from onyx.configs.model_configs import GEN_AI_NUM_RESERVED_OUTPUT_TOKENS
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.enums import ModelFlowType
from onyx.db.models import LLMProvider
from onyx.db.models import ModelConfiguration
from onyx.llm.constants import LlmProviderNames
@@ -689,8 +690,8 @@ def model_supports_image_input(model_name: str, model_provider: str) -> bool:
LLMProvider.provider == model_provider,
)
)
if model_config and model_config.supports_image_input is not None:
return model_config.supports_image_input
if model_config:
return ModelFlowType.VISION in model_config.model_flow_types
except Exception as e:
logger.warning(
f"Failed to query database for {model_provider} model {model_name} image support: {e}"
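The behavioral change in this hunk is that image support is now derived from a model's flow memberships rather than a per-model boolean column. A condensed sketch of that check follows; the helper name is illustrative only.

from onyx.db.enums import ModelFlowType
from onyx.db.models import ModelConfiguration


def db_model_supports_images(model_config: ModelConfiguration | None) -> bool | None:
    # None signals "no row found"; callers keep their existing fallback behavior.
    if model_config is None:
        return None
    return ModelFlowType.VISION in model_config.model_flow_types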
@@ -18,7 +18,7 @@ from onyx.llm.well_known_providers.constants import OPENAI_PROVIDER_NAME
|
||||
from onyx.llm.well_known_providers.constants import OPENROUTER_PROVIDER_NAME
|
||||
from onyx.llm.well_known_providers.constants import VERTEXAI_PROVIDER_NAME
|
||||
from onyx.llm.well_known_providers.models import WellKnownLLMProviderDescriptor
|
||||
from onyx.server.manage.llm.models import ModelConfigurationView
|
||||
from onyx.server.manage.llm.models import ModelConfiguration
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
@@ -212,11 +212,11 @@ def get_vertexai_model_names() -> list[str]:
|
||||
|
||||
def model_configurations_for_provider(
|
||||
provider_name: str, llm_recommendations: LLMRecommendations
|
||||
) -> list[ModelConfigurationView]:
|
||||
) -> list[ModelConfiguration]:
|
||||
recommended_visible_models = llm_recommendations.get_visible_models(provider_name)
|
||||
recommended_visible_models_names = [m.name for m in recommended_visible_models]
|
||||
return [
|
||||
ModelConfigurationView(
|
||||
ModelConfiguration(
|
||||
name=model_name,
|
||||
is_visible=model_name in recommended_visible_models_names,
|
||||
max_input_tokens=get_max_input_tokens(model_name, provider_name),
|
||||
|
||||
@@ -3,7 +3,7 @@ from enum import Enum
|
||||
from pydantic import BaseModel
|
||||
from pydantic import Field
|
||||
|
||||
from onyx.server.manage.llm.models import ModelConfigurationView
|
||||
from onyx.server.manage.llm.models import ModelConfiguration
|
||||
|
||||
|
||||
class CustomConfigKeyType(str, Enum):
|
||||
@@ -28,5 +28,5 @@ class WellKnownLLMProviderDescriptor(BaseModel):
|
||||
name: str
|
||||
|
||||
# NOTE: the recommended visible models are encoded in the known_models list
|
||||
known_models: list[ModelConfigurationView] = Field(default_factory=list)
|
||||
known_models: list[ModelConfiguration] = Field(default_factory=list)
|
||||
recommended_default_model: SimpleKnownModel | None = None
|
||||
|
||||
@@ -27,8 +27,10 @@ from sqlalchemy.orm import Session as DBSession
|
||||
|
||||
from onyx.configs.app_configs import WEB_DOMAIN
|
||||
from onyx.configs.constants import MessageType
|
||||
from onyx.db.enums import ModelFlowType
|
||||
from onyx.db.enums import SandboxStatus
|
||||
from onyx.db.llm import fetch_default_provider
|
||||
from onyx.db.llm import fetch_default_model
|
||||
from onyx.db.llm import fetch_existing_llm_provider_by_id
|
||||
from onyx.db.models import BuildMessage
|
||||
from onyx.db.models import BuildSession
|
||||
from onyx.db.models import User
|
||||
@@ -338,15 +340,25 @@ class SessionManager:
|
||||
)
|
||||
|
||||
# Fallback to system default
|
||||
default_provider = fetch_default_provider(self._db_session)
|
||||
if not default_provider:
|
||||
default_model = fetch_default_model(
|
||||
db_session=self._db_session, flow_type=ModelFlowType.CONVERSATION
|
||||
)
|
||||
if not default_model:
|
||||
raise ValueError("No LLM provider configured")
|
||||
|
||||
provider = fetch_existing_llm_provider_by_id(
|
||||
id=default_model.provider_id,
|
||||
db_session=self._db_session,
|
||||
)
|
||||
|
||||
if not provider:
|
||||
raise ValueError("No LLM provider found")
|
||||
|
||||
return LLMProviderConfig(
|
||||
provider=default_provider.provider,
|
||||
model_name=default_provider.default_model_name,
|
||||
api_key=default_provider.api_key,
|
||||
api_base=default_provider.api_base,
|
||||
provider=provider.provider,
|
||||
model_name=default_model.model_name,
|
||||
api_key=provider.api_key,
|
||||
api_base=provider.api_base,
|
||||
)
|
||||
|
||||
# =========================================================================
|
||||
|
||||
@@ -18,6 +18,7 @@ from onyx.configs.constants import FileOrigin
|
||||
from onyx.configs.constants import MilestoneRecordType
|
||||
from onyx.configs.constants import PUBLIC_API_TAGS
|
||||
from onyx.db.engine.sql_engine import get_session
|
||||
from onyx.db.llm import fetch_model_configuration_view
|
||||
from onyx.db.models import User
|
||||
from onyx.db.persona import create_assistant_label
|
||||
from onyx.db.persona import create_update_persona
|
||||
@@ -48,7 +49,7 @@ from onyx.server.features.persona.models import PersonaLabelCreate
|
||||
from onyx.server.features.persona.models import PersonaLabelResponse
|
||||
from onyx.server.features.persona.models import PersonaSnapshot
|
||||
from onyx.server.features.persona.models import PersonaUpsertRequest
|
||||
from onyx.server.manage.llm.api import get_valid_model_names_for_persona
|
||||
from onyx.server.manage.llm.api import get_valid_models_for_persona
|
||||
from onyx.server.models import DisplayPriorityRequest
|
||||
from onyx.server.settings.store import load_settings
|
||||
from onyx.utils.logger import setup_logger
|
||||
@@ -479,13 +480,21 @@ def get_persona(
|
||||
)
|
||||
|
||||
# Validate and fix default model if it's no longer valid for this persona's restrictions
|
||||
if persona.llm_model_version_override:
|
||||
valid_models = get_valid_model_names_for_persona(persona_id, user, db_session)
|
||||
if persona.default_model_configuration_id:
|
||||
valid_models = get_valid_models_for_persona(persona_id, user, db_session)
|
||||
|
||||
model_configuration = fetch_model_configuration_view(
|
||||
db_session=db_session,
|
||||
model_configuration_id=persona.default_model_configuration_id,
|
||||
)
|
||||
|
||||
# If current default model is not in the valid list, update to first valid or None
|
||||
if persona.llm_model_version_override not in valid_models:
|
||||
persona.llm_model_version_override = (
|
||||
valid_models[0] if valid_models else None
|
||||
if not (
|
||||
model_configuration
|
||||
and model_configuration.id in [m.id for m in valid_models]
|
||||
):
|
||||
persona.default_model_configuration_id = (
|
||||
valid_models[0].id if valid_models else None
|
||||
)
|
||||
db_session.commit()
|
||||
|
||||
|
||||
@@ -109,8 +109,7 @@ class PersonaUpsertRequest(BaseModel):
|
||||
recency_bias: RecencyBiasSetting
|
||||
llm_filter_extraction: bool
|
||||
llm_relevance_filter: bool
|
||||
llm_model_provider_override: str | None = None
|
||||
llm_model_version_override: str | None = None
|
||||
default_model_configuration_id: int | None = None
|
||||
starter_messages: list[StarterMessage] | None = None
|
||||
# For Private Personas, who should be able to access these
|
||||
users: list[UUID] = Field(default_factory=list)
|
||||
@@ -162,8 +161,7 @@ class MinimalPersonaSnapshot(BaseModel):
|
||||
# Unique sources from all knowledge (document sets + hierarchy nodes)
|
||||
# Used to populate source filters in chat
|
||||
knowledge_sources: list[DocumentSource]
|
||||
llm_model_version_override: str | None
|
||||
llm_model_provider_override: str | None
|
||||
default_model_configuration_id: int | None
|
||||
|
||||
uploaded_image_id: str | None
|
||||
icon_name: str | None
|
||||
@@ -219,8 +217,7 @@ class MinimalPersonaSnapshot(BaseModel):
|
||||
hierarchy_node_count=len(persona.hierarchy_nodes),
|
||||
attached_document_count=len(persona.attached_documents),
|
||||
knowledge_sources=list(sources),
|
||||
llm_model_version_override=persona.llm_model_version_override,
|
||||
llm_model_provider_override=persona.llm_model_provider_override,
|
||||
default_model_configuration_id=persona.default_model_configuration_id,
|
||||
uploaded_image_id=persona.uploaded_image_id,
|
||||
icon_name=persona.icon_name,
|
||||
is_public=persona.is_public,
|
||||
@@ -259,8 +256,7 @@ class PersonaSnapshot(BaseModel):
|
||||
users: list[MinimalUserSnapshot]
|
||||
groups: list[int]
|
||||
document_sets: list[DocumentSetSummary]
|
||||
llm_model_provider_override: str | None
|
||||
llm_model_version_override: str | None
|
||||
default_model_configuration_id: int | None
|
||||
num_chunks: float | None
|
||||
# Hierarchy nodes attached for scoped search
|
||||
hierarchy_nodes: list[HierarchyNodeSnapshot] = Field(default_factory=list)
|
||||
@@ -318,8 +314,7 @@ class PersonaSnapshot(BaseModel):
|
||||
DocumentSetSummary.from_model(document_set_model)
|
||||
for document_set_model in persona.document_sets
|
||||
],
|
||||
llm_model_provider_override=persona.llm_model_provider_override,
|
||||
llm_model_version_override=persona.llm_model_version_override,
|
||||
default_model_configuration_id=persona.default_model_configuration_id,
|
||||
num_chunks=persona.num_chunks,
|
||||
system_prompt=persona.system_prompt,
|
||||
replace_base_system_prompt=persona.replace_base_system_prompt,
|
||||
@@ -391,8 +386,7 @@ class FullPersonaSnapshot(PersonaSnapshot):
|
||||
search_start_date=persona.search_start_date,
|
||||
llm_relevance_filter=persona.llm_relevance_filter,
|
||||
llm_filter_extraction=persona.llm_filter_extraction,
|
||||
llm_model_provider_override=persona.llm_model_provider_override,
|
||||
llm_model_version_override=persona.llm_model_version_override,
|
||||
default_model_configuration_id=persona.default_model_configuration_id,
|
||||
system_prompt=persona.system_prompt,
|
||||
replace_base_system_prompt=persona.replace_base_system_prompt,
|
||||
task_prompt=persona.task_prompt,
|
||||
|
||||
@@ -141,8 +141,7 @@ def enable_or_disable_kg(
|
||||
recency_bias=RecencyBiasSetting.NO_DECAY,
|
||||
document_set_ids=[],
|
||||
tool_ids=[search_tool.id, kg_tool.id],
|
||||
llm_model_provider_override=None,
|
||||
llm_model_version_override=None,
|
||||
default_model_configuration_id=None,
|
||||
starter_messages=None,
|
||||
users=[user.id],
|
||||
groups=[],
|
||||
|
||||
@@ -11,7 +11,7 @@ from onyx.db.image_generation import get_all_image_generation_configs
|
||||
from onyx.db.image_generation import get_image_generation_config
|
||||
from onyx.db.image_generation import set_default_image_generation_config
|
||||
from onyx.db.image_generation import unset_default_image_generation_config
|
||||
from onyx.db.llm import remove_llm_provider__no_commit
|
||||
from onyx.db.llm import remove_llm_provider
|
||||
from onyx.db.models import LLMProvider as LLMProviderModel
|
||||
from onyx.db.models import ModelConfiguration
|
||||
from onyx.db.models import User
|
||||
@@ -93,7 +93,6 @@ def _build_llm_provider_request(
|
||||
api_key=source_provider.api_key, # Only this from source
|
||||
api_base=api_base, # From request
|
||||
api_version=api_version, # From request
|
||||
default_model_name=model_name,
|
||||
deployment_name=deployment_name, # From request
|
||||
is_public=True,
|
||||
groups=[],
|
||||
@@ -132,7 +131,6 @@ def _build_llm_provider_request(
|
||||
api_key=api_key,
|
||||
api_base=api_base,
|
||||
api_version=api_version,
|
||||
default_model_name=model_name,
|
||||
deployment_name=deployment_name,
|
||||
is_public=True,
|
||||
groups=[],
|
||||
@@ -164,7 +162,6 @@ def _create_image_gen_llm_provider__no_commit(
|
||||
api_key=provider_request.api_key,
|
||||
api_base=provider_request.api_base,
|
||||
api_version=provider_request.api_version,
|
||||
default_model_name=provider_request.default_model_name,
|
||||
deployment_name=provider_request.deployment_name,
|
||||
is_public=provider_request.is_public,
|
||||
custom_config=provider_request.custom_config,
|
||||
@@ -457,7 +454,9 @@ def update_config(
|
||||
existing_config.model_configuration_id = new_model_config_id
|
||||
|
||||
# 5. Delete old LLM provider (safe now - nothing references it)
|
||||
remove_llm_provider__no_commit(db_session, old_llm_provider_id)
|
||||
remove_llm_provider(
|
||||
db_session=db_session, provider_id=old_llm_provider_id, commit=False
|
||||
)
|
||||
|
||||
db_session.commit()
|
||||
db_session.refresh(existing_config)
|
||||
@@ -491,7 +490,9 @@ def delete_config(
|
||||
delete_image_generation_config__no_commit(db_session, image_provider_id)
|
||||
|
||||
# Clean up the orphaned LLM provider (it was exclusively for image gen)
|
||||
remove_llm_provider__no_commit(db_session, llm_provider_id)
|
||||
remove_llm_provider(
|
||||
db_session=db_session, provider_id=llm_provider_id, commit=False
|
||||
)
|
||||
|
||||
db_session.commit()
|
||||
except HTTPException:
|
||||
|
||||
@@ -19,14 +19,17 @@ from onyx.auth.schemas import UserRole
|
||||
from onyx.auth.users import current_admin_user
|
||||
from onyx.auth.users import current_chat_accessible_user
|
||||
from onyx.db.engine.sql_engine import get_session
|
||||
from onyx.db.enums import ModelFlowType
|
||||
from onyx.db.llm import can_user_access_llm_provider
|
||||
from onyx.db.llm import fetch_default_model
|
||||
from onyx.db.llm import fetch_existing_llm_provider
|
||||
from onyx.db.llm import fetch_existing_llm_providers
|
||||
from onyx.db.llm import fetch_existing_llm_providers_supporting_flows
|
||||
from onyx.db.llm import fetch_existing_model_configs_for_flow
|
||||
from onyx.db.llm import fetch_persona_with_groups
|
||||
from onyx.db.llm import fetch_user_group_ids
|
||||
from onyx.db.llm import remove_llm_provider
|
||||
from onyx.db.llm import sync_model_configurations
|
||||
from onyx.db.llm import update_default_provider
|
||||
from onyx.db.llm import update_default_text_provider
|
||||
from onyx.db.llm import update_default_vision_provider
|
||||
from onyx.db.llm import upsert_llm_provider
|
||||
from onyx.db.llm import validate_persona_ids_exist
|
||||
@@ -37,7 +40,6 @@ from onyx.llm.factory import get_llm
|
||||
from onyx.llm.factory import get_max_input_tokens_from_llm_provider
|
||||
from onyx.llm.utils import get_bedrock_token_limit
|
||||
from onyx.llm.utils import get_llm_contextual_cost
|
||||
from onyx.llm.utils import model_supports_image_input
|
||||
from onyx.llm.utils import test_llm
|
||||
from onyx.llm.well_known_providers.auto_update_service import (
|
||||
fetch_llm_recommendations_from_github,
|
||||
@@ -50,11 +52,13 @@ from onyx.llm.well_known_providers.llm_provider_options import (
|
||||
)
|
||||
from onyx.server.manage.llm.models import BedrockFinalModelResponse
|
||||
from onyx.server.manage.llm.models import BedrockModelsRequest
|
||||
from onyx.server.manage.llm.models import DefaultModel
|
||||
from onyx.server.manage.llm.models import LLMCost
|
||||
from onyx.server.manage.llm.models import LLMProviderDescriptor
|
||||
from onyx.server.manage.llm.models import LLMProviderResponse
|
||||
from onyx.server.manage.llm.models import LLMProviderUpsertRequest
|
||||
from onyx.server.manage.llm.models import LLMProviderView
|
||||
from onyx.server.manage.llm.models import ModelConfigurationUpsertRequest
|
||||
from onyx.server.manage.llm.models import ModelConfigurationView
|
||||
from onyx.server.manage.llm.models import OllamaFinalModelResponse
|
||||
from onyx.server.manage.llm.models import OllamaModelDetails
|
||||
from onyx.server.manage.llm.models import OllamaModelsRequest
|
||||
@@ -211,7 +215,7 @@ def test_llm_configuration(
|
||||
|
||||
llm = get_llm(
|
||||
provider=test_llm_request.provider,
|
||||
model=test_llm_request.default_model_name,
|
||||
model=test_llm_request.model,
|
||||
api_key=test_api_key,
|
||||
api_base=test_llm_request.api_base,
|
||||
api_version=test_llm_request.api_version,
|
||||
@@ -246,13 +250,15 @@ def list_llm_providers(
|
||||
include_image_gen: bool = Query(False),
|
||||
_: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> list[LLMProviderView]:
|
||||
) -> LLMProviderResponse[LLMProviderView]:
|
||||
start_time = datetime.now(timezone.utc)
|
||||
logger.debug("Starting to fetch LLM providers")
|
||||
|
||||
llm_provider_list: list[LLMProviderView] = []
|
||||
for llm_provider_model in fetch_existing_llm_providers(
|
||||
db_session, exclude_image_generation_providers=not include_image_gen
|
||||
for llm_provider_model in fetch_existing_llm_providers_supporting_flows(
|
||||
db_session=db_session,
|
||||
flows=[ModelFlowType.CONVERSATION, ModelFlowType.VISION],
|
||||
exclude_image_generation_providers=not include_image_gen,
|
||||
):
|
||||
from_model_start = datetime.now(timezone.utc)
|
||||
full_llm_provider = LLMProviderView.from_model(llm_provider_model)
|
||||
@@ -269,7 +275,18 @@ def list_llm_providers(
|
||||
duration = (end_time - start_time).total_seconds()
|
||||
logger.debug(f"Completed fetching LLM providers in {duration:.2f} seconds")
|
||||
|
||||
return llm_provider_list
|
||||
default_text = fetch_default_model(
|
||||
db_session=db_session, flow_type=ModelFlowType.CONVERSATION
|
||||
)
|
||||
default_vision = fetch_default_model(
|
||||
db_session=db_session, flow_type=ModelFlowType.VISION
|
||||
)
|
||||
|
||||
return LLMProviderResponse[LLMProviderView].from_models(
|
||||
default_text=default_text,
|
||||
default_vision=default_vision,
|
||||
providers=llm_provider_list,
|
||||
)
|
||||
|
||||
|
||||
@admin_router.put("/provider")
|
||||
@@ -328,20 +345,6 @@ def put_llm_provider(
|
||||
deduplicated_personas.append(persona_id)
|
||||
llm_provider_upsert_request.personas = deduplicated_personas
|
||||
|
||||
default_model_found = False
|
||||
|
||||
for model_configuration in llm_provider_upsert_request.model_configurations:
|
||||
if model_configuration.name == llm_provider_upsert_request.default_model_name:
|
||||
model_configuration.is_visible = True
|
||||
default_model_found = True
|
||||
|
||||
if not default_model_found:
|
||||
llm_provider_upsert_request.model_configurations.append(
|
||||
ModelConfigurationUpsertRequest(
|
||||
name=llm_provider_upsert_request.default_model_name, is_visible=True
|
||||
)
|
||||
)
|
||||
|
||||
# the llm api key is sanitized when returned to clients, so the only time we
|
||||
# should get a real key is when it is explicitly changed
|
||||
if existing_provider and not llm_provider_upsert_request.api_key_changed:
|
||||
@@ -354,6 +357,28 @@ def put_llm_provider(
|
not existing_provider or not existing_provider.is_auto_mode
)

default_vision_model = fetch_default_model(
db_session=db_session, flow_type=ModelFlowType.VISION
)

# Check that we're not disabling vision on the default vision model
if (
default_vision_model
and existing_provider
and default_vision_model.provider_id == existing_provider.id
and any(
m.name == default_vision_model.model_name and not m.supports_image_input
for m in llm_provider_upsert_request.model_configurations
)
):
raise HTTPException(
status_code=409,
detail=(
f"Cannot disable vision for '{default_vision_model.model_name}' because it is the default vision model. "
"Select another model as default first."
),
)

try:
result = upsert_llm_provider(
llm_provider_upsert_request=llm_provider_upsert_request,
||||
@@ -393,31 +418,51 @@ def delete_llm_provider(
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> None:
|
||||
try:
|
||||
default_model = fetch_default_model(
|
||||
db_session=db_session, flow_type=ModelFlowType.CONVERSATION
|
||||
)
|
||||
if default_model and default_model.provider_id == provider_id:
|
||||
raise HTTPException(
|
||||
status_code=409,
|
||||
detail="Cannot delete default provider. Select another provider as default",
|
||||
)
|
||||
default_vision_model = fetch_default_model(
|
||||
db_session=db_session, flow_type=ModelFlowType.VISION
|
||||
)
|
||||
if default_vision_model and default_vision_model.provider_id == provider_id:
|
||||
raise HTTPException(
|
||||
status_code=409,
|
||||
detail="Cannot delete default vision provider. Select another provider as default",
|
||||
)
|
||||
|
||||
remove_llm_provider(db_session, provider_id)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
|
@admin_router.post("/provider/{provider_id}/default")
def set_provider_as_default(
provider_id: int,
_: User = Depends(current_admin_user),
@admin_router.post("/default")
def set_default_model(
default_model_request: DefaultModel,
_: User | None = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> None:
update_default_provider(provider_id=provider_id, db_session=db_session)
update_default_text_provider(
provider_id=default_model_request.provider_id,
model=default_model_request.model_name,
db_session=db_session,
)


@admin_router.post("/provider/{provider_id}/default-vision")
def set_provider_as_default_vision(
provider_id: int,
vision_model: str | None = Query(
None, description="The default vision model to use"
),
_: User = Depends(current_admin_user),
@admin_router.post("/default-vision")
def set_default_vision_model(
default_vision_request: DefaultModel,
_: User | None = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> None:
update_default_vision_provider(
provider_id=provider_id, vision_model=vision_model, db_session=db_session
provider_id=default_vision_request.provider_id,
vision_model=default_vision_request.model_name,
db_session=db_session,
)

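For reference, the reworked admin routes take the provider id and model name in a DefaultModel body instead of a path parameter. A hedged example of calling them; the base URL, router prefix, and auth header below are placeholders, not values taken from this changeset.

import requests

ADMIN_LLM_URL = "http://localhost:8080/api/admin/llm"  # assumed router prefix
HEADERS = {"Authorization": "Bearer <admin-token>"}  # assumed auth scheme

# Set the default conversation/text model
body = {"provider_id": 1, "model_name": "gpt-4o-mini"}
requests.post(f"{ADMIN_LLM_URL}/default", json=body, headers=HEADERS)

# Set the default vision model
vision_body = {"provider_id": 1, "model_name": "gpt-4o"}
requests.post(f"{ADMIN_LLM_URL}/default-vision", json=vision_body, headers=HEADERS)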
||||
@@ -446,40 +491,34 @@ def get_vision_capable_providers(
|
||||
) -> list[VisionProviderResponse]:
|
||||
"""Return a list of LLM providers and their models that support image input"""
|
||||
|
||||
providers = fetch_existing_llm_providers(db_session)
|
||||
vision_providers = []
|
||||
vision_providers = fetch_existing_llm_providers_supporting_flows(
|
||||
db_session=db_session,
|
||||
flows=[ModelFlowType.VISION],
|
||||
)
|
||||
|
||||
logger.info("Fetching vision-capable providers")
|
||||
response: list[VisionProviderResponse] = []
|
||||
|
||||
for provider in providers:
|
||||
vision_models = []
|
||||
for provider in vision_providers:
|
||||
provider_view = LLMProviderView.from_model(provider)
|
||||
_mask_provider_credentials(provider_view)
|
||||
|
||||
# Check each model for vision capability
|
||||
for model_configuration in provider.model_configurations:
|
||||
if model_supports_image_input(model_configuration.name, provider.provider):
|
||||
vision_models.append(model_configuration.name)
|
||||
logger.debug(
|
||||
f"Vision model found: {provider.provider}/{model_configuration.name}"
|
||||
)
|
||||
|
||||
# Only include providers with at least one vision-capable model
|
||||
if vision_models:
|
||||
provider_view = LLMProviderView.from_model(provider)
|
||||
_mask_provider_credentials(provider_view)
|
||||
|
||||
vision_providers.append(
|
||||
VisionProviderResponse(
|
||||
**provider_view.model_dump(),
|
||||
vision_models=vision_models,
|
||||
)
|
||||
response.append(
|
||||
VisionProviderResponse(
|
||||
**provider_view.model_dump(),
|
||||
vision_models=[
|
||||
model.name
|
||||
for model in provider.model_configurations
|
||||
if model.supports_image_input
|
||||
],
|
||||
)
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Vision provider: {provider.provider} with models: {vision_models}"
|
||||
)
|
||||
logger.info(
|
||||
f"Vision provider: {provider.provider} with models: {response[-1].vision_models}"
|
||||
)
|
||||
|
||||
logger.info(f"Found {len(vision_providers)} vision-capable providers")
|
||||
return vision_providers
|
||||
return response
|
||||
|
||||
|
||||
"""Endpoints for all"""
|
||||
@@ -489,7 +528,7 @@ def get_vision_capable_providers(
|
||||
def list_llm_provider_basics(
|
||||
user: User = Depends(current_chat_accessible_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> list[LLMProviderDescriptor]:
|
||||
) -> LLMProviderResponse[LLMProviderDescriptor]:
|
||||
"""Get LLM providers accessible to the current user.
|
||||
|
||||
Returns:
|
||||
@@ -502,7 +541,10 @@ def list_llm_provider_basics(
|
||||
start_time = datetime.now(timezone.utc)
|
||||
logger.debug("Starting to fetch user-accessible LLM providers")
|
||||
|
||||
all_providers = fetch_existing_llm_providers(db_session)
|
||||
all_providers = fetch_existing_llm_providers_supporting_flows(
|
||||
db_session=db_session,
|
||||
flows=[ModelFlowType.CONVERSATION, ModelFlowType.VISION],
|
||||
)
|
||||
user_group_ids = fetch_user_group_ids(db_session, user)
|
||||
is_admin = user.role == UserRole.ADMIN
|
||||
|
||||
@@ -526,14 +568,25 @@ def list_llm_provider_basics(
|
||||
f"Completed fetching {len(accessible_providers)} user-accessible providers in {duration:.2f} seconds"
|
||||
)
|
||||
|
||||
return accessible_providers
|
||||
default_text = fetch_default_model(
|
||||
db_session=db_session, flow_type=ModelFlowType.CONVERSATION
|
||||
)
|
||||
default_vision = fetch_default_model(
|
||||
db_session=db_session, flow_type=ModelFlowType.VISION
|
||||
)
|
||||
|
||||
return LLMProviderResponse[LLMProviderDescriptor].from_models(
|
||||
providers=accessible_providers,
|
||||
default_text=default_text,
|
||||
default_vision=default_vision,
|
||||
)
|
||||
|
||||
|
||||
def get_valid_model_names_for_persona(
|
||||
def get_valid_models_for_persona(
|
||||
persona_id: int,
|
||||
user: User,
|
||||
db_session: Session,
|
||||
) -> list[str]:
|
||||
) -> list[ModelConfigurationView]:
|
||||
"""Get all valid model names that a user can access for this persona.
|
||||
|
||||
Returns a list of model names (e.g., ["gpt-4o", "claude-3-5-sonnet"]) that are
|
||||
@@ -545,7 +598,10 @@ def get_valid_model_names_for_persona(
|
||||
return []
|
||||
|
||||
is_admin = user.role == UserRole.ADMIN
|
||||
all_providers = fetch_existing_llm_providers(db_session)
|
||||
all_providers = fetch_existing_llm_providers_supporting_flows(
|
||||
db_session=db_session,
|
||||
flows=[ModelFlowType.CONVERSATION],
|
||||
)
|
||||
user_group_ids = set() if is_admin else fetch_user_group_ids(db_session, user)
|
||||
|
||||
valid_models = []
|
||||
@@ -557,7 +613,12 @@ def get_valid_model_names_for_persona(
|
||||
# Collect all model names from this provider
|
||||
for model_config in llm_provider_model.model_configurations:
|
||||
if model_config.is_visible:
|
||||
valid_models.append(model_config.name)
|
||||
valid_models.append(
|
||||
ModelConfigurationView.from_model(
|
||||
model_configuration_model=model_config,
|
||||
provider=llm_provider_model.provider,
|
||||
)
|
||||
)
|
||||
|
||||
return valid_models
|
||||
|
||||
@@ -592,7 +653,10 @@ def list_llm_providers_for_persona(
|
||||
)
|
||||
|
||||
is_admin = user.role == UserRole.ADMIN
|
||||
all_providers = fetch_existing_llm_providers(db_session)
|
||||
all_providers = fetch_existing_llm_providers_supporting_flows(
|
||||
db_session=db_session,
|
||||
flows=[ModelFlowType.CONVERSATION],
|
||||
)
|
||||
user_group_ids = set() if is_admin else fetch_user_group_ids(db_session, user)
|
||||
|
||||
llm_provider_list: list[LLMProviderDescriptor] = []
|
||||
@@ -630,32 +694,34 @@ def get_provider_contextual_cost(
|
||||
- the chunk_context
|
||||
- The per-token cost of the LLM used to generate the doc_summary and chunk_context
|
||||
"""
|
||||
providers = fetch_existing_llm_providers(db_session)
|
||||
models = fetch_existing_model_configs_for_flow(
|
||||
db_session=db_session,
|
||||
flows=[ModelFlowType.CONVERSATION],
|
||||
)
|
||||
costs = []
|
||||
for provider in providers:
|
||||
for model_configuration in provider.model_configurations:
|
||||
llm_provider = LLMProviderView.from_model(provider)
|
||||
llm = get_llm(
|
||||
provider=provider.provider,
|
||||
model=model_configuration.name,
|
||||
deployment_name=provider.deployment_name,
|
||||
api_key=provider.api_key,
|
||||
api_base=provider.api_base,
|
||||
api_version=provider.api_version,
|
||||
custom_config=provider.custom_config,
|
||||
max_input_tokens=get_max_input_tokens_from_llm_provider(
|
||||
llm_provider=llm_provider, model_name=model_configuration.name
|
||||
),
|
||||
)
|
||||
cost = get_llm_contextual_cost(llm)
|
||||
costs.append(
|
||||
LLMCost(
|
||||
provider=provider.name,
|
||||
model_name=model_configuration.name,
|
||||
cost=cost,
|
||||
)
|
||||
)
|
||||
|
||||
for model in models:
|
||||
llm_provider = LLMProviderView.from_model(model.llm_provider)
|
||||
llm = get_llm(
|
||||
provider=model.llm_provider.provider,
|
||||
model=model.name,
|
||||
deployment_name=model.llm_provider.deployment_name,
|
||||
api_key=model.llm_provider.api_key,
|
||||
api_base=model.llm_provider.api_base,
|
||||
api_version=model.llm_provider.api_version,
|
||||
custom_config=model.llm_provider.custom_config,
|
||||
max_input_tokens=get_max_input_tokens_from_llm_provider(
|
||||
llm_provider=llm_provider, model_name=model.name
|
||||
),
|
||||
)
|
||||
cost = get_llm_contextual_cost(llm)
|
||||
costs.append(
|
||||
LLMCost(
|
||||
provider=model.llm_provider.provider,
|
||||
model_name=model.name,
|
||||
cost=cost,
|
||||
)
|
||||
)
|
||||
return costs
|
||||
|
||||
|
||||
|
||||
@@ -1,10 +1,15 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
from typing import Generic
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import TypeVar
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import Field
|
||||
from pydantic import field_validator
|
||||
|
||||
from onyx.db.enums import ModelFlowType
|
||||
from onyx.llm.utils import get_max_input_tokens
|
||||
from onyx.llm.utils import litellm_thinks_model_supports_image_input
|
||||
from onyx.llm.utils import model_is_reasoning_model
|
||||
@@ -21,22 +26,23 @@ if TYPE_CHECKING:
|
||||
)
|
||||
|
||||
|
||||
T = TypeVar("T", bound="LLMProviderDescriptor | LLMProviderView")
|
||||
|
||||
|
||||
class TestLLMRequest(BaseModel):
|
||||
# provider level
|
||||
name: str | None = None
|
||||
name: str
|
||||
provider: str
|
||||
model: str
|
||||
api_key: str | None = None
|
||||
api_base: str | None = None
|
||||
api_version: str | None = None
|
||||
custom_config: dict[str, str] | None = None
|
||||
|
||||
# model level
|
||||
default_model_name: str
|
||||
deployment_name: str | None = None
|
||||
|
||||
model_configurations: list["ModelConfigurationUpsertRequest"]
|
||||
|
||||
# if try and use the existing API/custom config key
|
||||
# if try and use the existing API key
|
||||
api_key_changed: bool
|
||||
custom_config_changed: bool
|
||||
|
||||
@@ -54,10 +60,6 @@ class LLMProviderDescriptor(BaseModel):
|
||||
name: str
|
||||
provider: str
|
||||
provider_display_name: str # Human-friendly name like "Claude (Anthropic)"
|
||||
default_model_name: str
|
||||
is_default_provider: bool | None
|
||||
is_default_vision_provider: bool | None
|
||||
default_vision_model: str | None
|
||||
model_configurations: list["ModelConfigurationView"]
|
||||
|
||||
@classmethod
|
||||
@@ -75,10 +77,6 @@ class LLMProviderDescriptor(BaseModel):
|
||||
name=llm_provider_model.name,
|
||||
provider=provider,
|
||||
provider_display_name=get_provider_display_name(provider),
|
||||
default_model_name=llm_provider_model.default_model_name,
|
||||
is_default_provider=llm_provider_model.is_default_provider,
|
||||
is_default_vision_provider=llm_provider_model.is_default_vision_provider,
|
||||
default_vision_model=llm_provider_model.default_vision_model,
|
||||
model_configurations=filter_model_configurations(
|
||||
llm_provider_model.model_configurations, provider
|
||||
),
|
||||
@@ -92,13 +90,11 @@ class LLMProvider(BaseModel):
|
||||
api_base: str | None = None
|
||||
api_version: str | None = None
|
||||
custom_config: dict[str, str] | None = None
|
||||
default_model_name: str
|
||||
is_public: bool = True
|
||||
is_auto_mode: bool = False
|
||||
groups: list[int] = Field(default_factory=list)
|
||||
personas: list[int] = Field(default_factory=list)
|
||||
deployment_name: str | None = None
|
||||
default_vision_model: str | None = None
|
||||
|
||||
|
||||
class LLMProviderUpsertRequest(LLMProvider):
|
||||
@@ -119,8 +115,6 @@ class LLMProviderView(LLMProvider):
|
||||
"""Stripped down representation of LLMProvider for display / limited access info only"""
|
||||
|
||||
id: int
|
||||
is_default_provider: bool | None = None
|
||||
is_default_vision_provider: bool | None = None
|
||||
model_configurations: list["ModelConfigurationView"]
|
||||
|
||||
@classmethod
|
||||
@@ -150,10 +144,6 @@ class LLMProviderView(LLMProvider):
|
||||
api_base=llm_provider_model.api_base,
|
||||
api_version=llm_provider_model.api_version,
|
||||
custom_config=llm_provider_model.custom_config,
|
||||
default_model_name=llm_provider_model.default_model_name,
|
||||
is_default_provider=llm_provider_model.is_default_provider,
|
||||
is_default_vision_provider=llm_provider_model.is_default_vision_provider,
|
||||
default_vision_model=llm_provider_model.default_vision_model,
|
||||
is_public=llm_provider_model.is_public,
|
||||
is_auto_mode=llm_provider_model.is_auto_mode,
|
||||
groups=groups,
|
||||
@@ -165,13 +155,15 @@ class LLMProviderView(LLMProvider):
|
||||
)
|
||||
|
||||
|
||||
class ModelConfigurationUpsertRequest(BaseModel):
|
||||
class ModelConfiguration(BaseModel):
|
||||
name: str
|
||||
is_visible: bool
|
||||
max_input_tokens: int | None = None
|
||||
supports_image_input: bool | None = None
|
||||
display_name: str | None = None # For dynamic providers, from source API
|
||||
|
||||
|
||||
class ModelConfigurationUpsertRequest(ModelConfiguration):
|
||||
@classmethod
|
||||
def from_model(
|
||||
cls, model_configuration_model: "ModelConfigurationModel"
|
||||
@@ -180,18 +172,15 @@ class ModelConfigurationUpsertRequest(BaseModel):
|
||||
name=model_configuration_model.name,
|
||||
is_visible=model_configuration_model.is_visible,
|
||||
max_input_tokens=model_configuration_model.max_input_tokens,
|
||||
supports_image_input=model_configuration_model.supports_image_input,
|
||||
supports_image_input=ModelFlowType.VISION
|
||||
in (flow for flow in model_configuration_model.model_flow_types),
|
||||
display_name=model_configuration_model.display_name,
|
||||
)
|
||||
|
||||
|
||||
class ModelConfigurationView(BaseModel):
|
||||
name: str
|
||||
is_visible: bool
|
||||
max_input_tokens: int | None = None
|
||||
supports_image_input: bool
|
||||
class ModelConfigurationView(ModelConfiguration):
|
||||
id: int
|
||||
supports_reasoning: bool = False
|
||||
display_name: str | None = None
|
||||
provider_display_name: str | None = None
|
||||
vendor: str | None = None
|
||||
version: str | None = None
|
||||
@@ -201,26 +190,23 @@ class ModelConfigurationView(BaseModel):
|
||||
def from_model(
|
||||
cls,
|
||||
model_configuration_model: "ModelConfigurationModel",
|
||||
provider_name: str,
|
||||
provider: str,
|
||||
) -> "ModelConfigurationView":
|
||||
# For dynamic providers (OpenRouter, Bedrock, Ollama), use the display_name
|
||||
# stored in DB from the source API. Skip LiteLLM parsing entirely.
|
||||
if (
|
||||
provider_name in DYNAMIC_LLM_PROVIDERS
|
||||
and model_configuration_model.display_name
|
||||
):
|
||||
if provider in DYNAMIC_LLM_PROVIDERS and model_configuration_model.display_name:
|
||||
# Extract vendor from model name for grouping (e.g., "Anthropic", "OpenAI")
|
||||
vendor = extract_vendor_from_model_name(
|
||||
model_configuration_model.name, provider_name
|
||||
model_configuration_model.name, provider
|
||||
)
|
||||
|
||||
return cls(
|
||||
id=model_configuration_model.id,
|
||||
name=model_configuration_model.name,
|
||||
is_visible=model_configuration_model.is_visible,
|
||||
max_input_tokens=model_configuration_model.max_input_tokens,
|
||||
supports_image_input=(
|
||||
model_configuration_model.supports_image_input or False
|
||||
),
|
||||
supports_image_input=ModelFlowType.VISION
|
||||
in (flow for flow in model_configuration_model.model_flow_types),
|
||||
# Infer reasoning support from model name/display name
|
||||
supports_reasoning=is_reasoning_model(
|
||||
model_configuration_model.name,
|
||||
@@ -239,8 +225,8 @@ class ModelConfigurationView(BaseModel):
|
||||
# Parse the model name to get display information
|
||||
# Include provider prefix if not already present (enrichments use full keys like "vertex_ai/...")
|
||||
model_name = model_configuration_model.name
|
||||
if provider_name and not model_name.startswith(f"{provider_name}/"):
|
||||
model_name = f"{provider_name}/{model_name}"
|
||||
if provider and not model_name.startswith(f"{provider}/"):
|
||||
model_name = f"{provider}/{model_name}"
|
||||
parsed = parse_litellm_model_name(model_name)
|
||||
|
||||
# Include region in display name for Bedrock cross-region models
|
||||
@@ -251,24 +237,29 @@ class ModelConfigurationView(BaseModel):
|
||||
)
|
||||
|
||||
return cls(
|
||||
id=model_configuration_model.id,
|
||||
name=model_configuration_model.name,
|
||||
is_visible=model_configuration_model.is_visible,
|
||||
max_input_tokens=(
|
||||
model_configuration_model.max_input_tokens
|
||||
or get_max_input_tokens(
|
||||
model_name=model_configuration_model.name,
|
||||
model_provider=provider_name,
|
||||
model_provider=provider,
|
||||
)
|
||||
),
|
||||
supports_image_input=(
|
||||
val
|
||||
if (val := model_configuration_model.supports_image_input) is not None
|
||||
if (
|
||||
val := ModelFlowType.VISION
|
||||
in (flow for flow in model_configuration_model.model_flow_types)
|
||||
)
|
||||
is not None
|
||||
else litellm_thinks_model_supports_image_input(
|
||||
model_configuration_model.name, provider_name
|
||||
model_configuration_model.name, provider
|
||||
)
|
||||
),
|
||||
supports_reasoning=model_is_reasoning_model(
|
||||
model_configuration_model.name, provider_name
|
||||
model_configuration_model.name, provider
|
||||
),
|
||||
# Populate display fields from parsed model name
|
||||
display_name=display_name,
|
||||
@@ -371,3 +362,27 @@ class OpenRouterFinalModelResponse(BaseModel):
|
||||
int | None
|
||||
) # From OpenRouter API context_length (may be missing for some models)
|
||||
supports_image_input: bool
|
||||


class DefaultModel(BaseModel):
    provider_id: int
    model_name: str


class LLMProviderResponse(BaseModel, Generic[T]):
    providers: list[T]
    default_text: DefaultModel | None = None
    default_vision: DefaultModel | None = None

    @classmethod
    def from_models(
        cls,
        providers: list[T],
        default_text: DefaultModel | None = None,
        default_vision: DefaultModel | None = None,
    ) -> LLMProviderResponse[T]:
        return cls(
            default_text=default_text,
            default_vision=default_vision,
            providers=providers,
        )
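With the generic wrapper above, the provider list endpoints return the per-flow defaults alongside the providers themselves. Roughly the JSON shape clients should expect; values are illustrative only and the provider entries keep their existing LLMProviderView / LLMProviderDescriptor structure.

example_llm_provider_response = {
    "providers": [
        # one entry per accessible provider, structure unchanged by this PR
        {"name": "OpenAI", "provider": "openai", "model_configurations": ["..."]},
    ],
    "default_text": {"provider_id": 1, "model_name": "gpt-4o-mini"},
    "default_vision": {"provider_id": 1, "model_name": "gpt-4o"},
}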
||||
@@ -21,10 +21,11 @@ from onyx.db.connector_credential_pair import resync_cc_pair
|
||||
from onyx.db.credentials import create_initial_public_credential
|
||||
from onyx.db.document import check_docs_exist
|
||||
from onyx.db.enums import EmbeddingPrecision
|
||||
from onyx.db.enums import ModelFlowType
|
||||
from onyx.db.index_attempt import cancel_indexing_attempts_past_model
|
||||
from onyx.db.index_attempt import expire_index_attempts
|
||||
from onyx.db.llm import fetch_default_provider
|
||||
from onyx.db.llm import update_default_provider
|
||||
from onyx.db.llm import fetch_default_model
|
||||
from onyx.db.llm import update_default_text_provider
|
||||
from onyx.db.llm import upsert_llm_provider
|
||||
from onyx.db.search_settings import get_active_search_settings
|
||||
from onyx.db.search_settings import get_current_search_settings
|
||||
@@ -286,7 +287,7 @@ def setup_postgres(db_session: Session) -> None:
|
||||
create_initial_default_connector(db_session)
|
||||
associate_default_cc_pair(db_session)
|
||||
|
||||
if GEN_AI_API_KEY and fetch_default_provider(db_session) is None:
|
||||
if GEN_AI_API_KEY and fetch_default_model(db_session, ModelFlowType.TEXT) is None:
|
||||
# Only for dev flows
|
||||
logger.notice("Setting up default OpenAI LLM for dev.")
|
||||
|
||||
@@ -298,7 +299,6 @@ def setup_postgres(db_session: Session) -> None:
|
||||
api_base=None,
|
||||
api_version=None,
|
||||
custom_config=None,
|
||||
default_model_name=llm_model,
|
||||
is_public=True,
|
||||
groups=[],
|
||||
model_configurations=[
|
||||
@@ -310,7 +310,9 @@ def setup_postgres(db_session: Session) -> None:
|
||||
new_llm_provider = upsert_llm_provider(
|
||||
llm_provider_upsert_request=model_req, db_session=db_session
|
||||
)
|
||||
update_default_provider(provider_id=new_llm_provider.id, db_session=db_session)
|
||||
update_default_text_provider(
|
||||
provider_id=new_llm_provider.id, model=llm_model, db_session=db_session
|
||||
)
|
||||
|
||||
|
||||
def update_default_multipass_indexing(db_session: Session) -> None:
|
||||
|
||||
@@ -17,7 +17,6 @@ def test_bedrock_llm_configuration(client: TestClient) -> None:
|
||||
# Prepare the test request payload
|
||||
test_request: dict[str, Any] = {
|
||||
"provider": LlmProviderNames.BEDROCK,
|
||||
"default_model_name": _DEFAULT_BEDROCK_MODEL,
|
||||
"api_key": None,
|
||||
"api_base": None,
|
||||
"api_version": None,
|
||||
@@ -43,7 +42,6 @@ def test_bedrock_llm_configuration_invalid_key(client: TestClient) -> None:
|
||||
# Prepare the test request payload with invalid credentials
|
||||
test_request: dict[str, Any] = {
|
||||
"provider": LlmProviderNames.BEDROCK,
|
||||
"default_model_name": _DEFAULT_BEDROCK_MODEL,
|
||||
"api_key": None,
|
||||
"api_base": None,
|
||||
"api_version": None,
|
||||
|
||||
@@ -8,7 +8,7 @@ from unittest.mock import patch
|
||||
import pytest
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.db.llm import update_default_provider
|
||||
from onyx.db.llm import update_default_text_provider
|
||||
from onyx.db.llm import upsert_llm_provider
|
||||
from onyx.llm.constants import LlmProviderNames
|
||||
from onyx.server.manage.llm.models import LLMProviderUpsertRequest
|
||||
@@ -27,14 +27,15 @@ def ensure_default_llm_provider(db_session: Session) -> None:
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
api_key=os.environ.get("OPENAI_API_KEY", "test"),
|
||||
is_public=True,
|
||||
default_model_name="gpt-4o-mini",
|
||||
groups=[],
|
||||
)
|
||||
provider = upsert_llm_provider(
|
||||
llm_provider_upsert_request=llm_provider_request,
|
||||
db_session=db_session,
|
||||
)
|
||||
update_default_provider(provider.id, db_session)
|
||||
update_default_text_provider(
|
||||
provider_id=provider.id, model="gpt-4o-mini", db_session=db_session
|
||||
)
|
||||
except Exception as exc: # pragma: no cover - only hits on duplicate setup issues
|
||||
print(f"Note: Could not create LLM provider: {exc}")
|
||||
|
||||
|
||||
@@ -10,9 +10,10 @@ from onyx.chat.models import MessageResponseIDInfo
|
||||
from onyx.chat.models import StreamingError
|
||||
from onyx.chat.process_message import stream_chat_message_objects
|
||||
from onyx.db.chat import create_chat_session
|
||||
from onyx.db.llm import fetch_existing_llm_providers
|
||||
from onyx.db.enums import ModelFlowType
|
||||
from onyx.db.llm import fetch_existing_llm_providers_supporting_flows
|
||||
from onyx.db.llm import remove_llm_provider
|
||||
from onyx.db.llm import update_default_provider
|
||||
from onyx.db.llm import update_default_text_provider
|
||||
from onyx.db.llm import upsert_llm_provider
|
||||
from onyx.llm.constants import LlmProviderNames
|
||||
from onyx.server.manage.llm.models import LLMProviderUpsertRequest
|
||||
@@ -36,7 +37,10 @@ def test_answer_with_only_anthropic_provider(
|
||||
assert anthropic_api_key, "ANTHROPIC_API_KEY environment variable must be set"
|
||||
|
||||
# Drop any existing providers so that only Anthropic is available.
|
||||
for provider in fetch_existing_llm_providers(db_session):
|
||||
for provider in fetch_existing_llm_providers_supporting_flows(
|
||||
db_session=db_session,
|
||||
flows=[ModelFlowType.CONVERSATION],
|
||||
):
|
||||
remove_llm_provider(db_session, provider.id)
|
||||
|
||||
anthropic_model = "claude-haiku-4-5-20251001"
|
||||
@@ -47,7 +51,6 @@ def test_answer_with_only_anthropic_provider(
|
||||
name=provider_name,
|
||||
provider=LlmProviderNames.ANTHROPIC,
|
||||
api_key=anthropic_api_key,
|
||||
default_model_name=anthropic_model,
|
||||
is_public=True,
|
||||
groups=[],
|
||||
model_configurations=[
|
||||
@@ -59,7 +62,11 @@ def test_answer_with_only_anthropic_provider(
|
||||
)
|
||||
|
||||
try:
|
||||
update_default_provider(anthropic_provider.id, db_session)
|
||||
update_default_text_provider(
|
||||
provider_id=anthropic_provider.id,
|
||||
model=anthropic_model,
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
test_user = create_test_user(db_session, email_prefix="anthropic_only")
|
||||
chat_session = create_chat_session(
|
||||
|
||||
@@ -79,8 +79,7 @@ def test_stream_chat_message_objects_without_web_search(
|
||||
llm_relevance_filter=True,
|
||||
llm_filter_extraction=True,
|
||||
recency_bias=RecencyBiasSetting.BASE_DECAY,
|
||||
llm_model_provider_override=None,
|
||||
llm_model_version_override=None,
|
||||
default_model_configuration_id=None,
|
||||
starter_messages=None,
|
||||
system_prompt=None,
|
||||
task_prompt=None,
|
||||
|
||||
@@ -5,6 +5,8 @@ from unittest.mock import Mock
|
||||
from unittest.mock import patch
|
||||
from uuid import uuid4
|
||||
|
||||
from onyx.db.llm import update_default_text_provider
|
||||
|
||||
# Set environment variables to disable model server for testing
|
||||
os.environ["DISABLE_MODEL_SERVER"] = "true"
|
||||
os.environ["MODEL_SERVER_HOST"] = "disabled"
|
||||
@@ -418,13 +420,15 @@ class TestSlackBotFederatedSearch:
|
||||
name=f"test-llm-provider-{uuid4().hex[:8]}",
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
api_key=api_key,
|
||||
default_model_name="gpt-4o",
|
||||
is_default_provider=True,
|
||||
is_public=True,
|
||||
)
|
||||
db_session.add(llm_provider)
|
||||
db_session.commit()
|
||||
|
||||
update_default_text_provider(
|
||||
provider_id=llm_provider.id, model="gpt-4o", db_session=db_session
|
||||
)
|
||||
|
||||
def _teardown_common_mocks(self, patches: list) -> None:
|
||||
"""Stop all patches"""
|
||||
for p in patches:
|
||||
|
||||
@@ -6,6 +6,7 @@ import requests
|
||||
from onyx.llm.constants import LlmProviderNames
|
||||
from onyx.server.manage.llm.models import LLMProviderUpsertRequest
|
||||
from onyx.server.manage.llm.models import LLMProviderView
|
||||
from onyx.server.manage.llm.models import ModelConfigurationUpsertRequest
|
||||
from tests.integration.common_utils.constants import API_SERVER_URL
|
||||
from tests.integration.common_utils.constants import GENERAL_HEADERS
|
||||
from tests.integration.common_utils.test_models import DATestLLMProvider
|
||||
@@ -25,6 +26,7 @@ class LLMProviderManager:
|
||||
personas: list[int] | None = None,
|
||||
is_public: bool | None = None,
|
||||
set_as_default: bool = True,
|
||||
model_configurations: list[ModelConfigurationUpsertRequest] | None = None,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
) -> DATestLLMProvider:
|
||||
email = "Unknown"
|
||||
@@ -36,7 +38,6 @@ class LLMProviderManager:
|
||||
llm_provider = LLMProviderUpsertRequest(
|
||||
name=name or f"test-provider-{uuid4()}",
|
||||
provider=provider or LlmProviderNames.OPENAI,
|
||||
default_model_name=default_model_name or "gpt-4o-mini",
|
||||
api_key=api_key or os.environ["OPENAI_API_KEY"],
|
||||
api_base=api_base,
|
||||
api_version=api_version,
|
||||
@@ -44,7 +45,7 @@ class LLMProviderManager:
|
||||
is_public=True if is_public is None else is_public,
|
||||
groups=groups or [],
|
||||
personas=personas or [],
|
||||
model_configurations=[],
|
||||
model_configurations=model_configurations or [],
|
||||
api_key_changed=True,
|
||||
)
|
||||
|
||||
@@ -138,11 +139,25 @@ class LLMProviderManager:
|
||||
if (
|
||||
fetched_llm_groups == llm_provider_groups
|
||||
and llm_provider.provider == fetched_llm_provider.provider
|
||||
and llm_provider.default_model_name
|
||||
== fetched_llm_provider.default_model_name
|
||||
and llm_provider.is_public == fetched_llm_provider.is_public
|
||||
and set(fetched_llm_provider.personas) == set(llm_provider.personas)
|
||||
):
|
||||
return
|
||||
if not verify_deleted:
|
||||
raise ValueError(f"LLM Provider {llm_provider.id} not found")
|
||||
|
||||
@staticmethod
|
||||
def build_model_configuration_upsert_request(
|
||||
name: str,
|
||||
is_visible: bool,
|
||||
max_input_tokens: int | None,
|
||||
supports_image_input: bool,
|
||||
display_name: str | None,
|
||||
) -> ModelConfigurationUpsertRequest:
|
||||
return ModelConfigurationUpsertRequest(
|
||||
name=name,
|
||||
is_visible=is_visible,
|
||||
max_input_tokens=max_input_tokens,
|
||||
supports_image_input=supports_image_input,
|
||||
display_name=display_name,
|
||||
)
|
||||
|
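A sketch of how an integration test would presumably combine the new hooks: build explicit model configurations, including per-model image support, and pass them to LLMProviderManager.create through the new model_configurations argument. The model names and the bare create() call are illustrative; other arguments keep their defaults.

model_configs = [
    LLMProviderManager.build_model_configuration_upsert_request(
        name="gpt-4o",
        is_visible=True,
        max_input_tokens=None,
        supports_image_input=True,
        display_name=None,
    ),
    LLMProviderManager.build_model_configuration_upsert_request(
        name="gpt-4o-mini",
        is_visible=True,
        max_input_tokens=None,
        supports_image_input=False,
        display_name=None,
    ),
]
provider = LLMProviderManager.create(model_configurations=model_configs)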
||||
@@ -28,8 +28,7 @@ class PersonaManager:
|
||||
datetime_aware: bool = False,
|
||||
document_set_ids: list[int] | None = None,
|
||||
tool_ids: list[int] | None = None,
|
||||
llm_model_provider_override: str | None = None,
|
||||
llm_model_version_override: str | None = None,
|
||||
default_model_configuration_id: int | None = None,
|
||||
users: list[str] | None = None,
|
||||
groups: list[int] | None = None,
|
||||
label_ids: list[int] | None = None,
|
||||
@@ -55,8 +54,7 @@ class PersonaManager:
|
||||
recency_bias=recency_bias,
|
||||
document_set_ids=document_set_ids or [],
|
||||
tool_ids=tool_ids or [],
|
||||
llm_model_provider_override=llm_model_provider_override,
|
||||
llm_model_version_override=llm_model_version_override,
|
||||
default_model_configuration_id=default_model_configuration_id,
|
||||
users=[UUID(user) for user in (users or [])],
|
||||
groups=groups or [],
|
||||
label_ids=label_ids or [],
|
||||
@@ -90,8 +88,7 @@ class PersonaManager:
|
||||
datetime_aware=datetime_aware,
|
||||
document_set_ids=document_set_ids or [],
|
||||
tool_ids=tool_ids or [],
|
||||
llm_model_provider_override=llm_model_provider_override,
|
||||
llm_model_version_override=llm_model_version_override,
|
||||
default_model_configuration_id=default_model_configuration_id,
|
||||
users=users or [],
|
||||
groups=groups or [],
|
||||
label_ids=label_ids or [],
|
||||
@@ -112,8 +109,7 @@ class PersonaManager:
|
||||
datetime_aware: bool = False,
|
||||
document_set_ids: list[int] | None = None,
|
||||
tool_ids: list[int] | None = None,
|
||||
llm_model_provider_override: str | None = None,
|
||||
llm_model_version_override: str | None = None,
|
||||
default_model_configuration_id: int | None = None,
|
||||
users: list[str] | None = None,
|
||||
groups: list[int] | None = None,
|
||||
label_ids: list[int] | None = None,
|
||||
@@ -137,11 +133,8 @@ class PersonaManager:
|
||||
recency_bias=recency_bias or persona.recency_bias,
|
||||
document_set_ids=document_set_ids or persona.document_set_ids,
|
||||
tool_ids=tool_ids or persona.tool_ids,
|
||||
llm_model_provider_override=(
|
||||
llm_model_provider_override or persona.llm_model_provider_override
|
||||
),
|
||||
llm_model_version_override=(
|
||||
llm_model_version_override or persona.llm_model_version_override
|
||||
default_model_configuration_id=(
|
||||
default_model_configuration_id or persona.default_model_configuration_id
|
||||
),
|
||||
users=[UUID(user) for user in (users or persona.users)],
|
||||
groups=groups or persona.groups,
|
||||
@@ -174,11 +167,8 @@ class PersonaManager:
|
||||
datetime_aware=datetime_aware,
|
||||
document_set_ids=updated_persona_data["document_sets"],
|
||||
tool_ids=updated_persona_data["tools"],
|
||||
llm_model_provider_override=updated_persona_data[
|
||||
"llm_model_provider_override"
|
||||
],
|
||||
llm_model_version_override=updated_persona_data[
|
||||
"llm_model_version_override"
|
||||
default_model_configuration_id=updated_persona_data[
|
||||
"default_model_configuration_id"
|
||||
],
|
||||
users=[user["email"] for user in updated_persona_data["users"]],
|
||||
groups=updated_persona_data["groups"],
|
||||
@@ -267,25 +257,14 @@ class PersonaManager:
|
||||
)
|
||||
)
|
||||
if (
|
||||
fetched_persona.llm_model_provider_override
|
||||
!= persona.llm_model_provider_override
|
||||
fetched_persona.default_model_configuration_id
|
||||
!= persona.default_model_configuration_id
|
||||
):
|
||||
mismatches.append(
|
||||
(
|
||||
"llm_model_provider_override",
|
||||
persona.llm_model_provider_override,
|
||||
fetched_persona.llm_model_provider_override,
|
||||
)
|
||||
)
|
||||
if (
|
||||
fetched_persona.llm_model_version_override
|
||||
!= persona.llm_model_version_override
|
||||
):
|
||||
mismatches.append(
|
||||
(
|
||||
"llm_model_version_override",
|
||||
persona.llm_model_version_override,
|
||||
fetched_persona.llm_model_version_override,
|
||||
"default_model_configuration_id",
|
||||
persona.default_model_configuration_id,
|
||||
fetched_persona.default_model_configuration_id,
|
||||
)
|
||||
)
|
||||
if fetched_persona.system_prompt != persona.system_prompt:
|
||||
|
||||
@@ -111,6 +111,15 @@ class DATestUserGroup(BaseModel):
|
||||
cc_pair_ids: list[int]
|
||||
|
||||
|
||||
class DATestModelConfiguration(BaseModel):
|
||||
id: int
|
||||
name: str
|
||||
is_visible: bool
|
||||
max_input_tokens: int | None
|
||||
supports_image_input: bool | None
|
||||
display_name: str | None
|
||||
|
||||
|
||||
class DATestLLMProvider(BaseModel):
|
||||
id: int
|
||||
name: str
|
||||
@@ -123,6 +132,7 @@ class DATestLLMProvider(BaseModel):
|
||||
personas: list[int]
|
||||
api_base: str | None = None
|
||||
api_version: str | None = None
|
||||
model_configurations: list[DATestModelConfiguration] = Field(default_factory=list)
|
||||
|
||||
|
||||
class DATestImageGenerationConfig(BaseModel):
|
||||
@@ -157,8 +167,7 @@ class DATestPersona(BaseModel):
|
||||
recency_bias: RecencyBiasSetting
|
||||
document_set_ids: list[int]
|
||||
tool_ids: list[int]
|
||||
llm_model_provider_override: str | None
|
||||
llm_model_version_override: str | None
|
||||
default_model_configuration_id: int | None
|
||||
users: list[str]
|
||||
groups: list[int]
|
||||
label_ids: list[int]
|
||||
|
||||
@@ -42,7 +42,6 @@ def _create_provider_with_api(
|
||||
llm_provider_data = {
|
||||
"name": name,
|
||||
"provider": provider_type,
|
||||
"default_model_name": default_model,
|
||||
"api_key": "test-api-key-for-auto-mode-testing",
|
||||
"api_base": None,
|
||||
"api_version": None,
|
||||
|
||||
@@ -96,6 +96,8 @@ def assert_response_is_equivalent(
|
||||
for name in actual_by_name:
|
||||
actual_config = actual_by_name[name]
|
||||
expected_config = expected_by_name[name]
|
||||
assert "id" in actual_config, f"Config {name} is missing id"
|
||||
del actual_config["id"] # This will not be in the expected config
|
||||
assert actual_config == expected_config, (
|
||||
f"Config mismatch for {name}:\n"
|
||||
f"Actual: {actual_config}\n"
|
||||
|
||||
@@ -11,6 +11,7 @@ from onyx.db.llm import fetch_user_group_ids
|
||||
from onyx.db.models import LLMProvider as LLMProviderModel
|
||||
from onyx.db.models import LLMProvider__Persona
|
||||
from onyx.db.models import LLMProvider__UserGroup
|
||||
from onyx.db.models import ModelConfiguration
|
||||
from onyx.db.models import Persona
|
||||
from onyx.db.models import User
|
||||
from onyx.db.models import User__UserGroup
|
||||
@@ -48,22 +49,40 @@ def _create_llm_provider(
|
||||
default_model_name=default_model_name,
|
||||
deployment_name=None,
|
||||
is_public=is_public,
|
||||
# Use None instead of False to avoid unique constraint violation
|
||||
# The is_default_provider column has unique=True, so only one True and one False allowed
|
||||
is_default_provider=is_default if is_default else None,
|
||||
is_default_vision_provider=False,
|
||||
default_vision_model=None,
|
||||
)
|
||||
db_session.add(provider)
|
||||
db_session.flush()
|
||||
return provider
|
||||
|
||||
|
||||
def _create_model_configuration(
|
||||
db_session: Session,
|
||||
*,
|
||||
llm_provider_id: int,
|
||||
name: str,
|
||||
is_visible: bool,
|
||||
max_input_tokens: int | None,
|
||||
supports_image_input: bool,
|
||||
display_name: str | None,
|
||||
) -> ModelConfiguration:
|
||||
model_configuration = ModelConfiguration(
|
||||
llm_provider_id=llm_provider_id,
|
||||
name=name,
|
||||
is_visible=is_visible,
|
||||
max_input_tokens=max_input_tokens,
|
||||
supports_image_input=supports_image_input,
|
||||
display_name=display_name,
|
||||
)
|
||||
db_session.add(model_configuration)
|
||||
db_session.flush()
|
||||
return model_configuration
|
||||
|
||||
|
||||
def _create_persona(
|
||||
db_session: Session,
|
||||
*,
|
||||
name: str,
|
||||
provider_name: str,
|
||||
default_model_configuration_id: int,
|
||||
) -> Persona:
|
||||
persona = Persona(
|
||||
name=name,
|
||||
@@ -74,8 +93,7 @@ def _create_persona(
        llm_relevance_filter=True,
        llm_filter_extraction=True,
        recency_bias=RecencyBiasSetting.AUTO,
        llm_model_provider_override=provider_name,
        llm_model_version_override="gpt-4o-mini",
        default_model_configuration_id=default_model_configuration_id,
        system_prompt="System prompt",
        task_prompt="Task prompt",
        datetime_aware=True,
@@ -114,6 +132,15 @@ def test_can_user_access_llm_provider_or_logic(
        is_public=True,
        is_default=True,
    )
    _create_model_configuration(
        db_session,
        llm_provider_id=default_provider.id,
        name="default-model",
        is_visible=True,
        max_input_tokens=None,
        supports_image_input=False,
        display_name=None,
    )
    # Locked provider - is_public=False with no restrictions
    locked_provider = _create_llm_provider(
        db_session,
@@ -122,6 +149,15 @@ def test_can_user_access_llm_provider_or_logic(
        is_public=False,
        is_default=False,
    )
    _create_model_configuration(
        db_session,
        llm_provider_id=locked_provider.id,
        name="locked-model",
        is_visible=True,
        max_input_tokens=None,
        supports_image_input=False,
        display_name=None,
    )
    # Restricted provider - has both group AND persona restrictions (AND logic)
    restricted_provider = _create_llm_provider(
        db_session,
@@ -130,16 +166,34 @@ def test_can_user_access_llm_provider_or_logic(
        is_public=False,
        is_default=False,
    )
    restricted_model_1 = _create_model_configuration(
        db_session,
        llm_provider_id=restricted_provider.id,
        name="restricted-model-1",
        is_visible=True,
        max_input_tokens=None,
        supports_image_input=False,
        display_name=None,
    )
    restricted_model_2 = _create_model_configuration(
        db_session,
        llm_provider_id=restricted_provider.id,
        name="restricted-model-2",
        is_visible=True,
        max_input_tokens=None,
        supports_image_input=False,
        display_name=None,
    )

    allowed_persona = _create_persona(
        db_session,
        name="allowed-persona",
        provider_name=restricted_provider.name,
        default_model_configuration_id=restricted_model_1.id,
    )
    blocked_persona = _create_persona(
        db_session,
        name="blocked-persona",
        provider_name=restricted_provider.name,
        default_model_configuration_id=restricted_model_2.id,
    )

    access_group = UserGroup(name="access-group")
@@ -253,6 +307,15 @@ def test_get_llm_for_persona_falls_back_when_access_denied(
        is_public=True,
        is_default=True,
    )
    _create_model_configuration(
        db_session,
        llm_provider_id=default_provider.id,
        name="default-model",
        is_visible=True,
        max_input_tokens=None,
        supports_image_input=False,
        display_name=None,
    )
    restricted_provider = _create_llm_provider(
        db_session,
        name="restricted-provider",
@@ -260,11 +323,19 @@ def test_get_llm_for_persona_falls_back_when_access_denied(
        is_public=False,
        is_default=False,
    )

    restricted_model_configuration = _create_model_configuration(
        db_session,
        llm_provider_id=restricted_provider.id,
        name="restricted-model",
        is_visible=True,
        max_input_tokens=None,
        supports_image_input=False,
        display_name=None,
    )
    persona = _create_persona(
        db_session,
        name="fallback-persona",
        provider_name=restricted_provider.name,
        default_model_configuration_id=restricted_model_configuration.id,
    )

    access_group = UserGroup(name="persona-group")
@@ -300,13 +371,13 @@ def test_get_llm_for_persona_falls_back_when_access_denied(
        persona=persona,
        user=admin_model,
    )
    assert allowed_llm.config.model_name == restricted_provider.default_model_name
    assert allowed_llm.config.model_name == "gpt-4o-mini"

    fallback_llm = get_llm_for_persona(
        persona=persona,
        user=basic_model,
    )
    assert fallback_llm.config.model_name == default_provider.default_model_name
    assert fallback_llm.config.model_name == "gpt-4o"


def test_list_llm_provider_basics_excludes_non_public_unrestricted(
@@ -382,9 +453,19 @@ def test_provider_delete_clears_persona_references(reset: None) -> None:
        is_public=False,
        set_as_default=False,
        user_performing_action=admin_user,
        model_configurations=[
            LLMProviderManager.build_model_configuration_upsert_request(
                name="model-1",
                is_visible=True,
                max_input_tokens=None,
                supports_image_input=False,
                display_name=None,
            )
        ],
    )

    persona = PersonaManager.create(
        llm_model_provider_override=provider.name,
        default_model_configuration_id=provider.model_configurations[0].id,
        user_performing_action=admin_user,
    )

@@ -394,11 +475,11 @@ def test_provider_delete_clears_persona_references(reset: None) -> None:
        user_performing_action=admin_user,
    )

    # Verify the persona now falls back to default (llm_model_provider_override cleared)
    # Verify the persona now falls back to default (default_model_configuration_id cleared)
    persona_response = requests.get(
        f"{API_SERVER_URL}/persona/{persona.id}",
        headers=admin_user.headers,
    )
    assert persona_response.status_code == 200
    updated_persona = persona_response.json()
    assert updated_persona["llm_model_provider_override"] is None
    assert updated_persona["default_model_configuration_id"] is None

@@ -47,7 +47,8 @@ class TestSyncModelConfigurations:

        assert result == 2  # Two new models
        assert mock_session.execute.call_count == 2
        mock_session.commit.assert_called_once()
        # Single commit at end for atomicity (flow mappings use flush)
        assert mock_session.commit.call_count == 1

    def test_skips_existing_models(self) -> None:
        """Test that existing models are not overwritten."""
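
Editor's note: the "# Single commit at end for atomicity (flow mappings use flush)" comment above describes a general SQLAlchemy pattern rather than anything unique to this changeset: flush each new row so its autogenerated id is available to dependent rows inside the still-open transaction, then commit exactly once so the whole sync is all-or-nothing. A minimal self-contained sketch under those assumptions (hypothetical table and function names, not the project's actual sync code):

from sqlalchemy import ForeignKey, create_engine
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column


class Base(DeclarativeBase):
    pass


class ModelConfig(Base):  # hypothetical stand-in
    __tablename__ = "model_config"
    id: Mapped[int] = mapped_column(primary_key=True)
    name: Mapped[str]


class FlowMapping(Base):  # hypothetical stand-in
    __tablename__ = "flow_mapping"
    id: Mapped[int] = mapped_column(primary_key=True)
    flow_type: Mapped[str]
    model_config_id: Mapped[int] = mapped_column(ForeignKey("model_config.id"))


def sync_models(session: Session, names: list[str]) -> int:
    for name in names:
        config = ModelConfig(name=name)
        session.add(config)
        session.flush()  # id assigned here, still inside the open transaction
        session.add(FlowMapping(flow_type="text", model_config_id=config.id))
    session.commit()  # exactly one commit: every row lands or none do
    return len(names)


if __name__ == "__main__":
    engine = create_engine("sqlite+pysqlite:///:memory:")
    Base.metadata.create_all(engine)
    with Session(engine) as session:
        assert sync_models(session, ["model-a", "model-b"]) == 2
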
@@ -109,7 +109,9 @@ class TestGetOllamaAvailableModels:

        # Verify DB operations were called
        assert mock_session.execute.call_count == 3  # 3 models inserted
        mock_session.commit.assert_called_once()
        assert (
            mock_session.commit.call_count == 6
        )  # 3 for model config + 3 for flow mapping

    def test_no_sync_when_provider_name_not_specified(
        self, mock_ollama_tags_response: dict, mock_ollama_show_response: dict
@@ -251,7 +253,7 @@ class TestGetOpenRouterAvailableModels:

        # Verify DB operations were called
        assert mock_session.execute.call_count == 3  # 3 models inserted
        mock_session.commit.assert_called_once()
        assert mock_session.commit.call_count == 6  # 3 for model config + 3 for flow mapping

    def test_preserves_existing_models_on_sync(
        self, mock_openrouter_response: dict
@@ -43,8 +43,7 @@ export interface MinimalPersonaSnapshot {
  // Unique sources from all knowledge (document sets + hierarchy nodes)
  // Used to populate source filters in chat
  knowledge_sources?: ValidSources[];
  llm_model_version_override?: string;
  llm_model_provider_override?: string;
  default_model_configuration_id?: number;

  uploaded_image_id?: string;
  icon_name?: string;

@@ -16,8 +16,7 @@ interface PersonaUpsertRequest {
  recency_bias: string;
  llm_filter_extraction: boolean;
  llm_relevance_filter: boolean | null;
  llm_model_provider_override: string | null;
  llm_model_version_override: string | null;
  default_model_configuration_id: number | null;
  starter_messages: StarterMessage[] | null;
  users?: string[];
  groups: number[];
@@ -48,8 +47,7 @@ export interface PersonaUpsertParameters {
  num_chunks: number | null;
  is_public: boolean;
  llm_relevance_filter: boolean | null;
  llm_model_provider_override: string | null;
  llm_model_version_override: string | null;
  default_model_configuration_id: number | null;
  starter_messages: StarterMessage[] | null;
  users?: string[];
  groups: number[];
@@ -88,8 +86,7 @@ function buildPersonaUpsertRequest({
  uploaded_image_id,
  is_default_persona,
  llm_relevance_filter,
  llm_model_provider_override,
  llm_model_version_override,
  default_model_configuration_id,
  starter_messages,
  label_ids,
  replace_base_system_prompt,
@@ -114,8 +111,7 @@ function buildPersonaUpsertRequest({
    recency_bias: "base_decay",
    llm_filter_extraction: false,
    llm_relevance_filter: llm_relevance_filter ?? null,
    llm_model_provider_override: llm_model_provider_override ?? null,
    llm_model_version_override: llm_model_version_override ?? null,
    default_model_configuration_id: default_model_configuration_id ?? null,
    starter_messages: starter_messages ?? null,
    display_priority: null,
    label_ids: label_ids ?? null,
@@ -144,6 +144,7 @@ export function ModelConfigurationField({
        <CreateButton
          onClick={() => {
            arrayHelpers.push({
              id: -1,
              name: "",
              is_visible: true,
              // Use null so Yup.number().nullable() accepts empty inputs

@@ -214,6 +214,7 @@ describe("Custom LLM Provider Configuration Workflow", () => {
      default_model_name: "claude-3-opus",
      model_configurations: [
        {
          id: 1,
          name: "claude-3-opus",
          is_visible: true,
          max_input_tokens: null,

@@ -7,7 +7,7 @@ import {
  Formik,
  ErrorMessage,
} from "formik";
import { LLMProviderFormProps, LLMProviderView } from "../interfaces";
import { LLMProviderFormProps } from "../interfaces";
import * as Yup from "yup";
import { ProviderFormEntrypointWrapper } from "./components/FormWrapper";
import { DisplayNameField } from "./components/DisplayNameField";
@@ -69,7 +69,7 @@ export function CustomForm({
        ...modelConfiguration,
        max_input_tokens: modelConfiguration.max_input_tokens ?? null,
      })
    ) ?? [{ name: "", is_visible: true, max_input_tokens: null }],
    ) ?? [{ id: -1, name: "", is_visible: true, max_input_tokens: null }],
    custom_config_list: existingLlmProvider?.custom_config
      ? Object.entries(existingLlmProvider.custom_config)
      : [],
@@ -14,8 +14,11 @@ export interface ModelConfiguration {
  is_visible: boolean;
  max_input_tokens: number | null;
  supports_image_input: boolean | null;
  supports_reasoning?: boolean;
  display_name?: string;
}
export interface ModelConfigurationView extends ModelConfiguration {
  id: number;
  supports_reasoning?: boolean;
  provider_display_name?: string;
  vendor?: string;
  version?: string;
@@ -55,7 +58,7 @@ export interface LLMProvider {
  deployment_name: string | null;
  default_vision_model: string | null;
  is_default_vision_provider: boolean | null;
  model_configurations: ModelConfiguration[];
  model_configurations: ModelConfigurationView[];
}

export interface LLMProviderView extends LLMProvider {
@@ -78,7 +81,7 @@ export interface LLMProviderDescriptor {
  is_public?: boolean;
  groups?: number[];
  personas?: number[];
  model_configurations: ModelConfiguration[];
  model_configurations: ModelConfigurationView[];
}

export interface OllamaModelResponse {

@@ -618,9 +618,11 @@ export function useLlmManager(
      setCurrentLlm(
        getValidLlmDescriptor(currentChatSession.current_alternate_model)
      );
    } else if (liveAssistant?.llm_model_version_override) {
    } else if (liveAssistant?.default_model_configuration_id) {
      setCurrentLlm(
        getValidLlmDescriptor(liveAssistant.llm_model_version_override)
        getValidLlmDescriptorFromModelId(
          liveAssistant.default_model_configuration_id
        )
      );
    } else if (userHasManuallyOverriddenLLM) {
      // if the user has an override and there's nothing special about the
@@ -640,6 +642,17 @@ export function useLlmManager(
    setChatSession(currentChatSession || null);
  };

  function getValidLlmDescriptorFromModelId(
    modelId: number | null | undefined
  ): LlmDescriptor {
    // Get the name of the model
    const model = llmProviders
      ?.find((p) => p.model_configurations.find((m) => m.id === modelId))
      ?.model_configurations.find((m) => m.id === modelId);

    return getValidLlmDescriptor(model?.name);
  }

  function getValidLlmDescriptor(
    modelName: string | null | undefined
  ): LlmDescriptor {
@@ -18,15 +18,20 @@ export function getFinalLLM(
  let model = defaultProvider?.default_model_name || "";

  if (persona) {
    // Map "provider override" to actual LLLMProvider
    if (persona.llm_model_provider_override) {
    // Map "model override" to actual model
    if (persona.default_model_configuration_id) {
      const underlyingProvider = llmProviders.find(
        (item: LLMProviderDescriptor) =>
          item.name === persona.llm_model_provider_override
        item.model_configurations.find(
          (m) => m.id === persona.default_model_configuration_id
        )
      );
      const underlyingModel = underlyingProvider?.model_configurations.find(
        (m) => m.id === persona.default_model_configuration_id
      );
      provider = underlyingProvider?.provider || provider;
      model = underlyingModel?.name || model;
    }
    model = persona.llm_model_version_override || model;
  }

  if (currentLlm) {
@@ -41,26 +46,27 @@ export function getLLMProviderOverrideForPersona(
  liveAssistant: MinimalPersonaSnapshot,
  llmProviders: LLMProviderDescriptor[]
): LlmDescriptor | null {
  const overrideProvider = liveAssistant.llm_model_provider_override;
  const overrideModel = liveAssistant.llm_model_version_override;
  const defaultModelConfigurationId =
    liveAssistant.default_model_configuration_id;

  if (!overrideModel) {
  if (!defaultModelConfigurationId) {
    return null;
  }

  const matchingProvider = llmProviders.find(
    (provider) =>
      (overrideProvider ? provider.name === overrideProvider : true) &&
      provider.model_configurations
        .map((modelConfiguration) => modelConfiguration.name)
        .includes(overrideModel)
  const matchingProvider = llmProviders.find((provider) =>
    provider.model_configurations.find(
      (m) => m.id === defaultModelConfigurationId
    )
  );
  const underlyingModel = matchingProvider?.model_configurations.find(
    (m) => m.id === defaultModelConfigurationId
  );

  if (matchingProvider) {
  if (matchingProvider && underlyingModel) {
    return {
      name: matchingProvider.name,
      provider: matchingProvider.provider,
      modelName: overrideModel,
      modelName: underlyingModel.name,
    };
  }
@@ -451,37 +451,59 @@ export default function AgentEditorPage({
  const deleteAgentModal = useCreateModal();

  // LLM Model Selection
  const getCurrentLlm = useCallback(
    (values: any, llmProviders: any) =>
      values.llm_model_version_override && values.llm_model_provider_override
        ? (() => {
            const provider = llmProviders?.find(
              (p: any) => p.name === values.llm_model_provider_override
            );
            return structureValue(
              values.llm_model_provider_override,
              provider?.provider || "",
              values.llm_model_version_override
            );
          })()
        : null,
  const findProviderAndModel = useCallback(
    (
      providers: any[] | undefined,
      modelPredicate: (m: any) => boolean,
      providerPredicate?: (p: any) => boolean
    ) => {
      const candidateProviders = providerPredicate
        ? providers?.filter(providerPredicate)
        : providers;
      const provider = candidateProviders?.find(
        (p: any) => p.model_configurations?.find(modelPredicate)
      );
      const model = provider?.model_configurations?.find(modelPredicate);
      return { provider, model };
    },
    []
  );

  const getCurrentLlm = useCallback(
    (values: any, llmProviders: any) =>
      values.default_model_configuration_id
        ? (() => {
            const { provider, model } = findProviderAndModel(
              llmProviders,
              (m: any) => m.id === values.default_model_configuration_id
            );
            return structureValue(
              provider?.name || "",
              provider?.provider || "",
              model?.name || ""
            );
          })()
        : null,
    [findProviderAndModel]
  );

  const onLlmSelect = useCallback(
    (selected: string | null, setFieldValue: any) => {
    (selected: string | null, setFieldValue: any, llmProviders: any) => {
      if (selected === null) {
        setFieldValue("llm_model_version_override", null);
        setFieldValue("llm_model_provider_override", null);
        setFieldValue("default_model_configuration_id", null);
      } else {
        const { modelName, name } = parseLlmDescriptor(selected);
        if (modelName && name) {
          setFieldValue("llm_model_version_override", modelName);
          setFieldValue("llm_model_provider_override", name);
          const { model } = findProviderAndModel(
            llmProviders,
            (m: any) => m.name === modelName,
            (p: any) => p.name === name
          );
          setFieldValue("default_model_configuration_id", model?.id || null);
        }
      }
    },
    []
    [findProviderAndModel]
  );

  // Hooks for Knowledge section
@@ -576,10 +598,8 @@ export default function AgentEditorPage({
    selected_sources: [] as ValidSources[],

    // Advanced
    llm_model_provider_override:
      existingAgent?.llm_model_provider_override ?? null,
    llm_model_version_override:
      existingAgent?.llm_model_version_override ?? null,
    default_model_configuration_id:
      existingAgent?.default_model_configuration_id ?? null,
    knowledge_cutoff_date: existingAgent?.search_start_date
      ? new Date(existingAgent.search_start_date)
      : null,
@@ -701,8 +721,7 @@ export default function AgentEditorPage({
    ),

    // Advanced
    llm_model_provider_override: Yup.string().nullable().optional(),
    llm_model_version_override: Yup.string().nullable().optional(),
    default_model_configuration_id: Yup.number().nullable().optional(),
    knowledge_cutoff_date: Yup.date().nullable().optional(),
    replace_base_system_prompt: Yup.boolean(),
    reminders: Yup.string().optional(),
@@ -805,8 +824,8 @@ export default function AgentEditorPage({
    // recency_bias: ...,
    // llm_filter_extraction: ...,
    llm_relevance_filter: false,
    llm_model_provider_override: values.llm_model_provider_override || null,
    llm_model_version_override: values.llm_model_version_override || null,
    default_model_configuration_id:
      values.default_model_configuration_id || null,
    starter_messages: finalStarterMessages,
    users: values.shared_user_ids,
    groups: values.shared_group_ids,
@@ -1407,7 +1426,11 @@ export default function AgentEditorPage({
                    llmProviders
                  )}
                  onSelect={(selected) =>
                    onLlmSelect(selected, setFieldValue)
                    onLlmSelect(
                      selected,
                      setFieldValue,
                      llmProviders
                    )
                  }
                />
              </InputLayouts.Horizontal>