Compare commits

...

5 Commits

28 changed files with 2873 additions and 97 deletions

View File

@@ -0,0 +1,36 @@
"""add preferred_response_id and model_display_name to chat_message
Revision ID: a3f8b2c1d4e5
Create Date: 2026-03-22
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "a3f8b2c1d4e5"
down_revision = "b728689f45b1"
branch_labels = None
depends_on = None
def upgrade() -> None:
    """Add the two multi-model columns to ``chat_message``.

    - ``preferred_response_id``: self-referencing FK to the assistant
      response a user message selected as preferred for a multi-model turn
    - ``model_display_name``: label of the model that produced the message
    """
    preferred_response_column = sa.Column(
        "preferred_response_id",
        sa.Integer(),
        sa.ForeignKey("chat_message.id"),
        nullable=True,
    )
    display_name_column = sa.Column("model_display_name", sa.String(), nullable=True)

    op.add_column("chat_message", preferred_response_column)
    op.add_column("chat_message", display_name_column)
def downgrade() -> None:
    """Remove the multi-model columns, in reverse order of creation."""
    for column_name in ("model_display_name", "preferred_response_id"):
        op.drop_column("chat_message", column_name)

View File

@@ -8,6 +8,7 @@ from onyx.configs.constants import MessageType
from onyx.context.search.models import SearchDoc
from onyx.file_store.models import InMemoryChatFile
from onyx.server.query_and_chat.models import MessageResponseIDInfo
from onyx.server.query_and_chat.models import MultiModelMessageResponseIDInfo
from onyx.server.query_and_chat.streaming_models import CitationInfo
from onyx.server.query_and_chat.streaming_models import GeneratedImage
from onyx.server.query_and_chat.streaming_models import Packet
@@ -35,7 +36,13 @@ class CreateChatSessionID(BaseModel):
chat_session_id: UUID
# The union of every object type a chat answer stream may yield.
# (The pre-multi-model single-line form of this alias was left behind by a
# merge and immediately shadowed by the definition below; it is removed here.)
AnswerStreamPart = (
    Packet
    | MessageResponseIDInfo
    | MultiModelMessageResponseIDInfo
    | StreamingError
    | CreateChatSessionID
)
AnswerStream = Iterator[AnswerStreamPart]

View File

@@ -4,9 +4,11 @@ An overview can be found in the README.md file in this directory.
"""
import io
import queue
import re
import traceback
from collections.abc import Callable
from concurrent.futures import ThreadPoolExecutor
from contextvars import Token
from uuid import UUID
@@ -28,6 +30,7 @@ from onyx.chat.compression import calculate_total_history_tokens
from onyx.chat.compression import compress_chat_history
from onyx.chat.compression import find_summary_for_branch
from onyx.chat.compression import get_compression_params
from onyx.chat.emitter import Emitter
from onyx.chat.emitter import get_default_emitter
from onyx.chat.llm_loop import EmptyLLMResponseError
from onyx.chat.llm_loop import run_llm_loop
@@ -59,6 +62,8 @@ from onyx.db.chat import create_new_chat_message
from onyx.db.chat import get_chat_session_by_id
from onyx.db.chat import get_or_create_root_message
from onyx.db.chat import reserve_message_id
from onyx.db.chat import reserve_multi_model_message_ids
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.memory import get_memories
from onyx.db.models import ChatMessage
from onyx.db.models import ChatSession
@@ -77,16 +82,20 @@ from onyx.llm.factory import get_llm_for_persona
from onyx.llm.factory import get_llm_token_counter
from onyx.llm.interfaces import LLM
from onyx.llm.interfaces import LLMUserIdentity
from onyx.llm.override_models import LLMOverride
from onyx.llm.request_context import reset_llm_mock_response
from onyx.llm.request_context import set_llm_mock_response
from onyx.llm.utils import litellm_exception_to_error_msg
from onyx.onyxbot.slack.models import SlackContext
from onyx.server.query_and_chat.models import AUTO_PLACE_AFTER_LATEST_MESSAGE
from onyx.server.query_and_chat.models import MessageResponseIDInfo
from onyx.server.query_and_chat.models import MultiModelMessageResponseIDInfo
from onyx.server.query_and_chat.models import SendMessageRequest
from onyx.server.query_and_chat.placement import Placement
from onyx.server.query_and_chat.streaming_models import AgentResponseDelta
from onyx.server.query_and_chat.streaming_models import AgentResponseStart
from onyx.server.query_and_chat.streaming_models import CitationInfo
from onyx.server.query_and_chat.streaming_models import OverallStop
from onyx.server.query_and_chat.streaming_models import Packet
from onyx.server.usage_limits import check_llm_cost_limit_for_provider
from onyx.tools.constants import SEARCH_TOOL_ID
@@ -997,6 +1006,568 @@ def handle_stream_message_objects(
logger.exception("Error in setting processing status")
def _build_model_display_name(override: LLMOverride) -> str:
"""Build a human-readable display name from an LLM override."""
if override.display_name:
return override.display_name
if override.model_version:
return override.model_version
if override.model_provider:
return override.model_provider
return "unknown"
# Sentinel placed on the merged queue when a model thread finishes.
_MODEL_DONE = object()
class _ModelIndexEmitter(Emitter):
    """Emitter that stamps packets with a model_index and forwards them
    straight onto a shared queue.

    The standard Emitter accumulates packets in a local bus; this variant
    instead pushes each packet into ``merged_queue`` the moment it is
    emitted. That lets output from several models interleave on the wire
    in real time instead of arriving in per-model bursts.
    """

    def __init__(self, model_idx: int, merged_queue: queue.Queue) -> None:
        # The base class requires a bus; it is never read here.
        super().__init__(queue.Queue())
        self._model_idx = model_idx
        self._merged_queue = merged_queue

    def emit(self, packet: Packet) -> None:
        src = packet.placement
        tagged_placement = Placement(
            turn_index=src.turn_index if src else 0,
            tab_index=src.tab_index if src else 0,
            sub_turn_index=src.sub_turn_index if src else None,
            model_index=self._model_idx,
        )
        self._merged_queue.put(
            (self._model_idx, Packet(placement=tagged_placement, obj=packet.obj))
        )
def run_multi_model_stream(
    new_msg_req: SendMessageRequest,
    user: User,
    db_session: Session,
    llm_overrides: list[LLMOverride],
    litellm_additional_headers: dict[str, str] | None = None,
    custom_tool_additional_headers: dict[str, str] | None = None,
    mcp_headers: dict[str, str] | None = None,
) -> AnswerStream:
    """Run 2-3 LLMs in parallel and yield their packets tagged with model_index.

    Resource management:
      - Each model thread gets its OWN db_session (SQLAlchemy sessions are not
        thread-safe)
      - The caller's db_session is used only for setup (before threads launch)
        and completion callbacks (after threads finish)
      - ThreadPoolExecutor is bounded to len(overrides) workers
      - All threads are joined in the finally block regardless of success/failure
      - Queue-based merging avoids busy-waiting
    """
    # TODO: The setup logic below (session resolution through tool construction)
    # is duplicated from handle_stream_message_objects. Extract into a shared
    # _ChatStreamContext dataclass + _prepare_chat_stream_context() factory so
    # both paths call the same setup code. Tracked as follow-up refactor.

    # Validation happens before the try block so ValueErrors propagate to the
    # caller (the endpoint / tests rely on this) rather than being converted
    # into StreamingError packets.
    n_models = len(llm_overrides)
    if n_models < 2 or n_models > 3:
        raise ValueError(f"Multi-model requires 2-3 overrides, got {n_models}")
    if new_msg_req.deep_research:
        raise ValueError("Multi-model is not supported with deep research")

    tenant_id = get_current_tenant_id()
    cache: CacheBackend | None = None
    chat_session: ChatSession | None = None
    user_id = user.id
    if user.is_anonymous:
        llm_user_identifier = "anonymous_user"
    else:
        llm_user_identifier = user.email or str(user_id)

    try:
        # ── Session setup (same as single-model path) ──────────────────
        if not new_msg_req.chat_session_id:
            if not new_msg_req.chat_session_info:
                raise RuntimeError(
                    "Must specify a chat session id or chat session info"
                )
            chat_session = create_chat_session_from_request(
                chat_session_request=new_msg_req.chat_session_info,
                user_id=user_id,
                db_session=db_session,
            )
            yield CreateChatSessionID(chat_session_id=chat_session.id)
        else:
            chat_session = get_chat_session_by_id(
                chat_session_id=new_msg_req.chat_session_id,
                user_id=user_id,
                db_session=db_session,
            )

        persona = chat_session.persona
        message_text = new_msg_req.message

        # ── Build N LLM instances and validate costs ───────────────────
        llms: list[LLM] = []
        model_display_names: list[str] = []
        for override in llm_overrides:
            llm = get_llm_for_persona(
                persona=persona,
                user=user,
                llm_override=override,
                additional_headers=litellm_additional_headers,
            )
            check_llm_cost_limit_for_provider(
                db_session=db_session,
                tenant_id=tenant_id,
                llm_provider_api_key=llm.config.api_key,
            )
            llms.append(llm)
            model_display_names.append(_build_model_display_name(override))

        # Use first LLM for token counting (context window is checked per-model
        # but token counting is model-agnostic enough for setup purposes)
        token_counter = get_llm_token_counter(llms[0])

        verify_user_files(
            user_files=new_msg_req.file_descriptors,
            user_id=user_id,
            db_session=db_session,
            project_id=chat_session.project_id,
        )

        # ── Chat history chain (shared across all models) ──────────────
        chat_history = create_chat_history_chain(
            chat_session_id=chat_session.id, db_session=db_session
        )
        root_message = get_or_create_root_message(
            chat_session_id=chat_session.id, db_session=db_session
        )
        if new_msg_req.parent_message_id == AUTO_PLACE_AFTER_LATEST_MESSAGE:
            parent_message = chat_history[-1] if chat_history else root_message
        elif (
            new_msg_req.parent_message_id is None
            or new_msg_req.parent_message_id == root_message.id
        ):
            parent_message = root_message
            chat_history = []
        else:
            parent_message = None
            # Walk backwards so we keep the longest prefix ending at the parent
            for i in range(len(chat_history) - 1, -1, -1):
                if chat_history[i].id == new_msg_req.parent_message_id:
                    parent_message = chat_history[i]
                    chat_history = chat_history[: i + 1]
                    break
            if parent_message is None:
                raise ValueError(
                    "The new message sent is not on the latest mainline of messages"
                )

        if parent_message.message_type == MessageType.USER:
            # Regeneration case: reuse the existing user message
            user_message = parent_message
        else:
            user_message = create_new_chat_message(
                chat_session_id=chat_session.id,
                parent_message=parent_message,
                message=message_text,
                token_count=token_counter(message_text),
                message_type=MessageType.USER,
                files=new_msg_req.file_descriptors,
                db_session=db_session,
                commit=True,
            )
            chat_history.append(user_message)

        available_files = _collect_available_file_ids(
            chat_history=chat_history,
            project_id=chat_session.project_id,
            user_id=user_id,
            db_session=db_session,
        )

        # If part of the branch was summarized, keep metadata for files that
        # appeared before the cutoff and drop the summarized messages.
        summary_message = find_summary_for_branch(db_session, chat_history)
        summarized_file_metadata: dict[str, FileToolMetadata] = {}
        if summary_message and summary_message.last_summarized_message_id:
            cutoff_id = summary_message.last_summarized_message_id
            for msg in chat_history:
                if msg.id > cutoff_id or not msg.files:
                    continue
                for fd in msg.files:
                    file_id = fd.get("id")
                    if not file_id:
                        continue
                    summarized_file_metadata[file_id] = FileToolMetadata(
                        file_id=file_id,
                        filename=fd.get("name") or "unknown",
                        approx_char_count=0,
                    )
            chat_history = [m for m in chat_history if m.id > cutoff_id]

        user_memory_context = get_memories(user, db_session)
        custom_agent_prompt = get_custom_agent_prompt(persona, chat_session)
        prompt_memory_context = (
            user_memory_context
            if user.use_memories
            else user_memory_context.without_memories()
        )
        max_reserved_system_prompt_tokens_str = (persona.system_prompt or "") + (
            custom_agent_prompt or ""
        )
        reserved_token_count = calculate_reserved_tokens(
            db_session=db_session,
            persona_system_prompt=max_reserved_system_prompt_tokens_str,
            token_counter=token_counter,
            files=new_msg_req.file_descriptors,
            user_memory_context=prompt_memory_context,
        )
        context_user_files = resolve_context_user_files(
            persona=persona,
            project_id=chat_session.project_id,
            user_id=user_id,
            db_session=db_session,
        )
        # Use the smallest context window across all models for safety
        min_context_window = min(llm.config.max_input_tokens for llm in llms)
        extracted_context_files = extract_context_files(
            user_files=context_user_files,
            llm_max_context_window=min_context_window,
            reserved_token_count=reserved_token_count,
            db_session=db_session,
        )
        search_params = determine_search_params(
            persona_id=persona.id,
            project_id=chat_session.project_id,
            extracted_context_files=extracted_context_files,
        )
        if persona.user_files:
            existing = set(available_files.user_file_ids)
            for uf in persona.user_files:
                if uf.id not in existing:
                    available_files.user_file_ids.append(uf.id)

        all_tools = get_tools(db_session)
        tool_id_to_name_map = {tool.id: tool.name for tool in all_tools}
        search_tool_id = next(
            (tool.id for tool in all_tools if tool.in_code_tool_id == SEARCH_TOOL_ID),
            None,
        )
        # A forced search tool is dropped when search is disabled for this turn
        forced_tool_id = new_msg_req.forced_tool_id
        if (
            search_params.search_usage == SearchToolUsage.DISABLED
            and forced_tool_id is not None
            and search_tool_id is not None
            and forced_tool_id == search_tool_id
        ):
            forced_tool_id = None

        files = load_all_chat_files(chat_history, db_session)
        chat_files_for_tools = _convert_loaded_files_to_chat_files(files)

        # ── Reserve N assistant message IDs ────────────────────────────
        reserved_messages = reserve_multi_model_message_ids(
            db_session=db_session,
            chat_session_id=chat_session.id,
            parent_message_id=user_message.id,
            model_display_names=model_display_names,
        )
        yield MultiModelMessageResponseIDInfo(
            user_message_id=user_message.id,
            reserved_assistant_message_ids=[m.id for m in reserved_messages],
            model_names=model_display_names,
        )

        has_file_reader_tool = any(
            tool.in_code_tool_id == "file_reader" for tool in all_tools
        )
        chat_history_result = convert_chat_history(
            chat_history=chat_history,
            files=files,
            context_image_files=extracted_context_files.image_files,
            additional_context=new_msg_req.additional_context,
            token_counter=token_counter,
            tool_id_to_name_map=tool_id_to_name_map,
        )
        simple_chat_history = chat_history_result.simple_messages
        all_injected_file_metadata: dict[str, FileToolMetadata] = (
            chat_history_result.all_injected_file_metadata
            if has_file_reader_tool
            else {}
        )
        if summarized_file_metadata:
            for fid, meta in summarized_file_metadata.items():
                all_injected_file_metadata.setdefault(fid, meta)
        if summary_message is not None:
            summary_simple = ChatMessageSimple(
                message=summary_message.message,
                token_count=summary_message.token_count,
                message_type=MessageType.ASSISTANT,
            )
            simple_chat_history.insert(0, summary_simple)

        # ── Stop signal and processing status ──────────────────────────
        cache = get_cache_backend()
        reset_cancel_status(chat_session.id, cache)

        def check_is_connected() -> bool:
            return check_stop_signal(chat_session.id, cache)

        set_processing_status(
            chat_session_id=chat_session.id,
            cache=cache,
            value=True,
        )

        # Release the main session's read transaction before the long stream
        db_session.commit()

        # ── Parallel model execution ───────────────────────────────────
        # Each model thread writes tagged packets to this shared queue.
        # Sentinel _MODEL_DONE signals that a thread finished.
        merged_queue: queue.Queue[tuple[int, Packet | Exception | object]] = (
            queue.Queue()
        )
        # Track per-model state containers for completion callbacks
        state_containers: list[ChatStateContainer] = [
            ChatStateContainer() for _ in range(n_models)
        ]
        # Track which models completed successfully (for completion callbacks)
        model_succeeded: list[bool] = [False] * n_models

        user_identity = LLMUserIdentity(
            user_id=llm_user_identifier,
            session_id=str(chat_session.id),
        )

        def _run_model(model_idx: int) -> None:
            """Run a single model in a worker thread.

            Uses _ModelIndexEmitter so packets flow directly to merged_queue
            in real-time (not batched after completion). This enables true
            parallel streaming where both models' tokens interleave on the wire.

            DB access: tools may need a session during execution (e.g., search
            tool). Each thread creates its own session via context manager.
            """
            model_emitter = _ModelIndexEmitter(model_idx, merged_queue)
            sc = state_containers[model_idx]
            model_llm = llms[model_idx]
            try:
                # Each model thread gets its own DB session for tool execution.
                # The session is scoped to the thread and closed when done.
                with get_session_with_current_tenant() as thread_db_session:
                    # Construct tools per-thread with thread-local DB session
                    thread_tool_dict = construct_tools(
                        persona=persona,
                        db_session=thread_db_session,
                        emitter=model_emitter,
                        user=user,
                        llm=model_llm,
                        search_tool_config=SearchToolConfig(
                            user_selected_filters=new_msg_req.internal_search_filters,
                            project_id_filter=search_params.project_id_filter,
                            persona_id_filter=search_params.persona_id_filter,
                            bypass_acl=False,
                            enable_slack_search=_should_enable_slack_search(
                                persona, new_msg_req.internal_search_filters
                            ),
                        ),
                        custom_tool_config=CustomToolConfig(
                            chat_session_id=chat_session.id,
                            message_id=user_message.id,
                            additional_headers=custom_tool_additional_headers,
                            mcp_headers=mcp_headers,
                        ),
                        file_reader_tool_config=FileReaderToolConfig(
                            user_file_ids=available_files.user_file_ids,
                            chat_file_ids=available_files.chat_file_ids,
                        ),
                        allowed_tool_ids=new_msg_req.allowed_tool_ids,
                        search_usage_forcing_setting=search_params.search_usage,
                    )
                    model_tools: list[Tool] = []
                    for tool_list in thread_tool_dict.values():
                        model_tools.extend(tool_list)

                    # Run the LLM loop — this blocks until the model finishes.
                    # Packets flow to merged_queue in real-time via the emitter.
                    run_llm_loop(
                        emitter=model_emitter,
                        state_container=sc,
                        simple_chat_history=simple_chat_history,
                        tools=model_tools,
                        custom_agent_prompt=custom_agent_prompt,
                        context_files=extracted_context_files,
                        persona=persona,
                        user_memory_context=user_memory_context,
                        llm=model_llm,
                        token_counter=get_llm_token_counter(model_llm),
                        db_session=thread_db_session,
                        forced_tool_id=forced_tool_id,
                        user_identity=user_identity,
                        chat_session_id=str(chat_session.id),
                        chat_files=chat_files_for_tools,
                        include_citations=new_msg_req.include_citations,
                        all_injected_file_metadata=all_injected_file_metadata,
                        inject_memories_in_prompt=user.use_memories,
                    )
                model_succeeded[model_idx] = True
            except Exception as e:
                merged_queue.put((model_idx, e))
            finally:
                # Always enqueued — including after an exception — so the
                # merge loop decrements models_remaining exactly once per model.
                merged_queue.put((model_idx, _MODEL_DONE))

        # Launch model threads via ThreadPoolExecutor (bounded, context-propagating)
        executor = ThreadPoolExecutor(
            max_workers=n_models,
            thread_name_prefix="multi-model",
        )
        futures = []
        try:
            for i in range(n_models):
                futures.append(executor.submit(_run_model, i))

            # ── Main thread: merge and yield packets ───────────────────
            models_remaining = n_models
            while models_remaining > 0:
                try:
                    model_idx, item = merged_queue.get(timeout=0.3)
                except queue.Empty:
                    # Check cancellation during idle periods
                    if not check_is_connected():
                        yield Packet(
                            placement=Placement(turn_index=0),
                            obj=OverallStop(type="stop", stop_reason="user_cancelled"),
                        )
                        return
                    continue

                if item is _MODEL_DONE:
                    models_remaining -= 1
                    continue

                if isinstance(item, Exception):
                    # Yield error as a tagged StreamingError packet
                    error_msg = str(item)
                    stack_trace = "".join(
                        traceback.format_exception(type(item), item, item.__traceback__)
                    )
                    # Redact API keys from error messages
                    model_llm = llms[model_idx]
                    if model_llm.config.api_key and len(model_llm.config.api_key) > 2:
                        error_msg = error_msg.replace(
                            model_llm.config.api_key, "[REDACTED_API_KEY]"
                        )
                        stack_trace = stack_trace.replace(
                            model_llm.config.api_key, "[REDACTED_API_KEY]"
                        )
                    yield StreamingError(
                        error=error_msg,
                        stack_trace=stack_trace,
                        error_code="MODEL_ERROR",
                        is_retryable=True,
                        details={
                            "model": model_llm.config.model_name,
                            "provider": model_llm.config.model_provider,
                            "model_index": model_idx,
                        },
                    )
                    # BUGFIX: do NOT decrement models_remaining here. The
                    # worker's finally block always enqueues _MODEL_DONE after
                    # putting the exception, so decrementing in both places
                    # double-counts the failed model and could terminate this
                    # merge loop while another model is still streaming.
                    continue

                if isinstance(item, Packet):
                    # Packet is already tagged with model_index by _ModelIndexEmitter
                    yield item

            # ── Completion: save each successful model's response ──────
            # Run completion callbacks on the main thread using the main
            # session. This is safe because all worker threads have exited
            # by this point (merged_queue fully drained).
            for i in range(n_models):
                if not model_succeeded[i]:
                    continue
                try:
                    llm_loop_completion_handle(
                        state_container=state_containers[i],
                        is_connected=check_is_connected,
                        db_session=db_session,
                        assistant_message=reserved_messages[i],
                        llm=llms[i],
                        reserved_tokens=reserved_token_count,
                    )
                except Exception:
                    logger.exception(
                        f"Failed completion for model {i} "
                        f"({model_display_names[i]})"
                    )

            yield Packet(
                placement=Placement(turn_index=0),
                obj=OverallStop(type="stop", stop_reason="complete"),
            )
        finally:
            # Ensure all threads are cleaned up regardless of how we exit.
            # NOTE: wait=True means already-running models run to completion;
            # cancel_futures only cancels not-yet-started ones.
            executor.shutdown(wait=True, cancel_futures=True)

    except ValueError as e:
        logger.exception("Failed to process multi-model chat message.")
        yield StreamingError(
            error=str(e),
            error_code="VALIDATION_ERROR",
            is_retryable=True,
        )
        db_session.rollback()
        return
    except Exception as e:
        logger.exception(f"Failed multi-model chat: {e}")
        stack_trace = traceback.format_exc()
        yield StreamingError(
            error=str(e),
            stack_trace=stack_trace,
            error_code="MULTI_MODEL_ERROR",
            is_retryable=True,
        )
        db_session.rollback()
    finally:
        try:
            if cache is not None and chat_session is not None:
                set_processing_status(
                    chat_session_id=chat_session.id,
                    cache=cache,
                    value=False,
                )
        except Exception:
            logger.exception("Error clearing processing status")
def llm_loop_completion_handle(
state_container: ChatStateContainer,
is_connected: Callable[[], bool],

View File

@@ -602,6 +602,79 @@ def reserve_message_id(
return empty_message
def reserve_multi_model_message_ids(
    db_session: Session,
    chat_session_id: UUID,
    parent_message_id: int,
    model_display_names: list[str],
) -> list[ChatMessage]:
    """Reserve N assistant message placeholders for multi-model parallel streaming.

    All messages share the same parent (the user message). The parent's
    latest_child_message_id points to the LAST reserved message so that the
    default history-chain walker picks it up.

    Returns the reserved messages in the same order as ``model_display_names``.
    Commits the transaction on success.
    """
    # Guard: nothing to reserve. Without this, the reserved[-1] access below
    # would raise IndexError for an empty display-name list.
    if not model_display_names:
        return []

    reserved: list[ChatMessage] = []
    for display_name in model_display_names:
        msg = ChatMessage(
            chat_session_id=chat_session_id,
            parent_message_id=parent_message_id,
            latest_child_message_id=None,
            # Placeholder text shown if the stream dies before completion
            message="Response was terminated prior to completion, try regenerating.",
            token_count=15,
            message_type=MessageType.ASSISTANT,
            model_display_name=display_name,
        )
        db_session.add(msg)
        reserved.append(msg)

    # Flush to assign IDs without committing yet
    db_session.flush()

    # Point parent's latest_child to the last reserved message
    parent = (
        db_session.query(ChatMessage)
        .filter(ChatMessage.id == parent_message_id)
        .first()
    )
    if parent:
        parent.latest_child_message_id = reserved[-1].id

    db_session.commit()
    return reserved
def set_preferred_response(
    db_session: Session,
    user_message_id: int,
    preferred_assistant_message_id: int,
) -> None:
    """Record which assistant response a multi-model user message prefers.

    Validates that the user message exists and is USER-typed, and that the
    preferred assistant message exists and is a direct child of that user
    message; raises ValueError otherwise. Commits on success.
    """
    parent = db_session.query(ChatMessage).get(user_message_id)
    if parent is None:
        raise ValueError(f"User message {user_message_id} not found")
    if parent.message_type != MessageType.USER:
        raise ValueError(f"Message {user_message_id} is not a user message")

    preferred = db_session.query(ChatMessage).get(preferred_assistant_message_id)
    if preferred is None:
        raise ValueError(
            f"Assistant message {preferred_assistant_message_id} not found"
        )
    if preferred.parent_message_id != user_message_id:
        raise ValueError(
            f"Assistant message {preferred_assistant_message_id} is not a child "
            f"of user message {user_message_id}"
        )

    parent.preferred_response_id = preferred_assistant_message_id
    db_session.commit()
def create_new_chat_message(
chat_session_id: UUID,
parent_message: ChatMessage,
@@ -824,6 +897,8 @@ def translate_db_message_to_chat_message_detail(
error=chat_message.error,
current_feedback=current_feedback,
processing_duration_seconds=chat_message.processing_duration_seconds,
preferred_response_id=chat_message.preferred_response_id,
model_display_name=chat_message.model_display_name,
)
return chat_msg_detail

View File

@@ -2645,6 +2645,15 @@ class ChatMessage(Base):
nullable=True,
)
# For multi-model turns: the user message points to which assistant response
# was selected as the preferred one to continue the conversation with.
preferred_response_id: Mapped[int | None] = mapped_column(
ForeignKey("chat_message.id"), nullable=True
)
# The display name of the model that generated this assistant message
model_display_name: Mapped[str | None] = mapped_column(String, nullable=True)
# What does this message contain
reasoning_tokens: Mapped[str | None] = mapped_column(Text, nullable=True)
message: Mapped[str] = mapped_column(Text)
@@ -2712,6 +2721,12 @@ class ChatMessage(Base):
remote_side="ChatMessage.id",
)
preferred_response: Mapped["ChatMessage | None"] = relationship(
"ChatMessage",
foreign_keys=[preferred_response_id],
remote_side="ChatMessage.id",
)
# Chat messages only need to know their immediate tool call children
# If there are nested tool calls, they are stored in the tool_call_children relationship.
tool_calls: Mapped[list["ToolCall"] | None] = relationship(

View File

@@ -11,6 +11,7 @@ class LLMOverride(BaseModel):
model_provider: str | None = None
model_version: str | None = None
temperature: float | None = None
display_name: str | None = None
# This disables the "model_" protected namespace for pydantic
model_config = {"protected_namespaces": ()}

View File

@@ -29,6 +29,7 @@ from onyx.chat.models import ChatFullResponse
from onyx.chat.models import CreateChatSessionID
from onyx.chat.process_message import gather_stream_full
from onyx.chat.process_message import handle_stream_message_objects
from onyx.chat.process_message import run_multi_model_stream
from onyx.chat.prompt_utils import get_default_base_system_prompt
from onyx.chat.stop_signal_checker import set_fence
from onyx.configs.app_configs import WEB_DOMAIN
@@ -46,6 +47,7 @@ from onyx.db.chat import get_chat_messages_by_session
from onyx.db.chat import get_chat_session_by_id
from onyx.db.chat import get_chat_sessions_by_user
from onyx.db.chat import set_as_latest_chat_message
from onyx.db.chat import set_preferred_response
from onyx.db.chat import translate_db_message_to_chat_message_detail
from onyx.db.chat import update_chat_session
from onyx.db.chat_search import search_chat_sessions
@@ -81,6 +83,7 @@ from onyx.server.query_and_chat.models import ChatSessionUpdateRequest
from onyx.server.query_and_chat.models import MessageOrigin
from onyx.server.query_and_chat.models import RenameChatSessionResponse
from onyx.server.query_and_chat.models import SendMessageRequest
from onyx.server.query_and_chat.models import SetPreferredResponseRequest
from onyx.server.query_and_chat.models import UpdateChatSessionTemperatureRequest
from onyx.server.query_and_chat.models import UpdateChatSessionThreadRequest
from onyx.server.query_and_chat.session_loading import (
@@ -570,6 +573,38 @@ def handle_send_chat_message(
if get_hashed_api_key_from_request(request) or get_hashed_pat_from_request(request):
chat_message_req.origin = MessageOrigin.API
# Multi-model streaming path: 2-3 LLMs in parallel (streaming only)
is_multi_model = (
chat_message_req.llm_overrides is not None
and len(chat_message_req.llm_overrides) > 1
)
if is_multi_model and chat_message_req.stream:
def multi_model_stream_generator() -> Generator[str, None, None]:
try:
with get_session_with_current_tenant() as db_session:
for obj in run_multi_model_stream(
new_msg_req=chat_message_req,
user=user,
db_session=db_session,
llm_overrides=chat_message_req.llm_overrides, # type: ignore[arg-type]
litellm_additional_headers=extract_headers(
request.headers, LITELLM_PASS_THROUGH_HEADERS
),
custom_tool_additional_headers=get_custom_tool_additional_request_headers(
request.headers
),
mcp_headers=chat_message_req.mcp_headers,
):
yield get_json_line(obj.model_dump())
except Exception as e:
logger.exception("Error in multi-model streaming")
yield json.dumps({"error": str(e)})
return StreamingResponse(
multi_model_stream_generator(), media_type="text/event-stream"
)
# Non-streaming path: consume all packets and return complete response
if not chat_message_req.stream:
with get_session_with_current_tenant() as db_session:
@@ -660,6 +695,26 @@ def set_message_as_latest(
)
@router.put("/set-preferred-response")
def set_preferred_response_endpoint(
request_body: SetPreferredResponseRequest,
_user: User = Depends(current_user),
db_session: Session = Depends(get_session),
) -> None:
"""Set the preferred assistant response for a multi-model turn."""
from onyx.error_handling.error_codes import OnyxErrorCode
from onyx.error_handling.exceptions import OnyxError
try:
set_preferred_response(
db_session=db_session,
user_message_id=request_body.user_message_id,
preferred_assistant_message_id=request_body.preferred_response_id,
)
except ValueError as e:
raise OnyxError(OnyxErrorCode.INVALID_INPUT, str(e))
@router.post("/create-chat-message-feedback")
def create_chat_feedback(
feedback: ChatFeedbackRequest,

View File

@@ -41,6 +41,16 @@ class MessageResponseIDInfo(BaseModel):
reserved_assistant_message_id: int
class MultiModelMessageResponseIDInfo(BaseModel):
    """Sent at the start of a multi-model streaming response.

    Contains the user message ID and the reserved assistant message IDs
    for each model being run in parallel."""

    # ID of the user message this multi-model turn responds to
    user_message_id: int | None
    # One pre-reserved assistant message per model, index-aligned with model_names
    reserved_assistant_message_ids: list[int]
    # Human-readable display name for each model, same order as the IDs above
    model_names: list[str]
class SourceTag(Tag):
source: DocumentSource
@@ -86,6 +96,9 @@ class SendMessageRequest(BaseModel):
message: str
llm_override: LLMOverride | None = None
# For multi-model mode: up to 3 LLM overrides to run in parallel.
# When provided with >1 entry, triggers multi-model streaming.
llm_overrides: list[LLMOverride] | None = None
# Test-only override for deterministic LiteLLM mock responses.
mock_llm_response: str | None = None
@@ -211,6 +224,8 @@ class ChatMessageDetail(BaseModel):
error: str | None = None
current_feedback: str | None = None # "like" | "dislike" | null
processing_duration_seconds: float | None = None
preferred_response_id: int | None = None
model_display_name: str | None = None
def model_dump(self, *args: list, **kwargs: dict[str, Any]) -> dict[str, Any]: # type: ignore
initial_dict = super().model_dump(mode="json", *args, **kwargs) # type: ignore
@@ -218,6 +233,11 @@ class ChatMessageDetail(BaseModel):
return initial_dict
class SetPreferredResponseRequest(BaseModel):
    """Body for the set-preferred-response endpoint: marks one assistant
    response as the preferred continuation of a multi-model user message."""

    # ID of the USER-type message whose children are the candidate responses
    user_message_id: int
    # ID of the assistant child message to mark as preferred
    preferred_response_id: int
class ChatSessionDetailResponse(BaseModel):
chat_session_id: UUID
description: str | None

View File

@@ -8,3 +8,5 @@ class Placement(BaseModel):
tab_index: int = 0
# Used for tools/agents that call other tools, this currently doesn't support nested agents but can be added later
sub_turn_index: int | None = None
# For multi-model streaming: identifies which model (0, 1, 2) this packet belongs to.
model_index: int | None = None

View File

@@ -0,0 +1,206 @@
"""Unit tests for multi-model streaming validation and DB helpers.
These are pure unit tests — no real database or LLM calls required.
The validation logic in run_multi_model_stream fires before any external
calls, so we can trigger it with lightweight mocks.
"""
from typing import Any
from unittest.mock import MagicMock
from uuid import uuid4
import pytest
from onyx.configs.constants import MessageType
from onyx.db.chat import set_preferred_response
from onyx.llm.override_models import LLMOverride
from onyx.server.query_and_chat.models import SendMessageRequest
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _make_request(**kwargs: Any) -> SendMessageRequest:
    """Build a minimal SendMessageRequest; kwargs override the defaults."""
    params: dict[str, Any] = {"message": "hello", "chat_session_id": uuid4()}
    params.update(kwargs)
    return SendMessageRequest(**params)
def _make_override(provider: str = "openai", version: str = "gpt-4") -> LLMOverride:
return LLMOverride(model_provider=provider, model_version=version)
def _start_stream(req: SendMessageRequest, overrides: list[LLMOverride]) -> None:
    """Drive the generator one step so its early validation executes.

    run_multi_model_stream is a generator: its body — including the
    override-count and deep-research checks at the top — does not run until
    the first next(). Those checks raise before any yield, so mocks suffice.
    """
    from onyx.chat.process_message import run_multi_model_stream

    fake_user = MagicMock()
    fake_user.is_anonymous = False
    fake_user.email = "test@example.com"
    fake_db = MagicMock()

    stream = run_multi_model_stream(req, fake_user, fake_db, overrides)
    next(stream)
# ---------------------------------------------------------------------------
# run_multi_model_stream — validation
# ---------------------------------------------------------------------------
class TestRunMultiModelStreamValidation:
    """Early-validation behavior of run_multi_model_stream."""

    def test_single_override_raises(self) -> None:
        """A lone override is not a multi-model request — must raise."""
        with pytest.raises(ValueError, match="2-3"):
            _start_stream(_make_request(), [_make_override()])

    def test_four_overrides_raises(self) -> None:
        """More than three overrides exceeds the maximum — must raise."""
        too_many = [
            _make_override("openai", "gpt-4"),
            _make_override("anthropic", "claude-3"),
            _make_override("google", "gemini-pro"),
            _make_override("cohere", "command-r"),
        ]
        with pytest.raises(ValueError, match="2-3"):
            _start_stream(_make_request(), too_many)

    def test_zero_overrides_raises(self) -> None:
        """An empty override list is rejected."""
        with pytest.raises(ValueError, match="2-3"):
            _start_stream(_make_request(), [])

    def test_deep_research_raises(self) -> None:
        """deep_research=True cannot be combined with multi-model."""
        overrides = [_make_override(), _make_override("anthropic", "claude-3")]
        with pytest.raises(ValueError, match="not supported"):
            _start_stream(_make_request(deep_research=True), overrides)

    def test_exactly_two_overrides_is_minimum(self) -> None:
        """Fence-post check: one override fails validation, two pass it."""
        with pytest.raises(ValueError, match="2-3"):
            _start_stream(_make_request(), [_make_override()])
        # Two overrides must NOT raise ValueError; later non-ValueError
        # failures (e.g. missing session on the mocked db) are acceptable.
        pair = [_make_override(), _make_override("anthropic", "claude-3")]
        try:
            _start_stream(_make_request(), pair)
        except ValueError as exc:
            pytest.fail(f"2 overrides should pass validation, got ValueError: {exc}")
        except Exception:
            # Any other error means validation itself already passed.
            pass
# ---------------------------------------------------------------------------
# set_preferred_response — validation (mocked db)
# ---------------------------------------------------------------------------
class TestSetPreferredResponseValidation:
    """Validation behavior of set_preferred_response against a mocked session."""

    def test_user_message_not_found(self) -> None:
        """An unknown user message id raises."""
        db = MagicMock()
        db.query.return_value.get.return_value = None
        with pytest.raises(ValueError, match="not found"):
            set_preferred_response(
                db, user_message_id=999, preferred_assistant_message_id=1
            )

    def test_wrong_message_type(self) -> None:
        """A preferred response may only be set on a USER message."""
        assistant_typed = MagicMock()
        assistant_typed.message_type = MessageType.ASSISTANT  # wrong type
        db = MagicMock()
        db.query.return_value.get.return_value = assistant_typed
        with pytest.raises(ValueError, match="not a user message"):
            set_preferred_response(
                db, user_message_id=1, preferred_assistant_message_id=2
            )

    def test_assistant_message_not_found(self) -> None:
        """The user message exists but the assistant message does not."""
        user_row = MagicMock()
        user_row.message_type = MessageType.USER
        db = MagicMock()
        # First get() -> user message, second get() -> missing assistant.
        db.query.return_value.get.side_effect = [user_row, None]
        with pytest.raises(ValueError, match="not found"):
            set_preferred_response(
                db, user_message_id=1, preferred_assistant_message_id=2
            )

    def test_assistant_not_child_of_user(self) -> None:
        """The assistant message must be a direct child of the user message."""
        user_row = MagicMock()
        user_row.message_type = MessageType.USER
        stray_assistant = MagicMock()
        stray_assistant.parent_message_id = 999  # different parent
        db = MagicMock()
        db.query.return_value.get.side_effect = [user_row, stray_assistant]
        with pytest.raises(ValueError, match="not a child"):
            set_preferred_response(
                db, user_message_id=1, preferred_assistant_message_id=2
            )

    def test_valid_call_sets_preferred_response_id(self) -> None:
        """Happy path: the user row's preferred_response_id is updated."""
        user_row = MagicMock()
        user_row.message_type = MessageType.USER
        child_assistant = MagicMock()
        child_assistant.parent_message_id = 1  # correct parent
        db = MagicMock()
        db.query.return_value.get.side_effect = [user_row, child_assistant]
        set_preferred_response(db, user_message_id=1, preferred_assistant_message_id=2)
        assert user_row.preferred_response_id == 2
# ---------------------------------------------------------------------------
# LLMOverride — display_name field
# ---------------------------------------------------------------------------
class TestLLMOverrideDisplayName:
    """The optional display_name field on LLMOverride."""

    def test_display_name_defaults_none(self) -> None:
        plain = LLMOverride(model_provider="openai", model_version="gpt-4")
        assert plain.display_name is None

    def test_display_name_set(self) -> None:
        named = LLMOverride(
            model_provider="openai",
            model_version="gpt-4",
            display_name="GPT-4 Turbo",
        )
        assert named.display_name == "GPT-4 Turbo"

    def test_display_name_serializes(self) -> None:
        dumped = LLMOverride(
            model_provider="anthropic",
            model_version="claude-opus-4-6",
            display_name="Claude Opus",
        ).model_dump()
        assert dumped["display_name"] == "Claude Opus"

View File

@@ -0,0 +1,134 @@
"""Unit tests for multi-model answer generation types.
Tests cover:
- Placement.model_index serialization
- MultiModelMessageResponseIDInfo round-trip
- SendMessageRequest.llm_overrides backward compatibility
- ChatMessageDetail new fields
"""
from uuid import uuid4
from onyx.llm.override_models import LLMOverride
from onyx.server.query_and_chat.models import ChatMessageDetail
from onyx.server.query_and_chat.models import MultiModelMessageResponseIDInfo
from onyx.server.query_and_chat.models import SendMessageRequest
from onyx.server.query_and_chat.placement import Placement
class TestPlacementModelIndex:
    """Defaulting and serialization of Placement.model_index."""

    def test_default_none(self) -> None:
        assert Placement(turn_index=0).model_index is None

    def test_set_value(self) -> None:
        assert Placement(turn_index=0, model_index=2).model_index == 2

    def test_serializes(self) -> None:
        dumped = Placement(turn_index=0, tab_index=1, model_index=1).model_dump()
        assert dumped["model_index"] == 1

    def test_none_excluded_when_default(self) -> None:
        dumped = Placement(turn_index=0).model_dump()
        assert dumped["model_index"] is None
class TestMultiModelMessageResponseIDInfo:
    """Serialization round-trips for MultiModelMessageResponseIDInfo."""

    def test_round_trip(self) -> None:
        original = MultiModelMessageResponseIDInfo(
            user_message_id=42,
            reserved_assistant_message_ids=[43, 44, 45],
            model_names=["gpt-4", "claude-opus", "gemini-pro"],
        )
        # dump -> rebuild must preserve every field
        rebuilt = MultiModelMessageResponseIDInfo(**original.model_dump())
        assert rebuilt.user_message_id == 42
        assert rebuilt.reserved_assistant_message_ids == [43, 44, 45]
        assert rebuilt.model_names == ["gpt-4", "claude-opus", "gemini-pro"]

    def test_null_user_message_id(self) -> None:
        info = MultiModelMessageResponseIDInfo(
            user_message_id=None,
            reserved_assistant_message_ids=[1, 2],
            model_names=["a", "b"],
        )
        assert info.user_message_id is None
class TestSendMessageRequestOverrides:
    """llm_overrides on SendMessageRequest, including backward compatibility."""

    def test_llm_overrides_default_none(self) -> None:
        req = SendMessageRequest(message="hello", chat_session_id=uuid4())
        assert req.llm_overrides is None

    def test_llm_overrides_accepts_list(self) -> None:
        pair = [
            LLMOverride(model_provider="openai", model_version="gpt-4"),
            LLMOverride(model_provider="anthropic", model_version="claude-opus"),
        ]
        req = SendMessageRequest(
            message="hello",
            chat_session_id=uuid4(),
            llm_overrides=pair,
        )
        assert req.llm_overrides is not None
        assert len(req.llm_overrides) == 2

    def test_backward_compat_single_override(self) -> None:
        # The legacy singular llm_override must not populate llm_overrides.
        req = SendMessageRequest(
            message="hello",
            chat_session_id=uuid4(),
            llm_override=LLMOverride(model_provider="openai", model_version="gpt-4"),
        )
        assert req.llm_override is not None
        assert req.llm_overrides is None
class TestChatMessageDetailMultiModel:
    """preferred_response_id / model_display_name fields on ChatMessageDetail."""

    def test_defaults_none(self) -> None:
        from onyx.configs.constants import MessageType

        detail = ChatMessageDetail(
            message_id=1,
            message="hello",
            message_type=MessageType.ASSISTANT,
            time_sent="2026-03-22T00:00:00Z",
            files=[],
        )
        assert detail.preferred_response_id is None
        assert detail.model_display_name is None

    def test_set_values(self) -> None:
        from onyx.configs.constants import MessageType

        detail = ChatMessageDetail(
            message_id=1,
            message="hello",
            message_type=MessageType.USER,
            time_sent="2026-03-22T00:00:00Z",
            files=[],
            preferred_response_id=42,
            model_display_name="GPT-4",
        )
        assert detail.preferred_response_id == 42
        assert detail.model_display_name == "GPT-4"

    def test_serializes(self) -> None:
        from onyx.configs.constants import MessageType

        dumped = ChatMessageDetail(
            message_id=1,
            message="hello",
            message_type=MessageType.ASSISTANT,
            time_sent="2026-03-22T00:00:00Z",
            files=[],
            model_display_name="Claude Opus",
        ).model_dump()
        assert dumped["model_display_name"] == "Claude Opus"
        assert dumped["preferred_response_id"] is None

View File

@@ -11,15 +11,22 @@ import { MinimalPersonaSnapshot } from "@/app/admin/agents/interfaces";
import { useState, useEffect } from "react";
import { useSettingsContext } from "@/providers/SettingsProvider";
import FrostedDiv from "@/refresh-components/FrostedDiv";
import { cn } from "@/lib/utils";
export interface WelcomeMessageProps {
agent?: MinimalPersonaSnapshot;
isDefaultAgent: boolean;
/** Optional right-aligned element rendered on the same row as the greeting (e.g. model selector). */
rightChildren?: React.ReactNode;
/** When true, the greeting/logo content is hidden (but space is preserved). Used at max models. */
hideTitle?: boolean;
}
export default function WelcomeMessage({
agent,
isDefaultAgent,
rightChildren,
hideTitle,
}: WelcomeMessageProps) {
const settings = useSettingsContext();
const enterpriseSettings = settings?.enterpriseSettings;
@@ -39,8 +46,10 @@ export default function WelcomeMessage({
if (isDefaultAgent) {
content = (
<div data-testid="onyx-logo" className="flex flex-row items-center gap-4">
<Logo folded size={32} />
<div data-testid="onyx-logo" className="flex flex-col items-start gap-2">
<div className="flex items-center justify-center size-9 p-0.5">
<Logo folded size={32} />
</div>
<Text as="p" headingH2>
{greeting}
</Text>
@@ -48,17 +57,15 @@ export default function WelcomeMessage({
);
} else if (agent) {
content = (
<>
<div
data-testid="agent-name-display"
className="flex flex-row items-center gap-3"
>
<AgentAvatar agent={agent} size={36} />
<Text as="p" headingH2>
{agent.name}
</Text>
</div>
</>
<div
data-testid="agent-name-display"
className="flex flex-col items-start gap-2"
>
<AgentAvatar agent={agent} size={36} />
<Text as="p" headingH2>
{agent.name}
</Text>
</div>
);
}
@@ -69,9 +76,24 @@ export default function WelcomeMessage({
return (
<FrostedDiv
data-testid="chat-intro"
className="flex flex-col items-center justify-center gap-3 w-full max-w-[var(--app-page-main-content-width)]"
wrapperClassName="w-full"
className="flex flex-col items-center justify-center gap-3 w-full max-w-[var(--app-page-main-content-width)] mx-auto"
>
{content}
{rightChildren ? (
<div className="flex items-end gap-2 w-full">
<div
className={cn(
"flex-1 min-w-0 min-h-[80px] px-2 py-1",
hideTitle && "invisible"
)}
>
{content}
</div>
<div className="shrink-0">{rightChildren}</div>
</div>
) : (
content
)}
</FrostedDiv>
);
}

View File

@@ -159,6 +159,10 @@ export interface Message {
overridden_model?: string;
stopReason?: StreamStopReason | null;
// Multi-model answer generation
preferredResponseId?: number | null;
modelDisplayName?: string | null;
// new gen
packets: Packet[];
packetCount?: number; // Tracks packet count for React memo comparison (avoids reading from mutated array)
@@ -231,6 +235,9 @@ export interface BackendMessage {
parentMessageId: number | null;
refined_answer_improvement: boolean | null;
is_agentic: boolean | null;
// Multi-model answer generation
preferred_response_id: number | null;
model_display_name: string | null;
}
export interface MessageResponseIDInfo {
@@ -238,6 +245,12 @@ export interface MessageResponseIDInfo {
reserved_assistant_message_id: number; // TODO: rename to agent — https://linear.app/onyx-app/issue/ENG-3766
}
/**
 * Backend packet reserving message ids for a multi-model turn: the user
 * message id plus one assistant message id per selected model.
 */
export interface MultiModelMessageResponseIDInfo {
  user_message_id: number | null;
  /** One reserved assistant message id per selected model. */
  reserved_assistant_message_ids: number[];
  /** Backend-reported model names — presumably parallel to the reserved ids; confirm with backend. */
  model_names: string[];
}
export interface UserKnowledgeFilePacket {
user_files: FileDescriptor[];
}

View File

@@ -0,0 +1,149 @@
"use client";
import { useCallback } from "react";
import { Button } from "@opal/components";
import { Hoverable } from "@opal/core";
import { SvgEyeClosed, SvgMoreHorizontal, SvgX } from "@opal/icons";
import Text from "@/refresh-components/texts/Text";
import { getProviderIcon } from "@/app/admin/configuration/llm/utils";
import AgentMessage, {
AgentMessageProps,
} from "@/app/app/message/messageComponents/AgentMessage";
import { cn } from "@/lib/utils";
export interface MultiModelPanelProps {
  /** Index of this model in the selectedModels array (used for Hoverable group key) */
  modelIndex: number;
  /** Provider name for icon lookup */
  provider: string;
  /** Model name for icon lookup and display */
  modelName: string;
  /** Display-friendly model name shown in the panel header */
  displayName: string;
  /** Whether this panel is the preferred/selected response */
  isPreferred: boolean;
  /** Whether this panel is currently hidden (renders as a compact strip) */
  isHidden: boolean;
  /** Whether this is a non-preferred panel in selection mode (pushed off-screen) */
  isNonPreferredInSelection: boolean;
  /** Callback when user clicks this panel to select as preferred */
  onSelect: () => void;
  /** Callback to hide/show this panel */
  onToggleVisibility: () => void;
  /** Props to pass through to AgentMessage */
  agentMessageProps: AgentMessageProps;
}
/**
 * A single model's answer panel in the multi-model comparison view.
 *
 * Two render shapes:
 *  - hidden: a compact strip (provider icon, struck-through name, and an
 *    eye button that calls onToggleVisibility to re-show the panel);
 *  - visible: a header row (icon, name, optional "Preferred Response"
 *    badge, more/hide buttons), a hover-only "Select This Response"
 *    affordance for non-preferred panels, and the AgentMessage body.
 *
 * Clicking anywhere on a visible panel invokes onSelect; header buttons
 * call stopPropagation so they do not also select the panel.
 */
export default function MultiModelPanel({
  modelIndex,
  provider,
  modelName,
  displayName,
  isPreferred,
  isHidden,
  isNonPreferredInSelection,
  onSelect,
  onToggleVisibility,
  agentMessageProps,
}: MultiModelPanelProps) {
  const ProviderIcon = getProviderIcon(provider, modelName);

  // A hidden strip must never change the preferred selection.
  const handlePanelClick = useCallback(() => {
    if (!isHidden) {
      onSelect();
    }
  }, [isHidden, onSelect]);

  // Hidden/collapsed panel — compact strip: icon + strikethrough name + eye icon
  if (isHidden) {
    return (
      <div className="flex items-center gap-1.5 rounded-08 bg-background-tint-00 px-2 py-1 opacity-50">
        <div className="flex items-center justify-center size-5 shrink-0">
          <ProviderIcon size={16} />
        </div>
        <Text secondaryBody text02 nowrap className="line-through">
          {displayName}
        </Text>
        <Button
          prominence="tertiary"
          icon={SvgEyeClosed}
          size="2xs"
          onClick={onToggleVisibility}
          tooltip="Show response"
        />
      </div>
    );
  }

  // Unique hover-group key so each panel's hover affordance is independent.
  const hoverGroup = `panel-${modelIndex}`;

  return (
    <Hoverable.Root group={hoverGroup}>
      <div
        className="flex flex-col min-w-0 gap-3 cursor-pointer"
        onClick={handlePanelClick}
      >
        {/* Panel header */}
        <div
          className={cn(
            "flex items-center gap-1.5 rounded-12 px-2 py-1",
            isPreferred ? "bg-background-tint-02" : "bg-background-tint-00"
          )}
        >
          <div className="flex items-center justify-center size-5 shrink-0">
            <ProviderIcon size={16} />
          </div>
          <Text mainUiAction text04 nowrap className="flex-1 min-w-0 truncate">
            {displayName}
          </Text>
          {isPreferred && (
            <Text secondaryBody nowrap className="text-action-link-05 shrink-0">
              Preferred Response
            </Text>
          )}
          {/* stopPropagation: header buttons must not trigger panel selection */}
          <Button
            prominence="tertiary"
            icon={SvgMoreHorizontal}
            size="2xs"
            tooltip="More"
            onClick={(e) => e.stopPropagation()}
          />
          <Button
            prominence="tertiary"
            icon={SvgX}
            size="2xs"
            onClick={(e) => {
              e.stopPropagation();
              onToggleVisibility();
            }}
            tooltip="Hide response"
          />
        </div>
        {/* "Select This Response" hover affordance */}
        {!isPreferred && !isNonPreferredInSelection && (
          <Hoverable.Item group={hoverGroup} variant="opacity-on-hover">
            <div className="flex justify-center pointer-events-none">
              <div className="flex items-center h-6 bg-background-tint-00 rounded-08 px-1 shadow-sm">
                <Text
                  secondaryBody
                  className="font-semibold text-text-03 px-1 whitespace-nowrap"
                >
                  Select This Response
                </Text>
              </div>
            </div>
          </Hoverable.Item>
        )}
        {/* Response body — clicks disabled and footer hidden for
            non-preferred panels while in selection mode */}
        <div className={cn(isNonPreferredInSelection && "pointer-events-none")}>
          <AgentMessage
            {...agentMessageProps}
            hideFooter={isNonPreferredInSelection}
          />
        </div>
      </div>
    </Hoverable.Root>
  );
}

View File

@@ -0,0 +1,229 @@
"use client";
import { useState, useCallback, useMemo } from "react";
import { Packet } from "@/app/app/services/streamingModels";
import { FullChatState } from "@/app/app/message/messageComponents/interfaces";
import { FeedbackType, Message } from "@/app/app/interfaces";
import { LlmManager } from "@/lib/hooks";
import { RegenerationFactory } from "@/app/app/message/messageComponents/AgentMessage";
import MultiModelPanel from "@/app/app/message/MultiModelPanel";
import { cn } from "@/lib/utils";
/** Per-model response state rendered by MultiModelResponseView. */
export interface MultiModelResponse {
  /** Index of this model among the selected models; also the panel key. */
  modelIndex: number;
  /** Provider name (used for the provider icon). */
  provider: string;
  /** Raw model name (used for icon lookup). */
  modelName: string;
  /** Human-readable label shown in the panel header. */
  displayName: string;
  /** Streamed packets for this model's answer. */
  packets: Packet[];
  /** Packet count — presumably mirrors Message.packetCount for memo comparison; confirm. */
  packetCount: number;
  /** Node id of this response in the client message tree. */
  nodeId: number;
  /** Backend message id, once known. */
  messageId?: number;
  isHighlighted?: boolean;
  /** Existing feedback on this response, if any. */
  currentFeedback?: FeedbackType | null;
  /** True while this model is still streaming. */
  isGenerating?: boolean;
}
/** Props for the side-by-side multi-model comparison view. */
export interface MultiModelResponseViewProps {
  /** One entry per selected model; the view indexes this array by modelIndex. */
  responses: MultiModelResponse[];
  chatState: FullChatState;
  llmManager: LlmManager | null;
  /** Regeneration factory forwarded to each AgentMessage. */
  onRegenerate?: RegenerationFactory;
  parentMessage?: Message | null;
  otherMessagesCanSwitchTo?: number[];
  /** Called with a response's nodeId when the user picks it as preferred. */
  onMessageSelection?: (nodeId: number) => void;
}
/**
 * Side-by-side comparison of 2-3 model responses for one user message.
 *
 * Layouts:
 *  - Generation layout: all visible panels rendered as equal-width columns.
 *  - Selection layout (after the user picks a preferred response and
 *    streaming has finished): the preferred panel stays centered at normal
 *    chat width while the other panels are absolutely positioned off to the
 *    sides and clipped by overflow-hidden.
 *
 * Panels can be individually hidden (collapsed to a compact strip); the
 * last visible panel can never be hidden.
 */
export default function MultiModelResponseView({
  responses,
  chatState,
  llmManager,
  onRegenerate,
  parentMessage,
  otherMessagesCanSwitchTo,
  onMessageSelection,
}: MultiModelResponseViewProps) {
  // modelIndex of the response the user picked as preferred (null = none yet).
  const [preferredIndex, setPreferredIndex] = useState<number | null>(null);
  // Set of modelIndexes currently collapsed to their compact strip.
  const [hiddenPanels, setHiddenPanels] = useState<Set<number>>(new Set());

  const isGenerating = useMemo(
    () => responses.some((r) => r.isGenerating),
    [responses]
  );

  const visibleResponses = useMemo(
    () => responses.filter((r) => !hiddenPanels.has(r.modelIndex)),
    [responses, hiddenPanels]
  );

  const hiddenResponses = useMemo(
    () => responses.filter((r) => hiddenPanels.has(r.modelIndex)),
    [responses, hiddenPanels]
  );

  const toggleVisibility = useCallback(
    (modelIndex: number) => {
      setHiddenPanels((prev) => {
        const next = new Set(prev);
        if (next.has(modelIndex)) {
          next.delete(modelIndex);
        } else {
          // Don't hide the last visible panel
          const visibleCount = responses.length - next.size;
          if (visibleCount <= 1) return prev;
          next.add(modelIndex);
        }
        return next;
      });
    },
    [responses.length]
  );

  const handleSelectPreferred = useCallback(
    (modelIndex: number) => {
      setPreferredIndex(modelIndex);
      // NOTE(review): assumes responses[i].modelIndex === i (array indexed by
      // model position). If callers may pass a reordered/sparse list, this
      // should be responses.find((r) => r.modelIndex === modelIndex) — confirm.
      const response = responses[modelIndex];
      if (!response) return;
      // Sync with message tree — mark this response as the latest child
      // so the next message chains from it.
      if (onMessageSelection) {
        onMessageSelection(response.nodeId);
      }
    },
    [responses, onMessageSelection]
  );

  // Selection mode when preferred is set and not generating
  const showSelectionMode =
    preferredIndex !== null && !isGenerating && visibleResponses.length > 1;

  // Build common panel props
  const buildPanelProps = useCallback(
    (response: MultiModelResponse, isNonPreferred: boolean) => ({
      modelIndex: response.modelIndex,
      provider: response.provider,
      modelName: response.modelName,
      displayName: response.displayName,
      isPreferred: preferredIndex === response.modelIndex,
      isHidden: false as const,
      isNonPreferredInSelection: isNonPreferred,
      onSelect: () => handleSelectPreferred(response.modelIndex),
      onToggleVisibility: () => toggleVisibility(response.modelIndex),
      agentMessageProps: {
        rawPackets: response.packets,
        packetCount: response.packetCount,
        chatState,
        nodeId: response.nodeId,
        messageId: response.messageId,
        currentFeedback: response.currentFeedback,
        llmManager,
        otherMessagesCanSwitchTo,
        onMessageSelection,
        onRegenerate,
        parentMessage,
      },
    }),
    [
      preferredIndex,
      handleSelectPreferred,
      toggleVisibility,
      chatState,
      llmManager,
      otherMessagesCanSwitchTo,
      onMessageSelection,
      onRegenerate,
      parentMessage,
    ]
  );

  // Shared renderer for hidden panels (inline in the flex row)
  const renderHiddenPanels = () =>
    hiddenResponses.map((r) => (
      <div key={r.modelIndex} className="w-[240px] shrink-0">
        <MultiModelPanel
          modelIndex={r.modelIndex}
          provider={r.provider}
          modelName={r.modelName}
          displayName={r.displayName}
          isPreferred={false}
          isHidden
          isNonPreferredInSelection={false}
          onSelect={() => handleSelectPreferred(r.modelIndex)}
          onToggleVisibility={() => toggleVisibility(r.modelIndex)}
          agentMessageProps={buildPanelProps(r, false).agentMessageProps}
        />
      </div>
    ));

  if (showSelectionMode) {
    // ── Selection Layout ──
    // Preferred stays at normal chat width, centered.
    // Non-preferred panels are pushed to the viewport edges and clip off-screen.
    const preferredIdx = visibleResponses.findIndex(
      (r) => r.modelIndex === preferredIndex
    );
    const preferred = visibleResponses[preferredIdx];
    const leftPanels = visibleResponses.slice(0, preferredIdx);
    const rightPanels = visibleResponses.slice(preferredIdx + 1);
    // Non-preferred panel width and gap between panels
    const PANEL_W = 400;
    const GAP = 16;
    return (
      <div className="w-full relative overflow-hidden">
        {/* Preferred — centered at normal chat width, in flow to set container height */}
        {preferred && (
          <div className="w-full max-w-[720px] min-w-[400px] mx-auto">
            <MultiModelPanel {...buildPanelProps(preferred, false)} />
          </div>
        )}
        {/* Non-preferred on the left — anchored to the left of the preferred panel */}
        {leftPanels.map((r, i) => (
          <div
            key={r.modelIndex}
            className="absolute top-0"
            style={{
              width: `${PANEL_W}px`,
              // Right edge of this panel sits just left of the preferred panel
              right: `calc(50% + var(--app-page-main-content-width) / 2 + ${
                GAP + i * (PANEL_W + GAP)
              }px)`,
            }}
          >
            <MultiModelPanel {...buildPanelProps(r, true)} />
          </div>
        ))}
        {/* Non-preferred on the right — anchored to the right of the preferred panel */}
        {rightPanels.map((r, i) => (
          <div
            key={r.modelIndex}
            className="absolute top-0"
            style={{
              width: `${PANEL_W}px`,
              // Left edge of this panel sits just right of the preferred panel
              left: `calc(50% + var(--app-page-main-content-width) / 2 + ${
                GAP + i * (PANEL_W + GAP)
              }px)`,
            }}
          >
            <MultiModelPanel {...buildPanelProps(r, true)} />
          </div>
        ))}
      </div>
    );
  }

  // ── Generation Layout (equal panels) ──
  return (
    <div className="flex gap-6 items-start justify-center">
      {visibleResponses.map((r) => (
        <div key={r.modelIndex} className="flex-1 min-w-[400px] max-w-[720px]">
          <MultiModelPanel {...buildPanelProps(r, false)} />
        </div>
      ))}
      {renderHiddenPanels()}
    </div>
  );
}

View File

@@ -49,6 +49,8 @@ export interface AgentMessageProps {
parentMessage?: Message | null;
// Duration in seconds for processing this message (agent messages only)
processingDurationSeconds?: number;
/** Hide the feedback/toolbar footer (used in multi-model non-preferred panels) */
hideFooter?: boolean;
}
// TODO: Consider more robust comparisons:
@@ -76,7 +78,8 @@ function arePropsEqual(
prev.parentMessage?.messageId === next.parentMessage?.messageId &&
prev.llmManager?.isLoadingProviders ===
next.llmManager?.isLoadingProviders &&
prev.processingDurationSeconds === next.processingDurationSeconds
prev.processingDurationSeconds === next.processingDurationSeconds &&
prev.hideFooter === next.hideFooter
// Skip: chatState.regenerate, chatState.setPresentingDocument,
// most of llmManager, onMessageSelection (function/object props)
);
@@ -95,6 +98,7 @@ const AgentMessage = React.memo(function AgentMessage({
onRegenerate,
parentMessage,
processingDurationSeconds,
hideFooter,
}: AgentMessageProps) {
const markdownRef = useRef<HTMLDivElement>(null);
const finalAnswerRef = useRef<HTMLDivElement>(null);
@@ -326,7 +330,7 @@ const AgentMessage = React.memo(function AgentMessage({
</div>
{/* Feedback buttons - only show when streaming and rendering complete */}
{isComplete && (
{isComplete && !hideFooter && (
<MessageToolbar
nodeId={nodeId}
messageId={messageId}

View File

@@ -12,6 +12,7 @@ import {
FileChatDisplay,
Message,
MessageResponseIDInfo,
MultiModelMessageResponseIDInfo,
ResearchType,
RetrievalType,
StreamingError,
@@ -96,6 +97,7 @@ export type PacketType =
| FileChatDisplay
| StreamingError
| MessageResponseIDInfo
| MultiModelMessageResponseIDInfo
| StreamStopInfo
| UserKnowledgeFilePacket
| Packet;
@@ -109,6 +111,13 @@ export type MessageOrigin =
| "slackbot"
| "unknown";
/** Client-side mirror of the backend LLMOverride model (snake_case wire fields). */
export interface LLMOverride {
  model_provider: string;
  model_version: string;
  temperature?: number;
  /** Optional human-readable label for this model. */
  display_name?: string;
}
export interface SendMessageParams {
message: string;
fileDescriptors?: FileDescriptor[];
@@ -124,6 +133,8 @@ export interface SendMessageParams {
modelProvider?: string;
modelVersion?: string;
temperature?: number;
// Multi-model: send multiple LLM overrides for parallel generation
llmOverrides?: LLMOverride[];
// Origin of the message for telemetry tracking
origin?: MessageOrigin;
// Additional context injected into the LLM call but not stored/shown in chat.
@@ -144,6 +155,7 @@ export async function* sendMessage({
modelProvider,
modelVersion,
temperature,
llmOverrides,
origin,
additionalContext,
}: SendMessageParams): AsyncGenerator<PacketType, void, unknown> {
@@ -165,6 +177,8 @@ export async function* sendMessage({
model_version: modelVersion,
}
: null,
// Multi-model: list of LLM overrides for parallel generation
llm_overrides: llmOverrides ?? null,
// Default to "unknown" for consistency with backend; callers should set explicitly
origin: origin ?? "unknown",
additional_context: additionalContext ?? null,
@@ -188,6 +202,20 @@ export async function* sendMessage({
yield* handleSSEStream<PacketType>(response, signal);
}
/**
 * Persist which assistant response the user prefers for a multi-model turn.
 *
 * Issues a PUT to the backend with the user message id and the chosen
 * assistant message id; returns the raw fetch Response.
 */
export async function setPreferredResponse(
  userMessageId: number,
  preferredResponseId: number
): Promise<Response> {
  const payload = {
    user_message_id: userMessageId,
    preferred_response_id: preferredResponseId,
  };
  return fetch("/api/chat/set-preferred-response", {
    method: "PUT",
    headers: { "Content-Type": "application/json" },
    body: JSON.stringify(payload),
  });
}
export async function nameChatSession(chatSessionId: string) {
const response = await fetch("/api/chat/rename-chat-session", {
method: "PUT",
@@ -357,6 +385,9 @@ export function processRawChatHistory(
overridden_model: messageInfo.overridden_model,
packets: packetsForMessage || [],
currentFeedback: messageInfo.current_feedback as FeedbackType | null,
// Multi-model answer generation
preferredResponseId: messageInfo.preferred_response_id ?? null,
modelDisplayName: messageInfo.model_display_name ?? null,
};
messages.set(messageInfo.message_id, message);

View File

@@ -403,6 +403,7 @@ export interface Placement {
turn_index: number;
tab_index?: number; // For parallel tool calls - tools with same turn_index but different tab_index run in parallel
sub_turn_index?: number | null;
model_index?: number | null; // For multi-model answer generation - identifies which model produced this packet
}
// Packet wrapper for streaming objects

View File

@@ -459,6 +459,7 @@ export default function NRFPage({ isSidePanel = false }: NRFPageProps) {
onResubmit={handleResubmitLastMessage}
deepResearchEnabled={deepResearchEnabled}
anchorNodeId={anchorNodeId}
selectedModels={[]}
/>
</ChatScrollContainer>
</>

View File

@@ -3,6 +3,7 @@
import {
buildChatUrl,
getAvailableContextTokens,
LLMOverride,
nameChatSession,
updateLlmOverrideForChatSession,
} from "@/app/app/services/lib";
@@ -33,6 +34,7 @@ import {
FileDescriptor,
Message,
MessageResponseIDInfo,
MultiModelMessageResponseIDInfo,
RegenerationState,
RetrievalType,
StreamingError,
@@ -70,6 +72,7 @@ import {
} from "@/app/app/stores/useChatSessionStore";
import { Packet, MessageStart } from "@/app/app/services/streamingModels";
import useAgentPreferences from "@/hooks/useAgentPreferences";
import { SelectedModel } from "@/refresh-components/popovers/ModelSelector";
import { useForcedTools } from "@/lib/hooks/useForcedTools";
import { ProjectFile, useProjectsContext } from "@/providers/ProjectsContext";
import { useAppParams } from "@/hooks/appNavigation";
@@ -94,6 +97,8 @@ export interface OnSubmitProps {
regenerationRequest?: RegenerationRequest | null;
// Additional context injected into the LLM call but not stored/shown in chat.
additionalContext?: string;
// Multi-model chat: up to 3 models selected for parallel comparison.
selectedModels?: SelectedModel[];
}
interface RegenerationRequest {
@@ -370,7 +375,10 @@ export default function useChatController({
modelOverride,
regenerationRequest,
additionalContext,
selectedModels,
}: OnSubmitProps) => {
// Check if this is multi-model mode (2 or 3 models selected)
const isMultiModelMode = selectedModels && selectedModels.length >= 2;
const projectId = params(SEARCH_PARAM_NAMES.PROJECT_ID);
{
const params = new URLSearchParams(searchParams?.toString() || "");
@@ -601,6 +609,7 @@ export default function useChatController({
// immediately reflects the user message
let initialUserNode: Message;
let initialAgentNode: Message;
let initialAssistantNodes: Message[] = [];
if (regenerationRequest) {
// For regeneration: keep the existing user message, only create new agent
@@ -623,12 +632,30 @@ export default function useChatController({
);
initialUserNode = result.initialUserNode;
initialAgentNode = result.initialAgentNode;
// In multi-model mode, create N assistant nodes (one per selected model)
if (isMultiModelMode && selectedModels) {
for (let i = 0; i < selectedModels.length; i++) {
initialAssistantNodes.push(
buildEmptyMessage({
messageType: "assistant",
parentNodeId: initialUserNode.nodeId,
nodeIdOffset: i + 1,
})
);
}
}
}
// make messages appear + clear input bar
const messagesToUpsert = regenerationRequest
? [initialAgentNode] // Only upsert the new agent for regeneration
: [initialUserNode, initialAgentNode]; // Upsert both for normal/edit flow
let messagesToUpsert: Message[];
if (regenerationRequest) {
messagesToUpsert = [initialAgentNode];
} else if (isMultiModelMode) {
messagesToUpsert = [initialUserNode, ...initialAssistantNodes];
} else {
messagesToUpsert = [initialUserNode, initialAgentNode];
}
currentMessageTreeLocal = upsertToCompleteMessageTree({
messages: messagesToUpsert,
completeMessageTreeOverride: currentMessageTreeLocal,
@@ -662,6 +689,24 @@ export default function useChatController({
let newUserMessageId: number | null = null;
let newAgentMessageId: number | null = null;
// Multi-model mode state tracking (dynamically sized based on selected models)
const numModels = selectedModels?.length ?? 0;
let newAssistantMessageIds: (number | null)[] = isMultiModelMode
? Array(numModels).fill(null)
: [];
let packetsPerModel: Packet[][] = isMultiModelMode
? Array.from({ length: numModels }, () => [])
: [];
let modelDisplayNames: string[] = isMultiModelMode
? selectedModels?.map((m) => m.displayName) ?? []
: [];
let documentsPerModel: OnyxDocument[][] = isMultiModelMode
? Array.from({ length: numModels }, () => [])
: [];
let citationsPerModel: (CitationMap | null)[] = isMultiModelMode
? Array(numModels).fill(null)
: [];
try {
const lastSuccessfulMessageId = getLastSuccessfulMessageId(
currentMessageTreeLocal
@@ -710,13 +755,15 @@ export default function useChatController({
filterManager.timeRange,
filterManager.selectedTags
),
modelProvider:
modelOverride?.name || llmManager.currentLlm.name || undefined,
modelVersion:
modelOverride?.modelName ||
llmManager.currentLlm.modelName ||
searchParams?.get(SEARCH_PARAM_NAMES.MODEL_VERSION) ||
undefined,
modelProvider: isMultiModelMode
? undefined
: modelOverride?.name || llmManager.currentLlm.name || undefined,
modelVersion: isMultiModelMode
? undefined
: modelOverride?.modelName ||
llmManager.currentLlm.modelName ||
searchParams?.get(SEARCH_PARAM_NAMES.MODEL_VERSION) ||
undefined,
temperature: llmManager.temperature || undefined,
deepResearch,
enabledToolIds:
@@ -728,6 +775,12 @@ export default function useChatController({
forcedToolId: effectiveForcedToolId,
origin: messageOrigin,
additionalContext,
llmOverrides: isMultiModelMode
? selectedModels!.map((model) => ({
model_provider: model.name,
model_version: model.modelName,
}))
: undefined,
});
const delay = (ms: number) => {
@@ -780,6 +833,26 @@ export default function useChatController({
.reserved_assistant_message_id;
}
// Multi-model: handle reserved IDs for N parallel model responses
if (
isMultiModelMode &&
Object.hasOwn(packet, "reserved_assistant_message_ids") &&
Array.isArray(
(packet as MultiModelMessageResponseIDInfo)
.reserved_assistant_message_ids
)
) {
const multiPacket = packet as MultiModelMessageResponseIDInfo;
newAssistantMessageIds =
multiPacket.reserved_assistant_message_ids;
newUserMessageId =
multiPacket.user_message_id ?? newUserMessageId;
// Capture backend model names for display on reload
if (multiPacket.model_names?.length) {
modelDisplayNames = multiPacket.model_names;
}
}
if (Object.hasOwn(packet, "user_files")) {
const userFiles = (packet as UserKnowledgeFilePacket).user_files;
// Ensure files are unique by id
@@ -823,32 +896,73 @@ export default function useChatController({
updateCanContinue(true, frozenSessionId);
}
} else if (Object.hasOwn(packet, "obj")) {
packets.push(packet as Packet);
packetsVersion++;
const typedPacket = packet as Packet;
// Check if the packet contains document information
const packetObj = (packet as Packet).obj;
// In multi-model mode, route packets by model_index
if (isMultiModelMode) {
const modelIndex = typedPacket.placement?.model_index ?? 0;
if (
modelIndex >= 0 &&
modelIndex < packetsPerModel.length &&
packetsPerModel[modelIndex]
) {
packetsPerModel[modelIndex] = [
...packetsPerModel[modelIndex]!,
typedPacket,
];
if (packetObj.type === "citation_info") {
// Individual citation packet from backend streaming
const citationInfo = packetObj as {
type: "citation_info";
citation_number: number;
document_id: string;
};
// Incrementally build citations map
citations = {
...(citations || {}),
[citationInfo.citation_number]: citationInfo.document_id,
};
} else if (packetObj.type === "message_start") {
const messageStart = packetObj as MessageStart;
if (messageStart.final_documents) {
documents = messageStart.final_documents;
updateSelectedNodeForDocDisplay(
frozenSessionId,
initialAgentNode.nodeId
);
const packetObj = typedPacket.obj;
if (packetObj.type === "citation_info") {
const citationInfo = packetObj as {
type: "citation_info";
citation_number: number;
document_id: string;
};
citationsPerModel[modelIndex] = {
...(citationsPerModel[modelIndex] || {}),
[citationInfo.citation_number]: citationInfo.document_id,
};
} else if (packetObj.type === "message_start") {
const messageStart = packetObj as MessageStart;
if (messageStart.final_documents) {
documentsPerModel[modelIndex] =
messageStart.final_documents;
if (modelIndex === 0 && initialAssistantNodes[0]) {
updateSelectedNodeForDocDisplay(
frozenSessionId,
initialAssistantNodes[0].nodeId
);
}
}
}
}
} else {
// Single model mode
packets.push(typedPacket);
packetsVersion++;
const packetObj = typedPacket.obj;
if (packetObj.type === "citation_info") {
const citationInfo = packetObj as {
type: "citation_info";
citation_number: number;
document_id: string;
};
citations = {
...(citations || {}),
[citationInfo.citation_number]: citationInfo.document_id,
};
} else if (packetObj.type === "message_start") {
const messageStart = packetObj as MessageStart;
if (messageStart.final_documents) {
documents = messageStart.final_documents;
updateSelectedNodeForDocDisplay(
frozenSessionId,
initialAgentNode.nodeId
);
}
}
}
} else {
@@ -860,8 +974,48 @@ export default function useChatController({
parentMessage =
parentMessage || currentMessageTreeLocal?.get(SYSTEM_NODE_ID)!;
currentMessageTreeLocal = upsertToCompleteMessageTree({
messages: [
// Build messages to upsert based on mode
let messagesToUpsertInLoop: Message[];
if (isMultiModelMode) {
// Multi-model mode: update user node + all N assistant nodes
const updatedUserNode = {
...initialUserNode,
messageId: newUserMessageId ?? undefined,
files: files,
};
const updatedAssistantNodes = initialAssistantNodes.map(
(node, idx) => ({
...node,
messageId: newAssistantMessageIds[idx] ?? undefined,
message: "",
type: "assistant" as const,
retrievalType,
query: query,
documents: documentsPerModel[idx] || [],
citations: citationsPerModel[idx] || {},
files: [] as FileDescriptor[],
toolCall: null,
stackTrace: null,
overridden_model: selectedModels?.[idx]?.displayName,
modelDisplayName:
modelDisplayNames[idx] ||
selectedModels?.[idx]?.displayName ||
null,
stopReason: stopReason,
packets: packetsPerModel[idx] || [],
packetCount: packetsPerModel[idx]?.length || 0,
})
);
messagesToUpsertInLoop = [
updatedUserNode,
...updatedAssistantNodes,
];
} else {
// Single model mode (existing logic)
messagesToUpsertInLoop = [
{
...initialUserNode,
messageId: newUserMessageId ?? undefined,
@@ -894,8 +1048,11 @@ export default function useChatController({
: undefined;
})(),
},
],
// Pass the latest map state
];
}
currentMessageTreeLocal = upsertToCompleteMessageTree({
messages: messagesToUpsertInLoop,
completeMessageTreeOverride: currentMessageTreeLocal,
chatSessionId: frozenSessionId!,
});

View File

@@ -0,0 +1,232 @@
import { renderHook, act } from "@testing-library/react";
import useMultiModelChat from "@/hooks/useMultiModelChat";
import { LlmManager } from "@/lib/hooks";
import { SelectedModel } from "@/refresh-components/popovers/ModelSelector";
// buildLlmOptions is mocked out: the hook calls it during its initialization
// effect, but these tests exercise only the CRUD API, not initialization.
jest.mock("@/refresh-components/popovers/LLMPopover", () => ({
  buildLlmOptions: jest.fn(() => []),
}));

/** Minimal LlmManager stub — just enough shape for the hook to mount. */
function makeLlmManager(): LlmManager {
  const stub = {
    llmProviders: [],
    currentLlm: { modelName: null, provider: null },
    isLoadingProviders: false,
  };
  return stub as unknown as LlmManager;
}

/** Build a SelectedModel fixture; displayName is "<provider>/<model>". */
function makeModel(provider: string, modelName: string): SelectedModel {
  return {
    name: provider,
    provider,
    modelName,
    displayName: `${provider}/${modelName}`,
  };
}

const GPT4 = makeModel("openai", "gpt-4");
const CLAUDE = makeModel("anthropic", "claude-opus-4-6");
const GEMINI = makeModel("google", "gemini-pro");
const GPT4_TURBO = makeModel("openai", "gpt-4-turbo");
// ---------------------------------------------------------------------------
// addModel
// ---------------------------------------------------------------------------
describe("addModel", () => {
  const mount = () => renderHook(() => useMultiModelChat(makeLlmManager()));

  it("adds a model to an empty selection", () => {
    const { result } = mount();
    act(() => result.current.addModel(GPT4));
    expect(result.current.selectedModels).toEqual([GPT4]);
  });

  it("does not add a duplicate model", () => {
    const { result } = mount();
    act(() => {
      result.current.addModel(GPT4);
      result.current.addModel(GPT4); // same model twice — second is ignored
    });
    expect(result.current.selectedModels).toHaveLength(1);
  });

  it("enforces MAX_MODELS (3) cap", () => {
    const { result } = mount();
    act(() => {
      // Fourth add exceeds the cap and should be ignored.
      [GPT4, CLAUDE, GEMINI, GPT4_TURBO].forEach((m) =>
        result.current.addModel(m)
      );
    });
    expect(result.current.selectedModels).toHaveLength(3);
  });
});
// ---------------------------------------------------------------------------
// removeModel
// ---------------------------------------------------------------------------
describe("removeModel", () => {
  const mount = () => renderHook(() => useMultiModelChat(makeLlmManager()));

  it("removes a model by index", () => {
    const { result } = mount();
    act(() => {
      result.current.addModel(GPT4);
      result.current.addModel(CLAUDE);
    });
    act(() => result.current.removeModel(0)); // drop GPT4
    expect(result.current.selectedModels).toEqual([CLAUDE]);
  });

  it("handles out-of-range index gracefully", () => {
    const { result } = mount();
    act(() => result.current.addModel(GPT4));
    act(() => result.current.removeModel(99)); // filter drops nothing
    expect(result.current.selectedModels).toHaveLength(1);
  });
});
// ---------------------------------------------------------------------------
// replaceModel
// ---------------------------------------------------------------------------
describe("replaceModel", () => {
  const mount = () => renderHook(() => useMultiModelChat(makeLlmManager()));

  it("replaces the model at the given index", () => {
    const { result } = mount();
    act(() => {
      result.current.addModel(GPT4);
      result.current.addModel(CLAUDE);
    });
    act(() => result.current.replaceModel(0, GEMINI));
    expect(result.current.selectedModels).toEqual([GEMINI, CLAUDE]);
  });

  it("does not replace with a model already selected at another index", () => {
    const { result } = mount();
    act(() => {
      result.current.addModel(GPT4);
      result.current.addModel(CLAUDE);
    });
    // CLAUDE already sits at index 1, so this must be a no-op.
    act(() => result.current.replaceModel(0, CLAUDE));
    expect(result.current.selectedModels[0]).toEqual(GPT4);
  });
});
// ---------------------------------------------------------------------------
// isMultiModelActive
// ---------------------------------------------------------------------------
describe("isMultiModelActive", () => {
  const mount = () => renderHook(() => useMultiModelChat(makeLlmManager()));

  it("is false with zero models", () => {
    expect(mount().result.current.isMultiModelActive).toBe(false);
  });

  it("is false with exactly one model", () => {
    const { result } = mount();
    act(() => result.current.addModel(GPT4));
    expect(result.current.isMultiModelActive).toBe(false);
  });

  it("is true with two or more models", () => {
    const { result } = mount();
    act(() => {
      result.current.addModel(GPT4);
      result.current.addModel(CLAUDE);
    });
    expect(result.current.isMultiModelActive).toBe(true);
  });
});
// ---------------------------------------------------------------------------
// buildLlmOverrides
// ---------------------------------------------------------------------------
describe("buildLlmOverrides", () => {
  const mount = () => renderHook(() => useMultiModelChat(makeLlmManager()));

  it("returns empty array when no models selected", () => {
    expect(mount().result.current.buildLlmOverrides()).toEqual([]);
  });

  it("maps selectedModels to LLMOverride format", () => {
    const { result } = mount();
    act(() => {
      result.current.addModel(GPT4);
      result.current.addModel(CLAUDE);
    });
    // Each override carries provider name, raw version, and display name.
    expect(result.current.buildLlmOverrides()).toEqual([
      {
        model_provider: "openai",
        model_version: "gpt-4",
        display_name: "openai/gpt-4",
      },
      {
        model_provider: "anthropic",
        model_version: "claude-opus-4-6",
        display_name: "anthropic/claude-opus-4-6",
      },
    ]);
  });
});
// ---------------------------------------------------------------------------
// clearModels
// ---------------------------------------------------------------------------
describe("clearModels", () => {
  it("empties the selection", () => {
    const { result } = renderHook(() =>
      useMultiModelChat(makeLlmManager())
    );
    act(() => {
      result.current.addModel(GPT4);
      result.current.addModel(CLAUDE);
    });
    act(() => result.current.clearModels());
    expect(result.current.selectedModels).toEqual([]);
    expect(result.current.isMultiModelActive).toBe(false);
  });
});

View File

@@ -0,0 +1,191 @@
"use client";
import { useState, useCallback, useEffect, useMemo } from "react";
import { SelectedModel } from "@/refresh-components/popovers/ModelSelector";
import { LLMOverride } from "@/app/app/services/lib";
import { LlmManager } from "@/lib/hooks";
import { buildLlmOptions } from "@/refresh-components/popovers/LLMPopover";
// Maximum number of models that can be selected for comparison at once
// (enforced inside addModel).
const MAX_MODELS = 3;
export interface UseMultiModelChatReturn {
  /** Currently selected models for multi-model comparison. */
  selectedModels: SelectedModel[];
  /** Whether multi-model mode is active (>1 model selected). */
  isMultiModelActive: boolean;
  /** Add a model to the selection. */
  addModel: (model: SelectedModel) => void;
  /** Remove a model by index. */
  removeModel: (index: number) => void;
  /** Replace a model at a specific index with a new one. */
  replaceModel: (index: number, model: SelectedModel) => void;
  /** Clear all selected models. */
  clearModels: () => void;
  /** Build the LLMOverride[] array from selectedModels. */
  buildLlmOverrides: () => LLMOverride[];
  /**
   * Restore multi-model selection from model version strings (e.g. from chat history).
   * Matches against available llmOptions to reconstruct full SelectedModel objects.
   */
  restoreFromModelNames: (modelNames: string[]) => void;
  /**
   * Switch to a single model by name (after user picks a preferred response).
   * Matches against llmOptions to find the full SelectedModel.
   */
  selectSingleModel: (modelName: string) => void;
}
/**
 * Manages the model selection for multi-model ("compare") chat.
 *
 * Holds up to MAX_MODELS SelectedModel entries, seeds the list with the
 * manager's current default model once providers load, and exposes CRUD
 * helpers plus converters to/from the backend LLMOverride shape.
 */
export default function useMultiModelChat(
  llmManager: LlmManager
): UseMultiModelChatReturn {
  const [selectedModels, setSelectedModels] = useState<SelectedModel[]>([]);
  // Guards the one-time default seeding below; also set by
  // restoreFromModelNames so a restored selection is not overwritten.
  const [defaultInitialized, setDefaultInitialized] = useState(false);
  // Initialize with the default model from llmManager once providers load
  const llmOptions = useMemo(
    () =>
      llmManager.llmProviders ? buildLlmOptions(llmManager.llmProviders) : [],
    [llmManager.llmProviders]
  );
  useEffect(() => {
    if (defaultInitialized) return;
    if (llmOptions.length === 0) return;
    const { currentLlm } = llmManager;
    // Don't initialize if currentLlm hasn't loaded yet
    if (!currentLlm.modelName) return;
    const match = llmOptions.find(
      (opt) =>
        opt.provider === currentLlm.provider &&
        opt.modelName === currentLlm.modelName
    );
    if (match) {
      setSelectedModels([
        {
          name: match.name,
          provider: match.provider,
          modelName: match.modelName,
          displayName: match.displayName,
        },
      ]);
      setDefaultInitialized(true);
    }
    // NOTE(review): `llmManager` already subsumes `llmManager.currentLlm`
    // in this dependency list — the extra entry is redundant but harmless.
  }, [llmOptions, llmManager.currentLlm, defaultInitialized, llmManager]);
  const isMultiModelActive = selectedModels.length > 1;
  // Append a model; no-op when at the cap or when the same
  // (provider, modelName) pair is already selected.
  const addModel = useCallback((model: SelectedModel) => {
    setSelectedModels((prev) => {
      if (prev.length >= MAX_MODELS) return prev;
      if (
        prev.some(
          (m) =>
            m.provider === model.provider && m.modelName === model.modelName
        )
      ) {
        return prev;
      }
      return [...prev, model];
    });
  }, []);
  // Out-of-range indices are a no-op (filter removes nothing).
  const removeModel = useCallback((index: number) => {
    setSelectedModels((prev) => prev.filter((_, i) => i !== index));
  }, []);
  const replaceModel = useCallback((index: number, model: SelectedModel) => {
    setSelectedModels((prev) => {
      // Don't replace with a model that's already selected elsewhere
      if (
        prev.some(
          (m, i) =>
            i !== index &&
            m.provider === model.provider &&
            m.modelName === model.modelName
        )
      ) {
        return prev;
      }
      const next = [...prev];
      next[index] = model;
      return next;
    });
  }, []);
  const clearModels = useCallback(() => {
    setSelectedModels([]);
  }, []);
  // Rebuild a multi-model selection from persisted name strings. Requires at
  // least two resolvable names; otherwise the current selection (and the
  // default-seeding flag) is left untouched.
  const restoreFromModelNames = useCallback(
    (modelNames: string[]) => {
      if (modelNames.length < 2 || llmOptions.length === 0) return;
      const restored: SelectedModel[] = [];
      for (const name of modelNames) {
        // Try matching by modelName (raw version string like "claude-opus-4-6")
        // or by displayName (friendly name like "Claude Opus 4.6")
        const match = llmOptions.find(
          (opt) =>
            opt.modelName === name ||
            opt.displayName === name ||
            opt.name === name
        );
        if (match) {
          restored.push({
            name: match.name,
            provider: match.provider,
            modelName: match.modelName,
            displayName: match.displayName,
          });
        }
      }
      if (restored.length >= 2) {
        setSelectedModels(restored);
        setDefaultInitialized(true);
      }
    },
    [llmOptions]
  );
  // Collapse the selection to a single model. An unresolvable name leaves
  // the selection unchanged.
  const selectSingleModel = useCallback(
    (modelName: string) => {
      if (llmOptions.length === 0) return;
      const match = llmOptions.find(
        (opt) =>
          opt.modelName === modelName ||
          opt.displayName === modelName ||
          opt.name === modelName
      );
      if (match) {
        setSelectedModels([
          {
            name: match.name,
            provider: match.provider,
            modelName: match.modelName,
            displayName: match.displayName,
          },
        ]);
      }
    },
    [llmOptions]
  );
  // Convert the selection into the request payload shape: `name` feeds
  // model_provider and `modelName` feeds model_version.
  const buildLlmOverrides = useCallback((): LLMOverride[] => {
    return selectedModels.map((m) => ({
      model_provider: m.name,
      model_version: m.modelName,
      display_name: m.displayName,
    }));
  }, [selectedModels]);
  return {
    selectedModels,
    isMultiModelActive,
    addModel,
    removeModel,
    replaceModel,
    clearModels,
    buildLlmOverrides,
    restoreFromModelNames,
    selectSingleModel,
  };
}

View File

@@ -32,6 +32,12 @@ export interface FrostedDivProps extends React.HTMLAttributes<HTMLDivElement> {
* Additional classes for the frost overlay element itself
*/
overlayClassName?: string;
/**
* Additional classes for the outermost wrapper div (the `relative` container).
* Useful for width propagation (e.g. `w-full`).
*/
wrapperClassName?: string;
}
/**
@@ -60,13 +66,14 @@ export default function FrostedDiv({
backdropBlur = "6px",
borderRadius = "1rem",
overlayClassName,
wrapperClassName,
className,
style,
children,
...props
}: FrostedDivProps) {
return (
<div className="relative">
<div className={cn("relative", wrapperClassName)}>
{/* Frost effect overlay - positioned behind content with bloom extending outward */}
<div
className={cn("absolute pointer-events-none", overlayClassName)}

View File

@@ -0,0 +1,469 @@
"use client";
import { useState, useMemo, useRef } from "react";
import Popover, { PopoverMenu } from "@/refresh-components/Popover";
import { LlmManager } from "@/lib/hooks";
import { getProviderIcon } from "@/app/admin/configuration/llm/utils";
import InputTypeIn from "@/refresh-components/inputs/InputTypeIn";
import Text from "@/refresh-components/texts/Text";
import { Button } from "@opal/components";
import {
SvgCheck,
SvgChevronDown,
SvgChevronRight,
SvgColumn,
SvgPlusCircle,
SvgX,
} from "@opal/icons";
import { Section } from "@/layouts/general-layouts";
import {
LLMOption,
LLMOptionGroup,
} from "@/refresh-components/popovers/interfaces";
import {
buildLlmOptions,
groupLlmOptions,
} from "@/refresh-components/popovers/LLMPopover";
import {
Accordion,
AccordionContent,
AccordionItem,
AccordionTrigger,
} from "@/components/ui/accordion";
import { cn } from "@/lib/utils";
// Maximum number of models selectable for comparison (matches the cap in
// useMultiModelChat).
const MAX_MODELS = 3;
/** A model the user has picked for (multi-)model chat. */
export interface SelectedModel {
  // Provider config name; used as `model_provider` when building overrides.
  name: string;
  // Provider id; paired with modelName to dedupe selections.
  provider: string;
  // Raw model version string (e.g. "gpt-4").
  modelName: string;
  // Human-friendly label shown in pills and menus.
  displayName: string;
}
/** Props for ModelSelector — selection state is owned by the parent. */
export interface ModelSelectorProps {
  llmManager: LlmManager;
  /** Current selection, in display order. */
  selectedModels: SelectedModel[];
  /** Add a model (add mode). */
  onAdd: (model: SelectedModel) => void;
  /** Remove the model at an index (pill X button or toggle-off). */
  onRemove: (index: number) => void;
  /** Swap in a new model at an index (replace mode via pill click). */
  onReplace: (index: number, model: SelectedModel) => void;
}
/** Vertical 1px divider between model bar elements */
function BarDivider() {
return <div className="h-9 w-px bg-border-01 shrink-0" />;
}
/** A single selected-model pill shown in the model bar. */
function ModelPill({
  model,
  isMultiModel,
  onRemove,
  onClick,
}: {
  model: SelectedModel;
  isMultiModel: boolean;
  onRemove?: () => void;
  onClick?: () => void;
}) {
  const ProviderIcon = getProviderIcon(model.provider, model.modelName);

  const pillClasses = cn(
    "flex items-center gap-0.5 rounded-12 p-2 shrink-0 cursor-pointer",
    "hover:bg-background-tint-02 transition-colors",
    isMultiModel && "bg-background-tint-02"
  );

  return (
    <div
      role="button"
      tabIndex={0}
      onClick={onClick}
      onKeyDown={(e) => {
        // Keyboard activation mirrors the click handler (Enter / Space).
        if (e.key !== "Enter" && e.key !== " ") return;
        e.preventDefault();
        onClick?.();
      }}
      className={pillClasses}
    >
      <div className="flex items-center justify-center size-5 shrink-0 p-0.5">
        <ProviderIcon size={16} />
      </div>
      <Text mainUiAction text04 nowrap className="px-1">
        {model.displayName}
      </Text>
      {isMultiModel ? (
        <button
          type="button"
          onClick={(e) => {
            // Keep the remove click from bubbling into the pill's onClick.
            e.stopPropagation();
            onRemove?.();
          }}
          className="flex items-center justify-center size-4 shrink-0 hover:opacity-70"
        >
          <SvgX className="size-4 stroke-text-03" />
        </button>
      ) : (
        <SvgChevronDown className="size-4 stroke-text-03 shrink-0" />
      )}
    </div>
  );
}
/** One selectable model row inside the add/replace popover. */
function ModelItem({
  option,
  isSelected,
  isDisabled,
  onToggle,
}: {
  option: LLMOption;
  isSelected: boolean;
  isDisabled: boolean;
  onToggle: () => void;
}) {
  const ProviderIcon = getProviderIcon(option.provider, option.modelName);

  // Subtitle advertises capabilities; falls back to the raw model name.
  const subtitle = useMemo(() => {
    const tags: string[] = [];
    if (option.supportsReasoning) tags.push("reasoning");
    if (option.supportsImageInput) tags.push("multi-modal");
    if (tags.length > 0) return tags.join(", ");
    return option.modelName ? option.modelName : "";
  }, [option]);

  const rowClasses = cn(
    "flex items-center gap-1.5 w-full rounded-08 p-1.5 text-left transition-colors",
    isSelected ? "bg-action-link-01" : "hover:bg-background-tint-02",
    isDisabled && !isSelected && "opacity-50 cursor-not-allowed"
  );

  return (
    <button
      type="button"
      disabled={isDisabled}
      onClick={onToggle}
      className={rowClasses}
    >
      <div className="flex items-center justify-center size-5 shrink-0 p-0.5">
        <ProviderIcon size={16} />
      </div>
      <div className="flex flex-col flex-1 min-w-0">
        <Text
          mainUiAction
          nowrap
          className={cn(isSelected ? "text-action-link-03" : "text-text-04")}
        >
          {option.displayName}
        </Text>
        {subtitle && (
          <Text secondaryBody text03 nowrap>
            {subtitle}
          </Text>
        )}
      </div>
      {isSelected && (
        <Text secondaryBody nowrap className="text-action-link-05 shrink-0">
          Added
        </Text>
      )}
    </button>
  );
}
/**
 * Model bar + popover for picking up to MAX_MODELS models to compare.
 *
 * The popover operates in two modes:
 *  - add mode (opened via the "+" button): clicking a row toggles it in or
 *    out of the selection;
 *  - replace mode (opened by clicking an existing pill): clicking a row
 *    swaps it in at that pill's index.
 * Selection state itself lives in the parent (see ModelSelectorProps).
 */
export default function ModelSelector({
  llmManager,
  selectedModels,
  onAdd,
  onRemove,
  onReplace,
}: ModelSelectorProps) {
  const [open, setOpen] = useState(false);
  const [searchQuery, setSearchQuery] = useState("");
  const scrollContainerRef = useRef<HTMLDivElement>(null);
  // null = add mode (via + button), number = replace mode (via pill click)
  const [replacingIndex, setReplacingIndex] = useState<number | null>(null);
  const isMultiModel = selectedModels.length > 1;
  const atMax = selectedModels.length >= MAX_MODELS;
  const llmOptions = useMemo(
    () => buildLlmOptions(llmManager.llmProviders),
    [llmManager.llmProviders]
  );
  // "provider:modelName" keys of the current selection, for O(1) lookups.
  const selectedKeys = useMemo(
    () => new Set(selectedModels.map((m) => `${m.provider}:${m.modelName}`)),
    [selectedModels]
  );
  // Case-insensitive filter over display name, raw model name, and vendor.
  const filteredOptions = useMemo(() => {
    if (!searchQuery.trim()) return llmOptions;
    const query = searchQuery.toLowerCase();
    return llmOptions.filter(
      (opt) =>
        opt.displayName.toLowerCase().includes(query) ||
        opt.modelName.toLowerCase().includes(query) ||
        (opt.vendor && opt.vendor.toLowerCase().includes(query))
    );
  }, [llmOptions, searchQuery]);
  const groupedOptions = useMemo(
    () => groupLlmOptions(filteredOptions),
    [filteredOptions]
  );
  const isSearching = searchQuery.trim().length > 0;
  // In replace mode, other selected models (not the one being replaced) are disabled
  const otherSelectedKeys = useMemo(() => {
    if (replacingIndex === null) return new Set<string>();
    return new Set(
      selectedModels
        .filter((_, i) => i !== replacingIndex)
        .map((m) => `${m.provider}:${m.modelName}`)
    );
  }, [selectedModels, replacingIndex]);
  // Current model at the replacing index (shows as "selected" in replace mode)
  const replacingKey = useMemo(() => {
    if (replacingIndex === null) return null;
    const m = selectedModels[replacingIndex];
    return m ? `${m.provider}:${m.modelName}` : null;
  }, [selectedModels, replacingIndex]);
  // Resolve a row's selected/disabled state for a "provider:modelName" key;
  // the semantics differ between add mode and replace mode.
  const getItemState = (optKey: string) => {
    if (replacingIndex !== null) {
      // Replace mode
      return {
        isSelected: optKey === replacingKey,
        isDisabled: otherSelectedKeys.has(optKey),
      };
    }
    // Add mode
    return {
      isSelected: selectedKeys.has(optKey),
      isDisabled: !selectedKeys.has(optKey) && atMax,
    };
  };
  const handleSelectModel = (option: LLMOption) => {
    const model: SelectedModel = {
      name: option.name,
      provider: option.provider,
      modelName: option.modelName,
      displayName: option.displayName,
    };
    if (replacingIndex !== null) {
      // Replace mode: swap the model at the clicked pill's index
      onReplace(replacingIndex, model);
      setOpen(false);
      setReplacingIndex(null);
      setSearchQuery("");
      return;
    }
    // Add mode: toggle (add/remove)
    const key = `${option.provider}:${option.modelName}`;
    const existingIndex = selectedModels.findIndex(
      (m) => `${m.provider}:${m.modelName}` === key
    );
    if (existingIndex >= 0) {
      onRemove(existingIndex);
    } else if (!atMax) {
      onAdd(model);
    }
  };
  // Reset transient popover state (mode + search) whenever it closes.
  const handleOpenChange = (nextOpen: boolean) => {
    setOpen(nextOpen);
    if (!nextOpen) {
      setReplacingIndex(null);
      setSearchQuery("");
    }
  };
  // Clicking a pill opens the popover in replace mode for that slot.
  const handlePillClick = (index: number) => {
    setReplacingIndex(index);
    setOpen(true);
  };
  return (
    <div className="flex items-center justify-end gap-1 p-1">
      {/* (+) Add model button — hidden at max models */}
      {!atMax && (
        <Popover open={open} onOpenChange={handleOpenChange}>
          <Popover.Trigger asChild>
            <Button
              prominence="tertiary"
              icon={SvgPlusCircle}
              size="sm"
              tooltip="Add Model"
            />
          </Popover.Trigger>
          <Popover.Content side="top" align="start" width="lg">
            <Section gap={0.25}>
              <InputTypeIn
                leftSearchIcon
                variant="internal"
                value={searchQuery}
                onChange={(e) => setSearchQuery(e.target.value)}
                placeholder="Search models..."
              />
              <PopoverMenu scrollContainerRef={scrollContainerRef}>
                {groupedOptions.length === 0
                  ? [
                      <div key="empty" className="py-3 px-2">
                        <Text secondaryBody text03>
                          No models found
                        </Text>
                      </div>,
                    ]
                  : groupedOptions.length === 1
                    ? [
                        <div key="single" className="flex flex-col gap-0.5">
                          {groupedOptions[0]!.options.map((opt) => {
                            const key = `${opt.provider}:${opt.modelName}`;
                            const state = getItemState(key);
                            return (
                              // NOTE(review): this React key uses modelName
                              // alone, while the accordion branch keys by
                              // provider:modelName — identical model names
                              // across providers would collide; confirm and
                              // align.
                              <ModelItem
                                key={opt.modelName}
                                option={opt}
                                isSelected={state.isSelected}
                                isDisabled={state.isDisabled}
                                onToggle={() => handleSelectModel(opt)}
                              />
                            );
                          })}
                        </div>,
                      ]
                    : [
                        <ModelGroupAccordion
                          key="accordion"
                          groups={groupedOptions}
                          isSearching={isSearching}
                          getItemState={getItemState}
                          onToggle={handleSelectModel}
                        />,
                      ]}
              </PopoverMenu>
              <div className="border-t border-border-01 mt-1 pt-1">
                {/* NOTE(review): no onClick here — "Compare Model" appears
                    inert; confirm whether a handler is still to be wired up. */}
                <button
                  type="button"
                  className="flex items-center gap-1.5 w-full rounded-08 p-1.5 text-left hover:bg-background-tint-02 transition-colors"
                >
                  <SvgColumn className="size-5 stroke-text-03 shrink-0" />
                  <Text mainUiAction text04>
                    Compare Model
                  </Text>
                </button>
              </div>
            </Section>
          </Popover.Content>
        </Popover>
      )}
      {/* Divider + model pills */}
      {selectedModels.length > 0 && (
        <>
          <BarDivider />
          {selectedModels.map((model, index) => (
            <div
              key={`${model.provider}:${model.modelName}`}
              className="flex items-center gap-1"
            >
              {index > 0 && <BarDivider />}
              <ModelPill
                model={model}
                isMultiModel={isMultiModel}
                onRemove={() => onRemove(index)}
                onClick={() => handlePillClick(index)}
              />
            </div>
          ))}
        </>
      )}
    </div>
  );
}
/** Props for ModelGroupAccordion. */
interface ModelGroupAccordionProps {
  /** Provider groups to render, in display order. */
  groups: LLMOptionGroup[];
  /** True while the search box has a non-empty query. */
  isSearching: boolean;
  /** Resolves selected/disabled state for a "provider:modelName" key. */
  getItemState: (key: string) => { isSelected: boolean; isDisabled: boolean };
  /** Invoked when a model row is clicked. */
  onToggle: (option: LLMOption) => void;
}
/** Collapsible per-provider groups for the model popover. */
function ModelGroupAccordion({
  groups,
  isSearching,
  getItemState,
  onToggle,
}: ModelGroupAccordionProps) {
  const groupKeys = groups.map((group) => group.key);
  // User-driven expansion state; the first group starts open.
  const [openGroups, setOpenGroups] = useState<string[]>([groupKeys[0] ?? ""]);
  // While searching, every group is forced open and toggles are ignored.
  const expanded = isSearching ? groupKeys : openGroups;

  return (
    <Accordion
      type="multiple"
      value={expanded}
      onValueChange={(next) => {
        if (isSearching) return;
        setOpenGroups(next);
      }}
      className="w-full flex flex-col"
    >
      {groups.map((group) => {
        const isOpen = expanded.includes(group.key);
        const Chevron = isOpen ? SvgChevronDown : SvgChevronRight;
        return (
          <AccordionItem
            key={group.key}
            value={group.key}
            className="border-none pt-1"
          >
            <AccordionTrigger className="flex items-center rounded-08 hover:no-underline hover:bg-background-tint-02 group [&>svg]:hidden w-full py-1">
              <div className="flex items-center gap-1 shrink-0">
                <div className="flex items-center justify-center size-5 shrink-0">
                  <group.Icon size={16} />
                </div>
                <Text secondaryBody text03 nowrap className="px-0.5">
                  {group.displayName}
                </Text>
              </div>
              <div className="flex-1" />
              <div className="flex items-center justify-center size-6 shrink-0">
                <Chevron className="h-4 w-4 stroke-text-04 shrink-0" />
              </div>
            </AccordionTrigger>
            <AccordionContent className="pb-0 pt-0">
              <div className="flex flex-col gap-0.5">
                {group.options.map((opt) => {
                  const itemKey = `${opt.provider}:${opt.modelName}`;
                  const { isSelected, isDisabled } = getItemState(itemKey);
                  return (
                    <ModelItem
                      key={itemKey}
                      option={opt}
                      isSelected={isSelected}
                      isDisabled={isDisabled}
                      onToggle={() => onToggle(opt)}
                    />
                  );
                })}
              </div>
            </AccordionContent>
          </AccordionItem>
        );
      })}
    </Accordion>
  );
}

View File

@@ -15,6 +15,8 @@ import {
} from "@/providers/SettingsProvider";
import Dropzone from "react-dropzone";
import AppInputBar, { AppInputBarHandle } from "@/sections/input/AppInputBar";
import ModelSelector from "@/refresh-components/popovers/ModelSelector";
import useMultiModelChat from "@/hooks/useMultiModelChat";
import useChatSessions from "@/hooks/useChatSessions";
import useCCPairs from "@/hooks/useCCPairs";
import useTags from "@/hooks/useTags";
@@ -64,6 +66,7 @@ import { SvgChevronDown, SvgFileText } from "@opal/icons";
import { Button } from "@opal/components";
import Spacer from "@/refresh-components/Spacer";
import useAppFocus from "@/hooks/useAppFocus";
import { useAppSidebarContext } from "@/providers/AppSidebarProvider";
import { useQueryController } from "@/providers/QueryControllerProvider";
import WelcomeMessage from "@/app/app/components/WelcomeMessage";
import ChatUI from "@/sections/chat/ChatUI";
@@ -364,6 +367,30 @@ export default function AppPage({ firstMessage }: ChatPageProps) {
const autoScrollEnabled = user?.preferences?.auto_scroll !== false;
const isStreaming = currentChatState === "streaming";
const multiModel = useMultiModelChat(llmManager);
// Auto-fold sidebar when multi-model is active (needs full width)
const { folded: sidebarFolded, setFolded: setSidebarFolded } =
useAppSidebarContext();
const preMultiModelFoldedRef = useRef<boolean | null>(null);
useEffect(() => {
if (
multiModel.isMultiModelActive &&
preMultiModelFoldedRef.current === null
) {
preMultiModelFoldedRef.current = sidebarFolded;
setSidebarFolded(true);
} else if (
!multiModel.isMultiModelActive &&
preMultiModelFoldedRef.current !== null
) {
setSidebarFolded(preMultiModelFoldedRef.current);
preMultiModelFoldedRef.current = null;
}
// eslint-disable-next-line react-hooks/exhaustive-deps
}, [multiModel.isMultiModelActive]);
const {
onSubmit,
stopGenerating,
@@ -463,6 +490,9 @@ export default function AppPage({ firstMessage }: ChatPageProps) {
message,
currentMessageFiles,
deepResearch: deepResearchEnabledForCurrentWorkflow,
selectedModels: multiModel.isMultiModelActive
? multiModel.selectedModels
: undefined,
});
if (showOnboarding || !onboardingDismissed) {
finishOnboarding();
@@ -473,6 +503,8 @@ export default function AppPage({ firstMessage }: ChatPageProps) {
onSubmit,
currentMessageFiles,
deepResearchEnabledForCurrentWorkflow,
multiModel.isMultiModelActive,
multiModel.selectedModels,
showOnboarding,
onboardingDismissed,
finishOnboarding,
@@ -511,6 +543,9 @@ export default function AppPage({ firstMessage }: ChatPageProps) {
message,
currentMessageFiles,
deepResearch: deepResearchEnabledForCurrentWorkflow,
selectedModels: multiModel.isMultiModelActive
? multiModel.selectedModels
: undefined,
});
if (showOnboarding || !onboardingDismissed) {
finishOnboarding();
@@ -535,6 +570,7 @@ export default function AppPage({ firstMessage }: ChatPageProps) {
showOnboarding,
onboardingDismissed,
finishOnboarding,
multiModel,
]
);
@@ -675,7 +711,7 @@ export default function AppPage({ firstMessage }: ChatPageProps) {
style={gridStyle}
>
{/* ── Top row: ChatUI / WelcomeMessage / ProjectUI ── */}
<div className="row-start-1 min-h-0 overflow-hidden flex flex-col items-center">
<div className="row-start-1 min-h-0 overflow-y-hidden flex flex-col items-center">
{/* ChatUI */}
<Fade
show={
@@ -704,6 +740,7 @@ export default function AppPage({ firstMessage }: ChatPageProps) {
stopGenerating={stopGenerating}
onResubmit={handleResubmitLastMessage}
anchorNodeId={anchorNodeId}
selectedModels={multiModel.selectedModels}
/>
</ChatScrollContainer>
</Fade>
@@ -730,6 +767,16 @@ export default function AppPage({ firstMessage }: ChatPageProps) {
<WelcomeMessage
agent={liveAgent}
isDefaultAgent={isDefaultAgent}
hideTitle={multiModel.selectedModels.length >= 3}
rightChildren={
<ModelSelector
llmManager={llmManager}
selectedModels={multiModel.selectedModels}
onAdd={multiModel.addModel}
onRemove={multiModel.removeModel}
onReplace={multiModel.replaceModel}
/>
}
/>
<Spacer rem={1.5} />
</Fade>
@@ -791,6 +838,15 @@ export default function AppPage({ firstMessage }: ChatPageProps) {
isSearch ? "h-[14px]" : "h-0"
)}
/>
{appFocus.isChat() && (
<ModelSelector
llmManager={llmManager}
selectedModels={multiModel.selectedModels}
onAdd={multiModel.addModel}
onRemove={multiModel.removeModel}
onReplace={multiModel.replaceModel}
/>
)}
<AppInputBar
ref={chatInputBarRef}
deepResearchEnabled={

View File

@@ -347,18 +347,19 @@ const ChatScrollContainer = React.memo(
const contentMask = buildContentMask();
return (
<div className="flex flex-col flex-1 min-h-0 w-full relative overflow-hidden mb-1">
<div className="flex flex-col flex-1 min-h-0 w-full relative overflow-y-hidden mb-1">
<div
key={sessionId}
ref={scrollContainerRef}
data-testid="chat-scroll-container"
className={cn(
"flex flex-col flex-1 min-h-0 overflow-y-auto overflow-x-hidden",
"flex flex-col flex-1 min-h-0 overflow-y-auto overflow-x-auto",
hideScrollbar ? "no-scrollbar" : "default-scrollbar"
)}
onScroll={handleScroll}
style={{
scrollbarGutter: "stable both-edges",
overflowX: "auto",
// Apply mask to fade content opacity at edges
maskImage: contentMask,
WebkitMaskImage: contentMask,
@@ -366,7 +367,7 @@ const ChatScrollContainer = React.memo(
>
<div
ref={contentWrapperRef}
className="w-full flex-1 flex flex-col items-center px-4"
className="min-w-full flex-1 flex flex-col items-center px-4"
data-scroll-ready={isScrollReady}
style={{
visibility: isScrollReady ? "visible" : "hidden",

View File

@@ -8,7 +8,10 @@ import { ErrorBanner } from "@/app/app/message/Resubmit";
import { MinimalPersonaSnapshot } from "@/app/admin/agents/interfaces";
import { LlmDescriptor, LlmManager } from "@/lib/hooks";
import AgentMessage from "@/app/app/message/messageComponents/AgentMessage";
import Spacer from "@/refresh-components/Spacer";
import MultiModelResponseView, {
MultiModelResponse,
} from "@/app/app/message/MultiModelResponseView";
import { SelectedModel } from "@/refresh-components/popovers/ModelSelector";
import DynamicBottomSpacer from "@/components/chat/DynamicBottomSpacer";
import {
useCurrentMessageHistory,
@@ -17,6 +20,8 @@ import {
useUncaughtError,
} from "@/app/app/stores/useChatSessionStore";
const MSG_MAX_W = "max-w-[720px] min-w-[400px]";
export interface ChatUIProps {
liveAgent: MinimalPersonaSnapshot;
llmManager: LlmManager;
@@ -37,6 +42,7 @@ export interface ChatUIProps {
forceSearch?: boolean;
};
forceSearch?: boolean;
selectedModels?: SelectedModel[];
}) => Promise<void>;
deepResearchEnabled: boolean;
currentMessageFiles: any[];
@@ -48,6 +54,9 @@ export interface ChatUIProps {
* Used by DynamicBottomSpacer to position the push-up effect.
*/
anchorNodeId?: number;
/** Currently selected models for multi-model comparison. */
selectedModels: SelectedModel[];
}
const ChatUI = React.memo(
@@ -62,6 +71,7 @@ const ChatUI = React.memo(
currentMessageFiles,
onResubmit,
anchorNodeId,
selectedModels,
}: ChatUIProps) => {
// Get messages and error state from store
const messages = useCurrentMessageHistory();
@@ -76,9 +86,11 @@ const ChatUI = React.memo(
const onSubmitRef = useRef(onSubmit);
const deepResearchEnabledRef = useRef(deepResearchEnabled);
const currentMessageFilesRef = useRef(currentMessageFiles);
const selectedModelsRef = useRef(selectedModels);
onSubmitRef.current = onSubmit;
deepResearchEnabledRef.current = deepResearchEnabled;
currentMessageFilesRef.current = currentMessageFiles;
selectedModelsRef.current = selectedModels;
const createRegenerator = useCallback(
(regenerationRequest: {
@@ -103,19 +115,72 @@ const ChatUI = React.memo(
const handleEditWithMessageId = useCallback(
(editedContent: string, msgId: number) => {
const models = selectedModelsRef.current;
onSubmitRef.current({
message: editedContent,
messageIdToResend: msgId,
currentMessageFiles: [],
deepResearch: deepResearchEnabledRef.current,
selectedModels: models.length >= 2 ? models : undefined,
});
},
[]
);
// Helper to check if a user message has multi-model responses
const getMultiModelResponses = useCallback(
(userMessage: Message): MultiModelResponse[] | null => {
if (!messageTree) return null;
const childrenNodeIds = userMessage.childrenNodeIds || [];
if (childrenNodeIds.length < 2) return null;
const childMessages = childrenNodeIds
.map((nodeId) => messageTree.get(nodeId))
.filter(
(msg): msg is Message =>
msg !== undefined && msg.type === "assistant"
);
if (childMessages.length < 2) return null;
// Distinguish multi-model from regenerations: multi-model messages
// have modelDisplayName or overridden_model set. Regenerations don't.
// During streaming, overridden_model is set. On reload, modelDisplayName is set.
const multiModelChildren = childMessages.filter(
(msg) => msg.modelDisplayName || msg.overridden_model
);
if (multiModelChildren.length < 2) return null;
const latestChildNodeId = userMessage.latestChildNodeId;
return childMessages.map((msg, idx) => {
// During streaming, overridden_model has the friendly display name.
// On reload from history, modelDisplayName has the DB-stored name.
const name = msg.overridden_model || msg.modelDisplayName || "Model";
return {
modelIndex: idx,
provider: "",
modelName: name,
displayName: name,
packets: msg.packets || [],
packetCount: msg.packetCount || msg.packets?.length || 0,
nodeId: msg.nodeId,
messageId: msg.messageId,
isHighlighted: msg.nodeId === latestChildNodeId,
currentFeedback: msg.currentFeedback,
isGenerating: msg.is_generating || false,
};
});
},
[messageTree]
);
return (
<>
<div className="flex flex-col w-full max-w-[var(--app-page-main-content-width)] h-full pt-4 pb-8 pr-1 gap-12">
{/* No max-width on container — individual messages control their own width.
Multi-model responses use full width while normal messages stay centered. */}
<div className="flex flex-col w-full h-full pt-4 pb-8 pr-1 gap-12">
{messages.map((message, i) => {
const messageReactComponentKey = `message-${message.nodeId}`;
const parentMessage = message.parentNodeId
@@ -125,32 +190,63 @@ const ChatUI = React.memo(
const nextMessage =
messages.length > i + 1 ? messages[i + 1] : null;
// Check for multi-model responses
const multiModelResponses = getMultiModelResponses(message);
return (
<div
id={messageReactComponentKey}
key={messageReactComponentKey}
className="flex flex-col gap-12 w-full"
>
<HumanMessage
disableSwitchingForStreaming={
(nextMessage && nextMessage.is_generating) || false
}
stopGenerating={stopGenerating}
content={message.message}
files={message.files}
messageId={message.messageId}
nodeId={message.nodeId}
onEdit={handleEditWithMessageId}
otherMessagesCanSwitchTo={
parentMessage?.childrenNodeIds ?? emptyChildrenIds
}
onMessageSelection={onMessageSelection}
/>
{/* Human message stays at normal chat width */}
<div className={`w-full ${MSG_MAX_W} self-center`}>
<HumanMessage
disableSwitchingForStreaming={
(nextMessage && nextMessage.is_generating) || false
}
stopGenerating={stopGenerating}
content={message.message}
files={message.files}
messageId={message.messageId}
nodeId={message.nodeId}
onEdit={handleEditWithMessageId}
otherMessagesCanSwitchTo={
parentMessage?.childrenNodeIds ?? emptyChildrenIds
}
onMessageSelection={onMessageSelection}
/>
</div>
{/* Multi-model response uses full width */}
{multiModelResponses && (
<MultiModelResponseView
responses={multiModelResponses}
chatState={{
agent: liveAgent,
docs: emptyDocs,
citations: undefined,
setPresentingDocument,
overriddenModel: llmManager.currentLlm?.modelName,
}}
llmManager={llmManager}
onRegenerate={createRegenerator}
parentMessage={message}
otherMessagesCanSwitchTo={
parentMessage?.childrenNodeIds ?? emptyChildrenIds
}
onMessageSelection={onMessageSelection}
/>
)}
</div>
);
} else if (message.type === "assistant") {
if ((error || loadError) && i === messages.length - 1) {
return (
<div key={`error-${message.nodeId}`} className="p-4">
<div
key={`error-${message.nodeId}`}
className={`p-4 w-full ${MSG_MAX_W} self-center`}
>
<ErrorBanner
resubmit={onResubmit}
error={error || loadError || ""}
@@ -164,6 +260,16 @@ const ChatUI = React.memo(
}
const previousMessage = i !== 0 ? messages[i - 1] : null;
// Check if this assistant message is part of a multi-model response
// If so, skip rendering since it's already rendered in MultiModelResponseView
if (
previousMessage?.type === "user" &&
getMultiModelResponses(previousMessage)
) {
return null;
}
const chatStateData = {
agent: liveAgent,
docs: message.documents ?? emptyDocs,
@@ -177,6 +283,7 @@ const ChatUI = React.memo(
<div
id={`message-${message.nodeId}`}
key={messageReactComponentKey}
className={`w-full ${MSG_MAX_W} self-center`}
>
<AgentMessage
rawPackets={message.packets}
@@ -206,7 +313,7 @@ const ChatUI = React.memo(
{(((error !== null || loadError !== null) &&
messages[messages.length - 1]?.type === "user") ||
messages[messages.length - 1]?.type === "error") && (
<div className="p-4">
<div className={`p-4 w-full ${MSG_MAX_W} self-center`}>
<ErrorBanner
resubmit={onResubmit}
error={error || loadError || ""}

View File

@@ -10,7 +10,6 @@ import React, {
} from "react";
import LineItem from "@/refresh-components/buttons/LineItem";
import { MinimalPersonaSnapshot } from "@/app/admin/agents/interfaces";
import LLMPopover from "@/refresh-components/popovers/LLMPopover";
import { InputPrompt } from "@/app/app/interfaces";
import { FilterManager, LlmManager, useFederatedConnectors } from "@/lib/hooks";
import usePromptShortcuts from "@/hooks/usePromptShortcuts";
@@ -20,7 +19,7 @@ import { MinimalOnyxDocument } from "@/lib/search/interfaces";
import { ChatState } from "@/app/app/interfaces";
import { useForcedTools } from "@/lib/hooks/useForcedTools";
import useAppFocus from "@/hooks/useAppFocus";
import { cn, isImageFile } from "@/lib/utils";
import { cn } from "@/lib/utils";
import { Disabled } from "@opal/core";
import { useUser } from "@/providers/UserProvider";
import {
@@ -423,11 +422,6 @@ const AppInputBar = React.memo(
return currentMessageFiles.length > 1;
}, [currentMessageFiles]);
const hasImageFiles = useMemo(
() => currentMessageFiles.some((f) => isImageFile(f.name)),
[currentMessageFiles]
);
// Check if the agent has search tools available (internal search or web search)
// AND if deep research is globally enabled in admin settings
const showDeepResearch = useMemo(() => {
@@ -615,16 +609,6 @@ const AppInputBar = React.memo(
{/* Bottom right controls */}
<div className="flex flex-row items-center gap-1">
<div
data-testid="AppInputBar/llm-popover-trigger"
className={cn(controlsLoading && "invisible")}
>
<LLMPopover
llmManager={llmManager}
requiresImageInput={hasImageFiles}
disabled={disabled}
/>
</div>
{showMicButton &&
(sttEnabled ? (
<MicrophoneButton