Compare commits

..

4 Commits

29 changed files with 2509 additions and 152 deletions

View File

@@ -0,0 +1,36 @@
"""add preferred_response_id and model_display_name to chat_message
Revision ID: a3f8b2c1d4e5
Create Date: 2026-03-22
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "a3f8b2c1d4e5"
down_revision = "b728689f45b1"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.add_column(
"chat_message",
sa.Column(
"preferred_response_id",
sa.Integer(),
sa.ForeignKey("chat_message.id"),
nullable=True,
),
)
op.add_column(
"chat_message",
sa.Column("model_display_name", sa.String(), nullable=True),
)
def downgrade() -> None:
op.drop_column("chat_message", "model_display_name")
op.drop_column("chat_message", "preferred_response_id")

View File

@@ -8,6 +8,7 @@ from onyx.configs.constants import MessageType
from onyx.context.search.models import SearchDoc
from onyx.file_store.models import InMemoryChatFile
from onyx.server.query_and_chat.models import MessageResponseIDInfo
from onyx.server.query_and_chat.models import MultiModelMessageResponseIDInfo
from onyx.server.query_and_chat.streaming_models import CitationInfo
from onyx.server.query_and_chat.streaming_models import GeneratedImage
from onyx.server.query_and_chat.streaming_models import Packet
@@ -35,7 +36,13 @@ class CreateChatSessionID(BaseModel):
chat_session_id: UUID
AnswerStreamPart = Packet | MessageResponseIDInfo | StreamingError | CreateChatSessionID
AnswerStreamPart = (
Packet
| MessageResponseIDInfo
| MultiModelMessageResponseIDInfo
| StreamingError
| CreateChatSessionID
)
AnswerStream = Iterator[AnswerStreamPart]

View File

@@ -4,9 +4,11 @@ An overview can be found in the README.md file in this directory.
"""
import io
import queue
import re
import traceback
from collections.abc import Callable
from concurrent.futures import ThreadPoolExecutor
from contextvars import Token
from uuid import UUID
@@ -28,6 +30,7 @@ from onyx.chat.compression import calculate_total_history_tokens
from onyx.chat.compression import compress_chat_history
from onyx.chat.compression import find_summary_for_branch
from onyx.chat.compression import get_compression_params
from onyx.chat.emitter import Emitter
from onyx.chat.emitter import get_default_emitter
from onyx.chat.llm_loop import EmptyLLMResponseError
from onyx.chat.llm_loop import run_llm_loop
@@ -59,6 +62,8 @@ from onyx.db.chat import create_new_chat_message
from onyx.db.chat import get_chat_session_by_id
from onyx.db.chat import get_or_create_root_message
from onyx.db.chat import reserve_message_id
from onyx.db.chat import reserve_multi_model_message_ids
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.memory import get_memories
from onyx.db.models import ChatMessage
from onyx.db.models import ChatSession
@@ -77,16 +82,20 @@ from onyx.llm.factory import get_llm_for_persona
from onyx.llm.factory import get_llm_token_counter
from onyx.llm.interfaces import LLM
from onyx.llm.interfaces import LLMUserIdentity
from onyx.llm.override_models import LLMOverride
from onyx.llm.request_context import reset_llm_mock_response
from onyx.llm.request_context import set_llm_mock_response
from onyx.llm.utils import litellm_exception_to_error_msg
from onyx.onyxbot.slack.models import SlackContext
from onyx.server.query_and_chat.models import AUTO_PLACE_AFTER_LATEST_MESSAGE
from onyx.server.query_and_chat.models import MessageResponseIDInfo
from onyx.server.query_and_chat.models import MultiModelMessageResponseIDInfo
from onyx.server.query_and_chat.models import SendMessageRequest
from onyx.server.query_and_chat.placement import Placement
from onyx.server.query_and_chat.streaming_models import AgentResponseDelta
from onyx.server.query_and_chat.streaming_models import AgentResponseStart
from onyx.server.query_and_chat.streaming_models import CitationInfo
from onyx.server.query_and_chat.streaming_models import OverallStop
from onyx.server.query_and_chat.streaming_models import Packet
from onyx.server.usage_limits import check_llm_cost_limit_for_provider
from onyx.tools.constants import SEARCH_TOOL_ID
@@ -997,6 +1006,568 @@ def handle_stream_message_objects(
logger.exception("Error in setting processing status")
def _build_model_display_name(override: LLMOverride) -> str:
"""Build a human-readable display name from an LLM override."""
if override.display_name:
return override.display_name
if override.model_version:
return override.model_version
if override.model_provider:
return override.model_provider
return "unknown"
# Sentinel placed on the merged queue when a model thread finishes.
_MODEL_DONE = object()
class _ModelIndexEmitter(Emitter):
"""Emitter that tags packets with model_index and forwards directly to a shared queue.
Unlike the standard Emitter (which accumulates in a local bus), this puts
packets into the shared merged_queue in real-time as they're emitted. This
enables true parallel streaming — packets from multiple models interleave
on the wire instead of arriving in bursts after each model completes.
"""
def __init__(self, model_idx: int, merged_queue: queue.Queue) -> None:
super().__init__(queue.Queue()) # bus exists for compat, unused
self._model_idx = model_idx
self._merged_queue = merged_queue
def emit(self, packet: Packet) -> None:
tagged_placement = Placement(
turn_index=packet.placement.turn_index if packet.placement else 0,
tab_index=packet.placement.tab_index if packet.placement else 0,
sub_turn_index=(
packet.placement.sub_turn_index if packet.placement else None
),
model_index=self._model_idx,
)
tagged_packet = Packet(placement=tagged_placement, obj=packet.obj)
self._merged_queue.put((self._model_idx, tagged_packet))
def run_multi_model_stream(
new_msg_req: SendMessageRequest,
user: User,
db_session: Session,
llm_overrides: list[LLMOverride],
litellm_additional_headers: dict[str, str] | None = None,
custom_tool_additional_headers: dict[str, str] | None = None,
mcp_headers: dict[str, str] | None = None,
) -> AnswerStream:
# TODO: The setup logic below (session resolution through tool construction)
# is duplicated from handle_stream_message_objects. Extract into a shared
# _ChatStreamContext dataclass + _prepare_chat_stream_context() factory so
# both paths call the same setup code. Tracked as follow-up refactor.
"""Run 2-3 LLMs in parallel and yield their packets tagged with model_index.
Resource management:
- Each model thread gets its OWN db_session (SQLAlchemy sessions are not thread-safe)
- The caller's db_session is used only for setup (before threads launch) and
completion callbacks (after threads finish)
- ThreadPoolExecutor is bounded to len(overrides) workers
- All threads are joined in the finally block regardless of success/failure
- Queue-based merging avoids busy-waiting
"""
n_models = len(llm_overrides)
if n_models < 2 or n_models > 3:
raise ValueError(f"Multi-model requires 2-3 overrides, got {n_models}")
if new_msg_req.deep_research:
raise ValueError("Multi-model is not supported with deep research")
tenant_id = get_current_tenant_id()
cache: CacheBackend | None = None
chat_session: ChatSession | None = None
user_id = user.id
if user.is_anonymous:
llm_user_identifier = "anonymous_user"
else:
llm_user_identifier = user.email or str(user_id)
try:
# ── Session setup (same as single-model path) ──────────────────
if not new_msg_req.chat_session_id:
if not new_msg_req.chat_session_info:
raise RuntimeError(
"Must specify a chat session id or chat session info"
)
chat_session = create_chat_session_from_request(
chat_session_request=new_msg_req.chat_session_info,
user_id=user_id,
db_session=db_session,
)
yield CreateChatSessionID(chat_session_id=chat_session.id)
else:
chat_session = get_chat_session_by_id(
chat_session_id=new_msg_req.chat_session_id,
user_id=user_id,
db_session=db_session,
)
persona = chat_session.persona
message_text = new_msg_req.message
# ── Build N LLM instances and validate costs ───────────────────
llms: list[LLM] = []
model_display_names: list[str] = []
for override in llm_overrides:
llm = get_llm_for_persona(
persona=persona,
user=user,
llm_override=override,
additional_headers=litellm_additional_headers,
)
check_llm_cost_limit_for_provider(
db_session=db_session,
tenant_id=tenant_id,
llm_provider_api_key=llm.config.api_key,
)
llms.append(llm)
model_display_names.append(_build_model_display_name(override))
# Use first LLM for token counting (context window is checked per-model
# but token counting is model-agnostic enough for setup purposes)
token_counter = get_llm_token_counter(llms[0])
verify_user_files(
user_files=new_msg_req.file_descriptors,
user_id=user_id,
db_session=db_session,
project_id=chat_session.project_id,
)
# ── Chat history chain (shared across all models) ──────────────
chat_history = create_chat_history_chain(
chat_session_id=chat_session.id, db_session=db_session
)
root_message = get_or_create_root_message(
chat_session_id=chat_session.id, db_session=db_session
)
if new_msg_req.parent_message_id == AUTO_PLACE_AFTER_LATEST_MESSAGE:
parent_message = chat_history[-1] if chat_history else root_message
elif (
new_msg_req.parent_message_id is None
or new_msg_req.parent_message_id == root_message.id
):
parent_message = root_message
chat_history = []
else:
parent_message = None
for i in range(len(chat_history) - 1, -1, -1):
if chat_history[i].id == new_msg_req.parent_message_id:
parent_message = chat_history[i]
chat_history = chat_history[: i + 1]
break
if parent_message is None:
raise ValueError(
"The new message sent is not on the latest mainline of messages"
)
if parent_message.message_type == MessageType.USER:
user_message = parent_message
else:
user_message = create_new_chat_message(
chat_session_id=chat_session.id,
parent_message=parent_message,
message=message_text,
token_count=token_counter(message_text),
message_type=MessageType.USER,
files=new_msg_req.file_descriptors,
db_session=db_session,
commit=True,
)
chat_history.append(user_message)
available_files = _collect_available_file_ids(
chat_history=chat_history,
project_id=chat_session.project_id,
user_id=user_id,
db_session=db_session,
)
summary_message = find_summary_for_branch(db_session, chat_history)
summarized_file_metadata: dict[str, FileToolMetadata] = {}
if summary_message and summary_message.last_summarized_message_id:
cutoff_id = summary_message.last_summarized_message_id
for msg in chat_history:
if msg.id > cutoff_id or not msg.files:
continue
for fd in msg.files:
file_id = fd.get("id")
if not file_id:
continue
summarized_file_metadata[file_id] = FileToolMetadata(
file_id=file_id,
filename=fd.get("name") or "unknown",
approx_char_count=0,
)
chat_history = [m for m in chat_history if m.id > cutoff_id]
user_memory_context = get_memories(user, db_session)
custom_agent_prompt = get_custom_agent_prompt(persona, chat_session)
prompt_memory_context = (
user_memory_context
if user.use_memories
else user_memory_context.without_memories()
)
max_reserved_system_prompt_tokens_str = (persona.system_prompt or "") + (
custom_agent_prompt or ""
)
reserved_token_count = calculate_reserved_tokens(
db_session=db_session,
persona_system_prompt=max_reserved_system_prompt_tokens_str,
token_counter=token_counter,
files=new_msg_req.file_descriptors,
user_memory_context=prompt_memory_context,
)
context_user_files = resolve_context_user_files(
persona=persona,
project_id=chat_session.project_id,
user_id=user_id,
db_session=db_session,
)
# Use the smallest context window across all models for safety
min_context_window = min(llm.config.max_input_tokens for llm in llms)
extracted_context_files = extract_context_files(
user_files=context_user_files,
llm_max_context_window=min_context_window,
reserved_token_count=reserved_token_count,
db_session=db_session,
)
search_params = determine_search_params(
persona_id=persona.id,
project_id=chat_session.project_id,
extracted_context_files=extracted_context_files,
)
if persona.user_files:
existing = set(available_files.user_file_ids)
for uf in persona.user_files:
if uf.id not in existing:
available_files.user_file_ids.append(uf.id)
all_tools = get_tools(db_session)
tool_id_to_name_map = {tool.id: tool.name for tool in all_tools}
search_tool_id = next(
(tool.id for tool in all_tools if tool.in_code_tool_id == SEARCH_TOOL_ID),
None,
)
forced_tool_id = new_msg_req.forced_tool_id
if (
search_params.search_usage == SearchToolUsage.DISABLED
and forced_tool_id is not None
and search_tool_id is not None
and forced_tool_id == search_tool_id
):
forced_tool_id = None
files = load_all_chat_files(chat_history, db_session)
chat_files_for_tools = _convert_loaded_files_to_chat_files(files)
# ── Reserve N assistant message IDs ────────────────────────────
reserved_messages = reserve_multi_model_message_ids(
db_session=db_session,
chat_session_id=chat_session.id,
parent_message_id=user_message.id,
model_display_names=model_display_names,
)
yield MultiModelMessageResponseIDInfo(
user_message_id=user_message.id,
reserved_assistant_message_ids=[m.id for m in reserved_messages],
model_names=model_display_names,
)
has_file_reader_tool = any(
tool.in_code_tool_id == "file_reader" for tool in all_tools
)
chat_history_result = convert_chat_history(
chat_history=chat_history,
files=files,
context_image_files=extracted_context_files.image_files,
additional_context=new_msg_req.additional_context,
token_counter=token_counter,
tool_id_to_name_map=tool_id_to_name_map,
)
simple_chat_history = chat_history_result.simple_messages
all_injected_file_metadata: dict[str, FileToolMetadata] = (
chat_history_result.all_injected_file_metadata
if has_file_reader_tool
else {}
)
if summarized_file_metadata:
for fid, meta in summarized_file_metadata.items():
all_injected_file_metadata.setdefault(fid, meta)
if summary_message is not None:
summary_simple = ChatMessageSimple(
message=summary_message.message,
token_count=summary_message.token_count,
message_type=MessageType.ASSISTANT,
)
simple_chat_history.insert(0, summary_simple)
# ── Stop signal and processing status ──────────────────────────
cache = get_cache_backend()
reset_cancel_status(chat_session.id, cache)
def check_is_connected() -> bool:
return check_stop_signal(chat_session.id, cache)
set_processing_status(
chat_session_id=chat_session.id,
cache=cache,
value=True,
)
# Release the main session's read transaction before the long stream
db_session.commit()
# ── Parallel model execution ───────────────────────────────────
# Each model thread writes tagged packets to this shared queue.
# Sentinel _MODEL_DONE signals that a thread finished.
merged_queue: queue.Queue[tuple[int, Packet | Exception | object]] = (
queue.Queue()
)
# Track per-model state containers for completion callbacks
state_containers: list[ChatStateContainer] = [
ChatStateContainer() for _ in range(n_models)
]
# Track which models completed successfully (for completion callbacks)
model_succeeded: list[bool] = [False] * n_models
user_identity = LLMUserIdentity(
user_id=llm_user_identifier,
session_id=str(chat_session.id),
)
def _run_model(model_idx: int) -> None:
"""Run a single model in a worker thread.
Uses _ModelIndexEmitter so packets flow directly to merged_queue
in real-time (not batched after completion). This enables true
parallel streaming where both models' tokens interleave on the wire.
DB access: tools may need a session during execution (e.g., search
tool). Each thread creates its own session via context manager.
"""
model_emitter = _ModelIndexEmitter(model_idx, merged_queue)
sc = state_containers[model_idx]
model_llm = llms[model_idx]
try:
# Each model thread gets its own DB session for tool execution.
# The session is scoped to the thread and closed when done.
with get_session_with_current_tenant() as thread_db_session:
# Construct tools per-thread with thread-local DB session
thread_tool_dict = construct_tools(
persona=persona,
db_session=thread_db_session,
emitter=model_emitter,
user=user,
llm=model_llm,
search_tool_config=SearchToolConfig(
user_selected_filters=new_msg_req.internal_search_filters,
project_id_filter=search_params.project_id_filter,
persona_id_filter=search_params.persona_id_filter,
bypass_acl=False,
enable_slack_search=_should_enable_slack_search(
persona, new_msg_req.internal_search_filters
),
),
custom_tool_config=CustomToolConfig(
chat_session_id=chat_session.id,
message_id=user_message.id,
additional_headers=custom_tool_additional_headers,
mcp_headers=mcp_headers,
),
file_reader_tool_config=FileReaderToolConfig(
user_file_ids=available_files.user_file_ids,
chat_file_ids=available_files.chat_file_ids,
),
allowed_tool_ids=new_msg_req.allowed_tool_ids,
search_usage_forcing_setting=search_params.search_usage,
)
model_tools: list[Tool] = []
for tool_list in thread_tool_dict.values():
model_tools.extend(tool_list)
# Run the LLM loop — this blocks until the model finishes.
# Packets flow to merged_queue in real-time via the emitter.
run_llm_loop(
emitter=model_emitter,
state_container=sc,
simple_chat_history=simple_chat_history,
tools=model_tools,
custom_agent_prompt=custom_agent_prompt,
context_files=extracted_context_files,
persona=persona,
user_memory_context=user_memory_context,
llm=model_llm,
token_counter=get_llm_token_counter(model_llm),
db_session=thread_db_session,
forced_tool_id=forced_tool_id,
user_identity=user_identity,
chat_session_id=str(chat_session.id),
chat_files=chat_files_for_tools,
include_citations=new_msg_req.include_citations,
all_injected_file_metadata=all_injected_file_metadata,
inject_memories_in_prompt=user.use_memories,
)
model_succeeded[model_idx] = True
except Exception as e:
merged_queue.put((model_idx, e))
finally:
merged_queue.put((model_idx, _MODEL_DONE))
# Launch model threads via ThreadPoolExecutor (bounded, context-propagating)
executor = ThreadPoolExecutor(
max_workers=n_models,
thread_name_prefix="multi-model",
)
futures = []
try:
for i in range(n_models):
futures.append(executor.submit(_run_model, i))
# ── Main thread: merge and yield packets ───────────────────
models_remaining = n_models
while models_remaining > 0:
try:
model_idx, item = merged_queue.get(timeout=0.3)
except queue.Empty:
# Check cancellation during idle periods
if not check_is_connected():
yield Packet(
placement=Placement(turn_index=0),
obj=OverallStop(type="stop", stop_reason="user_cancelled"),
)
return
continue
if item is _MODEL_DONE:
models_remaining -= 1
continue
if isinstance(item, Exception):
# Yield error as a tagged StreamingError packet
error_msg = str(item)
stack_trace = "".join(
traceback.format_exception(type(item), item, item.__traceback__)
)
# Redact API keys from error messages
model_llm = llms[model_idx]
if model_llm.config.api_key and len(model_llm.config.api_key) > 2:
error_msg = error_msg.replace(
model_llm.config.api_key, "[REDACTED_API_KEY]"
)
stack_trace = stack_trace.replace(
model_llm.config.api_key, "[REDACTED_API_KEY]"
)
yield StreamingError(
error=error_msg,
stack_trace=stack_trace,
error_code="MODEL_ERROR",
is_retryable=True,
details={
"model": model_llm.config.model_name,
"provider": model_llm.config.model_provider,
"model_index": model_idx,
},
)
models_remaining -= 1
continue
if isinstance(item, Packet):
# Packet is already tagged with model_index by _ModelIndexEmitter
yield item
# ── Completion: save each successful model's response ──────
# Run completion callbacks on the main thread using the main
# session. This is safe because all worker threads have exited
# by this point (merged_queue fully drained).
for i in range(n_models):
if not model_succeeded[i]:
continue
try:
llm_loop_completion_handle(
state_container=state_containers[i],
is_connected=check_is_connected,
db_session=db_session,
assistant_message=reserved_messages[i],
llm=llms[i],
reserved_tokens=reserved_token_count,
)
except Exception:
logger.exception(
f"Failed completion for model {i} "
f"({model_display_names[i]})"
)
yield Packet(
placement=Placement(turn_index=0),
obj=OverallStop(type="stop", stop_reason="complete"),
)
finally:
# Ensure all threads are cleaned up regardless of how we exit
executor.shutdown(wait=True, cancel_futures=True)
except ValueError as e:
logger.exception("Failed to process multi-model chat message.")
yield StreamingError(
error=str(e),
error_code="VALIDATION_ERROR",
is_retryable=True,
)
db_session.rollback()
return
except Exception as e:
logger.exception(f"Failed multi-model chat: {e}")
stack_trace = traceback.format_exc()
yield StreamingError(
error=str(e),
stack_trace=stack_trace,
error_code="MULTI_MODEL_ERROR",
is_retryable=True,
)
db_session.rollback()
finally:
try:
if cache is not None and chat_session is not None:
set_processing_status(
chat_session_id=chat_session.id,
cache=cache,
value=False,
)
except Exception:
logger.exception("Error clearing processing status")
def llm_loop_completion_handle(
state_container: ChatStateContainer,
is_connected: Callable[[], bool],

View File

@@ -602,6 +602,79 @@ def reserve_message_id(
return empty_message
def reserve_multi_model_message_ids(
db_session: Session,
chat_session_id: UUID,
parent_message_id: int,
model_display_names: list[str],
) -> list[ChatMessage]:
"""Reserve N assistant message placeholders for multi-model parallel streaming.
All messages share the same parent (the user message). The parent's
latest_child_message_id points to the LAST reserved message so that the
default history-chain walker picks it up.
"""
reserved: list[ChatMessage] = []
for display_name in model_display_names:
msg = ChatMessage(
chat_session_id=chat_session_id,
parent_message_id=parent_message_id,
latest_child_message_id=None,
message="Response was terminated prior to completion, try regenerating.",
token_count=15,
message_type=MessageType.ASSISTANT,
model_display_name=display_name,
)
db_session.add(msg)
reserved.append(msg)
# Flush to assign IDs without committing yet
db_session.flush()
# Point parent's latest_child to the last reserved message
parent = (
db_session.query(ChatMessage)
.filter(ChatMessage.id == parent_message_id)
.first()
)
if parent:
parent.latest_child_message_id = reserved[-1].id
db_session.commit()
return reserved
def set_preferred_response(
db_session: Session,
user_message_id: int,
preferred_assistant_message_id: int,
) -> None:
"""Set the preferred assistant response for a multi-model user message.
Validates that the user message is a USER type and that the preferred
assistant message is a direct child of that user message.
"""
user_msg = db_session.query(ChatMessage).get(user_message_id)
if user_msg is None:
raise ValueError(f"User message {user_message_id} not found")
if user_msg.message_type != MessageType.USER:
raise ValueError(f"Message {user_message_id} is not a user message")
assistant_msg = db_session.query(ChatMessage).get(preferred_assistant_message_id)
if assistant_msg is None:
raise ValueError(
f"Assistant message {preferred_assistant_message_id} not found"
)
if assistant_msg.parent_message_id != user_message_id:
raise ValueError(
f"Assistant message {preferred_assistant_message_id} is not a child "
f"of user message {user_message_id}"
)
user_msg.preferred_response_id = preferred_assistant_message_id
db_session.commit()
def create_new_chat_message(
chat_session_id: UUID,
parent_message: ChatMessage,
@@ -824,6 +897,8 @@ def translate_db_message_to_chat_message_detail(
error=chat_message.error,
current_feedback=current_feedback,
processing_duration_seconds=chat_message.processing_duration_seconds,
preferred_response_id=chat_message.preferred_response_id,
model_display_name=chat_message.model_display_name,
)
return chat_msg_detail

View File

@@ -2645,6 +2645,15 @@ class ChatMessage(Base):
nullable=True,
)
# For multi-model turns: the user message points to which assistant response
# was selected as the preferred one to continue the conversation with.
preferred_response_id: Mapped[int | None] = mapped_column(
ForeignKey("chat_message.id"), nullable=True
)
# The display name of the model that generated this assistant message
model_display_name: Mapped[str | None] = mapped_column(String, nullable=True)
# What does this message contain
reasoning_tokens: Mapped[str | None] = mapped_column(Text, nullable=True)
message: Mapped[str] = mapped_column(Text)
@@ -2712,6 +2721,12 @@ class ChatMessage(Base):
remote_side="ChatMessage.id",
)
preferred_response: Mapped["ChatMessage | None"] = relationship(
"ChatMessage",
foreign_keys=[preferred_response_id],
remote_side="ChatMessage.id",
)
# Chat messages only need to know their immediate tool call children
# If there are nested tool calls, they are stored in the tool_call_children relationship.
tool_calls: Mapped[list["ToolCall"] | None] = relationship(

View File

@@ -5,7 +5,6 @@ accidentally reaches the vector DB layer will fail loudly instead of timing
out against a nonexistent Vespa/OpenSearch instance.
"""
from collections.abc import Iterable
from typing import Any
from onyx.context.search.models import IndexFilters
@@ -67,7 +66,7 @@ class DisabledDocumentIndex(DocumentIndex):
# ------------------------------------------------------------------
def index(
self,
chunks: Iterable[DocMetadataAwareIndexChunk], # noqa: ARG002
chunks: list[DocMetadataAwareIndexChunk], # noqa: ARG002
index_batch_params: IndexBatchParams, # noqa: ARG002
) -> set[DocumentInsertionRecord]:
raise RuntimeError(VECTOR_DB_DISABLED_ERROR)

View File

@@ -1,5 +1,4 @@
import abc
from collections.abc import Iterable
from dataclasses import dataclass
from datetime import datetime
from typing import Any
@@ -207,7 +206,7 @@ class Indexable(abc.ABC):
@abc.abstractmethod
def index(
self,
chunks: Iterable[DocMetadataAwareIndexChunk],
chunks: list[DocMetadataAwareIndexChunk],
index_batch_params: IndexBatchParams,
) -> set[DocumentInsertionRecord]:
"""
@@ -227,8 +226,8 @@ class Indexable(abc.ABC):
it is done automatically outside of this code.
Parameters:
- chunks: Document chunks with all of the information needed for
indexing to the document index.
- chunks: Document chunks with all of the information needed for indexing to the document
index.
- tenant_id: The tenant id of the user whose chunks are being indexed
- large_chunks_enabled: Whether large chunks are enabled

View File

@@ -1,5 +1,4 @@
import abc
from collections.abc import Iterable
from typing import Self
from pydantic import BaseModel
@@ -210,10 +209,10 @@ class Indexable(abc.ABC):
@abc.abstractmethod
def index(
self,
chunks: Iterable[DocMetadataAwareIndexChunk],
chunks: list[DocMetadataAwareIndexChunk],
indexing_metadata: IndexingMetadata,
) -> list[DocumentInsertionRecord]:
"""Indexes an iterable of document chunks into the document index.
"""Indexes a list of document chunks into the document index.
This is often a batch operation including chunks from multiple
documents.

View File

@@ -1,5 +1,5 @@
import json
from collections.abc import Iterable
from collections import defaultdict
from typing import Any
import httpx
@@ -350,7 +350,7 @@ class OpenSearchOldDocumentIndex(OldDocumentIndex):
def index(
self,
chunks: Iterable[DocMetadataAwareIndexChunk],
chunks: list[DocMetadataAwareIndexChunk],
index_batch_params: IndexBatchParams,
) -> set[OldDocumentInsertionRecord]:
"""
@@ -646,10 +646,10 @@ class OpenSearchDocumentIndex(DocumentIndex):
def index(
self,
chunks: Iterable[DocMetadataAwareIndexChunk],
indexing_metadata: IndexingMetadata,
chunks: list[DocMetadataAwareIndexChunk],
indexing_metadata: IndexingMetadata, # noqa: ARG002
) -> list[DocumentInsertionRecord]:
"""Indexes an iterable of document chunks into the document index.
"""Indexes a list of document chunks into the document index.
Groups chunks by document ID and for each document, deletes existing
chunks and indexes the new chunks in bulk.
@@ -672,34 +672,29 @@ class OpenSearchDocumentIndex(DocumentIndex):
document is newly indexed or had already existed and was just
updated.
"""
total_chunks = sum(
cc.new_chunk_cnt
for cc in indexing_metadata.doc_id_to_chunk_cnt_diff.values()
# Group chunks by document ID.
doc_id_to_chunks: dict[str, list[DocMetadataAwareIndexChunk]] = defaultdict(
list
)
for chunk in chunks:
doc_id_to_chunks[chunk.source_document.id].append(chunk)
logger.debug(
f"[OpenSearchDocumentIndex] Indexing {total_chunks} chunks from {len(indexing_metadata.doc_id_to_chunk_cnt_diff)} "
f"[OpenSearchDocumentIndex] Indexing {len(chunks)} chunks from {len(doc_id_to_chunks)} "
f"documents for index {self._index_name}."
)
document_indexing_results: list[DocumentInsertionRecord] = []
deleted_doc_ids: set[str] = set()
# Buffer chunks per document as they arrive from the iterable.
# When the document ID changes flush the buffered chunks.
current_doc_id: str | None = None
current_chunks: list[DocMetadataAwareIndexChunk] = []
def _flush_chunks(doc_chunks: list[DocMetadataAwareIndexChunk]) -> None:
assert len(doc_chunks) > 0, "doc_chunks is empty"
# Try to index per-document.
for _, chunks in doc_id_to_chunks.items():
# Create a batch of OpenSearch-formatted chunks for bulk insertion.
# Since we are doing this in batches, an error occurring midway
# can result in a state where chunks are deleted and not all the
# new chunks have been indexed.
# Do this before deleting existing chunks to reduce the amount of
# time the document index has no content for a given document, and
# to reduce the chance of entering a state where we delete chunks,
# then some error happens, and never successfully index new chunks.
chunk_batch: list[DocumentChunk] = [
_convert_onyx_chunk_to_opensearch_document(chunk)
for chunk in doc_chunks
_convert_onyx_chunk_to_opensearch_document(chunk) for chunk in chunks
]
onyx_document: Document = doc_chunks[0].source_document
onyx_document: Document = chunks[0].source_document
# First delete the doc's chunks from the index. This is so that
# there are no dangling chunks in the index, in the event that the
# new document's content contains fewer chunks than the previous
@@ -708,40 +703,22 @@ class OpenSearchDocumentIndex(DocumentIndex):
# if the chunk count has actually decreased. This assumes that
# overlapping chunks are perfectly overwritten. If we can't
# guarantee that then we need the code as-is.
if onyx_document.id not in deleted_doc_ids:
num_chunks_deleted = self.delete(
onyx_document.id, onyx_document.chunk_count
)
deleted_doc_ids.add(onyx_document.id)
# If we see that chunks were deleted we assume the doc already
# existed. We record the result before bulk_index_documents
# runs. If indexing raises, this entire result list is discarded
# by the caller's retry logic, so early recording is safe.
document_indexing_results.append(
DocumentInsertionRecord(
document_id=onyx_document.id,
already_existed=num_chunks_deleted > 0,
)
)
num_chunks_deleted = self.delete(
onyx_document.id, onyx_document.chunk_count
)
# If we see that chunks were deleted we assume the doc already
# existed.
document_insertion_record = DocumentInsertionRecord(
document_id=onyx_document.id,
already_existed=num_chunks_deleted > 0,
)
# Now index. This will raise if a chunk of the same ID exists, which
# we do not expect because we should have deleted all chunks.
self._client.bulk_index_documents(
documents=chunk_batch,
tenant_state=self._tenant_state,
)
for chunk in chunks:
doc_id = chunk.source_document.id
if doc_id != current_doc_id:
if current_chunks:
_flush_chunks(current_chunks)
current_doc_id = doc_id
current_chunks = [chunk]
else:
current_chunks.append(chunk)
if current_chunks:
_flush_chunks(current_chunks)
document_indexing_results.append(document_insertion_record)
return document_indexing_results

View File

@@ -6,7 +6,6 @@ import re
import time
import urllib
import zipfile
from collections.abc import Iterable
from dataclasses import dataclass
from datetime import datetime
from datetime import timedelta
@@ -462,7 +461,7 @@ class VespaIndex(DocumentIndex):
def index(
self,
chunks: Iterable[DocMetadataAwareIndexChunk],
chunks: list[DocMetadataAwareIndexChunk],
index_batch_params: IndexBatchParams,
) -> set[OldDocumentInsertionRecord]:
"""

View File

@@ -1,8 +1,6 @@
import concurrent.futures
import logging
import random
from collections.abc import Generator
from collections.abc import Iterable
from typing import Any
from uuid import UUID
@@ -320,7 +318,7 @@ class VespaDocumentIndex(DocumentIndex):
def index(
self,
chunks: Iterable[DocMetadataAwareIndexChunk],
chunks: list[DocMetadataAwareIndexChunk],
indexing_metadata: IndexingMetadata,
) -> list[DocumentInsertionRecord]:
doc_id_to_chunk_cnt_diff = indexing_metadata.doc_id_to_chunk_cnt_diff
@@ -340,31 +338,22 @@ class VespaDocumentIndex(DocumentIndex):
# Vespa has restrictions on valid characters, yet document IDs come from
# external w.r.t. this class. We need to sanitize them.
#
# Instead of materializing all cleaned chunks upfront, we stream them
# through a generator that cleans IDs and builds the original-ID mapping
# incrementally as chunks flow into Vespa.
def _clean_and_track(
chunks_iter: Iterable[DocMetadataAwareIndexChunk],
id_map: dict[str, str],
seen_ids: set[str],
) -> Generator[DocMetadataAwareIndexChunk, None, None]:
"""Cleans chunk IDs and builds the original-ID mapping
incrementally as chunks flow through, avoiding a separate
materialization pass."""
for chunk in chunks_iter:
original_id = chunk.source_document.id
cleaned = clean_chunk_id_copy(chunk)
cleaned_id = cleaned.source_document.id
# Needed so the final DocumentInsertionRecord returned can have
# the original document ID. cleaned_chunks might not contain IDs
# exactly as callers supplied them.
id_map[cleaned_id] = original_id
seen_ids.add(cleaned_id)
yield cleaned
cleaned_chunks: list[DocMetadataAwareIndexChunk] = [
clean_chunk_id_copy(chunk) for chunk in chunks
]
assert len(cleaned_chunks) == len(
chunks
), "Bug: Cleaned chunks and input chunks have different lengths."
new_document_id_to_original_document_id: dict[str, str] = {}
all_cleaned_doc_ids: set[str] = set()
# Needed so the final DocumentInsertionRecord returned can have the
# original document ID. cleaned_chunks might not contain IDs exactly as
# callers supplied them.
new_document_id_to_original_document_id: dict[str, str] = dict()
for i, cleaned_chunk in enumerate(cleaned_chunks):
old_chunk = chunks[i]
new_document_id_to_original_document_id[
cleaned_chunk.source_document.id
] = old_chunk.source_document.id
existing_docs: set[str] = set()
@@ -420,13 +409,7 @@ class VespaDocumentIndex(DocumentIndex):
executor=executor,
)
# Insert new Vespa documents, streaming through the cleaning
# pipeline so chunks are never fully materialized.
cleaned_chunks = _clean_and_track(
chunks,
new_document_id_to_original_document_id,
all_cleaned_doc_ids,
)
# Insert new Vespa documents.
for chunk_batch in batch_generator(cleaned_chunks, BATCH_SIZE):
batch_index_vespa_chunks(
chunks=chunk_batch,
@@ -436,6 +419,10 @@ class VespaDocumentIndex(DocumentIndex):
executor=executor,
)
all_cleaned_doc_ids: set[str] = {
chunk.source_document.id for chunk in cleaned_chunks
}
return [
DocumentInsertionRecord(
document_id=new_document_id_to_original_document_id[cleaned_doc_id],

View File

@@ -11,6 +11,7 @@ class LLMOverride(BaseModel):
model_provider: str | None = None
model_version: str | None = None
temperature: float | None = None
display_name: str | None = None
# This disables the "model_" protected namespace for pydantic
model_config = {"protected_namespaces": ()}

View File

@@ -29,6 +29,7 @@ from onyx.chat.models import ChatFullResponse
from onyx.chat.models import CreateChatSessionID
from onyx.chat.process_message import gather_stream_full
from onyx.chat.process_message import handle_stream_message_objects
from onyx.chat.process_message import run_multi_model_stream
from onyx.chat.prompt_utils import get_default_base_system_prompt
from onyx.chat.stop_signal_checker import set_fence
from onyx.configs.app_configs import WEB_DOMAIN
@@ -46,6 +47,7 @@ from onyx.db.chat import get_chat_messages_by_session
from onyx.db.chat import get_chat_session_by_id
from onyx.db.chat import get_chat_sessions_by_user
from onyx.db.chat import set_as_latest_chat_message
from onyx.db.chat import set_preferred_response
from onyx.db.chat import translate_db_message_to_chat_message_detail
from onyx.db.chat import update_chat_session
from onyx.db.chat_search import search_chat_sessions
@@ -81,6 +83,7 @@ from onyx.server.query_and_chat.models import ChatSessionUpdateRequest
from onyx.server.query_and_chat.models import MessageOrigin
from onyx.server.query_and_chat.models import RenameChatSessionResponse
from onyx.server.query_and_chat.models import SendMessageRequest
from onyx.server.query_and_chat.models import SetPreferredResponseRequest
from onyx.server.query_and_chat.models import UpdateChatSessionTemperatureRequest
from onyx.server.query_and_chat.models import UpdateChatSessionThreadRequest
from onyx.server.query_and_chat.session_loading import (
@@ -570,6 +573,38 @@ def handle_send_chat_message(
if get_hashed_api_key_from_request(request) or get_hashed_pat_from_request(request):
chat_message_req.origin = MessageOrigin.API
# Multi-model streaming path: 2-3 LLMs in parallel (streaming only)
is_multi_model = (
chat_message_req.llm_overrides is not None
and len(chat_message_req.llm_overrides) > 1
)
if is_multi_model and chat_message_req.stream:
def multi_model_stream_generator() -> Generator[str, None, None]:
try:
with get_session_with_current_tenant() as db_session:
for obj in run_multi_model_stream(
new_msg_req=chat_message_req,
user=user,
db_session=db_session,
llm_overrides=chat_message_req.llm_overrides, # type: ignore[arg-type]
litellm_additional_headers=extract_headers(
request.headers, LITELLM_PASS_THROUGH_HEADERS
),
custom_tool_additional_headers=get_custom_tool_additional_request_headers(
request.headers
),
mcp_headers=chat_message_req.mcp_headers,
):
yield get_json_line(obj.model_dump())
except Exception as e:
logger.exception("Error in multi-model streaming")
yield json.dumps({"error": str(e)})
return StreamingResponse(
multi_model_stream_generator(), media_type="text/event-stream"
)
# Non-streaming path: consume all packets and return complete response
if not chat_message_req.stream:
with get_session_with_current_tenant() as db_session:
@@ -660,6 +695,26 @@ def set_message_as_latest(
)
@router.put("/set-preferred-response")
def set_preferred_response_endpoint(
request_body: SetPreferredResponseRequest,
_user: User = Depends(current_user),
db_session: Session = Depends(get_session),
) -> None:
"""Set the preferred assistant response for a multi-model turn."""
from onyx.error_handling.error_codes import OnyxErrorCode
from onyx.error_handling.exceptions import OnyxError
try:
set_preferred_response(
db_session=db_session,
user_message_id=request_body.user_message_id,
preferred_assistant_message_id=request_body.preferred_response_id,
)
except ValueError as e:
raise OnyxError(OnyxErrorCode.INVALID_INPUT, str(e))
@router.post("/create-chat-message-feedback")
def create_chat_feedback(
feedback: ChatFeedbackRequest,

View File

@@ -41,6 +41,16 @@ class MessageResponseIDInfo(BaseModel):
reserved_assistant_message_id: int
class MultiModelMessageResponseIDInfo(BaseModel):
"""Sent at the start of a multi-model streaming response.
Contains the user message ID and the reserved assistant message IDs
for each model being run in parallel."""
user_message_id: int | None
reserved_assistant_message_ids: list[int]
model_names: list[str]
class SourceTag(Tag):
source: DocumentSource
@@ -86,6 +96,9 @@ class SendMessageRequest(BaseModel):
message: str
llm_override: LLMOverride | None = None
# For multi-model mode: up to 3 LLM overrides to run in parallel.
# When provided with >1 entry, triggers multi-model streaming.
llm_overrides: list[LLMOverride] | None = None
# Test-only override for deterministic LiteLLM mock responses.
mock_llm_response: str | None = None
@@ -211,6 +224,8 @@ class ChatMessageDetail(BaseModel):
error: str | None = None
current_feedback: str | None = None # "like" | "dislike" | null
processing_duration_seconds: float | None = None
preferred_response_id: int | None = None
model_display_name: str | None = None
def model_dump(self, *args: list, **kwargs: dict[str, Any]) -> dict[str, Any]: # type: ignore
initial_dict = super().model_dump(mode="json", *args, **kwargs) # type: ignore
@@ -218,6 +233,11 @@ class ChatMessageDetail(BaseModel):
return initial_dict
class SetPreferredResponseRequest(BaseModel):
user_message_id: int
preferred_response_id: int
class ChatSessionDetailResponse(BaseModel):
chat_session_id: UUID
description: str | None

View File

@@ -8,3 +8,5 @@ class Placement(BaseModel):
tab_index: int = 0
# Used for tools/agents that call other tools, this currently doesn't support nested agents but can be added later
sub_turn_index: int | None = None
# For multi-model streaming: identifies which model (0, 1, 2) this packet belongs to.
model_index: int | None = None

View File

@@ -6,7 +6,6 @@ These tests assume Vespa and OpenSearch are running.
import time
import uuid
from collections.abc import Generator
from collections.abc import Iterator
import httpx
import pytest
@@ -22,7 +21,6 @@ from onyx.document_index.opensearch.opensearch_document_index import (
)
from onyx.document_index.vespa.index import VespaIndex
from onyx.document_index.vespa.vespa_document_index import VespaDocumentIndex
from onyx.indexing.models import DocMetadataAwareIndexChunk
from tests.external_dependency_unit.constants import TEST_TENANT_ID
from tests.external_dependency_unit.document_index.conftest import EMBEDDING_DIM
from tests.external_dependency_unit.document_index.conftest import make_chunk
@@ -203,25 +201,3 @@ class TestDocumentIndexNew:
assert len(result_map) == 2
assert result_map[existing_doc] is True
assert result_map[new_doc] is False
def test_index_accepts_generator(
self,
document_indices: list[DocumentIndexNew],
tenant_context: None, # noqa: ARG002
) -> None:
"""index() accepts a generator (any iterable), not just a list."""
for document_index in document_indices:
doc_id = f"test_gen_{uuid.uuid4().hex[:8]}"
metadata = make_indexing_metadata([doc_id], old_counts=[0], new_counts=[3])
def chunk_gen() -> Iterator[DocMetadataAwareIndexChunk]:
for i in range(3):
yield make_chunk(doc_id, chunk_id=i)
results = document_index.index(
chunks=chunk_gen(), indexing_metadata=metadata
)
assert len(results) == 1
assert results[0].document_id == doc_id
assert results[0].already_existed is False

View File

@@ -5,7 +5,6 @@ These tests assume Vespa and OpenSearch are running.
import time
from collections.abc import Generator
from collections.abc import Iterator
import pytest
@@ -167,29 +166,3 @@ class TestDocumentIndexOld:
batch_retrieval=True,
)
assert len(inference_chunks) == 0
def test_index_accepts_generator(
self,
document_indices: list[DocumentIndex],
tenant_context: None, # noqa: ARG002
) -> None:
"""index() accepts a generator (any iterable), not just a list."""
for document_index in document_indices:
def chunk_gen() -> Iterator[DocMetadataAwareIndexChunk]:
for i in range(3):
yield make_chunk("test_doc_gen", chunk_id=i)
index_batch_params = IndexBatchParams(
doc_id_to_previous_chunk_cnt={"test_doc_gen": 0},
doc_id_to_new_chunk_cnt={"test_doc_gen": 3},
tenant_id=get_current_tenant_id(),
large_chunks_enabled=False,
)
results = document_index.index(chunk_gen(), index_batch_params)
assert len(results) == 1
record = results.pop()
assert record.document_id == "test_doc_gen"
assert record.already_existed is False

View File

@@ -0,0 +1,206 @@
"""Unit tests for multi-model streaming validation and DB helpers.
These are pure unit tests — no real database or LLM calls required.
The validation logic in run_multi_model_stream fires before any external
calls, so we can trigger it with lightweight mocks.
"""
from typing import Any
from unittest.mock import MagicMock
from uuid import uuid4
import pytest
from onyx.configs.constants import MessageType
from onyx.db.chat import set_preferred_response
from onyx.llm.override_models import LLMOverride
from onyx.server.query_and_chat.models import SendMessageRequest
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _make_request(**kwargs: Any) -> SendMessageRequest:
defaults: dict[str, Any] = {
"message": "hello",
"chat_session_id": uuid4(),
}
defaults.update(kwargs)
return SendMessageRequest(**defaults)
def _make_override(provider: str = "openai", version: str = "gpt-4") -> LLMOverride:
return LLMOverride(model_provider=provider, model_version=version)
def _start_stream(req: SendMessageRequest, overrides: list[LLMOverride]) -> None:
"""Advance the generator one step to trigger early validation."""
from onyx.chat.process_message import run_multi_model_stream
user = MagicMock()
user.is_anonymous = False
user.email = "test@example.com"
db = MagicMock()
gen = run_multi_model_stream(req, user, db, overrides)
# Calling next() executes until the first yield OR raises.
# Validation errors are raised before any yield.
next(gen)
# ---------------------------------------------------------------------------
# run_multi_model_stream — validation
# ---------------------------------------------------------------------------
class TestRunMultiModelStreamValidation:
def test_single_override_raises(self) -> None:
"""Exactly 1 override is not multi-model — must raise."""
req = _make_request()
with pytest.raises(ValueError, match="2-3"):
_start_stream(req, [_make_override()])
def test_four_overrides_raises(self) -> None:
"""4 overrides exceeds maximum — must raise."""
req = _make_request()
with pytest.raises(ValueError, match="2-3"):
_start_stream(
req,
[
_make_override("openai", "gpt-4"),
_make_override("anthropic", "claude-3"),
_make_override("google", "gemini-pro"),
_make_override("cohere", "command-r"),
],
)
def test_zero_overrides_raises(self) -> None:
"""Empty override list raises."""
req = _make_request()
with pytest.raises(ValueError, match="2-3"):
_start_stream(req, [])
def test_deep_research_raises(self) -> None:
"""deep_research=True is incompatible with multi-model."""
req = _make_request(deep_research=True)
with pytest.raises(ValueError, match="not supported"):
_start_stream(
req, [_make_override(), _make_override("anthropic", "claude-3")]
)
def test_exactly_two_overrides_is_minimum(self) -> None:
"""Boundary: 1 override fails, 2 passes — ensures fence-post is correct."""
req = _make_request()
# 1 override must fail
with pytest.raises(ValueError, match="2-3"):
_start_stream(req, [_make_override()])
# 2 overrides must NOT raise ValueError (may raise later due to missing session, that's OK)
try:
_start_stream(
req, [_make_override(), _make_override("anthropic", "claude-3")]
)
except ValueError as exc:
pytest.fail(f"2 overrides should pass validation, got ValueError: {exc}")
except Exception:
pass # Any other error means validation passed
# ---------------------------------------------------------------------------
# set_preferred_response — validation (mocked db)
# ---------------------------------------------------------------------------
class TestSetPreferredResponseValidation:
def test_user_message_not_found(self) -> None:
db = MagicMock()
db.query.return_value.get.return_value = None
with pytest.raises(ValueError, match="not found"):
set_preferred_response(
db, user_message_id=999, preferred_assistant_message_id=1
)
def test_wrong_message_type(self) -> None:
"""Cannot set preferred response on a non-USER message."""
db = MagicMock()
user_msg = MagicMock()
user_msg.message_type = MessageType.ASSISTANT # wrong type
db.query.return_value.get.return_value = user_msg
with pytest.raises(ValueError, match="not a user message"):
set_preferred_response(
db, user_message_id=1, preferred_assistant_message_id=2
)
def test_assistant_message_not_found(self) -> None:
db = MagicMock()
user_msg = MagicMock()
user_msg.message_type = MessageType.USER
# First call returns user_msg, second call (for assistant) returns None
db.query.return_value.get.side_effect = [user_msg, None]
with pytest.raises(ValueError, match="not found"):
set_preferred_response(
db, user_message_id=1, preferred_assistant_message_id=2
)
def test_assistant_not_child_of_user(self) -> None:
db = MagicMock()
user_msg = MagicMock()
user_msg.message_type = MessageType.USER
assistant_msg = MagicMock()
assistant_msg.parent_message_id = 999 # different parent
db.query.return_value.get.side_effect = [user_msg, assistant_msg]
with pytest.raises(ValueError, match="not a child"):
set_preferred_response(
db, user_message_id=1, preferred_assistant_message_id=2
)
def test_valid_call_sets_preferred_response_id(self) -> None:
db = MagicMock()
user_msg = MagicMock()
user_msg.message_type = MessageType.USER
assistant_msg = MagicMock()
assistant_msg.parent_message_id = 1 # correct parent
db.query.return_value.get.side_effect = [user_msg, assistant_msg]
set_preferred_response(db, user_message_id=1, preferred_assistant_message_id=2)
assert user_msg.preferred_response_id == 2
# ---------------------------------------------------------------------------
# LLMOverride — display_name field
# ---------------------------------------------------------------------------
class TestLLMOverrideDisplayName:
def test_display_name_defaults_none(self) -> None:
override = LLMOverride(model_provider="openai", model_version="gpt-4")
assert override.display_name is None
def test_display_name_set(self) -> None:
override = LLMOverride(
model_provider="openai",
model_version="gpt-4",
display_name="GPT-4 Turbo",
)
assert override.display_name == "GPT-4 Turbo"
def test_display_name_serializes(self) -> None:
override = LLMOverride(
model_provider="anthropic",
model_version="claude-opus-4-6",
display_name="Claude Opus",
)
d = override.model_dump()
assert d["display_name"] == "Claude Opus"

View File

@@ -0,0 +1,134 @@
"""Unit tests for multi-model answer generation types.
Tests cover:
- Placement.model_index serialization
- MultiModelMessageResponseIDInfo round-trip
- SendMessageRequest.llm_overrides backward compatibility
- ChatMessageDetail new fields
"""
from uuid import uuid4
from onyx.llm.override_models import LLMOverride
from onyx.server.query_and_chat.models import ChatMessageDetail
from onyx.server.query_and_chat.models import MultiModelMessageResponseIDInfo
from onyx.server.query_and_chat.models import SendMessageRequest
from onyx.server.query_and_chat.placement import Placement
class TestPlacementModelIndex:
def test_default_none(self) -> None:
p = Placement(turn_index=0)
assert p.model_index is None
def test_set_value(self) -> None:
p = Placement(turn_index=0, model_index=2)
assert p.model_index == 2
def test_serializes(self) -> None:
p = Placement(turn_index=0, tab_index=1, model_index=1)
d = p.model_dump()
assert d["model_index"] == 1
def test_none_excluded_when_default(self) -> None:
p = Placement(turn_index=0)
d = p.model_dump()
assert d["model_index"] is None
class TestMultiModelMessageResponseIDInfo:
def test_round_trip(self) -> None:
info = MultiModelMessageResponseIDInfo(
user_message_id=42,
reserved_assistant_message_ids=[43, 44, 45],
model_names=["gpt-4", "claude-opus", "gemini-pro"],
)
d = info.model_dump()
restored = MultiModelMessageResponseIDInfo(**d)
assert restored.user_message_id == 42
assert restored.reserved_assistant_message_ids == [43, 44, 45]
assert restored.model_names == ["gpt-4", "claude-opus", "gemini-pro"]
def test_null_user_message_id(self) -> None:
info = MultiModelMessageResponseIDInfo(
user_message_id=None,
reserved_assistant_message_ids=[1, 2],
model_names=["a", "b"],
)
assert info.user_message_id is None
class TestSendMessageRequestOverrides:
def test_llm_overrides_default_none(self) -> None:
req = SendMessageRequest(
message="hello",
chat_session_id=uuid4(),
)
assert req.llm_overrides is None
def test_llm_overrides_accepts_list(self) -> None:
overrides = [
LLMOverride(model_provider="openai", model_version="gpt-4"),
LLMOverride(model_provider="anthropic", model_version="claude-opus"),
]
req = SendMessageRequest(
message="hello",
chat_session_id=uuid4(),
llm_overrides=overrides,
)
assert req.llm_overrides is not None
assert len(req.llm_overrides) == 2
def test_backward_compat_single_override(self) -> None:
req = SendMessageRequest(
message="hello",
chat_session_id=uuid4(),
llm_override=LLMOverride(model_provider="openai", model_version="gpt-4"),
)
assert req.llm_override is not None
assert req.llm_overrides is None
class TestChatMessageDetailMultiModel:
def test_defaults_none(self) -> None:
from onyx.configs.constants import MessageType
detail = ChatMessageDetail(
message_id=1,
message="hello",
message_type=MessageType.ASSISTANT,
time_sent="2026-03-22T00:00:00Z",
files=[],
)
assert detail.preferred_response_id is None
assert detail.model_display_name is None
def test_set_values(self) -> None:
from onyx.configs.constants import MessageType
detail = ChatMessageDetail(
message_id=1,
message="hello",
message_type=MessageType.USER,
time_sent="2026-03-22T00:00:00Z",
files=[],
preferred_response_id=42,
model_display_name="GPT-4",
)
assert detail.preferred_response_id == 42
assert detail.model_display_name == "GPT-4"
def test_serializes(self) -> None:
from onyx.configs.constants import MessageType
detail = ChatMessageDetail(
message_id=1,
message="hello",
message_type=MessageType.ASSISTANT,
time_sent="2026-03-22T00:00:00Z",
files=[],
model_display_name="Claude Opus",
)
d = detail.model_dump()
assert d["model_display_name"] == "Claude Opus"
assert d["preferred_response_id"] is None

View File

@@ -159,6 +159,10 @@ export interface Message {
overridden_model?: string;
stopReason?: StreamStopReason | null;
// Multi-model answer generation
preferredResponseId?: number | null;
modelDisplayName?: string | null;
// new gen
packets: Packet[];
packetCount?: number; // Tracks packet count for React memo comparison (avoids reading from mutated array)
@@ -231,6 +235,9 @@ export interface BackendMessage {
parentMessageId: number | null;
refined_answer_improvement: boolean | null;
is_agentic: boolean | null;
// Multi-model answer generation
preferred_response_id: number | null;
model_display_name: string | null;
}
export interface MessageResponseIDInfo {
@@ -238,6 +245,12 @@ export interface MessageResponseIDInfo {
reserved_assistant_message_id: number; // TODO: rename to agent — https://linear.app/onyx-app/issue/ENG-3766
}
export interface MultiModelMessageResponseIDInfo {
user_message_id: number | null;
reserved_assistant_message_ids: number[];
model_names: string[];
}
export interface UserKnowledgeFilePacket {
user_files: FileDescriptor[];
}

View File

@@ -0,0 +1,149 @@
"use client";
import { useCallback } from "react";
import { Button } from "@opal/components";
import { Hoverable } from "@opal/core";
import { SvgEyeClosed, SvgMoreHorizontal, SvgX } from "@opal/icons";
import Text from "@/refresh-components/texts/Text";
import { getProviderIcon } from "@/app/admin/configuration/llm/utils";
import AgentMessage, {
AgentMessageProps,
} from "@/app/app/message/messageComponents/AgentMessage";
import { cn } from "@/lib/utils";
export interface MultiModelPanelProps {
/** Index of this model in the selectedModels array (used for Hoverable group key) */
modelIndex: number;
/** Provider name for icon lookup */
provider: string;
/** Model name for icon lookup and display */
modelName: string;
/** Display-friendly model name */
displayName: string;
/** Whether this panel is the preferred/selected response */
isPreferred: boolean;
/** Whether this panel is currently hidden */
isHidden: boolean;
/** Whether this is a non-preferred panel in selection mode (pushed off-screen) */
isNonPreferredInSelection: boolean;
/** Callback when user clicks this panel to select as preferred */
onSelect: () => void;
/** Callback to hide/show this panel */
onToggleVisibility: () => void;
/** Props to pass through to AgentMessage */
agentMessageProps: AgentMessageProps;
}
export default function MultiModelPanel({
modelIndex,
provider,
modelName,
displayName,
isPreferred,
isHidden,
isNonPreferredInSelection,
onSelect,
onToggleVisibility,
agentMessageProps,
}: MultiModelPanelProps) {
const ProviderIcon = getProviderIcon(provider, modelName);
const handlePanelClick = useCallback(() => {
if (!isHidden) {
onSelect();
}
}, [isHidden, onSelect]);
// Hidden/collapsed panel — compact strip: icon + strikethrough name + eye icon
if (isHidden) {
return (
<div className="flex items-center gap-1.5 rounded-08 bg-background-tint-00 px-2 py-1 opacity-50">
<div className="flex items-center justify-center size-5 shrink-0">
<ProviderIcon size={16} />
</div>
<Text secondaryBody text02 nowrap className="line-through">
{displayName}
</Text>
<Button
prominence="tertiary"
icon={SvgEyeClosed}
size="2xs"
onClick={onToggleVisibility}
tooltip="Show response"
/>
</div>
);
}
const hoverGroup = `panel-${modelIndex}`;
return (
<Hoverable.Root group={hoverGroup}>
<div
className="flex flex-col min-w-0 gap-3 cursor-pointer"
onClick={handlePanelClick}
>
{/* Panel header */}
<div
className={cn(
"flex items-center gap-1.5 rounded-12 px-2 py-1",
isPreferred ? "bg-background-tint-02" : "bg-background-tint-00"
)}
>
<div className="flex items-center justify-center size-5 shrink-0">
<ProviderIcon size={16} />
</div>
<Text mainUiAction text04 nowrap className="flex-1 min-w-0 truncate">
{displayName}
</Text>
{isPreferred && (
<Text secondaryBody nowrap className="text-action-link-05 shrink-0">
Preferred Response
</Text>
)}
<Button
prominence="tertiary"
icon={SvgMoreHorizontal}
size="2xs"
tooltip="More"
onClick={(e) => e.stopPropagation()}
/>
<Button
prominence="tertiary"
icon={SvgX}
size="2xs"
onClick={(e) => {
e.stopPropagation();
onToggleVisibility();
}}
tooltip="Hide response"
/>
</div>
{/* "Select This Response" hover affordance */}
{!isPreferred && !isNonPreferredInSelection && (
<Hoverable.Item group={hoverGroup} variant="opacity-on-hover">
<div className="flex justify-center pointer-events-none">
<div className="flex items-center h-6 bg-background-tint-00 rounded-08 px-1 shadow-sm">
<Text
secondaryBody
className="font-semibold text-text-03 px-1 whitespace-nowrap"
>
Select This Response
</Text>
</div>
</div>
</Hoverable.Item>
)}
{/* Response body */}
<div className={cn(isNonPreferredInSelection && "pointer-events-none")}>
<AgentMessage
{...agentMessageProps}
hideFooter={isNonPreferredInSelection}
/>
</div>
</div>
</Hoverable.Root>
);
}

View File

@@ -0,0 +1,229 @@
"use client";
import { useState, useCallback, useMemo } from "react";
import { Packet } from "@/app/app/services/streamingModels";
import { FullChatState } from "@/app/app/message/messageComponents/interfaces";
import { FeedbackType, Message } from "@/app/app/interfaces";
import { LlmManager } from "@/lib/hooks";
import { RegenerationFactory } from "@/app/app/message/messageComponents/AgentMessage";
import MultiModelPanel from "@/app/app/message/MultiModelPanel";
import { cn } from "@/lib/utils";
export interface MultiModelResponse {
modelIndex: number;
provider: string;
modelName: string;
displayName: string;
packets: Packet[];
packetCount: number;
nodeId: number;
messageId?: number;
isHighlighted?: boolean;
currentFeedback?: FeedbackType | null;
isGenerating?: boolean;
}
export interface MultiModelResponseViewProps {
responses: MultiModelResponse[];
chatState: FullChatState;
llmManager: LlmManager | null;
onRegenerate?: RegenerationFactory;
parentMessage?: Message | null;
otherMessagesCanSwitchTo?: number[];
onMessageSelection?: (nodeId: number) => void;
}
export default function MultiModelResponseView({
responses,
chatState,
llmManager,
onRegenerate,
parentMessage,
otherMessagesCanSwitchTo,
onMessageSelection,
}: MultiModelResponseViewProps) {
const [preferredIndex, setPreferredIndex] = useState<number | null>(null);
const [hiddenPanels, setHiddenPanels] = useState<Set<number>>(new Set());
const isGenerating = useMemo(
() => responses.some((r) => r.isGenerating),
[responses]
);
const visibleResponses = useMemo(
() => responses.filter((r) => !hiddenPanels.has(r.modelIndex)),
[responses, hiddenPanels]
);
const hiddenResponses = useMemo(
() => responses.filter((r) => hiddenPanels.has(r.modelIndex)),
[responses, hiddenPanels]
);
const toggleVisibility = useCallback(
(modelIndex: number) => {
setHiddenPanels((prev) => {
const next = new Set(prev);
if (next.has(modelIndex)) {
next.delete(modelIndex);
} else {
// Don't hide the last visible panel
const visibleCount = responses.length - next.size;
if (visibleCount <= 1) return prev;
next.add(modelIndex);
}
return next;
});
},
[responses.length]
);
const handleSelectPreferred = useCallback(
(modelIndex: number) => {
setPreferredIndex(modelIndex);
const response = responses[modelIndex];
if (!response) return;
// Sync with message tree — mark this response as the latest child
// so the next message chains from it.
if (onMessageSelection) {
onMessageSelection(response.nodeId);
}
},
[responses, onMessageSelection]
);
// Selection mode when preferred is set and not generating
const showSelectionMode =
preferredIndex !== null && !isGenerating && visibleResponses.length > 1;
// Build common panel props
const buildPanelProps = useCallback(
(response: MultiModelResponse, isNonPreferred: boolean) => ({
modelIndex: response.modelIndex,
provider: response.provider,
modelName: response.modelName,
displayName: response.displayName,
isPreferred: preferredIndex === response.modelIndex,
isHidden: false as const,
isNonPreferredInSelection: isNonPreferred,
onSelect: () => handleSelectPreferred(response.modelIndex),
onToggleVisibility: () => toggleVisibility(response.modelIndex),
agentMessageProps: {
rawPackets: response.packets,
packetCount: response.packetCount,
chatState,
nodeId: response.nodeId,
messageId: response.messageId,
currentFeedback: response.currentFeedback,
llmManager,
otherMessagesCanSwitchTo,
onMessageSelection,
onRegenerate,
parentMessage,
},
}),
[
preferredIndex,
handleSelectPreferred,
toggleVisibility,
chatState,
llmManager,
otherMessagesCanSwitchTo,
onMessageSelection,
onRegenerate,
parentMessage,
]
);
// Shared renderer for hidden panels (inline in the flex row)
const renderHiddenPanels = () =>
hiddenResponses.map((r) => (
<div key={r.modelIndex} className="w-[240px] shrink-0">
<MultiModelPanel
modelIndex={r.modelIndex}
provider={r.provider}
modelName={r.modelName}
displayName={r.displayName}
isPreferred={false}
isHidden
isNonPreferredInSelection={false}
onSelect={() => handleSelectPreferred(r.modelIndex)}
onToggleVisibility={() => toggleVisibility(r.modelIndex)}
agentMessageProps={buildPanelProps(r, false).agentMessageProps}
/>
</div>
));
if (showSelectionMode) {
// ── Selection Layout ──
// Preferred stays at normal chat width, centered.
// Non-preferred panels are pushed to the viewport edges and clip off-screen.
const preferredIdx = visibleResponses.findIndex(
(r) => r.modelIndex === preferredIndex
);
const preferred = visibleResponses[preferredIdx];
const leftPanels = visibleResponses.slice(0, preferredIdx);
const rightPanels = visibleResponses.slice(preferredIdx + 1);
// Non-preferred panel width and gap between panels
const PANEL_W = 400;
const GAP = 16;
return (
<div className="w-full relative overflow-hidden">
{/* Preferred — centered at normal chat width, in flow to set container height */}
{preferred && (
<div className="w-full max-w-[720px] min-w-[400px] mx-auto">
<MultiModelPanel {...buildPanelProps(preferred, false)} />
</div>
)}
{/* Non-preferred on the left — anchored to the left of the preferred panel */}
{leftPanels.map((r, i) => (
<div
key={r.modelIndex}
className="absolute top-0"
style={{
width: `${PANEL_W}px`,
// Right edge of this panel sits just left of the preferred panel
right: `calc(50% + var(--app-page-main-content-width) / 2 + ${
GAP + i * (PANEL_W + GAP)
}px)`,
}}
>
<MultiModelPanel {...buildPanelProps(r, true)} />
</div>
))}
{/* Non-preferred on the right — anchored to the right of the preferred panel */}
{rightPanels.map((r, i) => (
<div
key={r.modelIndex}
className="absolute top-0"
style={{
width: `${PANEL_W}px`,
// Left edge of this panel sits just right of the preferred panel
left: `calc(50% + var(--app-page-main-content-width) / 2 + ${
GAP + i * (PANEL_W + GAP)
}px)`,
}}
>
<MultiModelPanel {...buildPanelProps(r, true)} />
</div>
))}
</div>
);
}
// ── Generation Layout (equal panels) ──
return (
<div className="flex gap-6 items-start justify-center">
{visibleResponses.map((r) => (
<div key={r.modelIndex} className="flex-1 min-w-[400px] max-w-[720px]">
<MultiModelPanel {...buildPanelProps(r, false)} />
</div>
))}
{renderHiddenPanels()}
</div>
);
}

View File

@@ -49,6 +49,8 @@ export interface AgentMessageProps {
parentMessage?: Message | null;
// Duration in seconds for processing this message (agent messages only)
processingDurationSeconds?: number;
/** Hide the feedback/toolbar footer (used in multi-model non-preferred panels) */
hideFooter?: boolean;
}
// TODO: Consider more robust comparisons:
@@ -76,7 +78,8 @@ function arePropsEqual(
prev.parentMessage?.messageId === next.parentMessage?.messageId &&
prev.llmManager?.isLoadingProviders ===
next.llmManager?.isLoadingProviders &&
prev.processingDurationSeconds === next.processingDurationSeconds
prev.processingDurationSeconds === next.processingDurationSeconds &&
prev.hideFooter === next.hideFooter
// Skip: chatState.regenerate, chatState.setPresentingDocument,
// most of llmManager, onMessageSelection (function/object props)
);
@@ -95,6 +98,7 @@ const AgentMessage = React.memo(function AgentMessage({
onRegenerate,
parentMessage,
processingDurationSeconds,
hideFooter,
}: AgentMessageProps) {
const markdownRef = useRef<HTMLDivElement>(null);
const finalAnswerRef = useRef<HTMLDivElement>(null);
@@ -326,7 +330,7 @@ const AgentMessage = React.memo(function AgentMessage({
</div>
{/* Feedback buttons - only show when streaming and rendering complete */}
{isComplete && (
{isComplete && !hideFooter && (
<MessageToolbar
nodeId={nodeId}
messageId={messageId}

View File

@@ -12,6 +12,7 @@ import {
FileChatDisplay,
Message,
MessageResponseIDInfo,
MultiModelMessageResponseIDInfo,
ResearchType,
RetrievalType,
StreamingError,
@@ -96,6 +97,7 @@ export type PacketType =
| FileChatDisplay
| StreamingError
| MessageResponseIDInfo
| MultiModelMessageResponseIDInfo
| StreamStopInfo
| UserKnowledgeFilePacket
| Packet;
@@ -109,6 +111,13 @@ export type MessageOrigin =
| "slackbot"
| "unknown";
export interface LLMOverride {
model_provider: string;
model_version: string;
temperature?: number;
display_name?: string;
}
export interface SendMessageParams {
message: string;
fileDescriptors?: FileDescriptor[];
@@ -124,6 +133,8 @@ export interface SendMessageParams {
modelProvider?: string;
modelVersion?: string;
temperature?: number;
// Multi-model: send multiple LLM overrides for parallel generation
llmOverrides?: LLMOverride[];
// Origin of the message for telemetry tracking
origin?: MessageOrigin;
// Additional context injected into the LLM call but not stored/shown in chat.
@@ -144,6 +155,7 @@ export async function* sendMessage({
modelProvider,
modelVersion,
temperature,
llmOverrides,
origin,
additionalContext,
}: SendMessageParams): AsyncGenerator<PacketType, void, unknown> {
@@ -165,6 +177,8 @@ export async function* sendMessage({
model_version: modelVersion,
}
: null,
// Multi-model: list of LLM overrides for parallel generation
llm_overrides: llmOverrides ?? null,
// Default to "unknown" for consistency with backend; callers should set explicitly
origin: origin ?? "unknown",
additional_context: additionalContext ?? null,
@@ -188,6 +202,20 @@ export async function* sendMessage({
yield* handleSSEStream<PacketType>(response, signal);
}
export async function setPreferredResponse(
userMessageId: number,
preferredResponseId: number
): Promise<Response> {
return fetch("/api/chat/set-preferred-response", {
method: "PUT",
headers: { "Content-Type": "application/json" },
body: JSON.stringify({
user_message_id: userMessageId,
preferred_response_id: preferredResponseId,
}),
});
}
export async function nameChatSession(chatSessionId: string) {
const response = await fetch("/api/chat/rename-chat-session", {
method: "PUT",
@@ -357,6 +385,9 @@ export function processRawChatHistory(
overridden_model: messageInfo.overridden_model,
packets: packetsForMessage || [],
currentFeedback: messageInfo.current_feedback as FeedbackType | null,
// Multi-model answer generation
preferredResponseId: messageInfo.preferred_response_id ?? null,
modelDisplayName: messageInfo.model_display_name ?? null,
};
messages.set(messageInfo.message_id, message);

View File

@@ -403,6 +403,7 @@ export interface Placement {
turn_index: number;
tab_index?: number; // For parallel tool calls - tools with same turn_index but different tab_index run in parallel
sub_turn_index?: number | null;
model_index?: number | null; // For multi-model answer generation - identifies which model produced this packet
}
// Packet wrapper for streaming objects

View File

@@ -0,0 +1,232 @@
import { renderHook, act } from "@testing-library/react";
import useMultiModelChat from "@/hooks/useMultiModelChat";
import { LlmManager } from "@/lib/hooks";
import { SelectedModel } from "@/refresh-components/popovers/ModelSelector";
// Mock buildLlmOptions — hook uses it internally for initialization.
// Tests here focus on CRUD operations, not the initialization side-effect.
jest.mock("@/refresh-components/popovers/LLMPopover", () => ({
buildLlmOptions: jest.fn(() => []),
}));
const makeLlmManager = (): LlmManager =>
({
llmProviders: [],
currentLlm: { modelName: null, provider: null },
isLoadingProviders: false,
}) as unknown as LlmManager;
const makeModel = (provider: string, modelName: string): SelectedModel => ({
name: provider,
provider,
modelName,
displayName: `${provider}/${modelName}`,
});
const GPT4 = makeModel("openai", "gpt-4");
const CLAUDE = makeModel("anthropic", "claude-opus-4-6");
const GEMINI = makeModel("google", "gemini-pro");
const GPT4_TURBO = makeModel("openai", "gpt-4-turbo");
// ---------------------------------------------------------------------------
// addModel
// ---------------------------------------------------------------------------
describe("addModel", () => {
it("adds a model to an empty selection", () => {
const { result } = renderHook(() => useMultiModelChat(makeLlmManager()));
act(() => {
result.current.addModel(GPT4);
});
expect(result.current.selectedModels).toHaveLength(1);
expect(result.current.selectedModels[0]).toEqual(GPT4);
});
it("does not add a duplicate model", () => {
const { result } = renderHook(() => useMultiModelChat(makeLlmManager()));
act(() => {
result.current.addModel(GPT4);
result.current.addModel(GPT4); // duplicate
});
expect(result.current.selectedModels).toHaveLength(1);
});
it("enforces MAX_MODELS (3) cap", () => {
const { result } = renderHook(() => useMultiModelChat(makeLlmManager()));
act(() => {
result.current.addModel(GPT4);
result.current.addModel(CLAUDE);
result.current.addModel(GEMINI);
result.current.addModel(GPT4_TURBO); // should be ignored
});
expect(result.current.selectedModels).toHaveLength(3);
});
});
// ---------------------------------------------------------------------------
// removeModel
// ---------------------------------------------------------------------------
describe("removeModel", () => {
it("removes a model by index", () => {
const { result } = renderHook(() => useMultiModelChat(makeLlmManager()));
act(() => {
result.current.addModel(GPT4);
result.current.addModel(CLAUDE);
});
act(() => {
result.current.removeModel(0); // remove GPT4
});
expect(result.current.selectedModels).toHaveLength(1);
expect(result.current.selectedModels[0]).toEqual(CLAUDE);
});
it("handles out-of-range index gracefully", () => {
const { result } = renderHook(() => useMultiModelChat(makeLlmManager()));
act(() => {
result.current.addModel(GPT4);
});
act(() => {
result.current.removeModel(99); // no-op
});
expect(result.current.selectedModels).toHaveLength(1);
});
});
// ---------------------------------------------------------------------------
// replaceModel
// ---------------------------------------------------------------------------
describe("replaceModel", () => {
it("replaces the model at the given index", () => {
const { result } = renderHook(() => useMultiModelChat(makeLlmManager()));
act(() => {
result.current.addModel(GPT4);
result.current.addModel(CLAUDE);
});
act(() => {
result.current.replaceModel(0, GEMINI);
});
expect(result.current.selectedModels[0]).toEqual(GEMINI);
expect(result.current.selectedModels[1]).toEqual(CLAUDE);
});
it("does not replace with a model already selected at another index", () => {
const { result } = renderHook(() => useMultiModelChat(makeLlmManager()));
act(() => {
result.current.addModel(GPT4);
result.current.addModel(CLAUDE);
});
act(() => {
result.current.replaceModel(0, CLAUDE); // CLAUDE is already at index 1
});
// Should be a no-op — GPT4 stays at index 0
expect(result.current.selectedModels[0]).toEqual(GPT4);
});
});
// ---------------------------------------------------------------------------
// isMultiModelActive
// ---------------------------------------------------------------------------
describe("isMultiModelActive", () => {
it("is false with zero models", () => {
const { result } = renderHook(() => useMultiModelChat(makeLlmManager()));
expect(result.current.isMultiModelActive).toBe(false);
});
it("is false with exactly one model", () => {
const { result } = renderHook(() => useMultiModelChat(makeLlmManager()));
act(() => {
result.current.addModel(GPT4);
});
expect(result.current.isMultiModelActive).toBe(false);
});
it("is true with two or more models", () => {
const { result } = renderHook(() => useMultiModelChat(makeLlmManager()));
act(() => {
result.current.addModel(GPT4);
result.current.addModel(CLAUDE);
});
expect(result.current.isMultiModelActive).toBe(true);
});
});
// ---------------------------------------------------------------------------
// buildLlmOverrides
// ---------------------------------------------------------------------------
describe("buildLlmOverrides", () => {
it("returns empty array when no models selected", () => {
const { result } = renderHook(() => useMultiModelChat(makeLlmManager()));
expect(result.current.buildLlmOverrides()).toEqual([]);
});
it("maps selectedModels to LLMOverride format", () => {
const { result } = renderHook(() => useMultiModelChat(makeLlmManager()));
act(() => {
result.current.addModel(GPT4);
result.current.addModel(CLAUDE);
});
const overrides = result.current.buildLlmOverrides();
expect(overrides).toHaveLength(2);
expect(overrides[0]).toEqual({
model_provider: "openai",
model_version: "gpt-4",
display_name: "openai/gpt-4",
});
expect(overrides[1]).toEqual({
model_provider: "anthropic",
model_version: "claude-opus-4-6",
display_name: "anthropic/claude-opus-4-6",
});
});
});
// ---------------------------------------------------------------------------
// clearModels
// ---------------------------------------------------------------------------
describe("clearModels", () => {
it("empties the selection", () => {
const { result } = renderHook(() => useMultiModelChat(makeLlmManager()));
act(() => {
result.current.addModel(GPT4);
result.current.addModel(CLAUDE);
});
act(() => {
result.current.clearModels();
});
expect(result.current.selectedModels).toHaveLength(0);
expect(result.current.isMultiModelActive).toBe(false);
});
});

View File

@@ -0,0 +1,191 @@
"use client";
import { useState, useCallback, useEffect, useMemo } from "react";
import { SelectedModel } from "@/refresh-components/popovers/ModelSelector";
import { LLMOverride } from "@/app/app/services/lib";
import { LlmManager } from "@/lib/hooks";
import { buildLlmOptions } from "@/refresh-components/popovers/LLMPopover";
const MAX_MODELS = 3;
export interface UseMultiModelChatReturn {
/** Currently selected models for multi-model comparison. */
selectedModels: SelectedModel[];
/** Whether multi-model mode is active (>1 model selected). */
isMultiModelActive: boolean;
/** Add a model to the selection. */
addModel: (model: SelectedModel) => void;
/** Remove a model by index. */
removeModel: (index: number) => void;
/** Replace a model at a specific index with a new one. */
replaceModel: (index: number, model: SelectedModel) => void;
/** Clear all selected models. */
clearModels: () => void;
/** Build the LLMOverride[] array from selectedModels. */
buildLlmOverrides: () => LLMOverride[];
/**
* Restore multi-model selection from model version strings (e.g. from chat history).
* Matches against available llmOptions to reconstruct full SelectedModel objects.
*/
restoreFromModelNames: (modelNames: string[]) => void;
/**
* Switch to a single model by name (after user picks a preferred response).
* Matches against llmOptions to find the full SelectedModel.
*/
selectSingleModel: (modelName: string) => void;
}
export default function useMultiModelChat(
llmManager: LlmManager
): UseMultiModelChatReturn {
const [selectedModels, setSelectedModels] = useState<SelectedModel[]>([]);
const [defaultInitialized, setDefaultInitialized] = useState(false);
// Initialize with the default model from llmManager once providers load
const llmOptions = useMemo(
() =>
llmManager.llmProviders ? buildLlmOptions(llmManager.llmProviders) : [],
[llmManager.llmProviders]
);
useEffect(() => {
if (defaultInitialized) return;
if (llmOptions.length === 0) return;
const { currentLlm } = llmManager;
// Don't initialize if currentLlm hasn't loaded yet
if (!currentLlm.modelName) return;
const match = llmOptions.find(
(opt) =>
opt.provider === currentLlm.provider &&
opt.modelName === currentLlm.modelName
);
if (match) {
setSelectedModels([
{
name: match.name,
provider: match.provider,
modelName: match.modelName,
displayName: match.displayName,
},
]);
setDefaultInitialized(true);
}
}, [llmOptions, llmManager.currentLlm, defaultInitialized, llmManager]);
const isMultiModelActive = selectedModels.length > 1;
const addModel = useCallback((model: SelectedModel) => {
setSelectedModels((prev) => {
if (prev.length >= MAX_MODELS) return prev;
if (
prev.some(
(m) =>
m.provider === model.provider && m.modelName === model.modelName
)
) {
return prev;
}
return [...prev, model];
});
}, []);
const removeModel = useCallback((index: number) => {
setSelectedModels((prev) => prev.filter((_, i) => i !== index));
}, []);
const replaceModel = useCallback((index: number, model: SelectedModel) => {
setSelectedModels((prev) => {
// Don't replace with a model that's already selected elsewhere
if (
prev.some(
(m, i) =>
i !== index &&
m.provider === model.provider &&
m.modelName === model.modelName
)
) {
return prev;
}
const next = [...prev];
next[index] = model;
return next;
});
}, []);
const clearModels = useCallback(() => {
setSelectedModels([]);
}, []);
const restoreFromModelNames = useCallback(
(modelNames: string[]) => {
if (modelNames.length < 2 || llmOptions.length === 0) return;
const restored: SelectedModel[] = [];
for (const name of modelNames) {
// Try matching by modelName (raw version string like "claude-opus-4-6")
// or by displayName (friendly name like "Claude Opus 4.6")
const match = llmOptions.find(
(opt) =>
opt.modelName === name ||
opt.displayName === name ||
opt.name === name
);
if (match) {
restored.push({
name: match.name,
provider: match.provider,
modelName: match.modelName,
displayName: match.displayName,
});
}
}
if (restored.length >= 2) {
setSelectedModels(restored);
setDefaultInitialized(true);
}
},
[llmOptions]
);
const selectSingleModel = useCallback(
(modelName: string) => {
if (llmOptions.length === 0) return;
const match = llmOptions.find(
(opt) =>
opt.modelName === modelName ||
opt.displayName === modelName ||
opt.name === modelName
);
if (match) {
setSelectedModels([
{
name: match.name,
provider: match.provider,
modelName: match.modelName,
displayName: match.displayName,
},
]);
}
},
[llmOptions]
);
const buildLlmOverrides = useCallback((): LLMOverride[] => {
return selectedModels.map((m) => ({
model_provider: m.name,
model_version: m.modelName,
display_name: m.displayName,
}));
}, [selectedModels]);
return {
selectedModels,
isMultiModelActive,
addModel,
removeModel,
replaceModel,
clearModels,
buildLlmOverrides,
restoreFromModelNames,
selectSingleModel,
};
}

View File

@@ -32,6 +32,12 @@ export interface FrostedDivProps extends React.HTMLAttributes<HTMLDivElement> {
* Additional classes for the frost overlay element itself
*/
overlayClassName?: string;
/**
* Additional classes for the outermost wrapper div (the `relative` container).
* Useful for width propagation (e.g. `w-full`).
*/
wrapperClassName?: string;
}
/**
@@ -60,13 +66,14 @@ export default function FrostedDiv({
backdropBlur = "6px",
borderRadius = "1rem",
overlayClassName,
wrapperClassName,
className,
style,
children,
...props
}: FrostedDivProps) {
return (
<div className="relative">
<div className={cn("relative", wrapperClassName)}>
{/* Frost effect overlay - positioned behind content with bloom extending outward */}
<div
className={cn("absolute pointer-events-none", overlayClassName)}

View File

@@ -0,0 +1,469 @@
"use client";
import { useState, useMemo, useRef } from "react";
import Popover, { PopoverMenu } from "@/refresh-components/Popover";
import { LlmManager } from "@/lib/hooks";
import { getProviderIcon } from "@/app/admin/configuration/llm/utils";
import InputTypeIn from "@/refresh-components/inputs/InputTypeIn";
import Text from "@/refresh-components/texts/Text";
import { Button } from "@opal/components";
import {
SvgCheck,
SvgChevronDown,
SvgChevronRight,
SvgColumn,
SvgPlusCircle,
SvgX,
} from "@opal/icons";
import { Section } from "@/layouts/general-layouts";
import {
LLMOption,
LLMOptionGroup,
} from "@/refresh-components/popovers/interfaces";
import {
buildLlmOptions,
groupLlmOptions,
} from "@/refresh-components/popovers/LLMPopover";
import {
Accordion,
AccordionContent,
AccordionItem,
AccordionTrigger,
} from "@/components/ui/accordion";
import { cn } from "@/lib/utils";
const MAX_MODELS = 3;
export interface SelectedModel {
name: string;
provider: string;
modelName: string;
displayName: string;
}
export interface ModelSelectorProps {
llmManager: LlmManager;
selectedModels: SelectedModel[];
onAdd: (model: SelectedModel) => void;
onRemove: (index: number) => void;
onReplace: (index: number, model: SelectedModel) => void;
}
/** Vertical 1px divider between model bar elements */
function BarDivider() {
return <div className="h-9 w-px bg-border-01 shrink-0" />;
}
/** Individual model pill in the model bar */
function ModelPill({
model,
isMultiModel,
onRemove,
onClick,
}: {
model: SelectedModel;
isMultiModel: boolean;
onRemove?: () => void;
onClick?: () => void;
}) {
const ProviderIcon = getProviderIcon(model.provider, model.modelName);
return (
<div
role="button"
tabIndex={0}
onClick={onClick}
onKeyDown={(e) => {
if (e.key === "Enter" || e.key === " ") {
e.preventDefault();
onClick?.();
}
}}
className={cn(
"flex items-center gap-0.5 rounded-12 p-2 shrink-0 cursor-pointer",
"hover:bg-background-tint-02 transition-colors",
isMultiModel && "bg-background-tint-02"
)}
>
<div className="flex items-center justify-center size-5 shrink-0 p-0.5">
<ProviderIcon size={16} />
</div>
<Text mainUiAction text04 nowrap className="px-1">
{model.displayName}
</Text>
{isMultiModel ? (
<button
type="button"
onClick={(e) => {
e.stopPropagation();
onRemove?.();
}}
className="flex items-center justify-center size-4 shrink-0 hover:opacity-70"
>
<SvgX className="size-4 stroke-text-03" />
</button>
) : (
<SvgChevronDown className="size-4 stroke-text-03 shrink-0" />
)}
</div>
);
}
/** Model item row inside the add-model popover */
function ModelItem({
option,
isSelected,
isDisabled,
onToggle,
}: {
option: LLMOption;
isSelected: boolean;
isDisabled: boolean;
onToggle: () => void;
}) {
const ProviderIcon = getProviderIcon(option.provider, option.modelName);
// Build subtitle from model capabilities
const subtitle = useMemo(() => {
const parts: string[] = [];
if (option.supportsReasoning) parts.push("reasoning");
if (option.supportsImageInput) parts.push("multi-modal");
if (parts.length === 0 && option.modelName) return option.modelName;
return parts.join(", ");
}, [option]);
return (
<button
type="button"
disabled={isDisabled}
onClick={onToggle}
className={cn(
"flex items-center gap-1.5 w-full rounded-08 p-1.5 text-left transition-colors",
isSelected ? "bg-action-link-01" : "hover:bg-background-tint-02",
isDisabled && !isSelected && "opacity-50 cursor-not-allowed"
)}
>
<div className="flex items-center justify-center size-5 shrink-0 p-0.5">
<ProviderIcon size={16} />
</div>
<div className="flex flex-col flex-1 min-w-0">
<Text
mainUiAction
nowrap
className={cn(isSelected ? "text-action-link-03" : "text-text-04")}
>
{option.displayName}
</Text>
{subtitle && (
<Text secondaryBody text03 nowrap>
{subtitle}
</Text>
)}
</div>
{isSelected && (
<Text secondaryBody nowrap className="text-action-link-05 shrink-0">
Added
</Text>
)}
</button>
);
}
export default function ModelSelector({
llmManager,
selectedModels,
onAdd,
onRemove,
onReplace,
}: ModelSelectorProps) {
const [open, setOpen] = useState(false);
const [searchQuery, setSearchQuery] = useState("");
const scrollContainerRef = useRef<HTMLDivElement>(null);
// null = add mode (via + button), number = replace mode (via pill click)
const [replacingIndex, setReplacingIndex] = useState<number | null>(null);
const isMultiModel = selectedModels.length > 1;
const atMax = selectedModels.length >= MAX_MODELS;
const llmOptions = useMemo(
() => buildLlmOptions(llmManager.llmProviders),
[llmManager.llmProviders]
);
const selectedKeys = useMemo(
() => new Set(selectedModels.map((m) => `${m.provider}:${m.modelName}`)),
[selectedModels]
);
const filteredOptions = useMemo(() => {
if (!searchQuery.trim()) return llmOptions;
const query = searchQuery.toLowerCase();
return llmOptions.filter(
(opt) =>
opt.displayName.toLowerCase().includes(query) ||
opt.modelName.toLowerCase().includes(query) ||
(opt.vendor && opt.vendor.toLowerCase().includes(query))
);
}, [llmOptions, searchQuery]);
const groupedOptions = useMemo(
() => groupLlmOptions(filteredOptions),
[filteredOptions]
);
const isSearching = searchQuery.trim().length > 0;
// In replace mode, other selected models (not the one being replaced) are disabled
const otherSelectedKeys = useMemo(() => {
if (replacingIndex === null) return new Set<string>();
return new Set(
selectedModels
.filter((_, i) => i !== replacingIndex)
.map((m) => `${m.provider}:${m.modelName}`)
);
}, [selectedModels, replacingIndex]);
// Current model at the replacing index (shows as "selected" in replace mode)
const replacingKey = useMemo(() => {
if (replacingIndex === null) return null;
const m = selectedModels[replacingIndex];
return m ? `${m.provider}:${m.modelName}` : null;
}, [selectedModels, replacingIndex]);
const getItemState = (optKey: string) => {
if (replacingIndex !== null) {
// Replace mode
return {
isSelected: optKey === replacingKey,
isDisabled: otherSelectedKeys.has(optKey),
};
}
// Add mode
return {
isSelected: selectedKeys.has(optKey),
isDisabled: !selectedKeys.has(optKey) && atMax,
};
};
const handleSelectModel = (option: LLMOption) => {
const model: SelectedModel = {
name: option.name,
provider: option.provider,
modelName: option.modelName,
displayName: option.displayName,
};
if (replacingIndex !== null) {
// Replace mode: swap the model at the clicked pill's index
onReplace(replacingIndex, model);
setOpen(false);
setReplacingIndex(null);
setSearchQuery("");
return;
}
// Add mode: toggle (add/remove)
const key = `${option.provider}:${option.modelName}`;
const existingIndex = selectedModels.findIndex(
(m) => `${m.provider}:${m.modelName}` === key
);
if (existingIndex >= 0) {
onRemove(existingIndex);
} else if (!atMax) {
onAdd(model);
}
};
const handleOpenChange = (nextOpen: boolean) => {
setOpen(nextOpen);
if (!nextOpen) {
setReplacingIndex(null);
setSearchQuery("");
}
};
const handlePillClick = (index: number) => {
setReplacingIndex(index);
setOpen(true);
};
return (
<div className="flex items-center justify-end gap-1 p-1">
{/* (+) Add model button — hidden at max models */}
{!atMax && (
<Popover open={open} onOpenChange={handleOpenChange}>
<Popover.Trigger asChild>
<Button
prominence="tertiary"
icon={SvgPlusCircle}
size="sm"
tooltip="Add Model"
/>
</Popover.Trigger>
<Popover.Content side="top" align="start" width="lg">
<Section gap={0.25}>
<InputTypeIn
leftSearchIcon
variant="internal"
value={searchQuery}
onChange={(e) => setSearchQuery(e.target.value)}
placeholder="Search models..."
/>
<PopoverMenu scrollContainerRef={scrollContainerRef}>
{groupedOptions.length === 0
? [
<div key="empty" className="py-3 px-2">
<Text secondaryBody text03>
No models found
</Text>
</div>,
]
: groupedOptions.length === 1
? [
<div key="single" className="flex flex-col gap-0.5">
{groupedOptions[0]!.options.map((opt) => {
const key = `${opt.provider}:${opt.modelName}`;
const state = getItemState(key);
return (
<ModelItem
key={opt.modelName}
option={opt}
isSelected={state.isSelected}
isDisabled={state.isDisabled}
onToggle={() => handleSelectModel(opt)}
/>
);
})}
</div>,
]
: [
<ModelGroupAccordion
key="accordion"
groups={groupedOptions}
isSearching={isSearching}
getItemState={getItemState}
onToggle={handleSelectModel}
/>,
]}
</PopoverMenu>
<div className="border-t border-border-01 mt-1 pt-1">
<button
type="button"
className="flex items-center gap-1.5 w-full rounded-08 p-1.5 text-left hover:bg-background-tint-02 transition-colors"
>
<SvgColumn className="size-5 stroke-text-03 shrink-0" />
<Text mainUiAction text04>
Compare Model
</Text>
</button>
</div>
</Section>
</Popover.Content>
</Popover>
)}
{/* Divider + model pills */}
{selectedModels.length > 0 && (
<>
<BarDivider />
{selectedModels.map((model, index) => (
<div
key={`${model.provider}:${model.modelName}`}
className="flex items-center gap-1"
>
{index > 0 && <BarDivider />}
<ModelPill
model={model}
isMultiModel={isMultiModel}
onRemove={() => onRemove(index)}
onClick={() => handlePillClick(index)}
/>
</div>
))}
</>
)}
</div>
);
}
interface ModelGroupAccordionProps {
groups: LLMOptionGroup[];
isSearching: boolean;
getItemState: (key: string) => { isSelected: boolean; isDisabled: boolean };
onToggle: (option: LLMOption) => void;
}
function ModelGroupAccordion({
groups,
isSearching,
getItemState,
onToggle,
}: ModelGroupAccordionProps) {
const allKeys = groups.map((g) => g.key);
const [expandedGroups, setExpandedGroups] = useState<string[]>([
allKeys[0] ?? "",
]);
const effectiveExpanded = isSearching ? allKeys : expandedGroups;
return (
<Accordion
type="multiple"
value={effectiveExpanded}
onValueChange={(value) => {
if (!isSearching) setExpandedGroups(value);
}}
className="w-full flex flex-col"
>
{groups.map((group) => {
const isExpanded = effectiveExpanded.includes(group.key);
return (
<AccordionItem
key={group.key}
value={group.key}
className="border-none pt-1"
>
<AccordionTrigger className="flex items-center rounded-08 hover:no-underline hover:bg-background-tint-02 group [&>svg]:hidden w-full py-1">
<div className="flex items-center gap-1 shrink-0">
<div className="flex items-center justify-center size-5 shrink-0">
<group.Icon size={16} />
</div>
<Text secondaryBody text03 nowrap className="px-0.5">
{group.displayName}
</Text>
</div>
<div className="flex-1" />
<div className="flex items-center justify-center size-6 shrink-0">
{isExpanded ? (
<SvgChevronDown className="h-4 w-4 stroke-text-04 shrink-0" />
) : (
<SvgChevronRight className="h-4 w-4 stroke-text-04 shrink-0" />
)}
</div>
</AccordionTrigger>
<AccordionContent className="pb-0 pt-0">
<div className="flex flex-col gap-0.5">
{group.options.map((opt) => {
const key = `${opt.provider}:${opt.modelName}`;
const state = getItemState(key);
return (
<ModelItem
key={key}
option={opt}
isSelected={state.isSelected}
isDisabled={state.isDisabled}
onToggle={() => onToggle(opt)}
/>
);
})}
</div>
</AccordionContent>
</AccordionItem>
);
})}
</Accordion>
);
}