Compare commits

...

16 Commits

Author SHA1 Message Date
Batuhan Usluel
931988145d revert unnneeded change 2025-12-05 16:03:02 -08:00
Batuhan Usluel
9dc554f2a1 Cleanup PR 2025-12-05 14:29:39 -08:00
Batuhan Usluel
647addb026 Simplify multi model calling change 2025-12-05 13:47:57 -08:00
Batuhan Usluel
d2c6150f86 Remove unneeded comments 2025-12-05 13:31:46 -08:00
Batuhan Usluel
ec571c0ea5 Remove unneeded comments 2025-12-05 13:25:36 -08:00
Batuhan Usluel
7a437ebb81 Refactor process message for multile model handling 2025-12-05 11:01:31 -08:00
Batuhan Usluel
cf5f7e0936 Show latest tab selected for model selection 2025-12-05 10:17:36 -08:00
Batuhan Usluel
6c299ccec5 Remove unneeded response group id 2025-12-05 10:08:02 -08:00
Batuhan Usluel
1fc4c3a930 Show signifier when a model is streaming / done 2025-12-05 08:58:13 -08:00
Batuhan Usluel
89d6e04938 Branching based on model selection 2025-12-04 20:19:57 -08:00
Batuhan Usluel
8bf3559bbf Stream responses concurrently 2025-12-04 18:32:24 -08:00
Batuhan Usluel
1f9024c0b3 cleanup changes 2025-12-04 17:32:25 -08:00
Batuhan Usluel
9f0bbf0e17 Fix frontend message count 2025-12-04 17:23:51 -08:00
Batuhan Usluel
0aef537ba5 Backend + frontend updates for making model selection work (needs heavy cleanup) 2025-12-04 17:15:10 -08:00
Batuhan Usluel
bd918980a8 DB and model updates for multiple llm selection 2025-12-04 14:56:50 -08:00
Batuhan Usluel
e6e42fdbf6 Multiple models frontend draft 2025-12-04 14:47:24 -08:00
19 changed files with 1481 additions and 282 deletions

View File

@@ -0,0 +1,36 @@
"""add multi-modal response support to chat_message
Revision ID: 34ef1e82a4fa
Revises: e8f0d2a38171
Create Date: 2025-12-04 14:53:05.821715
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "34ef1e82a4fa"
down_revision = "e8f0d2a38171"
branch_labels = None
depends_on = None
def upgrade() -> None:
# Add model_provider column to track which LLM provider generated the response
op.add_column(
"chat_message",
sa.Column("model_provider", sa.String(), nullable=True),
)
# Add model_name column to track which specific model generated the response
op.add_column(
"chat_message",
sa.Column("model_name", sa.String(), nullable=True),
)
def downgrade() -> None:
# Drop the columns in reverse order
op.drop_column("chat_message", "model_name")
op.drop_column("chat_message", "model_provider")

View File

@@ -97,6 +97,8 @@ class OnyxAnswerPiece(BaseModel):
class MessageResponseIDInfo(BaseModel):
    """Reserved DB message IDs for a user message / assistant response pair."""

    user_message_id: int | None
    reserved_assistant_message_id: int
    # Which LLM will produce this response (multi-model support); populated
    # from the model's config when responses are reserved. None for legacy
    # single-model clients.
    model_provider: str | None = None
    model_name: str | None = None
class StreamingError(BaseModel):

View File

@@ -0,0 +1,277 @@
"""
Multi-model streaming infrastructure for concurrent LLM execution.
This module provides classes for running multiple LLM models concurrently
and merging their streaming outputs into a single stream.
"""
import contextvars
import threading
from collections.abc import Callable
from collections.abc import Iterator
from dataclasses import dataclass
from dataclasses import field
from queue import Empty
from queue import Queue
from typing import Any
from onyx.chat.chat_state import ChatStateContainer
from onyx.chat.emitter import Emitter
from onyx.server.query_and_chat.streaming_models import OverallStop
from onyx.server.query_and_chat.streaming_models import Packet
from onyx.server.query_and_chat.streaming_models import PacketException
from onyx.utils.logger import setup_logger
logger = setup_logger()
@dataclass
class ModelStreamContext:
    """Context for a single model's streaming execution."""

    # Unique model identifier, formatted "{provider}:{model_name}"
    model_id: str
    # Emitter that tags this model's packets before they hit the merged queue
    emitter: "ModelTaggingEmitter"
    # Accumulates this model's partial results for saving the chat turn
    state_container: ChatStateContainer
    # Background thread running this model's LLM loop (set when started)
    thread: threading.Thread | None = None
    # True once the model emitted OverallStop or was marked failed
    completed: bool = False
    # Exception raised by this model's thread, if any
    error: Exception | None = None
    # Additional data that may be needed for saving chat turns
    extra_data: dict[str, Any] = field(default_factory=dict)
class ModelTaggingEmitter(Emitter):
    """Emitter that stamps every packet with a model_id and forwards it to a
    shared merged queue.

    Wrapping the standard Emitter this way lets the frontend route each packet
    to the UI of the model that produced it.
    """

    def __init__(
        self,
        model_id: str,
        merged_queue: Queue[Packet | None],
        merger: "MultiModelStreamMerger",
    ):
        # Keep a private local bus so any existing code that reads
        # emitter.bus directly keeps working.
        super().__init__(bus=Queue())
        self.model_id = model_id
        self.merged_queue = merged_queue
        self.merger = merger

    def emit(self, packet: Packet) -> None:
        """Forward ``packet`` onto the merged queue, tagged with our model_id."""
        self.merged_queue.put(
            Packet(
                turn_index=packet.turn_index,
                obj=packet.obj,
                model_id=self.model_id,
            )
        )
        # OverallStop ends this model's stream. Notify the merger only AFTER
        # the packet is enqueued, so the merger's completion sentinel (None)
        # cannot overtake the final packet in the queue.
        if isinstance(packet.obj, OverallStop):
            self.merger.mark_model_complete(self.model_id)
class MultiModelStreamMerger:
    """Merges packet streams from multiple concurrent LLM executions.

    This class manages the concurrent execution of multiple LLM models and
    merges their streaming outputs into a single stream. Each model runs in
    its own thread and emits packets to a shared queue.

    Usage:
        merger = MultiModelStreamMerger()
        # Register models and get their emitters
        for model in models:
            ctx = merger.register_model(model_id)
            # Start thread with ctx.emitter
        # Stream merged packets
        for packet in merger.stream(is_connected):
            yield packet
    """

    def __init__(self) -> None:
        # Single queue shared by all ModelTaggingEmitters; a None sentinel
        # signals that every model has finished.
        self.merged_queue: Queue[Packet | None] = Queue()
        self.model_contexts: dict[str, ModelStreamContext] = {}
        # Guards model_contexts and the completion counters below.
        self._lock = threading.Lock()
        self._completed_count = 0
        self._total_models = 0
        self._all_complete = threading.Event()

    def register_model(self, model_id: str) -> ModelStreamContext:
        """Register a model for concurrent streaming.

        Args:
            model_id: Unique identifier for the model (e.g., "openai:gpt-4")

        Returns:
            ModelStreamContext containing the emitter to use for this model.

        Raises:
            ValueError: If model_id was already registered.
        """
        with self._lock:
            if model_id in self.model_contexts:
                raise ValueError(f"Model {model_id} already registered")
            emitter = ModelTaggingEmitter(
                model_id=model_id,
                merged_queue=self.merged_queue,
                merger=self,
            )
            state_container = ChatStateContainer()
            context = ModelStreamContext(
                model_id=model_id,
                emitter=emitter,
                state_container=state_container,
            )
            self.model_contexts[model_id] = context
            self._total_models += 1
            return context

    def mark_model_complete(
        self, model_id: str, error: Exception | None = None
    ) -> None:
        """Mark a model's stream as complete.

        Called automatically when an OverallStop packet is emitted,
        or can be called manually on error.

        Args:
            model_id: The model that has completed
            error: Optional exception if the model failed
        """
        with self._lock:
            if model_id in self.model_contexts:
                ctx = self.model_contexts[model_id]
                # Idempotent: a model can be marked complete both by its
                # OverallStop packet and by the error handler in
                # start_model_thread; only count it once.
                if not ctx.completed:
                    ctx.completed = True
                    ctx.error = error
                    self._completed_count += 1
                    logger.debug(
                        f"Model {model_id} completed "
                        f"({self._completed_count}/{self._total_models})"
                    )
                    if self._completed_count >= self._total_models:
                        # All models complete - send sentinel to unblock stream
                        self.merged_queue.put(None)
                        self._all_complete.set()

    def start_model_thread(
        self,
        model_id: str,
        target_func: Callable[..., None],
        *args: Any,
        **kwargs: Any,
    ) -> None:
        """Start a background thread for a model's LLM execution.

        Args:
            model_id: The model to run
            target_func: The function to run (should accept emitter as first arg)
            *args: Additional positional arguments for target_func
            **kwargs: Additional keyword arguments for target_func

        Raises:
            ValueError: If model_id was never registered.
        """
        ctx = self.model_contexts.get(model_id)
        if ctx is None:
            raise ValueError(f"Model {model_id} not registered")

        # Copy context vars for the new thread (important for tenant_id, etc.)
        context = contextvars.copy_context()

        def thread_target() -> None:
            try:
                # The model's emitter and state container are injected here so
                # target_func streams through this merger.
                context.run(
                    target_func,
                    ctx.emitter,
                    *args,
                    state_container=ctx.state_container,
                    **kwargs,
                )
            except Exception as e:
                logger.exception(f"Model {model_id} failed: {e}")
                # Emit error packet so the frontend can surface the failure
                ctx.emitter.emit(
                    Packet(
                        turn_index=0,
                        obj=PacketException(type="error", exception=e),
                    )
                )
                self.mark_model_complete(model_id, error=e)

        # Daemon thread: must not block interpreter shutdown if the client
        # disconnects mid-stream.
        thread = threading.Thread(target=thread_target, daemon=True)
        ctx.thread = thread
        thread.start()

    def stream(
        self,
        is_connected: Callable[[], bool],
        poll_timeout: float = 0.3,
    ) -> Iterator[Packet]:
        """Yield merged packets from all models.

        This generator yields packets from all models as they arrive,
        checking the stop signal periodically.

        Args:
            is_connected: Callable that returns False when stop signal is set
            poll_timeout: Timeout for queue polling (seconds)

        Yields:
            Packet objects with model_id set for routing
        """
        while True:
            try:
                packet = self.merged_queue.get(timeout=poll_timeout)
                if packet is None:
                    # All models complete
                    break

                # Check for exception packets - but don't raise, just yield
                # This allows other models to continue streaming
                if isinstance(packet.obj, PacketException):
                    # Log the error but continue streaming other models
                    logger.error(
                        f"Error from model {packet.model_id}: {packet.obj.exception}"
                    )
                    # Still yield the error packet so frontend can show it
                    yield packet
                    continue

                # Check for OverallStop - this indicates one model finished
                # Don't break here, wait for all models
                yield packet
            except Empty:
                # Check stop signal
                if not is_connected():
                    logger.debug("Stop signal detected, stopping all models")
                    break
                # Check if all models completed (defensive)
                if self._all_complete.is_set():
                    break

    def wait_for_threads(self, timeout: float | None = None) -> None:
        """Wait for all model threads to complete.

        Args:
            timeout: Optional timeout per thread (seconds)
        """
        for ctx in self.model_contexts.values():
            if ctx.thread is not None and ctx.thread.is_alive():
                ctx.thread.join(timeout=timeout)

    def get_model_contexts(self) -> dict[str, ModelStreamContext]:
        """Get all model contexts for post-processing (e.g., saving chat turns)."""
        # Shallow copy so callers can't mutate our registry.
        return self.model_contexts.copy()

View File

@@ -22,6 +22,7 @@ from onyx.chat.models import ExtractedProjectFiles
from onyx.chat.models import MessageResponseIDInfo
from onyx.chat.models import ProjectFileMetadata
from onyx.chat.models import StreamingError
from onyx.chat.multi_model_stream import MultiModelStreamMerger
from onyx.chat.prompt_utils import calculate_reserved_tokens
from onyx.chat.save_chat import save_chat_turn
from onyx.chat.stop_signal_checker import is_connected as check_stop_signal
@@ -51,6 +52,7 @@ from onyx.file_store.utils import verify_user_files
from onyx.llm.factory import get_llm_token_counter
from onyx.llm.factory import get_llms_for_persona
from onyx.llm.interfaces import LLM
from onyx.llm.override_models import LLMOverride
from onyx.llm.utils import litellm_exception_to_error_msg
from onyx.onyxbot.slack.models import SlackContext
from onyx.redis.redis_pool import get_redis_client
@@ -321,10 +323,18 @@ def stream_chat_message_objects(
"Must specify a set of documents for chat or specify search options"
)
# Build the list of LLM overrides to process (multi-model support)
llm_overrides: list[LLMOverride | None] = []
if new_msg_req.llm_overrides and len(new_msg_req.llm_overrides) > 0:
llm_overrides = list(new_msg_req.llm_overrides)
else:
llm_overrides = [new_msg_req.llm_override or chat_session.llm_override]
# Use first model for initial setup (token counting for user message)
llm, fast_llm = get_llms_for_persona(
persona=persona,
user=user,
llm_override=new_msg_req.llm_override or chat_session.llm_override,
llm_override=llm_overrides[0],
additional_headers=litellm_additional_headers,
long_term_logger=long_term_logger,
)
@@ -372,104 +382,14 @@ def stream_chat_message_objects(
custom_agent_prompt = get_custom_agent_prompt(persona, chat_session)
reserved_token_count = calculate_reserved_tokens(
db_session=db_session,
persona_system_prompt=custom_agent_prompt or "",
token_counter=token_counter,
files=last_chat_message.files,
memories=memories,
)
# Process projects, if all of the files fit in the context, it doesn't need to use RAG
extracted_project_files = _extract_project_file_texts_and_images(
project_id=chat_session.project_id,
user_id=user_id,
llm_max_context_window=llm.config.max_input_tokens,
reserved_token_count=reserved_token_count,
db_session=db_session,
)
# There are cases where the internal search tool should be disabled
# If the user is in a project, it should not use other sources / generic search
# If they are in a project but using a custom agent, it should use the agent setup
# (which means it can use search)
# However if in a project and there are more files than can fit in the context,
# it should use the search tool with the project filter on
disable_internal_search = bool(
chat_session.project_id
and persona.id is DEFAULT_PERSONA_ID
and (
extracted_project_files.project_file_texts
or not extracted_project_files.project_as_filter
)
)
emitter = get_default_emitter()
# Construct tools based on the persona configurations
tool_dict = construct_tools(
persona=persona,
db_session=db_session,
emitter=emitter,
user=user,
llm=llm,
fast_llm=fast_llm,
search_tool_config=SearchToolConfig(
user_selected_filters=user_selected_filters,
project_id=(
chat_session.project_id
if extracted_project_files.project_as_filter
else None
),
bypass_acl=bypass_acl,
slack_context=slack_context,
),
custom_tool_config=CustomToolConfig(
chat_session_id=chat_session_id,
message_id=user_message.id if user_message else None,
additional_headers=custom_tool_additional_headers,
),
allowed_tool_ids=new_msg_req.allowed_tool_ids,
disable_internal_search=disable_internal_search,
)
tools: list[Tool] = []
for tool_list in tool_dict.values():
tools.extend(tool_list)
# TODO Once summarization is done, we don't need to load all the files from the beginning anymore.
# load all files needed for this chat chain in memory
files = load_all_chat_files(chat_history, db_session)
# TODO Need to think of some way to support selected docs from the sidebar
# Reserve a message id for the assistant response for frontend to track packets
assistant_response = reserve_message_id(
db_session=db_session,
chat_session_id=chat_session_id,
parent_message=user_message.id,
message_type=MessageType.ASSISTANT,
)
yield MessageResponseIDInfo(
user_message_id=user_message.id,
reserved_assistant_message_id=assistant_response.id,
)
# Build a mapping of tool_id to tool_name for history reconstruction
all_tools = get_tools(db_session)
tool_id_to_name_map = {tool.id: tool.name for tool in all_tools}
# Convert the chat history into a simple format that is free of any DB objects
# and is easy to parse for the agent loop
simple_chat_history = convert_chat_history(
chat_history=chat_history,
files=files,
project_image_files=extracted_project_files.project_image_files,
additional_context=additional_context,
token_counter=token_counter,
tool_id_to_name_map=tool_id_to_name_map,
)
redis_client = get_redis_client()
reset_cancel_status(
@@ -480,76 +400,242 @@ def stream_chat_message_objects(
def check_is_connected() -> bool:
return check_stop_signal(chat_session_id, redis_client)
# Create state container for accumulating partial results
state_container = ChatStateContainer()
# Setup for multi-model (concurrent) vs single-model (synchronous)
is_multi_model = len(llm_overrides) > 1
merger = MultiModelStreamMerger() if is_multi_model else None
# Run the LLM loop with explicit wrapper for stop signal handling
# The wrapper runs run_llm_loop in a background thread and polls every 300ms
# for stop signals. run_llm_loop itself doesn't know about stopping.
# Note: DB session is not thread safe but nothing else uses it and the
# reference is passed directly so it's ok.
yield from run_chat_llm_with_state_containers(
run_llm_loop,
emitter=emitter,
state_container=state_container,
is_connected=check_is_connected, # Not passed through to run_llm_loop
simple_chat_history=simple_chat_history,
tools=tools,
custom_agent_prompt=custom_agent_prompt,
project_files=extracted_project_files,
persona=persona,
memories=memories,
llm=llm,
token_counter=token_counter,
db_session=db_session,
forced_tool_id=(
new_msg_req.forced_tool_ids[0] if new_msg_req.forced_tool_ids else None
),
)
# Track state for single-model case (used after the loop)
single_model_state: tuple[ChatStateContainer, ChatMessage, LLM] | None = None
# Process each model
for llm_override in llm_overrides:
# Get LLM for this model
llm, fast_llm = get_llms_for_persona(
persona=persona,
user=user,
llm_override=llm_override,
additional_headers=litellm_additional_headers,
long_term_logger=long_term_logger,
)
token_counter = get_llm_token_counter(llm)
reserved_token_count = calculate_reserved_tokens(
db_session=db_session,
persona_system_prompt=custom_agent_prompt or "",
token_counter=token_counter,
files=last_chat_message.files,
memories=memories,
)
# Process projects, if all of the files fit in the context, it doesn't need to use RAG
extracted_project_files = _extract_project_file_texts_and_images(
project_id=chat_session.project_id,
user_id=user_id,
llm_max_context_window=llm.config.max_input_tokens,
reserved_token_count=reserved_token_count,
db_session=db_session,
)
# There are cases where the internal search tool should be disabled
# If the user is in a project, it should not use other sources / generic search
# If they are in a project but using a custom agent, it should use the agent setup
# (which means it can use search)
# However if in a project and there are more files than can fit in the context,
# it should use the search tool with the project filter on
disable_internal_search = bool(
chat_session.project_id
and persona.id is DEFAULT_PERSONA_ID
and (
extracted_project_files.project_file_texts
or not extracted_project_files.project_as_filter
)
)
# Get the appropriate emitter
if merger is not None:
model_id = f"{llm.config.model_provider}:{llm.config.model_name}"
model_ctx = merger.register_model(model_id)
emitter = model_ctx.emitter
else:
emitter = get_default_emitter()
# Construct tools based on the persona configurations
tool_dict = construct_tools(
persona=persona,
db_session=db_session,
emitter=emitter,
user=user,
llm=llm,
fast_llm=fast_llm,
search_tool_config=SearchToolConfig(
user_selected_filters=user_selected_filters,
project_id=(
chat_session.project_id
if extracted_project_files.project_as_filter
else None
),
bypass_acl=bypass_acl,
slack_context=slack_context,
),
custom_tool_config=CustomToolConfig(
chat_session_id=chat_session_id,
message_id=user_message.id if user_message else None,
additional_headers=custom_tool_additional_headers,
),
allowed_tool_ids=new_msg_req.allowed_tool_ids,
disable_internal_search=disable_internal_search,
)
tools: list[Tool] = []
for tool_list in tool_dict.values():
tools.extend(tool_list)
# Reserve a message id for the assistant response for frontend to track packets
assistant_response = reserve_message_id(
db_session=db_session,
chat_session_id=chat_session_id,
parent_message=user_message.id,
message_type=MessageType.ASSISTANT,
model_provider=llm.config.model_provider,
model_name=llm.config.model_name,
)
yield MessageResponseIDInfo(
user_message_id=user_message.id,
reserved_assistant_message_id=assistant_response.id,
model_provider=llm.config.model_provider,
model_name=llm.config.model_name,
)
# Convert the chat history into a simple format that is free of any DB objects
# and is easy to parse for the agent loop
simple_chat_history = convert_chat_history(
chat_history=chat_history,
files=files,
project_image_files=extracted_project_files.project_image_files,
additional_context=additional_context,
token_counter=token_counter,
tool_id_to_name_map=tool_id_to_name_map,
)
if merger is not None:
# Multi-model: start thread for concurrent execution
model_ctx.extra_data["assistant_response"] = assistant_response
model_ctx.extra_data["llm"] = llm
merger.start_model_thread(
model_id=model_id,
target_func=run_llm_loop,
simple_chat_history=simple_chat_history,
tools=tools,
custom_agent_prompt=custom_agent_prompt,
project_files=extracted_project_files,
persona=persona,
memories=memories,
llm=llm,
token_counter=token_counter,
db_session=db_session,
forced_tool_id=(
new_msg_req.forced_tool_ids[0]
if new_msg_req.forced_tool_ids
else None
),
)
else:
# Single model: run synchronously
state_container = ChatStateContainer()
single_model_state = (state_container, assistant_response, llm)
# Run the LLM loop with explicit wrapper for stop signal handling
# The wrapper runs run_llm_loop in a background thread and polls every 300ms
# for stop signals. run_llm_loop itself doesn't know about stopping.
# Note: DB session is not thread safe but nothing else uses it and the
# reference is passed directly so it's ok.
yield from run_chat_llm_with_state_containers(
run_llm_loop,
emitter=emitter,
state_container=state_container,
is_connected=check_is_connected, # Not passed through to run_llm_loop
simple_chat_history=simple_chat_history,
tools=tools,
custom_agent_prompt=custom_agent_prompt,
project_files=extracted_project_files,
persona=persona,
memories=memories,
llm=llm,
token_counter=token_counter,
db_session=db_session,
forced_tool_id=(
new_msg_req.forced_tool_ids[0]
if new_msg_req.forced_tool_ids
else None
),
)
# After loop: handle streaming for multi-model
if merger is not None:
yield from merger.stream(is_connected=check_is_connected)
merger.wait_for_threads()
# Determine if stopped by user
completed_normally = check_is_connected()
if not completed_normally:
logger.debug(f"Chat session {chat_session_id} stopped by user")
# Build final answer based on completion status
if completed_normally:
if state_container.answer_tokens is None:
raise RuntimeError(
"LLM run completed normally but did not return an answer."
)
final_answer = state_container.answer_tokens
# Build list of (state_container, assistant_response, error) for saving
if merger is not None:
models_to_save = [
(ctx.state_container, ctx.extra_data["assistant_response"], ctx.error)
for ctx in merger.get_model_contexts().values()
]
else:
# Stopped by user - append stop message
if state_container.answer_tokens:
final_answer = (
state_container.answer_tokens
+ " ... The generation was stopped by the user here."
)
assert single_model_state is not None
state_container, assistant_response, _ = single_model_state
models_to_save = [(state_container, assistant_response, None)]
# Save chat turns for all models
for state_container, assistant_response, error in models_to_save:
if error is not None:
final_answer = f"Error: {str(error)}"
elif completed_normally:
if state_container.answer_tokens is None:
if merger is not None:
final_answer = "No response generated."
else:
raise RuntimeError(
"LLM run completed normally but did not return an answer."
)
else:
final_answer = state_container.answer_tokens
else:
final_answer = "The generation was stopped by the user."
# Build citation_docs_info from accumulated citations in state container
citation_docs_info: list[CitationDocInfo] = []
seen_citation_nums: set[int] = set()
for citation_num, search_doc in state_container.citation_to_doc.items():
if citation_num not in seen_citation_nums:
seen_citation_nums.add(citation_num)
citation_docs_info.append(
CitationDocInfo(
search_doc=search_doc,
citation_number=citation_num,
# Stopped by user - append stop message
if state_container.answer_tokens:
final_answer = (
state_container.answer_tokens
+ " ... The generation was stopped by the user here."
)
)
else:
final_answer = "The generation was stopped by the user."
save_chat_turn(
message_text=final_answer,
reasoning_tokens=state_container.reasoning_tokens,
citation_docs_info=citation_docs_info,
tool_calls=state_container.tool_calls,
db_session=db_session,
assistant_message=assistant_response,
)
# Build citation_docs_info from accumulated citations
citation_docs_info: list[CitationDocInfo] = []
seen_citation_nums: set[int] = set()
for citation_num, search_doc in state_container.citation_to_doc.items():
if citation_num not in seen_citation_nums:
seen_citation_nums.add(citation_num)
citation_docs_info.append(
CitationDocInfo(
search_doc=search_doc,
citation_number=citation_num,
)
)
save_chat_turn(
message_text=final_answer,
reasoning_tokens=state_container.reasoning_tokens,
citation_docs_info=citation_docs_info,
tool_calls=state_container.tool_calls,
db_session=db_session,
assistant_message=assistant_response,
)
except ValueError as e:
logger.exception("Failed to process chat message.")

View File

@@ -620,6 +620,8 @@ def reserve_message_id(
chat_session_id: UUID,
parent_message: int,
message_type: MessageType = MessageType.ASSISTANT,
model_provider: str | None = None,
model_name: str | None = None,
) -> ChatMessage:
# Create a temporary holding chat message to be updated and saved at the end
empty_message = ChatMessage(
@@ -629,6 +631,8 @@ def reserve_message_id(
message="Response was termination prior to completion, try regenerating.",
token_count=15,
message_type=message_type,
model_provider=model_provider,
model_name=model_name,
)
# Add the empty message to the session
@@ -661,6 +665,8 @@ def create_new_chat_message(
commit: bool = True,
reserved_message_id: int | None = None,
reasoning_tokens: str | None = None,
model_provider: str | None = None,
model_name: str | None = None,
) -> ChatMessage:
if reserved_message_id is not None:
# Edit existing message
@@ -676,6 +682,8 @@ def create_new_chat_message(
existing_message.files = files
existing_message.error = error
existing_message.reasoning_tokens = reasoning_tokens
existing_message.model_provider = model_provider
existing_message.model_name = model_name
new_chat_message = existing_message
else:
# Create new message
@@ -689,6 +697,8 @@ def create_new_chat_message(
files=files,
error=error,
reasoning_tokens=reasoning_tokens,
model_provider=model_provider,
model_name=model_name,
)
db_session.add(new_chat_message)
@@ -874,6 +884,8 @@ def translate_db_message_to_chat_message_detail(
files=chat_message.files or [],
error=chat_message.error,
current_feedback=current_feedback,
model_provider=chat_message.model_provider,
model_name=chat_message.model_name,
)
return chat_msg_detail

View File

@@ -2142,6 +2142,9 @@ class ChatMessage(Base):
DateTime(timezone=True), server_default=func.now()
)
model_provider: Mapped[str | None] = mapped_column(String, nullable=True)
model_name: Mapped[str | None] = mapped_column(String, nullable=True)
# Relationships
chat_session: Mapped[ChatSession] = relationship("ChatSession")

View File

@@ -109,6 +109,11 @@ class CreateChatMessageRequest(ChunkContext):
llm_override: LLMOverride | None = None
prompt_override: PromptOverride | None = None
# List of LLM overrides to generate responses from
# If provided, generates one response per override (all sharing the same parent message)
# Takes precedence over llm_override if both are provided
llm_overrides: list[LLMOverride] | None = None
# Allows the caller to override the temperature for the chat session
# this does persist in the chat thread details
temperature_override: float | None = None
@@ -245,6 +250,9 @@ class ChatMessageDetail(BaseModel):
error: str | None = None
current_feedback: str | None = None # "like" | "dislike" | null
model_provider: str | None = None
model_name: str | None = None
def model_dump(self, *args: list, **kwargs: dict[str, Any]) -> dict[str, Any]: # type: ignore
initial_dict = super().model_dump(mode="json", *args, **kwargs) # type: ignore
initial_dict["time_sent"] = self.time_sent.isoformat()

View File

@@ -260,6 +260,7 @@ PacketObj = Union[
class Packet(BaseModel):
    """A single streamed unit sent to the frontend."""

    turn_index: int | None
    obj: Annotated[PacketObj, Field(discriminator="type")]
    # Identifies which model produced this packet when multiple models stream
    # concurrently. Format: "{provider}:{model_name}" e.g. "openai:gpt-4".
    # None for single-model streams.
    model_id: str | None = None
# This is for replaying it back from the DB to the frontend

View File

@@ -10,6 +10,7 @@ import { EnterpriseSettings } from "@/app/admin/settings/interfaces";
import { FileDescriptor } from "@/app/chat/interfaces";
import { MemoizedAIMessage } from "../message/messageComponents/MemoizedAIMessage";
import { ProjectFile } from "../projects/projectsService";
import { ModelResponse } from "../message/messageComponents/ModelResponseTabs";
interface MessagesDisplayProps {
messageHistory: Message[];
@@ -76,6 +77,72 @@ export const MessagesDisplay: React.FC<MessagesDisplayProps> = ({
// Stable fallbacks to avoid changing prop identities on each render
const emptyDocs = useMemo<OnyxDocument[]>(() => [], []);
const emptyChildrenIds = useMemo<number[]>(() => [], []);
// Build a map of parentNodeId -> sibling assistant Messages for multi-model/regeneration grouping
// Uses completeMessageTree to know about ALL sibling groups across all branches
// Also track which nodeIds should be skipped (all but the first in each group)
const { siblingGroupMap, nodeIdsToSkip } = useMemo(() => {
const groupMap = new Map<number, Message[]>();
const skipNodeIds = new Set<number>();
// First, build group info from the COMPLETE message tree (all branches)
// This ensures we know about sibling groups even when viewing a different branch
if (completeMessageTree) {
for (const msg of Array.from(completeMessageTree.values())) {
if (msg.type === "assistant" && msg.parentNodeId !== null) {
const existing = groupMap.get(msg.parentNodeId);
if (existing) {
existing.push(msg);
} else {
groupMap.set(msg.parentNodeId, [msg]);
}
}
}
}
// Also include messages from messageHistory (for streaming/pre-created nodes
// that might not be in completeMessageTree yet)
for (const msg of messageHistory) {
if (msg.type === "assistant" && msg.parentNodeId !== null) {
const existing = groupMap.get(msg.parentNodeId);
if (existing) {
// Check if this message is already in the group (by nodeId)
if (!existing.some((m) => m.nodeId === msg.nodeId)) {
existing.push(msg);
}
} else {
groupMap.set(msg.parentNodeId, [msg]);
}
}
}
// For each group with multiple siblings, sort for consistent ordering
// Primary sort: by model name (for consistent order after refresh)
// Fallback sort: by nodeId (for stable order during streaming when model names aren't set yet)
for (const [, messages] of Array.from(groupMap.entries())) {
if (messages.length > 1) {
messages.sort((a: Message, b: Message) => {
const aName = `${a.modelProvider || ""}:${a.modelName || ""}`;
const bName = `${b.modelProvider || ""}:${b.modelName || ""}`;
// If both have model names, sort by model name
// Otherwise, fall back to nodeId for stable ordering during streaming
if (aName !== ":" && bName !== ":") {
return aName.localeCompare(bName);
}
return a.nodeId - b.nodeId;
});
// Skip all except the first (representative) message
for (let i = 1; i < messages.length; i++) {
const msg = messages[i];
if (msg) {
skipNodeIds.add(msg.nodeId);
}
}
}
}
return { siblingGroupMap: groupMap, nodeIdsToSkip: skipNodeIds };
}, [messageHistory, completeMessageTree]);
const createRegenerator = useCallback(
(regenerationRequest: {
messageId: number;
@@ -172,10 +239,56 @@ export const MessagesDisplay: React.FC<MessagesDisplayProps> = ({
);
}
// NOTE: it's fine to use the previous entry in messageHistory
// since this is a "parsed" version of the message tree
// so the previous message is guaranteed to be the parent of the current message
const previousMessage = i !== 0 ? messageHistory[i - 1] : null;
// For assistant messages, we need to find the actual parent (user message)
// by looking up parentNodeId, since messageHistory may contain sibling
// assistant messages (multi-model or regenerations) that share the same parent
const previousMessage = (() => {
if (i === 0) return null;
// For assistant messages, find the parent by parentNodeId
if (message.type === "assistant" && message.parentNodeId !== null) {
return (
messageHistory.find((m) => m.nodeId === message.parentNodeId) ??
null
);
}
// For user messages, the previous entry is the parent
return messageHistory[i - 1] ?? null;
})();
// Multi-model/regeneration sibling grouping:
// Skip messages that are not the first in their sibling group
if (nodeIdsToSkip.has(message.nodeId)) {
return null;
}
// Build modelResponses from all sibling messages with the same parent
let modelResponses: ModelResponse[] | undefined;
if (message.type === "assistant" && message.parentNodeId !== null) {
const siblingMessages = siblingGroupMap.get(message.parentNodeId);
if (siblingMessages && siblingMessages.length > 1) {
modelResponses = siblingMessages.map((msg) => ({
model: {
name: msg.modelProvider || "",
provider: msg.modelProvider || "",
modelName: msg.modelName || "",
},
// Include the actual message for this model's response
message: msg,
}));
}
}
// Filter out non-representative messages from ALL sibling groups for branch switching
// Each multi-model/regeneration group should appear as a single branch option
const switchableMessages = (() => {
const allChildren =
parentMessage?.childrenNodeIds ?? emptyChildrenIds;
// Filter out all nodeIds that are "non-representative" in their sibling group
// nodeIdsToSkip contains all messages except the first one in each group
return allChildren.filter((nodeId) => !nodeIdsToSkip.has(nodeId));
})();
return (
<div
className="text-text"
@@ -196,11 +309,11 @@ export const MessagesDisplay: React.FC<MessagesDisplayProps> = ({
overriddenModel={llmManager.currentLlm?.modelName}
nodeId={message.nodeId}
llmManager={llmManager}
otherMessagesCanSwitchTo={
parentMessage?.childrenNodeIds ?? emptyChildrenIds
}
otherMessagesCanSwitchTo={switchableMessages}
onMessageSelection={onMessageSelection}
researchType={message.researchType}
modelResponses={modelResponses}
latestChildNodeId={parentMessage?.latestChildNodeId}
/>
</div>
);

View File

@@ -17,7 +17,6 @@ import {
MessageTreeState,
upsertMessages,
SYSTEM_NODE_ID,
buildImmediateMessages,
buildEmptyMessage,
} from "../services/messageTree";
import { MinimalPersonaSnapshot } from "@/app/admin/assistants/interfaces";
@@ -592,69 +591,74 @@ export function useChatController({
// Add user message immediately to the message tree so that the chat
// immediately reflects the user message
// Assistant nodes are created on-demand when MessageResponseIDInfo arrives
let initialUserNode: Message;
let initialAssistantNode: Message;
if (regenerationRequest) {
// For regeneration: keep the existing user message, only create new assistant
// For regeneration: keep the existing user message
initialUserNode = regenerationRequest.parentMessage;
initialAssistantNode = buildEmptyMessage({
messageType: "assistant",
parentNodeId: initialUserNode.nodeId,
nodeIdOffset: 1,
});
} else {
// For new messages or editing: create/update user message and assistant
// For new messages or editing: create/update user message
const parentNodeIdForMessage = messageToResend
? messageToResend.parentNodeId || SYSTEM_NODE_ID
: parentMessage?.nodeId || SYSTEM_NODE_ID;
const result = buildImmediateMessages(
parentNodeIdForMessage,
currMessage,
projectFilesToFileDescriptors(currentMessageFiles),
messageToResend
);
initialUserNode = result.initialUserNode;
initialAssistantNode = result.initialAssistantNode;
initialUserNode = messageToResend
? { ...messageToResend }
: buildEmptyMessage({
messageType: "user",
parentNodeId: parentNodeIdForMessage,
message: currMessage,
files: projectFilesToFileDescriptors(currentMessageFiles),
});
}
// make messages appear + clear input bar
const messagesToUpsert = regenerationRequest
? [initialAssistantNode] // Only upsert the new assistant for regeneration
: [initialUserNode, initialAssistantNode]; // Upsert both for normal/edit flow
currentMessageTreeLocal = upsertToCompleteMessageTree({
messages: messagesToUpsert,
completeMessageTreeOverride: currentMessageTreeLocal,
chatSessionId: frozenSessionId,
});
// Track assistant nodes (supports both single and multi-model)
// Nodes are created on-demand when MessageResponseIDInfo arrives with backend IDs
interface AssistantNodeData {
node: Message;
packets: Packet[];
documents: OnyxDocument[];
citations: CitationMap | null;
finalMessage: BackendMessage | null;
}
// Map model_id (e.g. "openai:gpt-4") to index in assistantNodes array
const modelIdToIndex: Map<string, number> = new Map();
const assistantNodes: AssistantNodeData[] = [];
// Only upsert user message immediately; assistant nodes created when backend IDs arrive
if (!regenerationRequest) {
currentMessageTreeLocal = upsertToCompleteMessageTree({
messages: [initialUserNode],
completeMessageTreeOverride: currentMessageTreeLocal,
chatSessionId: frozenSessionId,
});
}
resetInputBar();
let answer = "";
// Shared state across all models
const stopReason: StreamStopReason | null = null;
let query: string | null = null;
let retrievalType: RetrievalType =
selectedDocuments.length > 0
? RetrievalType.SelectedDocs
: RetrievalType.None;
let documents: OnyxDocument[] = selectedDocuments;
let citations: CitationMap | null = null;
let aiMessageImages: FileDescriptor[] | null = null;
let error: string | null = null;
let stackTrace: string | null = null;
let finalMessage: BackendMessage | null = null;
let toolCall: ToolCallMetadata | null = null;
let files = projectFilesToFileDescriptors(currentMessageFiles);
let packets: Packet[] = [];
let newUserMessageId: number | null = null;
let newAssistantMessageId: number | null = null;
try {
const lastSuccessfulMessageId = getLastSuccessfulMessageId(
currentMessageTreeLocal
);
// Read the CURRENT message tree from store to get the correct parent
// This is important for multi-model responses where tab switching updates
// latestChildNodeId - we need the fresh state, not the captured closure
const freshMessageTree =
useChatSessionStore.getState().sessions.get(frozenSessionId)
?.messageTree || currentMessageTreeLocal;
const lastSuccessfulMessageId =
getLastSuccessfulMessageId(freshMessageTree);
const disabledToolIds = liveAssistant
? assistantPreferences?.[liveAssistant?.id]?.disabled_tool_ids
: undefined;
@@ -704,6 +708,14 @@ export function useChatController({
searchParams?.get(SEARCH_PARAM_NAMES.MODEL_VERSION) ||
undefined,
temperature: llmManager.temperature || undefined,
// Multi-model support: if multiple LLMs are selected and no single override, send all selected models
llmOverrides:
!modelOverride && llmManager.selectedLlms.length > 1
? llmManager.selectedLlms.map((llm) => ({
model_provider: llm.name,
model_version: llm.modelName,
}))
: undefined,
systemPromptOverride:
searchParams?.get(SEARCH_PARAM_NAMES.SYSTEM_PROMPT) || undefined,
useExistingUserMessage: isSeededChat,
@@ -746,8 +758,44 @@ export function useChatController({
if (
(packet as MessageResponseIDInfo).reserved_assistant_message_id
) {
newAssistantMessageId = (packet as MessageResponseIDInfo)
.reserved_assistant_message_id;
const msgInfo = packet as MessageResponseIDInfo;
const messageId = msgInfo.reserved_assistant_message_id;
// Create assistant node on-demand with nodeId = messageId
const newNode: Message = {
nodeId: messageId,
messageId: messageId,
message: "",
type: "assistant",
files: [],
toolCall: null,
parentNodeId: initialUserNode.nodeId,
packets: [],
modelProvider: msgInfo.model_provider || undefined,
modelName: msgInfo.model_name || undefined,
};
const nodeIndex = assistantNodes.length;
assistantNodes.push({
node: newNode,
packets: [],
documents: selectedDocuments,
citations: null,
finalMessage: null,
});
// Map model_id to index for packet routing
if (msgInfo.model_provider && msgInfo.model_name) {
const modelId = `${msgInfo.model_provider}:${msgInfo.model_name}`;
modelIdToIndex.set(modelId, nodeIndex);
}
// Upsert the new assistant node to the tree
currentMessageTreeLocal = upsertToCompleteMessageTree({
messages: [newNode],
completeMessageTreeOverride: currentMessageTreeLocal,
chatSessionId: frozenSessionId,
});
}
if (Object.hasOwn(packet, "user_files")) {
@@ -782,7 +830,16 @@ export function useChatController({
throw new Error((packet as StreamingError).error);
} else if (Object.hasOwn(packet, "message_id")) {
finalMessage = packet as BackendMessage;
// Route finalMessage to target assistant
const backendMsg = packet as BackendMessage;
const msgModelId = (backendMsg as any).model_id as
| string
| undefined;
const targetIndex = msgModelId
? modelIdToIndex.get(msgModelId) ?? 0
: 0;
const targetNode = assistantNodes[targetIndex];
if (targetNode) targetNode.finalMessage = backendMsg;
} else if (Object.hasOwn(packet, "stop_reason")) {
const stop_reason = (packet as StreamStopInfo).stop_reason;
if (stop_reason === StreamStopReason.CONTEXT_LENGTH) {
@@ -790,29 +847,35 @@ export function useChatController({
}
} else if (Object.hasOwn(packet, "obj")) {
console.debug("Object packet:", JSON.stringify(packet));
packets.push(packet as Packet);
const typedPacket = packet as Packet;
// Check if the packet contains document information
const packetObj = (packet as Packet).obj;
// Route packet to target assistant by model_id, or first assistant if single model
const targetIndex = typedPacket.model_id
? modelIdToIndex.get(typedPacket.model_id) ?? 0
: 0;
const target = assistantNodes[targetIndex];
if (!target) continue;
target.packets.push(typedPacket);
// Update per-assistant state based on packet type
const packetObj = typedPacket.obj;
if (packetObj.type === "citation_info") {
// Individual citation packet from backend streaming
const citationInfo = packetObj as {
type: "citation_info";
citation_number: number;
document_id: string;
};
// Incrementally build citations map
citations = {
...(citations || {}),
target.citations = {
...(target.citations || {}),
[citationInfo.citation_number]: citationInfo.document_id,
};
} else if (packetObj.type === "citation_delta") {
// Batched citation packet (for backwards compatibility)
const citationDelta = packetObj as CitationDelta;
if (citationDelta.citations) {
citations = {
...(citations || {}),
target.citations = {
...(target.citations || {}),
...Object.fromEntries(
citationDelta.citations.map((c) => [
c.citation_num,
@@ -824,10 +887,10 @@ export function useChatController({
} else if (packetObj.type === "message_start") {
const messageStart = packetObj as MessageStart;
if (messageStart.final_documents) {
documents = messageStart.final_documents;
target.documents = messageStart.final_documents;
updateSelectedNodeForDocDisplay(
frozenSessionId,
initialAssistantNode.nodeId
target.node.nodeId
);
}
}
@@ -840,30 +903,42 @@ export function useChatController({
parentMessage =
parentMessage || currentMessageTreeLocal?.get(SYSTEM_NODE_ID)!;
// Build messages to upsert
const messagesToUpsertInLoop: Message[] = [
{
...initialUserNode,
messageId: newUserMessageId ?? undefined,
files: files,
},
];
// Add all assistant nodes
for (const assistantData of assistantNodes) {
messagesToUpsertInLoop.push({
...assistantData.node,
message: error || "",
type: error ? "error" : ("assistant" as const),
retrievalType,
query: assistantData.finalMessage?.rephrased_query || query,
documents: assistantData.documents,
citations:
assistantData.finalMessage?.citations ||
assistantData.citations ||
{},
files:
assistantData.finalMessage?.files || aiMessageImages || [],
toolCall: assistantData.finalMessage?.tool_call || toolCall,
stackTrace: stackTrace,
overridden_model: assistantData.finalMessage?.overridden_model,
stopReason: stopReason,
packets: assistantData.packets,
modelProvider: assistantData.node.modelProvider,
modelName: assistantData.node.modelName,
});
}
currentMessageTreeLocal = upsertToCompleteMessageTree({
messages: [
{
...initialUserNode,
messageId: newUserMessageId ?? undefined,
files: files,
},
{
...initialAssistantNode,
messageId: newAssistantMessageId ?? undefined,
message: error || answer,
type: error ? "error" : "assistant",
retrievalType,
query: finalMessage?.rephrased_query || query,
documents: documents,
citations: finalMessage?.citations || citations || {},
files: finalMessage?.files || aiMessageImages || [],
toolCall: finalMessage?.tool_call || toolCall,
stackTrace: stackTrace,
overridden_model: finalMessage?.overridden_model,
stopReason: stopReason,
packets: packets,
},
],
messages: messagesToUpsertInLoop,
// Pass the latest map state
completeMessageTreeOverride: currentMessageTreeLocal,
chatSessionId: frozenSessionId!,
@@ -873,6 +948,9 @@ export function useChatController({
} catch (e: any) {
console.log("Error:", e);
const errorMsg = e.message;
// Use existing assistant node if available, otherwise create temp error node
const errorNodeId =
assistantNodes[0]?.node.nodeId ?? -1 * Date.now() - 1;
currentMessageTreeLocal = upsertToCompleteMessageTree({
messages: [
{
@@ -890,7 +968,7 @@ export function useChatController({
packets: [],
},
{
nodeId: initialAssistantNode.nodeId,
nodeId: errorNodeId,
message: errorMsg,
type: "error",
files: aiMessageImages || [],

View File

@@ -143,6 +143,9 @@ export interface Message {
// feedback state
currentFeedback?: FeedbackType | null;
modelProvider?: string;
modelName?: string;
}
export interface BackendChatSession {
@@ -192,6 +195,9 @@ export interface BackendMessage {
tool_call: ToolCallFinalResult | null;
current_feedback: string | null;
model_provider: string | null;
model_name: string | null;
sub_questions: SubQuestionDetail[];
// Keeping existing properties
comments: any;
@@ -203,6 +209,8 @@ export interface BackendMessage {
export interface MessageResponseIDInfo {
user_message_id: number | null;
reserved_assistant_message_id: number;
model_provider?: string | null;
model_name?: string | null;
}
export interface UserKnowledgeFilePacket {

View File

@@ -13,7 +13,14 @@ import { FeedbackType } from "@/app/chat/interfaces";
import { OnyxDocument } from "@/lib/search/interfaces";
import CitedSourcesToggle from "@/app/chat/message/messageComponents/CitedSourcesToggle";
import { TooltipGroup } from "@/components/tooltip/CustomTooltip";
import { useRef, useState, useEffect, useCallback, RefObject } from "react";
import {
useRef,
useState,
useEffect,
useCallback,
useMemo,
RefObject,
} from "react";
import {
useChatSessionStore,
useDocumentSidebarVisible,
@@ -46,6 +53,11 @@ import FeedbackModal, {
} from "../../components/modal/FeedbackModal";
import { usePopup } from "@/components/admin/connectors/Popup";
import { useFeedbackController } from "../../hooks/useFeedbackController";
import {
ModelResponse,
ModelResponseTabs,
useModelResponses,
} from "./ModelResponseTabs";
export interface AIMessageProps {
rawPackets: Packet[];
@@ -56,6 +68,11 @@ export interface AIMessageProps {
llmManager: LlmManager | null;
otherMessagesCanSwitchTo?: number[];
onMessageSelection?: (nodeId: number) => void;
// Multi-model responses: when multiple models are selected, each has its own response
// If undefined or length <= 1, renders normally without tabs
modelResponses?: ModelResponse[];
// The nodeId of the latest/selected child in the sibling group (for tab selection)
latestChildNodeId?: number | null;
}
export default function AIMessage({
@@ -67,11 +84,56 @@ export default function AIMessage({
llmManager,
otherMessagesCanSwitchTo,
onMessageSelection,
modelResponses,
latestChildNodeId,
}: AIMessageProps) {
const markdownRef = useRef<HTMLDivElement>(null);
const { popup, setPopup } = usePopup();
const { handleFeedbackChange } = useFeedbackController({ setPopup });
// Multi-model response state
const {
activeIndex: activeModelIndex,
setActiveIndex: setActiveModelIndex,
hasMultipleResponses,
activeResponse,
} = useModelResponses(modelResponses, latestChildNodeId);
// Handler for tab changes - switches branches via onMessageSelection
// This updates latestChildNodeId and persists to backend, enabling true branching
const handleTabChange = useCallback(
(index: number) => {
setActiveModelIndex(index);
// Switch branches via onMessageSelection - this updates the message tree
// and persists to backend so subsequent messages branch from this response
const selectedMessage = modelResponses?.[index]?.message;
if (selectedMessage?.nodeId && onMessageSelection) {
onMessageSelection(selectedMessage.nodeId);
}
},
[setActiveModelIndex, modelResponses, onMessageSelection]
);
// DEBUG: Log modelResponses
if (modelResponses && modelResponses.length > 0) {
console.log(
"[AIMessage] nodeId:",
nodeId,
"modelResponses:",
modelResponses.length,
"hasMultipleResponses:",
hasMultipleResponses
);
}
// When in multi-model mode, use the active response's packets instead of rawPackets
const effectivePackets = useMemo(() => {
if (hasMultipleResponses && activeResponse?.message?.packets) {
return activeResponse.message.packets;
}
return rawPackets;
}, [hasMultipleResponses, activeResponse, rawPackets]);
const modal = useCreateModal();
const [feedbackModalProps, setFeedbackModalProps] =
useState<FeedbackModalProps | null>(null);
@@ -139,7 +201,8 @@ export default function AIMessage({
);
const [finalAnswerComing, _setFinalAnswerComing] = useState(
isFinalAnswerComing(rawPackets) || isStreamingComplete(rawPackets)
isFinalAnswerComing(effectivePackets) ||
isStreamingComplete(effectivePackets)
);
const setFinalAnswerComing = (value: boolean) => {
_setFinalAnswerComing(value);
@@ -147,7 +210,7 @@ export default function AIMessage({
};
const [displayComplete, _setDisplayComplete] = useState(
isStreamingComplete(rawPackets)
isStreamingComplete(effectivePackets)
);
const setDisplayComplete = (value: boolean) => {
_setDisplayComplete(value);
@@ -155,7 +218,7 @@ export default function AIMessage({
};
const [stopPacketSeen, _setStopPacketSeen] = useState(
isStreamingComplete(rawPackets)
isStreamingComplete(effectivePackets)
);
const setStopPacketSeen = (value: boolean) => {
_setStopPacketSeen(value);
@@ -173,12 +236,20 @@ export default function AIMessage({
const groupedPacketsRef = useRef<{ turn_index: number; packets: Packet[] }[]>(
[]
);
const finalAnswerComingRef = useRef<boolean>(isFinalAnswerComing(rawPackets));
const displayCompleteRef = useRef<boolean>(isStreamingComplete(rawPackets));
const stopPacketSeenRef = useRef<boolean>(isStreamingComplete(rawPackets));
const finalAnswerComingRef = useRef<boolean>(
isFinalAnswerComing(effectivePackets)
);
const displayCompleteRef = useRef<boolean>(
isStreamingComplete(effectivePackets)
);
const stopPacketSeenRef = useRef<boolean>(
isStreamingComplete(effectivePackets)
);
// Track turn_index values for graceful SECTION_END injection
const seenTurnIndicesRef = useRef<Set<number>>(new Set());
const turnIndicesWithSectionEndRef = useRef<Set<number>>(new Set());
// Track previous activeModelIndex for synchronous reset detection
const prevActiveModelIndexRef = useRef<number>(activeModelIndex);
// Reset incremental state when switching messages or when stream resets
const resetState = () => {
@@ -189,9 +260,9 @@ export default function AIMessage({
documentMapRef.current = new Map();
groupedPacketsMapRef.current = new Map();
groupedPacketsRef.current = [];
finalAnswerComingRef.current = isFinalAnswerComing(rawPackets);
displayCompleteRef.current = isStreamingComplete(rawPackets);
stopPacketSeenRef.current = isStreamingComplete(rawPackets);
finalAnswerComingRef.current = isFinalAnswerComing(effectivePackets);
displayCompleteRef.current = isStreamingComplete(effectivePackets);
stopPacketSeenRef.current = isStreamingComplete(effectivePackets);
seenTurnIndicesRef.current = new Set();
turnIndicesWithSectionEndRef.current = new Set();
};
@@ -199,8 +270,19 @@ export default function AIMessage({
resetState();
}, [nodeId]);
// SYNCHRONOUS reset when switching model tabs - must happen BEFORE packet processing
// useEffect runs AFTER render, but packet processing happens DURING render,
// so we need to detect tab changes synchronously to avoid mixing packets
if (
hasMultipleResponses &&
prevActiveModelIndexRef.current !== activeModelIndex
) {
resetState();
prevActiveModelIndexRef.current = activeModelIndex;
}
// If the upstream replaces packets with a shorter list (reset), clear state
if (lastProcessedIndexRef.current > rawPackets.length) {
if (lastProcessedIndexRef.current > effectivePackets.length) {
resetState();
}
@@ -239,9 +321,13 @@ export default function AIMessage({
};
// Process only the new packets synchronously for this render
if (rawPackets.length > lastProcessedIndexRef.current) {
for (let i = lastProcessedIndexRef.current; i < rawPackets.length; i++) {
const packet = rawPackets[i];
if (effectivePackets.length > lastProcessedIndexRef.current) {
for (
let i = lastProcessedIndexRef.current;
i < effectivePackets.length;
i++
) {
const packet = effectivePackets[i];
if (!packet) continue;
const currentTurnIndex = packet.turn_index;
@@ -368,7 +454,7 @@ export default function AIMessage({
.filter(({ packets }) => hasContentPackets(packets))
.sort((a, b) => a.turn_index - b.turn_index);
lastProcessedIndexRef.current = rawPackets.length;
lastProcessedIndexRef.current = effectivePackets.length;
}
const citations = citationsRef.current;
@@ -431,6 +517,26 @@ export default function AIMessage({
<div className="w-full">
<div className="max-w-message-max break-words">
<div className="w-full desktop:ml-4">
{/* Multi-model response tabs */}
{hasMultipleResponses && modelResponses && (
<ModelResponseTabs
modelResponses={modelResponses}
activeIndex={activeModelIndex}
onTabChange={handleTabChange}
/>
)}
{/* Show active model name badge when multiple responses */}
{hasMultipleResponses && activeResponse && (
<div className="text-xs text-text-03 mb-2">
Response from{" "}
<span className="font-medium text-text-02">
{activeResponse.model.modelName ||
activeResponse.model.provider}
</span>
</div>
)}
<div className="max-w-message-max break-words">
<div
ref={markdownRef}
@@ -550,7 +656,9 @@ export default function AIMessage({
)}
<CopyIconButton
getCopyText={() => getTextContent(rawPackets)}
getCopyText={() =>
getTextContent(effectivePackets)
}
tertiary
data-testid="AIMessage/copy-button"
/>
@@ -583,13 +691,20 @@ export default function AIMessage({
<div data-testid="AIMessage/regenerate">
<LLMPopover
llmManager={llmManager}
currentModelName={chatState.overriddenModel}
currentModelName={
// Use the model from the active response (for multi-model)
// or fall back to session-level override
activeResponse?.message?.modelName ||
activeResponse?.model.modelName ||
chatState.overriddenModel
}
onSelect={(modelName) => {
const llmDescriptor =
parseLlmDescriptor(modelName);
chatState.regenerate!(llmDescriptor);
}}
folded
singleSelectMode={true}
/>
</div>
)}

View File

@@ -5,6 +5,7 @@ import { OnyxDocument, MinimalOnyxDocument } from "@/lib/search/interfaces";
import AIMessage from "./AIMessage";
import { LlmDescriptor, LlmManager } from "@/lib/hooks";
import { ProjectFile } from "@/app/chat/projects/projectsService";
import { ModelResponse } from "./ModelResponseTabs";
interface BaseMemoizedAIMessageProps {
rawPackets: any[];
@@ -21,6 +22,10 @@ interface BaseMemoizedAIMessageProps {
llmManager: LlmManager | null;
projectFiles?: ProjectFile[];
researchType?: string | null;
// Multi-model responses
modelResponses?: ModelResponse[];
// The nodeId of the latest/selected child in the sibling group (for tab selection)
latestChildNodeId?: number | null;
}
interface InternalMemoizedAIMessageProps extends BaseMemoizedAIMessageProps {
@@ -54,6 +59,8 @@ const InternalMemoizedAIMessage = React.memo(
llmManager,
projectFiles,
researchType,
modelResponses,
latestChildNodeId,
}: InternalMemoizedAIMessageProps) {
const chatState = React.useMemo(
() => ({
@@ -88,6 +95,8 @@ const InternalMemoizedAIMessage = React.memo(
llmManager={llmManager}
otherMessagesCanSwitchTo={otherMessagesCanSwitchTo}
onMessageSelection={onMessageSelection}
modelResponses={modelResponses}
latestChildNodeId={latestChildNodeId}
/>
);
}
@@ -110,6 +119,8 @@ export const MemoizedAIMessage = ({
llmManager,
projectFiles,
researchType,
modelResponses,
latestChildNodeId,
}: MemoizedAIMessageProps) => {
const regenerate = useMemo(() => {
if (messageId === undefined) {
@@ -145,6 +156,8 @@ export const MemoizedAIMessage = ({
llmManager={llmManager}
projectFiles={projectFiles}
researchType={researchType}
modelResponses={modelResponses}
latestChildNodeId={latestChildNodeId}
/>
);
};

View File

@@ -0,0 +1,154 @@
"use client";
import { useState, useMemo, useEffect, useRef } from "react";
import { LlmDescriptor } from "@/lib/hooks";
import { getProviderIcon } from "@/app/admin/configuration/llm/utils";
import { cn } from "@/lib/utils";
import { isStreamingComplete } from "@/app/chat/services/packetUtils";
import { FiCheck } from "react-icons/fi";
import { Message } from "@/app/chat/interfaces";
/**
 * One model's answer within a multi-model chat turn.
 * Pairs the model descriptor with the chat message (if any) that this
 * model produced for the shared parent user message.
 */
export interface ModelResponse {
  model: LlmDescriptor;
  // The actual message data for this model's response.
  // May be absent while the backend has not yet reserved/streamed a
  // message for this model.
  message?: Message;
}
/** Props for the tab strip rendered above a multi-model AI message. */
interface ModelResponseTabsProps {
  // One entry per sibling model response to the same user message
  modelResponses: ModelResponse[];
  // Index of the currently selected tab
  activeIndex: number;
  // Called with the newly selected tab index when the user clicks a tab
  onTabChange: (index: number) => void;
}
// Streaming indicator component - pulsing dot animation
function StreamingIndicator() {
return (
<span className="relative flex h-2 w-2">
<span className="animate-ping absolute inline-flex h-full w-full rounded-full bg-amber-400 opacity-75" />
<span className="relative inline-flex rounded-full h-2 w-2 bg-amber-500" />
</span>
);
}
// Completion indicator component - checkmark
function CompletedIndicator() {
return (
<span className="flex items-center justify-center h-3.5 w-3.5 rounded-full bg-emerald-500/20">
<FiCheck className="h-2.5 w-2.5 text-emerald-600" strokeWidth={3} />
</span>
);
}
/**
 * Tab strip for switching between multiple model responses to the same
 * user message. Renders one tab per model showing its provider icon,
 * name, and a live status indicator (pulsing dot while streaming, green
 * check once isStreamingComplete reports the packet stream finished).
 *
 * Returns null when there is at most one response — no tabs needed.
 */
export function ModelResponseTabs({
  modelResponses,
  activeIndex,
  onTabChange,
}: ModelResponseTabsProps) {
  if (modelResponses.length <= 1) {
    return null;
  }

  return (
    <div
      role="tablist"
      className="flex items-center gap-1 mb-3 pb-2 border-b border-border-01"
    >
      {modelResponses.map((response, index) => {
        const isActive = index === activeIndex;
        const Icon = getProviderIcon(
          response.model.provider,
          response.model.modelName
        );

        // Determine streaming status for this model's response.
        // No packets yet => neither indicator is shown.
        const packets = response.message?.packets || [];
        const isComplete = isStreamingComplete(packets);
        const hasStartedStreaming = packets.length > 0;

        return (
          // type="button" is required: buttons default to type="submit",
          // which would submit an enclosing <form> (e.g. the chat input)
          // on every tab click.
          <button
            key={`${response.model.provider}-${response.model.modelName}-${index}`}
            type="button"
            role="tab"
            aria-selected={isActive}
            onClick={() => onTabChange(index)}
            className={cn(
              "flex items-center gap-1.5 px-3 py-1.5 rounded-lg text-sm transition-colors",
              isActive
                ? "bg-background-emphasis text-text-01 font-medium"
                : "text-text-03 hover:bg-background-hover hover:text-text-02"
            )}
          >
            <Icon size={16} />
            <span className="max-w-[120px] truncate">
              {response.model.modelName || response.model.provider}
            </span>
            {/* Status indicator: streaming or complete (mutually exclusive) */}
            {hasStartedStreaming && !isComplete && <StreamingIndicator />}
            {isComplete && <CompletedIndicator />}
          </button>
        );
      })}
    </div>
  );
}
/**
 * Manages tab-selection state for a multi-model AI message.
 *
 * Given the sibling responses for one user message and the nodeId of the
 * branch currently marked "latest", returns:
 *  - activeIndex: bounds-clamped index of the selected response
 *  - setActiveIndex: manual tab selection
 *  - hasMultipleResponses: true when more than one response exists
 *  - activeResponse: the ModelResponse at the clamped index, if any
 */
export function useModelResponses(
  modelResponses?: ModelResponse[],
  latestChildNodeId?: number | null
) {
  // Calculate initial index based on latestChildNodeId.
  // Only consumed by useState on the first render; later changes to
  // latestChildNodeId are handled by the effect below.
  const initialIndex = useMemo(() => {
    if (!modelResponses || modelResponses.length === 0) return 0;
    if (latestChildNodeId === undefined || latestChildNodeId === null) return 0;
    const index = modelResponses.findIndex(
      (r) => r.message?.nodeId === latestChildNodeId
    );
    // Fall back to the first tab when the latest child is not in this group
    return index >= 0 ? index : 0;
  }, [modelResponses, latestChildNodeId]);

  const [activeIndex, setActiveIndex] = useState(initialIndex);

  // Track previous modelResponses length to detect new responses (regeneration)
  const prevLengthRef = useRef(modelResponses?.length ?? 0);

  // Auto-switch to new tab when a response is added (regeneration case)
  // Also sync with latestChildNodeId when it changes (e.g., on load or branch switch)
  // NOTE(review): activeIndex is in the dependency array, so a manual
  // setActiveIndex call can be re-synced back to latestChildNodeId on a
  // later render if the caller does not also persist the branch switch
  // (i.e. update latestChildNodeId) — confirm callers always do both.
  useEffect(() => {
    const currentLength = modelResponses?.length ?? 0;
    // If a new response was added, switch to it (regeneration case)
    if (currentLength > prevLengthRef.current && currentLength > 0) {
      setActiveIndex(currentLength - 1);
    }
    // If latestChildNodeId changed (e.g., loading chat), sync to it
    else if (
      latestChildNodeId !== undefined &&
      latestChildNodeId !== null &&
      modelResponses
    ) {
      const targetIndex = modelResponses.findIndex(
        (r) => r.message?.nodeId === latestChildNodeId
      );
      if (targetIndex >= 0 && targetIndex !== activeIndex) {
        setActiveIndex(targetIndex);
      }
    }
    // Record the length seen this run so the next run can detect growth
    prevLengthRef.current = currentLength;
  }, [modelResponses, latestChildNodeId, activeIndex]);

  // Reset active index if it's out of bounds
  // (e.g., the stored index outlived a shrinking response list)
  const safeActiveIndex = useMemo(() => {
    if (!modelResponses || modelResponses.length === 0) return 0;
    return Math.min(activeIndex, modelResponses.length - 1);
  }, [activeIndex, modelResponses]);

  const hasMultipleResponses = (modelResponses?.length ?? 0) > 1;
  const activeResponse = modelResponses?.[safeActiveIndex];

  return {
    // Expose the clamped index, not the raw state value
    activeIndex: safeActiveIndex,
    setActiveIndex,
    hasMultipleResponses,
    activeResponse,
  };
}

View File

@@ -180,6 +180,7 @@ export interface SendMessageParams {
useAgentSearch?: boolean;
enabledToolIds?: number[];
forcedToolIds?: number[];
llmOverrides?: { model_provider: string; model_version: string }[];
}
export async function* sendMessage({
@@ -203,6 +204,7 @@ export async function* sendMessage({
useAgentSearch,
enabledToolIds,
forcedToolIds,
llmOverrides,
}: SendMessageParams): AsyncGenerator<PacketType, void, unknown> {
const documentsAreSelected =
selectedDocumentIds && selectedDocumentIds.length > 0;
@@ -232,14 +234,18 @@ export async function* sendMessage({
system_prompt: systemPromptOverride,
}
: null,
// Multi-model response support: if llmOverrides is provided with multiple models,
// send it as llm_overrides; otherwise use single llm_override for backwards compatibility
llm_override:
temperature || modelVersion
!llmOverrides && (temperature || modelVersion)
? {
temperature,
model_provider: modelProvider,
model_version: modelVersion,
}
: null,
llm_overrides:
llmOverrides && llmOverrides.length > 0 ? llmOverrides : null,
use_existing_user_message: useExistingUserMessage,
use_agentic_search: useAgentSearch ?? false,
allowed_tool_ids: enabledToolIds,
@@ -550,6 +556,8 @@ export function processRawChatHistory(
overridden_model: messageInfo.overridden_model,
packets: packetsForMessage || [],
currentFeedback: messageInfo.current_feedback as FeedbackType | null,
modelProvider: messageInfo.model_provider ?? undefined,
modelName: messageInfo.model_name ?? undefined,
};
messages.set(messageInfo.message_id, message);

View File

@@ -254,6 +254,17 @@ export function getLatestMessageChain(messages: MessageTreeState): Message[] {
return chain;
}
// Build a map of parentNodeId -> sibling assistant messages
// This is used to include all multi-model responses (or regenerations) when we encounter one
const siblingAssistantMap = new Map<number, Message[]>();
for (const msg of Array.from(messages.values())) {
if (msg.type === "assistant" && msg.parentNodeId !== null) {
const existing = siblingAssistantMap.get(msg.parentNodeId) || [];
existing.push(msg);
siblingAssistantMap.set(msg.parentNodeId, existing);
}
}
// Find the root message
let root: Message | undefined;
if (messages.has(SYSTEM_NODE_ID)) {
@@ -285,6 +296,9 @@ export function getLatestMessageChain(messages: MessageTreeState): Message[] {
chain.push(root);
}
// Track which parent nodes we've already added sibling groups for (to avoid duplicates)
const addedSiblingGroups = new Set<number>();
while (
currentMessage?.latestChildNodeId !== null &&
currentMessage?.latestChildNodeId !== undefined
@@ -292,8 +306,46 @@ export function getLatestMessageChain(messages: MessageTreeState): Message[] {
const nextNodeId = currentMessage.latestChildNodeId;
const nextMessage = messages.get(nextNodeId);
if (nextMessage) {
chain.push(nextMessage);
currentMessage = nextMessage;
// Check if this is an assistant message that might have siblings (multi-model or regenerations)
if (
nextMessage.type === "assistant" &&
nextMessage.parentNodeId !== null &&
!addedSiblingGroups.has(nextMessage.parentNodeId)
) {
// Get all sibling assistant messages that share the same parent
const siblingMessages = siblingAssistantMap.get(
nextMessage.parentNodeId
);
if (siblingMessages && siblingMessages.length > 1) {
// Multiple siblings - add ALL of them for tabbed display
// Sort by nodeId to ensure consistent ordering
const sortedSiblings = [...siblingMessages].sort(
(a, b) => a.nodeId - b.nodeId
);
for (const msg of sortedSiblings) {
chain.push(msg);
}
addedSiblingGroups.add(nextMessage.parentNodeId);
// Continue from the one that was "latest" (the one we were following)
currentMessage = nextMessage;
} else {
// Only one assistant child, add normally
chain.push(nextMessage);
addedSiblingGroups.add(nextMessage.parentNodeId);
currentMessage = nextMessage;
}
} else if (
nextMessage.type === "assistant" &&
nextMessage.parentNodeId !== null &&
addedSiblingGroups.has(nextMessage.parentNodeId)
) {
// Already added this sibling group, just move to next
currentMessage = nextMessage;
} else {
// Normal message (user message or no parent)
chain.push(nextMessage);
currentMessage = nextMessage;
}
} else {
console.warn(
`Chain broken: Message with nodeId ${nextNodeId} not found.`
@@ -414,6 +466,8 @@ interface BuildEmptyMessageParams {
message?: string;
files?: FileDescriptor[];
nodeIdOffset?: number;
modelProvider?: string;
modelName?: string;
}
export const buildEmptyMessage = (params: BuildEmptyMessageParams): Message => {
@@ -427,6 +481,8 @@ export const buildEmptyMessage = (params: BuildEmptyMessageParams): Message => {
toolCall: null,
parentNodeId: params.parentNodeId,
packets: [],
modelProvider: params.modelProvider,
modelName: params.modelName,
};
};

View File

@@ -226,6 +226,7 @@ export type ObjTypes =
// One streamed event from the chat backend. `turn_index` presumably groups
// packets belonging to the same conversational turn — confirm against producer.
export interface Packet {
turn_index: number;
obj: ObjTypes;
// Identifies which LLM produced this packet when multiple models respond.
model_id?: string | null; // Format: "{provider}:{model_name}" e.g. "openai:gpt-4"
}
export interface ChatPacket {

View File

@@ -484,7 +484,12 @@ export interface LlmDescriptor {
}
export interface LlmManager {
// Multi-model selection (array of 1-4 models)
selectedLlms: LlmDescriptor[];
updateSelectedLlms: (llms: LlmDescriptor[]) => void;
// Convenience getter for backwards compatibility - returns first selected model
currentLlm: LlmDescriptor;
// Legacy single-model update - updates first model in selection
updateCurrentLlm: (newOverride: LlmDescriptor) => void;
temperature: number;
updateTemperature: (temperature: number) => void;
@@ -564,11 +569,20 @@ export function useLlmManager(
const [userHasManuallyOverriddenLLM, setUserHasManuallyOverriddenLLM] =
useState(false);
const [chatSession, setChatSession] = useState<ChatSession | null>(null);
const [currentLlm, setCurrentLlm] = useState<LlmDescriptor>({
name: "",
provider: "",
modelName: "",
});
// Models currently selected for response generation. Always holds at least one
// entry; starts as a single empty descriptor until a real default is resolved.
const [selectedLlms, setSelectedLlms] = useState<LlmDescriptor[]>([
{ name: "", provider: "", modelName: "" },
]);
// Convenience getter for backwards compatibility
const currentLlm = useMemo(
() => selectedLlms[0] || { name: "", provider: "", modelName: "" },
[selectedLlms]
);
// Replace the primary (first) selected model while keeping any additional
// models that follow it in the selection.
const setPrimaryLlm = (llm: LlmDescriptor) => {
  setSelectedLlms((previous) => {
    const others = previous.slice(1);
    return [llm, ...others];
  });
};
const llmUpdate = () => {
/* Should be called when the live assistant or current chat session changes */
@@ -588,11 +602,11 @@ export function useLlmManager(
}
if (currentChatSession?.current_alternate_model) {
setCurrentLlm(
setPrimaryLlm(
getValidLlmDescriptor(currentChatSession.current_alternate_model)
);
} else if (liveAssistant?.llm_model_version_override) {
setCurrentLlm(
setPrimaryLlm(
getValidLlmDescriptor(liveAssistant.llm_model_version_override)
);
} else if (userHasManuallyOverriddenLLM) {
@@ -600,14 +614,14 @@ export function useLlmManager(
// current chat session, use the override
return;
} else if (user?.preferences?.default_model) {
setCurrentLlm(getValidLlmDescriptor(user.preferences.default_model));
setPrimaryLlm(getValidLlmDescriptor(user.preferences.default_model));
} else {
const defaultProvider = llmProviders.find(
(provider) => provider.is_default_provider
);
if (defaultProvider) {
setCurrentLlm({
setPrimaryLlm({
name: defaultProvider.name,
provider: defaultProvider.provider,
modelName: defaultProvider.default_model_name,
@@ -665,20 +679,30 @@ export function useLlmManager(
setImageFilesPresent(present);
};
// Manually set the LLM
// Replace the whole multi-select model list and mark the choice as a manual
// user override.
const updateSelectedLlms = (llms: LlmDescriptor[]) => {
  // Guard: an empty selection is never allowed — keep the previous list.
  if (!llms.length) {
    return;
  }
  setSelectedLlms(llms);
  setUserHasManuallyOverriddenLLM(true);
};
// Legacy: Manually set the primary LLM (backwards compatibility)
const updateCurrentLlm = (newLlm: LlmDescriptor) => {
setCurrentLlm(newLlm);
setPrimaryLlm(newLlm);
setUserHasManuallyOverriddenLLM(true);
};
const updateCurrentLlmToModelName = (modelName: string) => {
setCurrentLlm(getValidLlmDescriptor(modelName));
setPrimaryLlm(getValidLlmDescriptor(modelName));
setUserHasManuallyOverriddenLLM(true);
};
const updateModelOverrideBasedOnChatSession = (chatSession?: ChatSession) => {
if (chatSession && chatSession.current_alternate_model?.length > 0) {
setCurrentLlm(getValidLlmDescriptor(chatSession.current_alternate_model));
setPrimaryLlm(getValidLlmDescriptor(chatSession.current_alternate_model));
}
};
@@ -757,6 +781,10 @@ export function useLlmManager(
return {
updateModelOverrideBasedOnChatSession,
// Multi-model selection
selectedLlms,
updateSelectedLlms,
// Backwards compatibility (single model)
currentLlm,
updateCurrentLlm,
temperature,

View File

@@ -53,6 +53,11 @@ export interface LLMPopoverProps {
onSelect?: (value: string) => void;
currentModelName?: string;
disabled?: boolean;
maxSelection?: number;
minSelection?: number;
// Single-select mode: for one-shot actions like regeneration
// Selects model, calls onSelect, and closes popover (doesn't update llmManager.selectedLlms)
singleSelectMode?: boolean;
}
export default function LLMPopover({
@@ -61,14 +66,35 @@ export default function LLMPopover({
onSelect,
currentModelName,
disabled = false,
maxSelection = 4,
minSelection = 1,
singleSelectMode = false,
}: LLMPopoverProps) {
const llmProviders = llmManager.llmProviders;
const isLoadingProviders = llmManager.isLoadingProviders;
// Use selectedLlms from llmManager directly
const selectedModels = llmManager.selectedLlms;
const onSelectionChange = llmManager.updateSelectedLlms;
const [open, setOpen] = useState(false);
const [searchQuery, setSearchQuery] = useState("");
const { user } = useUser();
// Helper to check if a model is selected
const isModelSelected = useCallback(
(option: LLMOption) => {
return selectedModels.some(
(m) =>
m.modelName === option.modelName && m.provider === option.provider
);
},
[selectedModels]
);
// Check if we're at max selection
const isAtMaxSelection = selectedModels.length >= maxSelection;
const [localTemperature, setLocalTemperature] = useState(
llmManager.temperature ?? 0.5
);
@@ -245,6 +271,35 @@ export default function LLMPopover({
return currentModel;
}, [llmProviders, llmManager.currentLlm.modelName]);
// Resolve the human-readable display name for a model (used in multi-select
// mode). Searches every provider's model configurations; falls back to the
// raw model name when no provider knows the model or providers aren't loaded.
const getModelDisplayName = useCallback(
  (model: LlmDescriptor) => {
    if (!llmProviders) return model.modelName;
    const config = llmProviders
      .flatMap((provider) => provider.model_configurations)
      .find((candidate) => candidate.name === model.modelName);
    return config ? config.display_name || config.name : model.modelName;
  },
  [llmProviders]
);
// Text shown on the popover trigger in multi-select mode: placeholder when
// nothing is selected, the model's display name for exactly one, a count
// otherwise.
const multiSelectTriggerText = useMemo(() => {
  switch (selectedModels.length) {
    case 0:
      return "Select models";
    case 1:
      return getModelDisplayName(selectedModels[0]!);
    default:
      return `${selectedModels.length} models selected`;
  }
}, [selectedModels, getModelDisplayName]);
// Determine which group the current model belongs to (for auto-expand)
const currentGroupKey = useMemo(() => {
const currentModel = llmManager.currentLlm.modelName;
@@ -318,13 +373,55 @@ export default function LLMPopover({
};
const handleSelectModel = (option: LLMOption) => {
llmManager.updateCurrentLlm({
const newDescriptor: LlmDescriptor = {
modelName: option.modelName,
provider: option.provider,
name: option.name,
} as LlmDescriptor);
} as LlmDescriptor;
if (singleSelectMode) {
// Single-select mode: just call onSelect and close (for actions like regeneration)
onSelect?.(
structureValue(option.name, option.provider, option.modelName)
);
setOpen(false);
return;
}
// Multi-select mode: toggle selection
const isSelected = isModelSelected(option);
if (isSelected) {
// Don't allow deselecting if at minimum
if (selectedModels.length <= minSelection) {
return;
}
const newSelection = selectedModels.filter(
(m) =>
!(m.modelName === option.modelName && m.provider === option.provider)
);
onSelectionChange(newSelection);
} else {
// Don't allow selecting if at maximum
if (selectedModels.length >= maxSelection) {
return;
}
onSelectionChange([...selectedModels, newDescriptor]);
}
// Also call legacy onSelect for backwards compatibility
onSelect?.(structureValue(option.name, option.provider, option.modelName));
setOpen(false);
};
// Handle removing a model from the selection (via the chip "x" buttons).
// No-ops when removal would drop the selection below the required minimum.
const handleRemoveModel = (model: LlmDescriptor) => {
  if (selectedModels.length <= minSelection) {
    return;
  }
  const newSelection = selectedModels.filter(
    (m) => !(m.modelName === model.modelName && m.provider === model.provider)
  );
  // `onSelectionChange` is unconditionally bound (llmManager.updateSelectedLlms),
  // so call it directly — consistent with handleSelectModel's usage.
  onSelectionChange(newSelection);
};
return (
@@ -335,10 +432,15 @@ export default function LLMPopover({
leftIcon={
folded
? SvgRefreshCw
: getProviderIcon(
llmManager.currentLlm.provider,
llmManager.currentLlm.modelName
)
: selectedModels.length > 0
? getProviderIcon(
selectedModels[0]!.provider,
selectedModels[0]!.modelName
)
: getProviderIcon(
llmManager.currentLlm.provider,
llmManager.currentLlm.modelName
)
}
onClick={() => setOpen(true)}
transient={open}
@@ -347,12 +449,79 @@ export default function LLMPopover({
disabled={disabled}
className={disabled ? "bg-transparent" : ""}
>
{currentLlmDisplayName}
{multiSelectTriggerText}
</SelectButton>
</div>
</PopoverTrigger>
<PopoverContent side="top" align="end" className="w-[280px] p-1">
<div className="flex flex-col gap-1">
{/* Selection Summary (hidden in single-select mode) */}
{!singleSelectMode && selectedModels.length > 0 && (
<div className="px-2 py-2 border-b border-border-02">
<div className="flex items-center justify-between mb-2">
<Text secondaryBody text03>
Selected ({selectedModels.length}/{maxSelection})
</Text>
{isAtMaxSelection && (
<Text
secondaryBody
className="text-amber-600 dark:text-amber-400 text-xs"
>
Max reached
</Text>
)}
</div>
{selectedModels.length === 0 ? (
<Text secondaryBody text03 className="text-text-04">
Select up to {maxSelection} models
</Text>
) : (
<div className="flex flex-wrap gap-1">
{selectedModels.map((model) => (
<div
key={`${model.provider}-${model.modelName}`}
className="inline-flex items-center gap-1 px-2 py-1 bg-action-link-01 text-text-inverse rounded-full text-xs"
>
<span className="max-w-[100px] truncate">
{getModelDisplayName(model)}
</span>
{selectedModels.length > minSelection && (
<button
onClick={(e) => {
e.stopPropagation();
handleRemoveModel(model);
}}
className="ml-0.5 hover:bg-white/20 rounded-full p-0.5 transition-colors"
title="Remove model"
>
<svg
className="w-3 h-3"
fill="none"
stroke="currentColor"
viewBox="0 0 24 24"
>
<path
strokeLinecap="round"
strokeLinejoin="round"
strokeWidth={2}
d="M6 18L18 6M6 6l12 12"
/>
</svg>
</button>
)}
</div>
))}
</div>
)}
{selectedModels.length > 0 &&
selectedModels.length < minSelection && (
<Text secondaryBody className="text-text-04 text-xs mt-1">
At least {minSelection} model required
</Text>
)}
</div>
)}
{/* Search Input */}
<InputTypeIn
ref={searchInputRef}
@@ -439,11 +608,18 @@ export default function LLMPopover({
<div className="flex flex-col gap-1">
{group.options.map((option) => {
// Match by both modelName AND provider to handle same model name across providers
const isSelected =
option.modelName ===
llmManager.currentLlm.modelName &&
option.provider ===
llmManager.currentLlm.provider;
const isSelected = isModelSelected(option);
// In single-select mode, check if this is the current model being replaced
const isCurrentModel =
singleSelectMode &&
currentModelName === option.modelName;
// Disable unselected items when at max (only in multi-select mode)
const isDisabled =
!singleSelectMode &&
!isSelected &&
isAtMaxSelection;
// Build description with version info
const description =
@@ -456,18 +632,42 @@ export default function LLMPopover({
<div
key={`${option.name}-${option.modelName}`}
ref={
isSelected ? selectedItemRef : undefined
isSelected || isCurrentModel
? selectedItemRef
: undefined
}
title={
isDisabled
? `Maximum of ${maxSelection} models selected`
: undefined
}
>
<LineItem
selected={isSelected}
selected={
singleSelectMode
? isCurrentModel
: isSelected
}
description={description}
onClick={() =>
!isDisabled &&
handleSelectModel(option)
}
className="pl-7"
className={`pl-7 ${
isDisabled
? "opacity-50 cursor-not-allowed"
: ""
}`}
rightChildren={
isSelected ? (
// In single-select mode: show "Current" badge for the model being replaced
// In multi-select mode: show checkmark for selected models
singleSelectMode ? (
isCurrentModel ? (
<span className="text-[10px] font-medium text-text-03 bg-background-emphasis px-1.5 py-0.5 rounded">
Current
</span>
) : null
) : isSelected ? (
<SvgCheck className="h-4 w-4 stroke-action-link-05 shrink-0" />
) : null
}