mirror of
https://github.com/onyx-dot-app/onyx.git
synced 2026-03-28 11:02:42 +00:00
Compare commits
5 Commits
cli/v0.2.0
...
multi-mode
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
cf403d9c89 | ||
|
|
ce7e68f671 | ||
|
|
65c5d5d5d9 | ||
|
|
ebe558e04f | ||
|
|
a49edf3e18 |
@@ -4,9 +4,11 @@ An overview can be found in the README.md file in this directory.
|
||||
"""
|
||||
|
||||
import io
|
||||
import queue
|
||||
import re
|
||||
import traceback
|
||||
from collections.abc import Callable
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from contextvars import Token
|
||||
from uuid import UUID
|
||||
|
||||
@@ -28,6 +30,7 @@ from onyx.chat.compression import calculate_total_history_tokens
|
||||
from onyx.chat.compression import compress_chat_history
|
||||
from onyx.chat.compression import find_summary_for_branch
|
||||
from onyx.chat.compression import get_compression_params
|
||||
from onyx.chat.emitter import Emitter
|
||||
from onyx.chat.emitter import get_default_emitter
|
||||
from onyx.chat.llm_loop import EmptyLLMResponseError
|
||||
from onyx.chat.llm_loop import run_llm_loop
|
||||
@@ -59,6 +62,8 @@ from onyx.db.chat import create_new_chat_message
|
||||
from onyx.db.chat import get_chat_session_by_id
|
||||
from onyx.db.chat import get_or_create_root_message
|
||||
from onyx.db.chat import reserve_message_id
|
||||
from onyx.db.chat import reserve_multi_model_message_ids
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.enums import HookPoint
|
||||
from onyx.db.memory import get_memories
|
||||
from onyx.db.models import ChatMessage
|
||||
@@ -86,16 +91,21 @@ from onyx.llm.factory import get_llm_for_persona
|
||||
from onyx.llm.factory import get_llm_token_counter
|
||||
from onyx.llm.interfaces import LLM
|
||||
from onyx.llm.interfaces import LLMUserIdentity
|
||||
from onyx.llm.override_models import LLMOverride
|
||||
from onyx.llm.request_context import reset_llm_mock_response
|
||||
from onyx.llm.request_context import set_llm_mock_response
|
||||
from onyx.llm.utils import litellm_exception_to_error_msg
|
||||
from onyx.onyxbot.slack.models import SlackContext
|
||||
from onyx.server.query_and_chat.models import AUTO_PLACE_AFTER_LATEST_MESSAGE
|
||||
from onyx.server.query_and_chat.models import MessageResponseIDInfo
|
||||
from onyx.server.query_and_chat.models import ModelResponseSlot
|
||||
from onyx.server.query_and_chat.models import MultiModelMessageResponseIDInfo
|
||||
from onyx.server.query_and_chat.models import SendMessageRequest
|
||||
from onyx.server.query_and_chat.placement import Placement
|
||||
from onyx.server.query_and_chat.streaming_models import AgentResponseDelta
|
||||
from onyx.server.query_and_chat.streaming_models import AgentResponseStart
|
||||
from onyx.server.query_and_chat.streaming_models import CitationInfo
|
||||
from onyx.server.query_and_chat.streaming_models import OverallStop
|
||||
from onyx.server.query_and_chat.streaming_models import Packet
|
||||
from onyx.server.usage_limits import check_llm_cost_limit_for_provider
|
||||
from onyx.tools.constants import SEARCH_TOOL_ID
|
||||
@@ -1069,6 +1079,583 @@ def handle_stream_message_objects(
|
||||
logger.exception("Error in setting processing status")
|
||||
|
||||
|
||||
def _build_model_display_name(override: LLMOverride) -> str:
    """Build a human-readable display name from an LLM override.

    Falls back through display_name -> model_version -> model_provider,
    returning "unknown" when none of them is set (empty strings are
    treated as unset).
    """
    for candidate in (
        override.display_name,
        override.model_version,
        override.model_provider,
    ):
        if candidate:
            return candidate
    return "unknown"
|
||||
|
||||
|
||||
# Sentinel placed on the merged queue when a model thread finishes.
# Compared by identity (`item is _MODEL_DONE`), so it must be a unique
# object that can never be produced by a model's packet stream.
_MODEL_DONE = object()
|
||||
|
||||
|
||||
class _ModelIndexEmitter(Emitter):
    """Emitter that tags packets with model_index and forwards directly to a shared queue.

    Unlike the standard Emitter (which accumulates in a local bus), this puts
    packets into the shared merged_queue in real-time as they're emitted. This
    enables true parallel streaming — packets from multiple models interleave
    on the wire instead of arriving in bursts after each model completes.
    """

    def __init__(self, model_idx: int, merged_queue: queue.Queue) -> None:
        # The parent class requires a bus queue; it is never read here.
        super().__init__(queue.Queue())  # bus exists for compat, unused
        self._model_idx = model_idx
        self._merged_queue = merged_queue

    def emit(self, packet: Packet) -> None:
        # Copy the placement (if any) and stamp it with this model's index.
        source = packet.placement
        placement_with_model = Placement(
            turn_index=source.turn_index if source else 0,
            tab_index=source.tab_index if source else 0,
            sub_turn_index=source.sub_turn_index if source else None,
            model_index=self._model_idx,
        )
        self._merged_queue.put(
            (self._model_idx, Packet(placement=placement_with_model, obj=packet.obj))
        )
|
||||
|
||||
|
||||
def run_multi_model_stream(
    new_msg_req: SendMessageRequest,
    user: User,
    db_session: Session,
    llm_overrides: list[LLMOverride],
    litellm_additional_headers: dict[str, str] | None = None,
    custom_tool_additional_headers: dict[str, str] | None = None,
    mcp_headers: dict[str, str] | None = None,
) -> AnswerStream:
    # TODO(ENG-3888): The setup logic below (session resolution through tool construction)
    # is duplicated from handle_stream_message_objects. Extract into a shared
    # _ChatStreamContext dataclass + _prepare_chat_stream_context() factory so
    # both paths call the same setup code.
    # https://linear.app/onyx-app/issue/ENG-3888
    """Run 2-3 LLMs in parallel and yield their packets tagged with model_index.

    Args:
        new_msg_req: The incoming message request (must have deep_research off).
        user: The requesting user; drives ownership checks and memory usage.
        db_session: Main-thread session — used ONLY for setup and completion.
        llm_overrides: Exactly 2 or 3 overrides; one worker thread per entry.
        litellm_additional_headers / custom_tool_additional_headers /
        mcp_headers: Pass-through headers for LLM calls and tool execution.

    Yields:
        CreateChatSessionID (when a new session is created),
        MultiModelMessageResponseIDInfo (reserved assistant message slots),
        Packet objects tagged with model_index, StreamingError on per-model
        or top-level failure, and a final Packet wrapping OverallStop.

    Resource management:
    - Each model thread gets its OWN db_session (SQLAlchemy sessions are not thread-safe)
    - The caller's db_session is used only for setup (before threads launch) and
      completion callbacks (after threads finish)
    - ThreadPoolExecutor is bounded to len(overrides) workers
    - All threads are joined in the finally block regardless of success/failure
    - Queue-based merging avoids busy-waiting
    """
    n_models = len(llm_overrides)
    if n_models < 2 or n_models > 3:
        raise ValueError(f"Multi-model requires 2-3 overrides, got {n_models}")
    if new_msg_req.deep_research:
        raise ValueError("Multi-model is not supported with deep research")

    tenant_id = get_current_tenant_id()
    # Declared up-front so the outer finally can safely test them even when
    # setup fails before they are assigned.
    cache: CacheBackend | None = None
    chat_session: ChatSession | None = None

    user_id = user.id
    if user.is_anonymous:
        llm_user_identifier = "anonymous_user"
    else:
        llm_user_identifier = user.email or str(user_id)

    try:
        # ── Session setup (same as single-model path) ──────────────────
        if not new_msg_req.chat_session_id:
            if not new_msg_req.chat_session_info:
                raise RuntimeError(
                    "Must specify a chat session id or chat session info"
                )
            chat_session = create_chat_session_from_request(
                chat_session_request=new_msg_req.chat_session_info,
                user_id=user_id,
                db_session=db_session,
            )
            yield CreateChatSessionID(chat_session_id=chat_session.id)
        else:
            chat_session = get_chat_session_by_id(
                chat_session_id=new_msg_req.chat_session_id,
                user_id=user_id,
                db_session=db_session,
            )

        persona = chat_session.persona
        message_text = new_msg_req.message

        # ── Build N LLM instances and validate costs ───────────────────
        llms: list[LLM] = []
        model_display_names: list[str] = []
        for override in llm_overrides:
            llm = get_llm_for_persona(
                persona=persona,
                user=user,
                llm_override=override,
                additional_headers=litellm_additional_headers,
            )
            # Enforce per-provider spend limits before launching any work.
            check_llm_cost_limit_for_provider(
                db_session=db_session,
                tenant_id=tenant_id,
                llm_provider_api_key=llm.config.api_key,
            )
            llms.append(llm)
            model_display_names.append(_build_model_display_name(override))

        # Use first LLM for token counting (context window is checked per-model
        # but token counting is model-agnostic enough for setup purposes)
        token_counter = get_llm_token_counter(llms[0])

        verify_user_files(
            user_files=new_msg_req.file_descriptors,
            user_id=user_id,
            db_session=db_session,
            project_id=chat_session.project_id,
        )

        # ── Chat history chain (shared across all models) ──────────────
        chat_history = create_chat_history_chain(
            chat_session_id=chat_session.id, db_session=db_session
        )

        root_message = get_or_create_root_message(
            chat_session_id=chat_session.id, db_session=db_session
        )

        # Resolve where the new message attaches in the message tree.
        if new_msg_req.parent_message_id == AUTO_PLACE_AFTER_LATEST_MESSAGE:
            parent_message = chat_history[-1] if chat_history else root_message
        elif (
            new_msg_req.parent_message_id is None
            or new_msg_req.parent_message_id == root_message.id
        ):
            parent_message = root_message
            chat_history = []
        else:
            # Walk backwards so the MOST RECENT occurrence wins, then truncate
            # history to the branch ending at the requested parent.
            parent_message = None
            for i in range(len(chat_history) - 1, -1, -1):
                if chat_history[i].id == new_msg_req.parent_message_id:
                    parent_message = chat_history[i]
                    chat_history = chat_history[: i + 1]
                    break

            if parent_message is None:
                raise ValueError(
                    "The new message sent is not on the latest mainline of messages"
                )

        if parent_message.message_type == MessageType.USER:
            # Regeneration case: reuse the existing user message as-is.
            user_message = parent_message
        else:
            user_message = create_new_chat_message(
                chat_session_id=chat_session.id,
                parent_message=parent_message,
                message=message_text,
                token_count=token_counter(message_text),
                message_type=MessageType.USER,
                files=new_msg_req.file_descriptors,
                db_session=db_session,
                commit=True,
            )
            chat_history.append(user_message)

        available_files = _collect_available_file_ids(
            chat_history=chat_history,
            project_id=chat_session.project_id,
            user_id=user_id,
            db_session=db_session,
        )

        # If the branch has been summarized, collect file metadata from the
        # summarized (pre-cutoff) messages and drop them from live history.
        summary_message = find_summary_for_branch(db_session, chat_history)
        summarized_file_metadata: dict[str, FileToolMetadata] = {}
        if summary_message and summary_message.last_summarized_message_id:
            cutoff_id = summary_message.last_summarized_message_id
            for msg in chat_history:
                if msg.id > cutoff_id or not msg.files:
                    continue
                for fd in msg.files:
                    file_id = fd.get("id")
                    if not file_id:
                        continue
                    summarized_file_metadata[file_id] = FileToolMetadata(
                        file_id=file_id,
                        filename=fd.get("name") or "unknown",
                        # NOTE(review): char count unknown for summarized files;
                        # 0 appears to act as "unknown" — confirm consumers.
                        approx_char_count=0,
                    )
            chat_history = [m for m in chat_history if m.id > cutoff_id]

        user_memory_context = get_memories(user, db_session)
        custom_agent_prompt = get_custom_agent_prompt(persona, chat_session)

        prompt_memory_context = (
            user_memory_context
            if user.use_memories
            else user_memory_context.without_memories()
        )

        max_reserved_system_prompt_tokens_str = (persona.system_prompt or "") + (
            custom_agent_prompt or ""
        )

        reserved_token_count = calculate_reserved_tokens(
            db_session=db_session,
            persona_system_prompt=max_reserved_system_prompt_tokens_str,
            token_counter=token_counter,
            files=new_msg_req.file_descriptors,
            user_memory_context=prompt_memory_context,
        )

        context_user_files = resolve_context_user_files(
            persona=persona,
            project_id=chat_session.project_id,
            user_id=user_id,
            db_session=db_session,
        )

        # Use the smallest context window across all models for safety
        min_context_window = min(llm.config.max_input_tokens for llm in llms)

        extracted_context_files = extract_context_files(
            user_files=context_user_files,
            llm_max_context_window=min_context_window,
            reserved_token_count=reserved_token_count,
            db_session=db_session,
        )

        search_params = determine_search_params(
            persona_id=persona.id,
            project_id=chat_session.project_id,
            extracted_context_files=extracted_context_files,
        )

        # Persona-attached files are always available, deduplicated against
        # what the history already provided.
        if persona.user_files:
            existing = set(available_files.user_file_ids)
            for uf in persona.user_files:
                if uf.id not in existing:
                    available_files.user_file_ids.append(uf.id)

        all_tools = get_tools(db_session)
        tool_id_to_name_map = {tool.id: tool.name for tool in all_tools}

        search_tool_id = next(
            (tool.id for tool in all_tools if tool.in_code_tool_id == SEARCH_TOOL_ID),
            None,
        )

        # A forced search tool is silently dropped when search is disabled,
        # rather than erroring out the whole request.
        forced_tool_id = new_msg_req.forced_tool_id
        if (
            search_params.search_usage == SearchToolUsage.DISABLED
            and forced_tool_id is not None
            and search_tool_id is not None
            and forced_tool_id == search_tool_id
        ):
            forced_tool_id = None

        files = load_all_chat_files(chat_history, db_session)
        chat_files_for_tools = _convert_loaded_files_to_chat_files(files)

        # ── Reserve N assistant message IDs ────────────────────────────
        reserved_messages = reserve_multi_model_message_ids(
            db_session=db_session,
            chat_session_id=chat_session.id,
            parent_message_id=user_message.id,
            model_display_names=model_display_names,
        )

        yield MultiModelMessageResponseIDInfo(
            user_message_id=user_message.id,
            responses=[
                ModelResponseSlot(message_id=m.id, model_name=name)
                for m, name in zip(reserved_messages, model_display_names)
            ],
        )

        has_file_reader_tool = any(
            tool.in_code_tool_id == "file_reader" for tool in all_tools
        )

        chat_history_result = convert_chat_history(
            chat_history=chat_history,
            files=files,
            context_image_files=extracted_context_files.image_files,
            additional_context=new_msg_req.additional_context,
            token_counter=token_counter,
            tool_id_to_name_map=tool_id_to_name_map,
        )
        simple_chat_history = chat_history_result.simple_messages

        # Injected-file metadata only matters when the file reader tool exists.
        all_injected_file_metadata: dict[str, FileToolMetadata] = (
            chat_history_result.all_injected_file_metadata
            if has_file_reader_tool
            else {}
        )
        if summarized_file_metadata:
            # setdefault: live-history metadata wins over summarized metadata.
            for fid, meta in summarized_file_metadata.items():
                all_injected_file_metadata.setdefault(fid, meta)

        if summary_message is not None:
            summary_simple = ChatMessageSimple(
                message=summary_message.message,
                token_count=summary_message.token_count,
                message_type=MessageType.ASSISTANT,
            )
            simple_chat_history.insert(0, summary_simple)

        # ── Stop signal and processing status ──────────────────────────
        cache = get_cache_backend()
        reset_cancel_status(chat_session.id, cache)

        def check_is_connected() -> bool:
            # Presumably True while no stop/cancel signal is set for this
            # session — TODO confirm against check_stop_signal's contract.
            return check_stop_signal(chat_session.id, cache)

        set_processing_status(
            chat_session_id=chat_session.id,
            cache=cache,
            value=True,
        )

        # Release the main session's read transaction before the long stream
        db_session.commit()

        # ── Parallel model execution ───────────────────────────────────
        # Each model thread writes tagged packets to this shared queue.
        # Sentinel _MODEL_DONE signals that a thread finished.
        merged_queue: queue.Queue[tuple[int, Packet | Exception | object]] = (
            queue.Queue()
        )

        # Track per-model state containers for completion callbacks
        state_containers: list[ChatStateContainer] = [
            ChatStateContainer() for _ in range(n_models)
        ]
        # Track which models completed successfully (for completion callbacks)
        model_succeeded: list[bool] = [False] * n_models

        user_identity = LLMUserIdentity(
            user_id=llm_user_identifier,
            session_id=str(chat_session.id),
        )

        def _run_model(model_idx: int) -> None:
            """Run a single model in a worker thread.

            Uses _ModelIndexEmitter so packets flow directly to merged_queue
            in real-time (not batched after completion). This enables true
            parallel streaming where both models' tokens interleave on the wire.

            DB access: tools may need a session during execution (e.g., search
            tool). Each thread creates its own session via context manager.
            """
            model_emitter = _ModelIndexEmitter(model_idx, merged_queue)
            sc = state_containers[model_idx]
            model_llm = llms[model_idx]

            try:
                # Each model thread gets its own DB session for tool execution.
                # The session is scoped to the thread and closed when done.
                with get_session_with_current_tenant() as thread_db_session:
                    # Construct tools per-thread with thread-local DB session
                    thread_tool_dict = construct_tools(
                        persona=persona,
                        db_session=thread_db_session,
                        emitter=model_emitter,
                        user=user,
                        llm=model_llm,
                        search_tool_config=SearchToolConfig(
                            user_selected_filters=new_msg_req.internal_search_filters,
                            project_id_filter=search_params.project_id_filter,
                            persona_id_filter=search_params.persona_id_filter,
                            bypass_acl=False,
                            enable_slack_search=_should_enable_slack_search(
                                persona, new_msg_req.internal_search_filters
                            ),
                        ),
                        custom_tool_config=CustomToolConfig(
                            chat_session_id=chat_session.id,
                            message_id=user_message.id,
                            additional_headers=custom_tool_additional_headers,
                            mcp_headers=mcp_headers,
                        ),
                        file_reader_tool_config=FileReaderToolConfig(
                            user_file_ids=available_files.user_file_ids,
                            chat_file_ids=available_files.chat_file_ids,
                        ),
                        allowed_tool_ids=new_msg_req.allowed_tool_ids,
                        search_usage_forcing_setting=search_params.search_usage,
                    )
                    model_tools: list[Tool] = []
                    for tool_list in thread_tool_dict.values():
                        model_tools.extend(tool_list)

                    # Run the LLM loop — this blocks until the model finishes.
                    # Packets flow to merged_queue in real-time via the emitter.
                    run_llm_loop(
                        emitter=model_emitter,
                        state_container=sc,
                        # Copy so per-model mutation can't leak across threads.
                        simple_chat_history=list(simple_chat_history),
                        tools=model_tools,
                        custom_agent_prompt=custom_agent_prompt,
                        context_files=extracted_context_files,
                        persona=persona,
                        user_memory_context=user_memory_context,
                        llm=model_llm,
                        token_counter=get_llm_token_counter(model_llm),
                        db_session=thread_db_session,
                        forced_tool_id=forced_tool_id,
                        user_identity=user_identity,
                        chat_session_id=str(chat_session.id),
                        chat_files=chat_files_for_tools,
                        include_citations=new_msg_req.include_citations,
                        all_injected_file_metadata=all_injected_file_metadata,
                        inject_memories_in_prompt=user.use_memories,
                    )

                model_succeeded[model_idx] = True

            except Exception as e:
                # Forward the exception to the merge loop; it is reported as
                # a StreamingError without killing the other models.
                merged_queue.put((model_idx, e))

            finally:
                # Always post the done sentinel — this is the SOLE completion
                # signal the merge loop counts.
                merged_queue.put((model_idx, _MODEL_DONE))

        # Launch model threads via ThreadPoolExecutor (bounded, context-propagating)
        executor = ThreadPoolExecutor(
            max_workers=n_models,
            thread_name_prefix="multi-model",
        )
        # NOTE(review): futures are collected but never awaited/inspected —
        # results and errors surface via merged_queue instead.
        futures = []
        try:
            for i in range(n_models):
                futures.append(executor.submit(_run_model, i))

            # ── Main thread: merge and yield packets ───────────────────
            models_remaining = n_models
            while models_remaining > 0:
                try:
                    model_idx, item = merged_queue.get(timeout=0.3)
                except queue.Empty:
                    # Check cancellation during idle periods
                    if not check_is_connected():
                        yield Packet(
                            placement=Placement(turn_index=0),
                            obj=OverallStop(type="stop", stop_reason="user_cancelled"),
                        )
                        return
                    continue
                else:
                    if item is _MODEL_DONE:
                        models_remaining -= 1
                        continue

                    if isinstance(item, Exception):
                        # Yield error as a tagged StreamingError packet.
                        # Do NOT decrement models_remaining here — the finally block
                        # in _run_model always posts _MODEL_DONE, which is the sole
                        # completion signal. Decrementing here too would double-count
                        # and cause the loop to exit early, silently dropping the
                        # surviving models' responses.
                        error_msg = str(item)
                        stack_trace = "".join(
                            traceback.format_exception(
                                type(item), item, item.__traceback__
                            )
                        )
                        # Redact API keys from error messages
                        model_llm = llms[model_idx]
                        if (
                            model_llm.config.api_key
                            and len(model_llm.config.api_key) > 2
                        ):
                            error_msg = error_msg.replace(
                                model_llm.config.api_key, "[REDACTED_API_KEY]"
                            )
                            stack_trace = stack_trace.replace(
                                model_llm.config.api_key, "[REDACTED_API_KEY]"
                            )

                        yield StreamingError(
                            error=error_msg,
                            stack_trace=stack_trace,
                            error_code="MODEL_ERROR",
                            is_retryable=True,
                            details={
                                "model": model_llm.config.model_name,
                                "provider": model_llm.config.model_provider,
                                "model_index": model_idx,
                            },
                        )
                        continue

                    if isinstance(item, Packet):
                        # Packet is already tagged with model_index by _ModelIndexEmitter
                        yield item

            # ── Completion: save each successful model's response ──────
            # Run completion callbacks on the main thread using the main
            # session. This is safe because all worker threads have exited
            # by this point (merged_queue fully drained).
            for i in range(n_models):
                if not model_succeeded[i]:
                    continue
                try:
                    llm_loop_completion_handle(
                        state_container=state_containers[i],
                        is_connected=check_is_connected,
                        db_session=db_session,
                        assistant_message=reserved_messages[i],
                        llm=llms[i],
                        reserved_tokens=reserved_token_count,
                    )
                except Exception:
                    # Best-effort: one model's failed save must not prevent
                    # the others from being persisted.
                    logger.exception(
                        f"Failed completion for model {i} "
                        f"({model_display_names[i]})"
                    )

            yield Packet(
                placement=Placement(turn_index=0),
                obj=OverallStop(type="stop", stop_reason="complete"),
            )

        finally:
            # Don't block on shutdown — futures making live LLM API calls
            # cannot be cancelled once started, so wait=True would block
            # the generator (and the HTTP response) until all calls finish.
            # wait=False lets threads complete in the background.
            executor.shutdown(wait=False)

    except ValueError as e:
        logger.exception("Failed to process multi-model chat message.")
        yield StreamingError(
            error=str(e),
            error_code="VALIDATION_ERROR",
            is_retryable=True,
        )
        db_session.rollback()
        return

    except Exception as e:
        logger.exception(f"Failed multi-model chat: {e}")
        stack_trace = traceback.format_exc()
        yield StreamingError(
            error=str(e),
            stack_trace=stack_trace,
            error_code="MULTI_MODEL_ERROR",
            is_retryable=True,
        )
        db_session.rollback()

    finally:
        # Clear the processing flag even on failure; guarded because setup
        # may have failed before cache/chat_session were assigned.
        try:
            if cache is not None and chat_session is not None:
                set_processing_status(
                    chat_session_id=chat_session.id,
                    cache=cache,
                    value=False,
                )
        except Exception:
            logger.exception("Error clearing processing status")
|
||||
|
||||
|
||||
def llm_loop_completion_handle(
|
||||
state_container: ChatStateContainer,
|
||||
is_connected: Callable[[], bool],
|
||||
|
||||
@@ -617,6 +617,80 @@ def reserve_message_id(
|
||||
return empty_message
|
||||
|
||||
|
||||
def reserve_multi_model_message_ids(
    db_session: Session,
    chat_session_id: UUID,
    parent_message_id: int,
    model_display_names: list[str],
) -> list[ChatMessage]:
    """Reserve N assistant message placeholders for multi-model parallel streaming.

    All messages share the same parent (the user message). The parent's
    latest_child_message_id points to the LAST reserved message so that the
    default history-chain walker picks it up.
    """
    # Build one placeholder per model; the message text is what users see if
    # the stream dies before the real response is saved.
    reserved = [
        ChatMessage(
            chat_session_id=chat_session_id,
            parent_message_id=parent_message_id,
            latest_child_message_id=None,
            message="Response was terminated prior to completion, try regenerating.",
            token_count=15,  # placeholder; updated on completion by llm_loop_completion_handle
            message_type=MessageType.ASSISTANT,
            model_display_name=display_name,
        )
        for display_name in model_display_names
    ]
    db_session.add_all(reserved)

    # Flush to assign IDs without committing yet
    db_session.flush()

    # Point parent's latest_child to the last reserved message
    parent = (
        db_session.query(ChatMessage)
        .filter(ChatMessage.id == parent_message_id)
        .first()
    )
    if parent:
        parent.latest_child_message_id = reserved[-1].id

    db_session.commit()
    return reserved
|
||||
|
||||
|
||||
def set_preferred_response(
    db_session: Session,
    user_message_id: int,
    preferred_assistant_message_id: int,
) -> None:
    """Set the preferred assistant response for a multi-model user message.

    Validates that the user message is a USER type and that the preferred
    assistant message is a direct child of that user message.
    """
    # Validate the parent (user) side first — matches the lookup order used
    # elsewhere so failures never touch the assistant row.
    parent = db_session.get(ChatMessage, user_message_id)
    if parent is None:
        raise ValueError(f"User message {user_message_id} not found")
    if parent.message_type != MessageType.USER:
        raise ValueError(f"Message {user_message_id} is not a user message")

    # Then validate the chosen assistant response.
    chosen = db_session.get(ChatMessage, preferred_assistant_message_id)
    if chosen is None:
        raise ValueError(
            f"Assistant message {preferred_assistant_message_id} not found"
        )
    if chosen.parent_message_id != user_message_id:
        raise ValueError(
            f"Assistant message {preferred_assistant_message_id} is not a child "
            f"of user message {user_message_id}"
        )

    # Record the preference and keep the history-chain walker pointed at it.
    parent.preferred_response_id = preferred_assistant_message_id
    parent.latest_child_message_id = preferred_assistant_message_id
    db_session.commit()
|
||||
|
||||
|
||||
def create_new_chat_message(
|
||||
chat_session_id: UUID,
|
||||
parent_message: ChatMessage,
|
||||
@@ -839,6 +913,8 @@ def translate_db_message_to_chat_message_detail(
|
||||
error=chat_message.error,
|
||||
current_feedback=current_feedback,
|
||||
processing_duration_seconds=chat_message.processing_duration_seconds,
|
||||
preferred_response_id=chat_message.preferred_response_id,
|
||||
model_display_name=chat_message.model_display_name,
|
||||
)
|
||||
|
||||
return chat_msg_detail
|
||||
|
||||
@@ -29,6 +29,7 @@ from onyx.chat.models import ChatFullResponse
|
||||
from onyx.chat.models import CreateChatSessionID
|
||||
from onyx.chat.process_message import gather_stream_full
|
||||
from onyx.chat.process_message import handle_stream_message_objects
|
||||
from onyx.chat.process_message import run_multi_model_stream
|
||||
from onyx.chat.prompt_utils import get_default_base_system_prompt
|
||||
from onyx.chat.stop_signal_checker import set_fence
|
||||
from onyx.configs.app_configs import WEB_DOMAIN
|
||||
@@ -46,6 +47,7 @@ from onyx.db.chat import get_chat_messages_by_session
|
||||
from onyx.db.chat import get_chat_session_by_id
|
||||
from onyx.db.chat import get_chat_sessions_by_user
|
||||
from onyx.db.chat import set_as_latest_chat_message
|
||||
from onyx.db.chat import set_preferred_response
|
||||
from onyx.db.chat import translate_db_message_to_chat_message_detail
|
||||
from onyx.db.chat import update_chat_session
|
||||
from onyx.db.chat_search import search_chat_sessions
|
||||
@@ -60,6 +62,8 @@ from onyx.db.persona import get_persona_by_id
|
||||
from onyx.db.usage import increment_usage
|
||||
from onyx.db.usage import UsageType
|
||||
from onyx.db.user_file import get_file_id_by_user_file_id
|
||||
from onyx.error_handling.error_codes import OnyxErrorCode
|
||||
from onyx.error_handling.exceptions import OnyxError
|
||||
from onyx.file_store.file_store import get_default_file_store
|
||||
from onyx.llm.constants import LlmProviderNames
|
||||
from onyx.llm.factory import get_default_llm
|
||||
@@ -81,6 +85,7 @@ from onyx.server.query_and_chat.models import ChatSessionUpdateRequest
|
||||
from onyx.server.query_and_chat.models import MessageOrigin
|
||||
from onyx.server.query_and_chat.models import RenameChatSessionResponse
|
||||
from onyx.server.query_and_chat.models import SendMessageRequest
|
||||
from onyx.server.query_and_chat.models import SetPreferredResponseRequest
|
||||
from onyx.server.query_and_chat.models import UpdateChatSessionTemperatureRequest
|
||||
from onyx.server.query_and_chat.models import UpdateChatSessionThreadRequest
|
||||
from onyx.server.query_and_chat.session_loading import (
|
||||
@@ -570,6 +575,46 @@ def handle_send_chat_message(
|
||||
if get_hashed_api_key_from_request(request) or get_hashed_pat_from_request(request):
|
||||
chat_message_req.origin = MessageOrigin.API
|
||||
|
||||
# Multi-model streaming path: 2-3 LLMs in parallel (streaming only)
|
||||
is_multi_model = (
|
||||
chat_message_req.llm_overrides is not None
|
||||
and len(chat_message_req.llm_overrides) > 1
|
||||
)
|
||||
if is_multi_model and chat_message_req.stream:
|
||||
# Narrowed here; is_multi_model already checked llm_overrides is not None
|
||||
llm_overrides = chat_message_req.llm_overrides or []
|
||||
|
||||
def multi_model_stream_generator() -> Generator[str, None, None]:
|
||||
try:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
for obj in run_multi_model_stream(
|
||||
new_msg_req=chat_message_req,
|
||||
user=user,
|
||||
db_session=db_session,
|
||||
llm_overrides=llm_overrides,
|
||||
litellm_additional_headers=extract_headers(
|
||||
request.headers, LITELLM_PASS_THROUGH_HEADERS
|
||||
),
|
||||
custom_tool_additional_headers=get_custom_tool_additional_request_headers(
|
||||
request.headers
|
||||
),
|
||||
mcp_headers=chat_message_req.mcp_headers,
|
||||
):
|
||||
yield get_json_line(obj.model_dump())
|
||||
except Exception as e:
|
||||
logger.exception("Error in multi-model streaming")
|
||||
yield json.dumps({"error": str(e)})
|
||||
|
||||
return StreamingResponse(
|
||||
multi_model_stream_generator(), media_type="text/event-stream"
|
||||
)
|
||||
|
||||
if is_multi_model and not chat_message_req.stream:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.INVALID_INPUT,
|
||||
"Multi-model mode (llm_overrides with >1 entry) requires stream=True.",
|
||||
)
|
||||
|
||||
# Non-streaming path: consume all packets and return complete response
|
||||
if not chat_message_req.stream:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
@@ -660,6 +705,30 @@ def set_message_as_latest(
|
||||
)
|
||||
|
||||
|
||||
@router.put("/set-preferred-response")
|
||||
def set_preferred_response_endpoint(
|
||||
request_body: SetPreferredResponseRequest,
|
||||
user: User | None = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> None:
|
||||
"""Set the preferred assistant response for a multi-model turn."""
|
||||
try:
|
||||
# Ownership check: get_chat_message raises ValueError if the message
|
||||
# doesn't belong to this user, preventing cross-user mutation.
|
||||
get_chat_message(
|
||||
chat_message_id=request_body.user_message_id,
|
||||
user_id=user.id if user else None,
|
||||
db_session=db_session,
|
||||
)
|
||||
set_preferred_response(
|
||||
db_session=db_session,
|
||||
user_message_id=request_body.user_message_id,
|
||||
preferred_assistant_message_id=request_body.preferred_response_id,
|
||||
)
|
||||
except ValueError as e:
|
||||
raise OnyxError(OnyxErrorCode.INVALID_INPUT, str(e))
|
||||
|
||||
|
||||
@router.post("/create-chat-message-feedback")
|
||||
def create_chat_feedback(
|
||||
feedback: ChatFeedbackRequest,
|
||||
|
||||
207
backend/tests/unit/onyx/chat/test_multi_model_streaming.py
Normal file
207
backend/tests/unit/onyx/chat/test_multi_model_streaming.py
Normal file
@@ -0,0 +1,207 @@
|
||||
"""Unit tests for multi-model streaming validation and DB helpers.
|
||||
|
||||
These are pure unit tests — no real database or LLM calls required.
|
||||
The validation logic in run_multi_model_stream fires before any external
|
||||
calls, so we can trigger it with lightweight mocks.
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from onyx.configs.constants import MessageType
|
||||
from onyx.db.chat import set_preferred_response
|
||||
from onyx.llm.override_models import LLMOverride
|
||||
from onyx.server.query_and_chat.models import SendMessageRequest
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_request(**kwargs: Any) -> SendMessageRequest:
|
||||
defaults: dict[str, Any] = {
|
||||
"message": "hello",
|
||||
"chat_session_id": uuid4(),
|
||||
}
|
||||
defaults.update(kwargs)
|
||||
return SendMessageRequest(**defaults)
|
||||
|
||||
|
||||
def _make_override(provider: str = "openai", version: str = "gpt-4") -> LLMOverride:
|
||||
return LLMOverride(model_provider=provider, model_version=version)
|
||||
|
||||
|
||||
def _start_stream(req: SendMessageRequest, overrides: list[LLMOverride]) -> None:
|
||||
"""Advance the generator one step to trigger early validation."""
|
||||
from onyx.chat.process_message import run_multi_model_stream
|
||||
|
||||
user = MagicMock()
|
||||
user.is_anonymous = False
|
||||
user.email = "test@example.com"
|
||||
db = MagicMock()
|
||||
|
||||
gen = run_multi_model_stream(req, user, db, overrides)
|
||||
# Calling next() executes until the first yield OR raises.
|
||||
# Validation errors are raised before any yield.
|
||||
next(gen)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# run_multi_model_stream — validation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestRunMultiModelStreamValidation:
|
||||
def test_single_override_raises(self) -> None:
|
||||
"""Exactly 1 override is not multi-model — must raise."""
|
||||
req = _make_request()
|
||||
with pytest.raises(ValueError, match="2-3"):
|
||||
_start_stream(req, [_make_override()])
|
||||
|
||||
def test_four_overrides_raises(self) -> None:
|
||||
"""4 overrides exceeds maximum — must raise."""
|
||||
req = _make_request()
|
||||
with pytest.raises(ValueError, match="2-3"):
|
||||
_start_stream(
|
||||
req,
|
||||
[
|
||||
_make_override("openai", "gpt-4"),
|
||||
_make_override("anthropic", "claude-3"),
|
||||
_make_override("google", "gemini-pro"),
|
||||
_make_override("cohere", "command-r"),
|
||||
],
|
||||
)
|
||||
|
||||
def test_zero_overrides_raises(self) -> None:
|
||||
"""Empty override list raises."""
|
||||
req = _make_request()
|
||||
with pytest.raises(ValueError, match="2-3"):
|
||||
_start_stream(req, [])
|
||||
|
||||
def test_deep_research_raises(self) -> None:
|
||||
"""deep_research=True is incompatible with multi-model."""
|
||||
req = _make_request(deep_research=True)
|
||||
with pytest.raises(ValueError, match="not supported"):
|
||||
_start_stream(
|
||||
req, [_make_override(), _make_override("anthropic", "claude-3")]
|
||||
)
|
||||
|
||||
def test_exactly_two_overrides_is_minimum(self) -> None:
|
||||
"""Boundary: 1 override fails, 2 passes — ensures fence-post is correct."""
|
||||
req = _make_request()
|
||||
# 1 override must fail
|
||||
with pytest.raises(ValueError, match="2-3"):
|
||||
_start_stream(req, [_make_override()])
|
||||
# 2 overrides must NOT raise ValueError (may raise later due to missing session, that's OK)
|
||||
try:
|
||||
_start_stream(
|
||||
req, [_make_override(), _make_override("anthropic", "claude-3")]
|
||||
)
|
||||
except ValueError as exc:
|
||||
pytest.fail(f"2 overrides should pass validation, got ValueError: {exc}")
|
||||
except Exception:
|
||||
pass # Any other error means validation passed
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# set_preferred_response — validation (mocked db)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSetPreferredResponseValidation:
|
||||
def test_user_message_not_found(self) -> None:
|
||||
db = MagicMock()
|
||||
db.get.return_value = None
|
||||
|
||||
with pytest.raises(ValueError, match="not found"):
|
||||
set_preferred_response(
|
||||
db, user_message_id=999, preferred_assistant_message_id=1
|
||||
)
|
||||
|
||||
def test_wrong_message_type(self) -> None:
|
||||
"""Cannot set preferred response on a non-USER message."""
|
||||
db = MagicMock()
|
||||
user_msg = MagicMock()
|
||||
user_msg.message_type = MessageType.ASSISTANT # wrong type
|
||||
|
||||
db.get.return_value = user_msg
|
||||
|
||||
with pytest.raises(ValueError, match="not a user message"):
|
||||
set_preferred_response(
|
||||
db, user_message_id=1, preferred_assistant_message_id=2
|
||||
)
|
||||
|
||||
def test_assistant_message_not_found(self) -> None:
|
||||
db = MagicMock()
|
||||
user_msg = MagicMock()
|
||||
user_msg.message_type = MessageType.USER
|
||||
|
||||
# First call returns user_msg, second call (for assistant) returns None
|
||||
db.get.side_effect = [user_msg, None]
|
||||
|
||||
with pytest.raises(ValueError, match="not found"):
|
||||
set_preferred_response(
|
||||
db, user_message_id=1, preferred_assistant_message_id=2
|
||||
)
|
||||
|
||||
def test_assistant_not_child_of_user(self) -> None:
|
||||
db = MagicMock()
|
||||
user_msg = MagicMock()
|
||||
user_msg.message_type = MessageType.USER
|
||||
|
||||
assistant_msg = MagicMock()
|
||||
assistant_msg.parent_message_id = 999 # different parent
|
||||
|
||||
db.get.side_effect = [user_msg, assistant_msg]
|
||||
|
||||
with pytest.raises(ValueError, match="not a child"):
|
||||
set_preferred_response(
|
||||
db, user_message_id=1, preferred_assistant_message_id=2
|
||||
)
|
||||
|
||||
def test_valid_call_sets_preferred_response_id(self) -> None:
|
||||
db = MagicMock()
|
||||
user_msg = MagicMock()
|
||||
user_msg.message_type = MessageType.USER
|
||||
|
||||
assistant_msg = MagicMock()
|
||||
assistant_msg.parent_message_id = 1 # correct parent
|
||||
|
||||
db.get.side_effect = [user_msg, assistant_msg]
|
||||
|
||||
set_preferred_response(db, user_message_id=1, preferred_assistant_message_id=2)
|
||||
|
||||
assert user_msg.preferred_response_id == 2
|
||||
assert user_msg.latest_child_message_id == 2
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# LLMOverride — display_name field
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestLLMOverrideDisplayName:
|
||||
def test_display_name_defaults_none(self) -> None:
|
||||
override = LLMOverride(model_provider="openai", model_version="gpt-4")
|
||||
assert override.display_name is None
|
||||
|
||||
def test_display_name_set(self) -> None:
|
||||
override = LLMOverride(
|
||||
model_provider="openai",
|
||||
model_version="gpt-4",
|
||||
display_name="GPT-4 Turbo",
|
||||
)
|
||||
assert override.display_name == "GPT-4 Turbo"
|
||||
|
||||
def test_display_name_serializes(self) -> None:
|
||||
override = LLMOverride(
|
||||
model_provider="anthropic",
|
||||
model_version="claude-opus-4-6",
|
||||
display_name="Claude Opus",
|
||||
)
|
||||
d = override.model_dump()
|
||||
assert d["display_name"] == "Claude Opus"
|
||||
301
web/src/app/app/dev/multi-model-preview/page.tsx
Normal file
301
web/src/app/app/dev/multi-model-preview/page.tsx
Normal file
@@ -0,0 +1,301 @@
|
||||
"use client";
|
||||
|
||||
import { useState } from "react";
|
||||
import MultiModelResponseView from "@/app/app/message/MultiModelResponseView";
|
||||
import type { MultiModelResponse } from "@/app/app/message/interfaces";
|
||||
import ModelSelector, {
|
||||
SelectedModel,
|
||||
} from "@/refresh-components/popovers/ModelSelector";
|
||||
import { Packet, StopReason } from "@/app/app/services/streamingModels";
|
||||
import { FullChatState } from "@/app/app/message/messageComponents/interfaces";
|
||||
import type { LlmManager } from "@/lib/hooks";
|
||||
import type { LLMProviderDescriptor } from "@/interfaces/llm";
|
||||
|
||||
// ────────────────────────────────────────────────────────────
|
||||
// Mock data
|
||||
// ────────────────────────────────────────────────────────────
|
||||
|
||||
function makeTextPackets(text: string): Packet[] {
|
||||
return [
|
||||
{
|
||||
placement: { turn_index: 0 },
|
||||
obj: {
|
||||
type: "message_start" as const,
|
||||
id: "msg-1",
|
||||
content: "",
|
||||
final_documents: null,
|
||||
},
|
||||
},
|
||||
{
|
||||
placement: { turn_index: 0 },
|
||||
obj: { type: "message_delta" as const, content: text },
|
||||
},
|
||||
{
|
||||
placement: { turn_index: 0 },
|
||||
obj: { type: "message_end" as const },
|
||||
},
|
||||
{
|
||||
placement: { turn_index: 0 },
|
||||
obj: { type: "stop" as const, stop_reason: StopReason.FINISHED },
|
||||
},
|
||||
];
|
||||
}
|
||||
|
||||
const mockChatState: FullChatState = {
|
||||
agent: {
|
||||
id: 0,
|
||||
name: "Test Agent",
|
||||
description: "",
|
||||
tools: [],
|
||||
starter_messages: null,
|
||||
document_sets: [],
|
||||
is_public: true,
|
||||
is_listed: false,
|
||||
display_priority: null,
|
||||
is_featured: false,
|
||||
builtin_persona: true,
|
||||
owner: null,
|
||||
},
|
||||
};
|
||||
|
||||
const PACKETS = {
|
||||
anthropic: makeTextPackets(
|
||||
"**Claude Opus 4.6** here.\n\nThe universe is fundamentally beautiful. Every particle dances to the rhythm of quantum fields, and consciousness itself may be the universe observing itself through our eyes.\n\nI recommend we consider both the mathematical elegance and the philosophical implications."
|
||||
),
|
||||
openai: makeTextPackets(
|
||||
"**GPT-4o** here.\n\nLet me break this down systematically:\n\n1. First, analyze the core problem\n2. Evaluate available options\n3. Select the optimal solution based on criteria\n\nThis structured approach ensures comprehensive coverage."
|
||||
),
|
||||
google: makeTextPackets(
|
||||
"**Gemini Pro** here.\n\nInteresting question! Based on my analysis, there are multiple valid perspectives to consider. The key insight is that context matters enormously—what works in one situation may not work in another."
|
||||
),
|
||||
};
|
||||
|
||||
const mockProviders: LLMProviderDescriptor[] = [
|
||||
{
|
||||
id: 1,
|
||||
name: "anthropic",
|
||||
provider: "anthropic",
|
||||
provider_display_name: "Anthropic",
|
||||
model_configurations: [
|
||||
{
|
||||
name: "claude-opus-4-6",
|
||||
display_name: "Claude Opus 4.6",
|
||||
is_visible: true,
|
||||
max_input_tokens: 200000,
|
||||
supports_image_input: true,
|
||||
supports_reasoning: true,
|
||||
},
|
||||
{
|
||||
name: "claude-sonnet-4-6",
|
||||
display_name: "Claude Sonnet 4.6",
|
||||
is_visible: true,
|
||||
max_input_tokens: 200000,
|
||||
supports_image_input: true,
|
||||
supports_reasoning: false,
|
||||
},
|
||||
],
|
||||
},
|
||||
{
|
||||
id: 2,
|
||||
name: "openai",
|
||||
provider: "openai",
|
||||
provider_display_name: "OpenAI",
|
||||
model_configurations: [
|
||||
{
|
||||
name: "gpt-4o",
|
||||
display_name: "GPT-4o",
|
||||
is_visible: true,
|
||||
max_input_tokens: 128000,
|
||||
supports_image_input: true,
|
||||
supports_reasoning: false,
|
||||
},
|
||||
{
|
||||
name: "gpt-4o-mini",
|
||||
display_name: "GPT-4o Mini",
|
||||
is_visible: true,
|
||||
max_input_tokens: 128000,
|
||||
supports_image_input: false,
|
||||
supports_reasoning: false,
|
||||
},
|
||||
],
|
||||
},
|
||||
{
|
||||
id: 3,
|
||||
name: "google",
|
||||
provider: "google",
|
||||
provider_display_name: "Google",
|
||||
model_configurations: [
|
||||
{
|
||||
name: "gemini-pro",
|
||||
display_name: "Gemini Pro",
|
||||
is_visible: true,
|
||||
max_input_tokens: 1000000,
|
||||
supports_image_input: true,
|
||||
supports_reasoning: false,
|
||||
},
|
||||
],
|
||||
},
|
||||
];
|
||||
|
||||
const mockLlmManager = {
|
||||
llmProviders: mockProviders,
|
||||
currentLlm: {
|
||||
name: "claude-opus-4-6",
|
||||
provider: "anthropic",
|
||||
modelName: "claude-opus-4-6",
|
||||
},
|
||||
isLoadingProviders: false,
|
||||
} as unknown as LlmManager;
|
||||
|
||||
// ────────────────────────────────────────────────────────────
|
||||
// All 3 responses (constant — we always pass all 3 to the view)
|
||||
// ────────────────────────────────────────────────────────────
|
||||
|
||||
function buildResponses(generating: boolean): MultiModelResponse[] {
|
||||
return [
|
||||
{
|
||||
modelIndex: 0,
|
||||
provider: "anthropic",
|
||||
modelName: "claude-opus-4-6",
|
||||
displayName: "Claude Opus 4.6",
|
||||
packets: generating ? PACKETS.anthropic.slice(0, 2) : PACKETS.anthropic,
|
||||
packetCount: generating ? 2 : PACKETS.anthropic.length,
|
||||
nodeId: 1,
|
||||
messageId: 1,
|
||||
isGenerating: generating,
|
||||
},
|
||||
{
|
||||
modelIndex: 1,
|
||||
provider: "openai",
|
||||
modelName: "gpt-4o",
|
||||
displayName: "GPT-4o",
|
||||
packets: generating ? [] : PACKETS.openai,
|
||||
packetCount: generating ? 0 : PACKETS.openai.length,
|
||||
nodeId: 2,
|
||||
messageId: 2,
|
||||
isGenerating: generating,
|
||||
},
|
||||
{
|
||||
modelIndex: 2,
|
||||
provider: "google",
|
||||
modelName: "gemini-pro",
|
||||
displayName: "Gemini Pro",
|
||||
packets: generating ? [] : PACKETS.google,
|
||||
packetCount: generating ? 0 : PACKETS.google.length,
|
||||
nodeId: 3,
|
||||
messageId: 3,
|
||||
isGenerating: generating,
|
||||
},
|
||||
];
|
||||
}
|
||||
|
||||
// ────────────────────────────────────────────────────────────
|
||||
// Page
|
||||
// ────────────────────────────────────────────────────────────
|
||||
|
||||
export default function MultiModelPreviewPage() {
|
||||
const [isGenerating, setIsGenerating] = useState(false);
|
||||
const [modelCount, setModelCount] = useState<2 | 3>(2);
|
||||
|
||||
// Selector state (independent from the response view)
|
||||
const [selectedModels, setSelectedModels] = useState<SelectedModel[]>([
|
||||
{
|
||||
name: "claude-opus-4-6",
|
||||
provider: "anthropic",
|
||||
modelName: "claude-opus-4-6",
|
||||
displayName: "Claude Opus 4.6",
|
||||
},
|
||||
{
|
||||
name: "gpt-4o",
|
||||
provider: "openai",
|
||||
modelName: "gpt-4o",
|
||||
displayName: "GPT-4o",
|
||||
},
|
||||
]);
|
||||
|
||||
const allResponses = buildResponses(isGenerating);
|
||||
const visibleResponses = allResponses.slice(0, modelCount);
|
||||
|
||||
return (
|
||||
<div className="min-h-screen bg-background-neutral-01">
|
||||
{/* Top bar */}
|
||||
<div className="sticky top-0 z-50 flex items-center gap-6 px-8 py-3 bg-background-tint-01 border-b border-border-01">
|
||||
<span className="text-text-04 font-semibold text-sm">
|
||||
Multi-Model Preview
|
||||
</span>
|
||||
|
||||
<label className="flex items-center gap-2 cursor-pointer">
|
||||
<input
|
||||
type="checkbox"
|
||||
checked={isGenerating}
|
||||
onChange={(e) => setIsGenerating(e.target.checked)}
|
||||
className="w-4 h-4 accent-action-link-05"
|
||||
/>
|
||||
<span className="text-text-03 text-sm">Generating</span>
|
||||
</label>
|
||||
|
||||
<div className="flex items-center gap-2">
|
||||
<span className="text-text-03 text-sm">Panels:</span>
|
||||
{([2, 3] as const).map((n) => (
|
||||
<button
|
||||
key={n}
|
||||
onClick={() => setModelCount(n)}
|
||||
className={`px-2 py-0.5 rounded text-sm transition-colors ${
|
||||
modelCount === n
|
||||
? "bg-action-link-01 text-action-link-03 font-semibold"
|
||||
: "text-text-03 hover:bg-background-tint-02"
|
||||
}`}
|
||||
>
|
||||
{n}
|
||||
</button>
|
||||
))}
|
||||
</div>
|
||||
|
||||
<div className="flex-1" />
|
||||
|
||||
<div className="text-text-03 text-xs">
|
||||
Click a panel to select preferred · X to hide · eye to restore
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div className="max-w-screen-2xl mx-auto px-8 py-8 space-y-12">
|
||||
{/* ── Section 1: Model Selector ── */}
|
||||
<section className="space-y-3">
|
||||
<div className="text-xs font-semibold text-text-03 uppercase tracking-widest">
|
||||
Model Selector (input bar)
|
||||
</div>
|
||||
<div className="inline-flex bg-background-tint-01 rounded-12 border border-border-01">
|
||||
<ModelSelector
|
||||
llmManager={mockLlmManager}
|
||||
selectedModels={selectedModels}
|
||||
onAdd={(model) => setSelectedModels((prev) => [...prev, model])}
|
||||
onRemove={(index) =>
|
||||
setSelectedModels((prev) => prev.filter((_, i) => i !== index))
|
||||
}
|
||||
onReplace={(index, model) =>
|
||||
setSelectedModels((prev) => {
|
||||
const next = [...prev];
|
||||
next[index] = model;
|
||||
return next;
|
||||
})
|
||||
}
|
||||
/>
|
||||
</div>
|
||||
</section>
|
||||
|
||||
{/* ── Section 2: Response View ── */}
|
||||
<section className="space-y-3">
|
||||
<div className="text-xs font-semibold text-text-03 uppercase tracking-widest">
|
||||
Response View ({modelCount} models ·{" "}
|
||||
{isGenerating ? "generating" : "complete"})
|
||||
</div>
|
||||
<MultiModelResponseView
|
||||
responses={visibleResponses}
|
||||
chatState={mockChatState}
|
||||
llmManager={null}
|
||||
/>
|
||||
</section>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -159,6 +159,10 @@ export interface Message {
|
||||
overridden_model?: string;
|
||||
stopReason?: StreamStopReason | null;
|
||||
|
||||
// Multi-model answer generation
|
||||
preferredResponseId?: number | null;
|
||||
modelDisplayName?: string | null;
|
||||
|
||||
// new gen
|
||||
packets: Packet[];
|
||||
packetCount?: number; // Tracks packet count for React memo comparison (avoids reading from mutated array)
|
||||
@@ -231,6 +235,9 @@ export interface BackendMessage {
|
||||
parentMessageId: number | null;
|
||||
refined_answer_improvement: boolean | null;
|
||||
is_agentic: boolean | null;
|
||||
// Multi-model answer generation
|
||||
preferred_response_id: number | null;
|
||||
model_display_name: string | null;
|
||||
}
|
||||
|
||||
export interface MessageResponseIDInfo {
|
||||
@@ -238,6 +245,16 @@ export interface MessageResponseIDInfo {
|
||||
reserved_assistant_message_id: number; // TODO: rename to agent — https://linear.app/onyx-app/issue/ENG-3766
|
||||
}
|
||||
|
||||
export interface ModelResponseSlot {
|
||||
message_id: number;
|
||||
model_name: string;
|
||||
}
|
||||
|
||||
export interface MultiModelMessageResponseIDInfo {
|
||||
user_message_id: number | null;
|
||||
responses: ModelResponseSlot[];
|
||||
}
|
||||
|
||||
export interface UserKnowledgeFilePacket {
|
||||
user_files: FileDescriptor[];
|
||||
}
|
||||
|
||||
137
web/src/app/app/message/MultiModelPanel.tsx
Normal file
137
web/src/app/app/message/MultiModelPanel.tsx
Normal file
@@ -0,0 +1,137 @@
|
||||
"use client";
|
||||
|
||||
import { useCallback } from "react";
|
||||
import { Button } from "@opal/components";
|
||||
import { SvgEyeClosed, SvgX } from "@opal/icons";
|
||||
import Text from "@/refresh-components/texts/Text";
|
||||
import { getProviderIcon } from "@/app/admin/configuration/llm/utils";
|
||||
import AgentMessage, {
|
||||
AgentMessageProps,
|
||||
} from "@/app/app/message/messageComponents/AgentMessage";
|
||||
import { Section } from "@/layouts/general-layouts";
|
||||
import { cn } from "@/lib/utils";
|
||||
|
||||
export interface MultiModelPanelProps {
|
||||
modelIndex: number;
|
||||
/** Provider name for icon lookup */
|
||||
provider: string;
|
||||
/** Model name for icon lookup and display */
|
||||
modelName: string;
|
||||
/** Display-friendly model name */
|
||||
displayName: string;
|
||||
/** Whether this panel is the preferred/selected response */
|
||||
isPreferred: boolean;
|
||||
/** Whether this panel is currently hidden */
|
||||
isHidden: boolean;
|
||||
/** Whether this is a non-preferred panel in selection mode (pushed off-screen) */
|
||||
isNonPreferredInSelection: boolean;
|
||||
/** Callback when user clicks this panel to select as preferred */
|
||||
onSelect: () => void;
|
||||
/** Callback to hide/show this panel */
|
||||
onToggleVisibility: () => void;
|
||||
/** Props to pass through to AgentMessage */
|
||||
agentMessageProps: AgentMessageProps;
|
||||
}
|
||||
|
||||
export default function MultiModelPanel({
|
||||
modelIndex,
|
||||
provider,
|
||||
modelName,
|
||||
displayName,
|
||||
isPreferred,
|
||||
isHidden,
|
||||
isNonPreferredInSelection,
|
||||
onSelect,
|
||||
onToggleVisibility,
|
||||
agentMessageProps,
|
||||
}: MultiModelPanelProps) {
|
||||
const ProviderIcon = getProviderIcon(provider, modelName);
|
||||
|
||||
const handlePanelClick = useCallback(() => {
|
||||
if (!isHidden) onSelect();
|
||||
}, [isHidden, onSelect]);
|
||||
|
||||
// Hidden/collapsed panel — compact strip at fixed 220px
|
||||
if (isHidden) {
|
||||
return (
|
||||
<div className="flex items-center gap-1.5 w-[220px] shrink-0 rounded-08 bg-background-tint-00 px-2 py-1 opacity-50 hover:opacity-100 transition-opacity cursor-pointer">
|
||||
<div className="flex items-center justify-center size-5 shrink-0">
|
||||
<ProviderIcon size={16} />
|
||||
</div>
|
||||
<Text
|
||||
secondaryBody
|
||||
text02
|
||||
nowrap
|
||||
className="line-through flex-1 min-w-0 truncate"
|
||||
>
|
||||
{displayName}
|
||||
</Text>
|
||||
<Button
|
||||
prominence="tertiary"
|
||||
icon={SvgEyeClosed}
|
||||
size="2xs"
|
||||
onClick={onToggleVisibility}
|
||||
tooltip="Show response"
|
||||
/>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
return (
|
||||
<Section
|
||||
flexDirection="column"
|
||||
alignItems="stretch"
|
||||
justifyContent="start"
|
||||
height="fit"
|
||||
gap={0.75}
|
||||
className={cn(
|
||||
"min-w-0 cursor-pointer rounded-16 transition-colors",
|
||||
!isPreferred && "hover:bg-background-tint-02"
|
||||
)}
|
||||
onClick={handlePanelClick}
|
||||
>
|
||||
{/* Panel header */}
|
||||
<Section
|
||||
flexDirection="row"
|
||||
alignItems="center"
|
||||
justifyContent="start"
|
||||
height="fit"
|
||||
gap={0.375}
|
||||
className={cn(
|
||||
"rounded-12 px-2 py-1",
|
||||
isPreferred ? "bg-background-tint-02" : "bg-background-tint-00"
|
||||
)}
|
||||
>
|
||||
<div className="flex items-center justify-center size-5 shrink-0">
|
||||
<ProviderIcon size={16} />
|
||||
</div>
|
||||
<Text mainUiAction text04 nowrap className="flex-1 min-w-0 truncate">
|
||||
{displayName}
|
||||
</Text>
|
||||
{isPreferred && (
|
||||
<Text secondaryBody nowrap className="text-action-link-05 shrink-0">
|
||||
Preferred Response
|
||||
</Text>
|
||||
)}
|
||||
<Button
|
||||
prominence="tertiary"
|
||||
icon={SvgX}
|
||||
size="2xs"
|
||||
onClick={(e) => {
|
||||
e.stopPropagation();
|
||||
onToggleVisibility();
|
||||
}}
|
||||
tooltip="Hide response"
|
||||
/>
|
||||
</Section>
|
||||
|
||||
{/* Response body */}
|
||||
<div className={cn(isNonPreferredInSelection && "pointer-events-none")}>
|
||||
<AgentMessage
|
||||
{...agentMessageProps}
|
||||
hideFooter={isNonPreferredInSelection}
|
||||
/>
|
||||
</div>
|
||||
</Section>
|
||||
);
|
||||
}
|
||||
250
web/src/app/app/message/MultiModelResponseView.tsx
Normal file
250
web/src/app/app/message/MultiModelResponseView.tsx
Normal file
@@ -0,0 +1,250 @@
|
||||
"use client";
|
||||
|
||||
import { useState, useCallback, useMemo, useEffect } from "react";
|
||||
import { FullChatState } from "@/app/app/message/messageComponents/interfaces";
|
||||
import { Message } from "@/app/app/interfaces";
|
||||
import { LlmManager } from "@/lib/hooks";
|
||||
import { RegenerationFactory } from "@/app/app/message/messageComponents/AgentMessage";
|
||||
import MultiModelPanel from "@/app/app/message/MultiModelPanel";
|
||||
import { MultiModelResponse } from "@/app/app/message/interfaces";
|
||||
import { cn } from "@/lib/utils";
|
||||
|
||||
export interface MultiModelResponseViewProps {
|
||||
responses: MultiModelResponse[];
|
||||
chatState: FullChatState;
|
||||
llmManager: LlmManager | null;
|
||||
onRegenerate?: RegenerationFactory;
|
||||
parentMessage?: Message | null;
|
||||
otherMessagesCanSwitchTo?: number[];
|
||||
onMessageSelection?: (nodeId: number) => void;
|
||||
}
|
||||
|
||||
// How many pixels of a non-preferred panel are visible at the viewport edge
|
||||
const PEEK_W = 64;
|
||||
// Width of each non-preferred panel in the selection layout
|
||||
const PANEL_W = 400;
|
||||
// Gap between panels
|
||||
const PANEL_GAP = 16;
|
||||
|
||||
export default function MultiModelResponseView({
|
||||
responses,
|
||||
chatState,
|
||||
llmManager,
|
||||
onRegenerate,
|
||||
parentMessage,
|
||||
otherMessagesCanSwitchTo,
|
||||
onMessageSelection,
|
||||
}: MultiModelResponseViewProps) {
|
||||
const [preferredIndex, setPreferredIndex] = useState<number | null>(null);
|
||||
const [hiddenPanels, setHiddenPanels] = useState<Set<number>>(new Set());
|
||||
// Controls animation: false = panels at start position, true = panels at peek position
|
||||
const [selectionEntered, setSelectionEntered] = useState(false);
|
||||
|
||||
const isGenerating = useMemo(
|
||||
() => responses.some((r) => r.isGenerating),
|
||||
[responses]
|
||||
);
|
||||
|
||||
const visibleResponses = useMemo(
|
||||
() => responses.filter((r) => !hiddenPanels.has(r.modelIndex)),
|
||||
[responses, hiddenPanels]
|
||||
);
|
||||
|
||||
const hiddenResponses = useMemo(
|
||||
() => responses.filter((r) => hiddenPanels.has(r.modelIndex)),
|
||||
[responses, hiddenPanels]
|
||||
);
|
||||
|
||||
const toggleVisibility = useCallback(
|
||||
(modelIndex: number) => {
|
||||
setHiddenPanels((prev) => {
|
||||
const next = new Set(prev);
|
||||
if (next.has(modelIndex)) {
|
||||
next.delete(modelIndex);
|
||||
} else {
|
||||
// Don't hide the last visible panel
|
||||
const visibleCount = responses.length - next.size;
|
||||
if (visibleCount <= 1) return prev;
|
||||
next.add(modelIndex);
|
||||
}
|
||||
return next;
|
||||
});
|
||||
},
|
||||
[responses.length]
|
||||
);
|
||||
|
||||
const handleSelectPreferred = useCallback(
|
||||
(modelIndex: number) => {
|
||||
setPreferredIndex(modelIndex);
|
||||
const response = responses[modelIndex];
|
||||
if (!response) return;
|
||||
if (onMessageSelection) {
|
||||
onMessageSelection(response.nodeId);
|
||||
}
|
||||
},
|
||||
[responses, onMessageSelection]
|
||||
);
|
||||
|
||||
// Selection mode when preferred is set and not generating
|
||||
const showSelectionMode =
|
||||
preferredIndex !== null && !isGenerating && visibleResponses.length > 1;
|
||||
|
||||
// Trigger the slide-out animation one frame after entering selection mode
|
||||
useEffect(() => {
|
||||
if (!showSelectionMode) {
|
||||
setSelectionEntered(false);
|
||||
return;
|
||||
}
|
||||
const raf = requestAnimationFrame(() => setSelectionEntered(true));
|
||||
return () => cancelAnimationFrame(raf);
|
||||
}, [showSelectionMode]);
|
||||
|
||||
// Build common panel props
|
||||
const buildPanelProps = useCallback(
|
||||
(response: MultiModelResponse, isNonPreferred: boolean) => ({
|
||||
modelIndex: response.modelIndex,
|
||||
provider: response.provider,
|
||||
modelName: response.modelName,
|
||||
displayName: response.displayName,
|
||||
isPreferred: preferredIndex === response.modelIndex,
|
||||
isHidden: false as const,
|
||||
isNonPreferredInSelection: isNonPreferred,
|
||||
onSelect: () => handleSelectPreferred(response.modelIndex),
|
||||
onToggleVisibility: () => toggleVisibility(response.modelIndex),
|
||||
agentMessageProps: {
|
||||
rawPackets: response.packets,
|
||||
packetCount: response.packetCount,
|
||||
chatState,
|
||||
nodeId: response.nodeId,
|
||||
messageId: response.messageId,
|
||||
currentFeedback: response.currentFeedback,
|
||||
llmManager,
|
||||
otherMessagesCanSwitchTo,
|
||||
onMessageSelection,
|
||||
onRegenerate,
|
||||
parentMessage,
|
||||
},
|
||||
}),
|
||||
[
|
||||
preferredIndex,
|
||||
handleSelectPreferred,
|
||||
toggleVisibility,
|
||||
chatState,
|
||||
llmManager,
|
||||
otherMessagesCanSwitchTo,
|
||||
onMessageSelection,
|
||||
onRegenerate,
|
||||
parentMessage,
|
||||
]
|
||||
);
|
||||
|
||||
// Shared renderer for hidden panels (inline in the flex row)
|
||||
const renderHiddenPanels = () =>
|
||||
hiddenResponses.map((r) => (
|
||||
<MultiModelPanel
|
||||
key={r.modelIndex}
|
||||
modelIndex={r.modelIndex}
|
||||
provider={r.provider}
|
||||
modelName={r.modelName}
|
||||
displayName={r.displayName}
|
||||
isPreferred={false}
|
||||
isHidden
|
||||
isNonPreferredInSelection={false}
|
||||
onSelect={() => handleSelectPreferred(r.modelIndex)}
|
||||
onToggleVisibility={() => toggleVisibility(r.modelIndex)}
|
||||
agentMessageProps={buildPanelProps(r, false).agentMessageProps}
|
||||
/>
|
||||
));
|
||||
|
||||
if (showSelectionMode) {
|
||||
// ── Selection Layout ──
|
||||
// Preferred panel stays centered at normal chat width.
|
||||
// Non-preferred panels are in a carousel: they peek from the viewport edges
|
||||
// with a fade, and animate in from adjacent to preferred on first render.
|
||||
const preferredIdx = visibleResponses.findIndex(
|
||||
(r) => r.modelIndex === preferredIndex
|
||||
);
|
||||
const preferred = visibleResponses[preferredIdx];
|
||||
const leftPanels = visibleResponses.slice(0, preferredIdx);
|
||||
const rightPanels = visibleResponses.slice(preferredIdx + 1);
|
||||
|
||||
// Peek position: panel's visible edge is at PEEK_W from container edge.
|
||||
// right: calc(100% - PEEK_W) → panel's right edge is PEEK_W from container left.
|
||||
// left: calc(100% - PEEK_W) → panel's left edge is PEEK_W from container right.
|
||||
//
|
||||
// Start position (for entry animation): panels start adjacent to the preferred panel
|
||||
// and slide out to peek position on the frame after mount.
|
||||
|
||||
const getLeftPanelStyle = (i: number): React.CSSProperties => ({
|
||||
width: `${PANEL_W}px`,
|
||||
transition: "right 0.45s cubic-bezier(0.2, 0, 0, 1)",
|
||||
right: selectionEntered
|
||||
? `calc(100% - ${PEEK_W - i * (PANEL_W + PANEL_GAP)}px)`
|
||||
: `calc(50% + 320px + ${PANEL_GAP + i * (PANEL_W + PANEL_GAP)}px)`,
|
||||
});
|
||||
|
||||
const getRightPanelStyle = (i: number): React.CSSProperties => ({
|
||||
width: `${PANEL_W}px`,
|
||||
transition: "left 0.45s cubic-bezier(0.2, 0, 0, 1)",
|
||||
left: selectionEntered
|
||||
? `calc(100% - ${PEEK_W - i * (PANEL_W + PANEL_GAP)}px)`
|
||||
: `calc(50% + 320px + ${PANEL_GAP + i * (PANEL_W + PANEL_GAP)}px)`,
|
||||
});
|
||||
|
||||
return (
|
||||
<div
|
||||
className="w-full relative overflow-hidden"
|
||||
style={{
|
||||
// Fade the viewport edges so peeking panels dissolve naturally.
|
||||
// Fade zone = PEEK_W px; center is fully opaque (preferred panel unaffected).
|
||||
maskImage: `linear-gradient(to right, transparent 0px, black ${PEEK_W}px, black calc(100% - ${PEEK_W}px), transparent 100%)`,
|
||||
WebkitMaskImage: `linear-gradient(to right, transparent 0px, black ${PEEK_W}px, black calc(100% - ${PEEK_W}px), transparent 100%)`,
|
||||
}}
|
||||
>
|
||||
{/* Preferred — centered, in normal flow to establish container height */}
|
||||
{preferred && (
|
||||
<div className="w-full max-w-[640px] min-w-[400px] mx-auto">
|
||||
<MultiModelPanel {...buildPanelProps(preferred, false)} />
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* Non-preferred on the left — animate from adjacent to peeking */}
|
||||
{leftPanels.map((r, i) => (
|
||||
<div
|
||||
key={r.modelIndex}
|
||||
className="absolute top-0"
|
||||
style={getLeftPanelStyle(i)}
|
||||
>
|
||||
<MultiModelPanel {...buildPanelProps(r, true)} />
|
||||
</div>
|
||||
))}
|
||||
|
||||
{/* Non-preferred on the right — animate from adjacent to peeking */}
|
||||
{rightPanels.map((r, i) => (
|
||||
<div
|
||||
key={r.modelIndex}
|
||||
className="absolute top-0"
|
||||
style={getRightPanelStyle(i)}
|
||||
>
|
||||
<MultiModelPanel {...buildPanelProps(r, true)} />
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
// ── Generation Layout (equal panels side-by-side) ──
|
||||
return (
|
||||
<div className="flex gap-6 items-start justify-center">
|
||||
{visibleResponses.map((r) => (
|
||||
<div
|
||||
key={r.modelIndex}
|
||||
className={cn("flex-1 min-w-[400px] max-w-[640px]")}
|
||||
>
|
||||
<MultiModelPanel {...buildPanelProps(r, false)} />
|
||||
</div>
|
||||
))}
|
||||
{renderHiddenPanels()}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
16
web/src/app/app/message/interfaces.ts
Normal file
16
web/src/app/app/message/interfaces.ts
Normal file
@@ -0,0 +1,16 @@
|
||||
import { Packet } from "@/app/app/services/streamingModels";
|
||||
import { FeedbackType } from "@/app/app/interfaces";
|
||||
|
||||
/**
 * One model's response within a multi-model (parallel generation) turn.
 */
export interface MultiModelResponse {
  // Index of this model within the multi-model selection (0-based).
  modelIndex: number;
  // Provider identifier (e.g. "openai", "anthropic").
  provider: string;
  // Raw model version string (e.g. "gpt-4").
  modelName: string;
  // Human-friendly model name shown in the panel header.
  displayName: string;
  // Streaming packets received so far for this model's response.
  packets: Packet[];
  // Count of packets for this response.
  packetCount: number;
  // Node id for this response — presumably its position in the chat
  // message tree; confirm against the producer of this struct.
  nodeId: number;
  // Backend message id, present once the response has been persisted.
  messageId?: number;
  // Whether this panel is currently highlighted in the UI.
  isHighlighted?: boolean;
  // The user's current feedback on this response, if any.
  currentFeedback?: FeedbackType | null;
  // True while this model is still streaming its answer.
  isGenerating?: boolean;
}
|
||||
@@ -49,6 +49,8 @@ export interface AgentMessageProps {
|
||||
parentMessage?: Message | null;
|
||||
// Duration in seconds for processing this message (agent messages only)
|
||||
processingDurationSeconds?: number;
|
||||
/** Hide the feedback/toolbar footer (used in multi-model non-preferred panels) */
|
||||
hideFooter?: boolean;
|
||||
}
|
||||
|
||||
// TODO: Consider more robust comparisons:
|
||||
@@ -76,7 +78,8 @@ function arePropsEqual(
|
||||
prev.parentMessage?.messageId === next.parentMessage?.messageId &&
|
||||
prev.llmManager?.isLoadingProviders ===
|
||||
next.llmManager?.isLoadingProviders &&
|
||||
prev.processingDurationSeconds === next.processingDurationSeconds
|
||||
prev.processingDurationSeconds === next.processingDurationSeconds &&
|
||||
prev.hideFooter === next.hideFooter
|
||||
// Skip: chatState.regenerate, chatState.setPresentingDocument,
|
||||
// most of llmManager, onMessageSelection (function/object props)
|
||||
);
|
||||
@@ -95,6 +98,7 @@ const AgentMessage = React.memo(function AgentMessage({
|
||||
onRegenerate,
|
||||
parentMessage,
|
||||
processingDurationSeconds,
|
||||
hideFooter,
|
||||
}: AgentMessageProps) {
|
||||
const markdownRef = useRef<HTMLDivElement>(null);
|
||||
const finalAnswerRef = useRef<HTMLDivElement>(null);
|
||||
@@ -326,7 +330,7 @@ const AgentMessage = React.memo(function AgentMessage({
|
||||
</div>
|
||||
|
||||
{/* Feedback buttons - only show when streaming and rendering complete */}
|
||||
{isComplete && (
|
||||
{isComplete && !hideFooter && (
|
||||
<MessageToolbar
|
||||
nodeId={nodeId}
|
||||
messageId={messageId}
|
||||
|
||||
@@ -12,6 +12,7 @@ import {
|
||||
FileChatDisplay,
|
||||
Message,
|
||||
MessageResponseIDInfo,
|
||||
MultiModelMessageResponseIDInfo,
|
||||
ResearchType,
|
||||
RetrievalType,
|
||||
StreamingError,
|
||||
@@ -96,6 +97,7 @@ export type PacketType =
|
||||
| FileChatDisplay
|
||||
| StreamingError
|
||||
| MessageResponseIDInfo
|
||||
| MultiModelMessageResponseIDInfo
|
||||
| StreamStopInfo
|
||||
| UserKnowledgeFilePacket
|
||||
| Packet;
|
||||
@@ -109,6 +111,13 @@ export type MessageOrigin =
|
||||
| "slackbot"
|
||||
| "unknown";
|
||||
|
||||
/**
 * Per-model LLM override sent to the backend; multiple overrides trigger
 * parallel multi-model generation.
 */
export interface LLMOverride {
  // Provider name (e.g. "openai").
  model_provider: string;
  // Model version string (e.g. "gpt-4").
  model_version: string;
  // Optional sampling temperature override.
  temperature?: number;
  // Optional friendly name for display in the UI.
  display_name?: string;
}
|
||||
|
||||
export interface SendMessageParams {
|
||||
message: string;
|
||||
fileDescriptors?: FileDescriptor[];
|
||||
@@ -124,6 +133,8 @@ export interface SendMessageParams {
|
||||
modelProvider?: string;
|
||||
modelVersion?: string;
|
||||
temperature?: number;
|
||||
// Multi-model: send multiple LLM overrides for parallel generation
|
||||
llmOverrides?: LLMOverride[];
|
||||
// Origin of the message for telemetry tracking
|
||||
origin?: MessageOrigin;
|
||||
// Additional context injected into the LLM call but not stored/shown in chat.
|
||||
@@ -144,6 +155,7 @@ export async function* sendMessage({
|
||||
modelProvider,
|
||||
modelVersion,
|
||||
temperature,
|
||||
llmOverrides,
|
||||
origin,
|
||||
additionalContext,
|
||||
}: SendMessageParams): AsyncGenerator<PacketType, void, unknown> {
|
||||
@@ -165,6 +177,8 @@ export async function* sendMessage({
|
||||
model_version: modelVersion,
|
||||
}
|
||||
: null,
|
||||
// Multi-model: list of LLM overrides for parallel generation
|
||||
llm_overrides: llmOverrides ?? null,
|
||||
// Default to "unknown" for consistency with backend; callers should set explicitly
|
||||
origin: origin ?? "unknown",
|
||||
additional_context: additionalContext ?? null,
|
||||
@@ -188,6 +202,20 @@ export async function* sendMessage({
|
||||
yield* handleSSEStream<PacketType>(response, signal);
|
||||
}
|
||||
|
||||
/**
 * Mark one of the parallel multi-model responses as the user's preferred
 * answer for a given user message.
 *
 * Issues a PUT to `/api/chat/set-preferred-response` and returns the raw
 * Response; the caller is responsible for checking `response.ok`.
 *
 * @param userMessageId - id of the user message that spawned the responses
 * @param preferredResponseId - id of the chosen assistant response
 */
export async function setPreferredResponse(
  userMessageId: number,
  preferredResponseId: number
): Promise<Response> {
  return fetch("/api/chat/set-preferred-response", {
    method: "PUT",
    headers: { "Content-Type": "application/json" },
    body: JSON.stringify({
      user_message_id: userMessageId,
      preferred_response_id: preferredResponseId,
    }),
  });
}
|
||||
|
||||
export async function nameChatSession(chatSessionId: string) {
|
||||
const response = await fetch("/api/chat/rename-chat-session", {
|
||||
method: "PUT",
|
||||
@@ -357,6 +385,9 @@ export function processRawChatHistory(
|
||||
overridden_model: messageInfo.overridden_model,
|
||||
packets: packetsForMessage || [],
|
||||
currentFeedback: messageInfo.current_feedback as FeedbackType | null,
|
||||
// Multi-model answer generation
|
||||
preferredResponseId: messageInfo.preferred_response_id ?? null,
|
||||
modelDisplayName: messageInfo.model_display_name ?? null,
|
||||
};
|
||||
|
||||
messages.set(messageInfo.message_id, message);
|
||||
|
||||
@@ -403,6 +403,7 @@ export interface Placement {
|
||||
turn_index: number;
|
||||
tab_index?: number; // For parallel tool calls - tools with same turn_index but different tab_index run in parallel
|
||||
sub_turn_index?: number | null;
|
||||
model_index?: number | null; // For multi-model answer generation - identifies which model produced this packet
|
||||
}
|
||||
|
||||
// Packet wrapper for streaming objects
|
||||
|
||||
232
web/src/hooks/useMultiModelChat.test.tsx
Normal file
232
web/src/hooks/useMultiModelChat.test.tsx
Normal file
@@ -0,0 +1,232 @@
|
||||
// Unit tests for useMultiModelChat: selection CRUD (add/remove/replace/clear),
// the multi-model active flag, and LLMOverride serialization.

import { renderHook, act } from "@tests/setup/test-utils";
import useMultiModelChat from "@/hooks/useMultiModelChat";
import { LlmManager } from "@/lib/hooks";
import { SelectedModel } from "@/refresh-components/popovers/ModelSelector";

// Mock buildLlmOptions — hook uses it internally for initialization.
// Tests here focus on CRUD operations, not the initialization side-effect.
jest.mock("@/refresh-components/popovers/LLMPopover", () => ({
  buildLlmOptions: jest.fn(() => []),
}));

// Minimal LlmManager stub: empty providers + unset currentLlm keep the
// hook's default-model initialization effect a no-op.
const makeLlmManager = (): LlmManager =>
  ({
    llmProviders: [],
    currentLlm: { modelName: null, provider: null },
    isLoadingProviders: false,
  }) as unknown as LlmManager;

// Fixture factory: a SelectedModel whose display name is derived from
// provider/modelName.
const makeModel = (provider: string, modelName: string): SelectedModel => ({
  name: provider,
  provider,
  modelName,
  displayName: `${provider}/${modelName}`,
});

const GPT4 = makeModel("openai", "gpt-4");
const CLAUDE = makeModel("anthropic", "claude-opus-4-6");
const GEMINI = makeModel("google", "gemini-pro");
const GPT4_TURBO = makeModel("openai", "gpt-4-turbo");

// ---------------------------------------------------------------------------
// addModel
// ---------------------------------------------------------------------------

describe("addModel", () => {
  it("adds a model to an empty selection", () => {
    const { result } = renderHook(() => useMultiModelChat(makeLlmManager()));

    act(() => {
      result.current.addModel(GPT4);
    });

    expect(result.current.selectedModels).toHaveLength(1);
    expect(result.current.selectedModels[0]).toEqual(GPT4);
  });

  it("does not add a duplicate model", () => {
    const { result } = renderHook(() => useMultiModelChat(makeLlmManager()));

    act(() => {
      result.current.addModel(GPT4);
      result.current.addModel(GPT4); // duplicate
    });

    expect(result.current.selectedModels).toHaveLength(1);
  });

  it("enforces MAX_MODELS (3) cap", () => {
    const { result } = renderHook(() => useMultiModelChat(makeLlmManager()));

    act(() => {
      result.current.addModel(GPT4);
      result.current.addModel(CLAUDE);
      result.current.addModel(GEMINI);
      result.current.addModel(GPT4_TURBO); // should be ignored
    });

    expect(result.current.selectedModels).toHaveLength(3);
  });
});

// ---------------------------------------------------------------------------
// removeModel
// ---------------------------------------------------------------------------

describe("removeModel", () => {
  it("removes a model by index", () => {
    const { result } = renderHook(() => useMultiModelChat(makeLlmManager()));

    act(() => {
      result.current.addModel(GPT4);
      result.current.addModel(CLAUDE);
    });

    act(() => {
      result.current.removeModel(0); // remove GPT4
    });

    expect(result.current.selectedModels).toHaveLength(1);
    expect(result.current.selectedModels[0]).toEqual(CLAUDE);
  });

  it("handles out-of-range index gracefully", () => {
    const { result } = renderHook(() => useMultiModelChat(makeLlmManager()));

    act(() => {
      result.current.addModel(GPT4);
    });

    act(() => {
      result.current.removeModel(99); // no-op
    });

    expect(result.current.selectedModels).toHaveLength(1);
  });
});

// ---------------------------------------------------------------------------
// replaceModel
// ---------------------------------------------------------------------------

describe("replaceModel", () => {
  it("replaces the model at the given index", () => {
    const { result } = renderHook(() => useMultiModelChat(makeLlmManager()));

    act(() => {
      result.current.addModel(GPT4);
      result.current.addModel(CLAUDE);
    });

    act(() => {
      result.current.replaceModel(0, GEMINI);
    });

    expect(result.current.selectedModels[0]).toEqual(GEMINI);
    expect(result.current.selectedModels[1]).toEqual(CLAUDE);
  });

  it("does not replace with a model already selected at another index", () => {
    const { result } = renderHook(() => useMultiModelChat(makeLlmManager()));

    act(() => {
      result.current.addModel(GPT4);
      result.current.addModel(CLAUDE);
    });

    act(() => {
      result.current.replaceModel(0, CLAUDE); // CLAUDE is already at index 1
    });

    // Should be a no-op — GPT4 stays at index 0
    expect(result.current.selectedModels[0]).toEqual(GPT4);
  });
});

// ---------------------------------------------------------------------------
// isMultiModelActive
// ---------------------------------------------------------------------------

describe("isMultiModelActive", () => {
  it("is false with zero models", () => {
    const { result } = renderHook(() => useMultiModelChat(makeLlmManager()));
    expect(result.current.isMultiModelActive).toBe(false);
  });

  it("is false with exactly one model", () => {
    const { result } = renderHook(() => useMultiModelChat(makeLlmManager()));

    act(() => {
      result.current.addModel(GPT4);
    });

    expect(result.current.isMultiModelActive).toBe(false);
  });

  it("is true with two or more models", () => {
    const { result } = renderHook(() => useMultiModelChat(makeLlmManager()));

    act(() => {
      result.current.addModel(GPT4);
      result.current.addModel(CLAUDE);
    });

    expect(result.current.isMultiModelActive).toBe(true);
  });
});

// ---------------------------------------------------------------------------
// buildLlmOverrides
// ---------------------------------------------------------------------------

describe("buildLlmOverrides", () => {
  it("returns empty array when no models selected", () => {
    const { result } = renderHook(() => useMultiModelChat(makeLlmManager()));
    expect(result.current.buildLlmOverrides()).toEqual([]);
  });

  it("maps selectedModels to LLMOverride format", () => {
    const { result } = renderHook(() => useMultiModelChat(makeLlmManager()));

    act(() => {
      result.current.addModel(GPT4);
      result.current.addModel(CLAUDE);
    });

    const overrides = result.current.buildLlmOverrides();

    expect(overrides).toHaveLength(2);
    expect(overrides[0]).toEqual({
      model_provider: "openai",
      model_version: "gpt-4",
      display_name: "openai/gpt-4",
    });
    expect(overrides[1]).toEqual({
      model_provider: "anthropic",
      model_version: "claude-opus-4-6",
      display_name: "anthropic/claude-opus-4-6",
    });
  });
});

// ---------------------------------------------------------------------------
// clearModels
// ---------------------------------------------------------------------------

describe("clearModels", () => {
  it("empties the selection", () => {
    const { result } = renderHook(() => useMultiModelChat(makeLlmManager()));

    act(() => {
      result.current.addModel(GPT4);
      result.current.addModel(CLAUDE);
    });

    act(() => {
      result.current.clearModels();
    });

    expect(result.current.selectedModels).toHaveLength(0);
    expect(result.current.isMultiModelActive).toBe(false);
  });
});
|
||||
192
web/src/hooks/useMultiModelChat.ts
Normal file
192
web/src/hooks/useMultiModelChat.ts
Normal file
@@ -0,0 +1,192 @@
|
||||
"use client";
|
||||
|
||||
import { useState, useCallback, useEffect, useMemo } from "react";
|
||||
import {
|
||||
MAX_MODELS,
|
||||
SelectedModel,
|
||||
} from "@/refresh-components/popovers/ModelSelector";
|
||||
import { LLMOverride } from "@/app/app/services/lib";
|
||||
import { LlmManager } from "@/lib/hooks";
|
||||
import { buildLlmOptions } from "@/refresh-components/popovers/LLMPopover";
|
||||
|
||||
/** Public API returned by the useMultiModelChat hook. */
export interface UseMultiModelChatReturn {
  /** Currently selected models for multi-model comparison. */
  selectedModels: SelectedModel[];
  /** Whether multi-model mode is active (>1 model selected). */
  isMultiModelActive: boolean;
  /** Add a model to the selection (no-op at the MAX_MODELS cap or on duplicates). */
  addModel: (model: SelectedModel) => void;
  /** Remove a model by index. */
  removeModel: (index: number) => void;
  /** Replace a model at a specific index with a new one. */
  replaceModel: (index: number, model: SelectedModel) => void;
  /** Clear all selected models. */
  clearModels: () => void;
  /** Build the LLMOverride[] array from selectedModels. */
  buildLlmOverrides: () => LLMOverride[];
  /**
   * Restore multi-model selection from model version strings (e.g. from chat history).
   * Matches against available llmOptions to reconstruct full SelectedModel objects.
   */
  restoreFromModelNames: (modelNames: string[]) => void;
  /**
   * Switch to a single model by name (after user picks a preferred response).
   * Matches against llmOptions to find the full SelectedModel.
   */
  selectSingleModel: (modelName: string) => void;
}
|
||||
|
||||
export default function useMultiModelChat(
|
||||
llmManager: LlmManager
|
||||
): UseMultiModelChatReturn {
|
||||
const [selectedModels, setSelectedModels] = useState<SelectedModel[]>([]);
|
||||
const [defaultInitialized, setDefaultInitialized] = useState(false);
|
||||
|
||||
// Initialize with the default model from llmManager once providers load
|
||||
const llmOptions = useMemo(
|
||||
() =>
|
||||
llmManager.llmProviders ? buildLlmOptions(llmManager.llmProviders) : [],
|
||||
[llmManager.llmProviders]
|
||||
);
|
||||
|
||||
useEffect(() => {
|
||||
if (defaultInitialized) return;
|
||||
if (llmOptions.length === 0) return;
|
||||
const { currentLlm } = llmManager;
|
||||
// Don't initialize if currentLlm hasn't loaded yet
|
||||
if (!currentLlm.modelName) return;
|
||||
const match = llmOptions.find(
|
||||
(opt) =>
|
||||
opt.provider === currentLlm.provider &&
|
||||
opt.modelName === currentLlm.modelName
|
||||
);
|
||||
if (match) {
|
||||
setSelectedModels([
|
||||
{
|
||||
name: match.name,
|
||||
provider: match.provider,
|
||||
modelName: match.modelName,
|
||||
displayName: match.displayName,
|
||||
},
|
||||
]);
|
||||
setDefaultInitialized(true);
|
||||
}
|
||||
}, [llmOptions, llmManager.currentLlm, defaultInitialized, llmManager]);
|
||||
|
||||
const isMultiModelActive = selectedModels.length > 1;
|
||||
|
||||
const addModel = useCallback((model: SelectedModel) => {
|
||||
setSelectedModels((prev) => {
|
||||
if (prev.length >= MAX_MODELS) return prev;
|
||||
if (
|
||||
prev.some(
|
||||
(m) =>
|
||||
m.provider === model.provider && m.modelName === model.modelName
|
||||
)
|
||||
) {
|
||||
return prev;
|
||||
}
|
||||
return [...prev, model];
|
||||
});
|
||||
}, []);
|
||||
|
||||
const removeModel = useCallback((index: number) => {
|
||||
setSelectedModels((prev) => prev.filter((_, i) => i !== index));
|
||||
}, []);
|
||||
|
||||
const replaceModel = useCallback((index: number, model: SelectedModel) => {
|
||||
setSelectedModels((prev) => {
|
||||
// Don't replace with a model that's already selected elsewhere
|
||||
if (
|
||||
prev.some(
|
||||
(m, i) =>
|
||||
i !== index &&
|
||||
m.provider === model.provider &&
|
||||
m.modelName === model.modelName
|
||||
)
|
||||
) {
|
||||
return prev;
|
||||
}
|
||||
const next = [...prev];
|
||||
next[index] = model;
|
||||
return next;
|
||||
});
|
||||
}, []);
|
||||
|
||||
const clearModels = useCallback(() => {
|
||||
setSelectedModels([]);
|
||||
}, []);
|
||||
|
||||
const restoreFromModelNames = useCallback(
|
||||
(modelNames: string[]) => {
|
||||
if (modelNames.length < 2 || llmOptions.length === 0) return;
|
||||
const restored: SelectedModel[] = [];
|
||||
for (const name of modelNames) {
|
||||
// Try matching by modelName (raw version string like "claude-opus-4-6")
|
||||
// or by displayName (friendly name like "Claude Opus 4.6")
|
||||
const match = llmOptions.find(
|
||||
(opt) =>
|
||||
opt.modelName === name ||
|
||||
opt.displayName === name ||
|
||||
opt.name === name
|
||||
);
|
||||
if (match) {
|
||||
restored.push({
|
||||
name: match.name,
|
||||
provider: match.provider,
|
||||
modelName: match.modelName,
|
||||
displayName: match.displayName,
|
||||
});
|
||||
}
|
||||
}
|
||||
if (restored.length >= 2) {
|
||||
setSelectedModels(restored);
|
||||
setDefaultInitialized(true);
|
||||
}
|
||||
},
|
||||
[llmOptions]
|
||||
);
|
||||
|
||||
const selectSingleModel = useCallback(
|
||||
(modelName: string) => {
|
||||
if (llmOptions.length === 0) return;
|
||||
const match = llmOptions.find(
|
||||
(opt) =>
|
||||
opt.modelName === modelName ||
|
||||
opt.displayName === modelName ||
|
||||
opt.name === modelName
|
||||
);
|
||||
if (match) {
|
||||
setSelectedModels([
|
||||
{
|
||||
name: match.name,
|
||||
provider: match.provider,
|
||||
modelName: match.modelName,
|
||||
displayName: match.displayName,
|
||||
},
|
||||
]);
|
||||
}
|
||||
},
|
||||
[llmOptions]
|
||||
);
|
||||
|
||||
const buildLlmOverrides = useCallback((): LLMOverride[] => {
|
||||
return selectedModels.map((m) => ({
|
||||
model_provider: m.name,
|
||||
model_version: m.modelName,
|
||||
display_name: m.displayName,
|
||||
}));
|
||||
}, [selectedModels]);
|
||||
|
||||
return {
|
||||
selectedModels,
|
||||
isMultiModelActive,
|
||||
addModel,
|
||||
removeModel,
|
||||
replaceModel,
|
||||
clearModels,
|
||||
buildLlmOverrides,
|
||||
restoreFromModelNames,
|
||||
selectSingleModel,
|
||||
};
|
||||
}
|
||||
453
web/src/refresh-components/popovers/ModelSelector.tsx
Normal file
453
web/src/refresh-components/popovers/ModelSelector.tsx
Normal file
@@ -0,0 +1,453 @@
|
||||
"use client";
|
||||
|
||||
import { useState, useMemo, useRef } from "react";
|
||||
import Popover, { PopoverMenu } from "@/refresh-components/Popover";
|
||||
import { LlmManager } from "@/lib/hooks";
|
||||
import { getProviderIcon } from "@/app/admin/configuration/llm/utils";
|
||||
import InputTypeIn from "@/refresh-components/inputs/InputTypeIn";
|
||||
import Text from "@/refresh-components/texts/Text";
|
||||
import { Button } from "@opal/components";
|
||||
import {
|
||||
SvgCheck,
|
||||
SvgChevronDown,
|
||||
SvgChevronRight,
|
||||
SvgPlusCircle,
|
||||
SvgX,
|
||||
} from "@opal/icons";
|
||||
import { Section } from "@/layouts/general-layouts";
|
||||
import {
|
||||
LLMOption,
|
||||
LLMOptionGroup,
|
||||
} from "@/refresh-components/popovers/interfaces";
|
||||
import {
|
||||
buildLlmOptions,
|
||||
groupLlmOptions,
|
||||
} from "@/refresh-components/popovers/LLMPopover";
|
||||
import * as AccordionPrimitive from "@radix-ui/react-accordion";
|
||||
import { cn } from "@/lib/utils";
|
||||
|
||||
/** Maximum number of models that can be compared side-by-side. */
export const MAX_MODELS = 3;

/** A model chosen for multi-model comparison. */
export interface SelectedModel {
  // Provider configuration name (passed as `name` from LLMOption).
  name: string;
  // Provider identifier (e.g. "openai"); used with modelName as the
  // uniqueness key for the selection.
  provider: string;
  // Raw model version string (e.g. "gpt-4").
  modelName: string;
  // Friendly name shown in pills and panels.
  displayName: string;
}

/** Props for the ModelSelector bar (model pills + add-model popover). */
export interface ModelSelectorProps {
  llmManager: LlmManager;
  // Currently selected models, in display order.
  selectedModels: SelectedModel[];
  // Add a model to the selection.
  onAdd: (model: SelectedModel) => void;
  // Remove the model at the given index.
  onRemove: (index: number) => void;
  // Replace the model at the given index with a new one.
  onReplace: (index: number, model: SelectedModel) => void;
}
|
||||
|
||||
/** Vertical 1px divider between model bar elements. */
function BarDivider() {
  return <div className="h-9 w-px bg-border-01 shrink-0" />;
}
|
||||
|
||||
/**
 * Individual model pill in the model bar.
 *
 * Rendered as a div-with-button-role so it can contain a nested remove
 * Button. Clicking (or Enter/Space) fires onClick; in multi-model mode a
 * remove (X) button is shown, otherwise a chevron indicating the pill
 * opens the selector.
 */
function ModelPill({
  model,
  isMultiModel,
  onRemove,
  onClick,
}: {
  model: SelectedModel;
  isMultiModel: boolean;
  onRemove?: () => void;
  onClick?: () => void;
}) {
  const ProviderIcon = getProviderIcon(model.provider, model.modelName);

  return (
    <div
      role="button"
      tabIndex={0}
      onClick={onClick}
      onKeyDown={(e) => {
        // Keyboard activation for the div-as-button (Enter / Space).
        if (e.key === "Enter" || e.key === " ") {
          e.preventDefault();
          onClick?.();
        }
      }}
      className={cn(
        "flex items-center gap-0.5 rounded-12 p-2 shrink-0 cursor-pointer",
        "hover:bg-background-tint-02 transition-colors",
        isMultiModel && "bg-background-tint-02"
      )}
    >
      <div className="flex items-center justify-center size-5 shrink-0 p-0.5">
        <ProviderIcon size={16} />
      </div>
      <Text mainUiAction text04 nowrap className="px-1">
        {model.displayName}
      </Text>
      {isMultiModel ? (
        <Button
          prominence="tertiary"
          icon={SvgX}
          size="2xs"
          onClick={(e) => {
            // Stop propagation so removing doesn't also trigger the pill's
            // own onClick.
            e.stopPropagation();
            onRemove?.();
          }}
          tooltip="Remove model"
        />
      ) : (
        <SvgChevronDown className="size-4 stroke-text-03 shrink-0" />
      )}
    </div>
  );
}
|
||||
|
||||
/**
 * Model item row inside the add-model popover.
 *
 * Shows the provider icon, display name, a capability subtitle, and an
 * "Added" marker when selected. Disabled rows (already selected elsewhere,
 * or at the selection cap) are dimmed and not clickable.
 */
function ModelItem({
  option,
  isSelected,
  isDisabled,
  onToggle,
}: {
  option: LLMOption;
  isSelected: boolean;
  isDisabled: boolean;
  onToggle: () => void;
}) {
  const ProviderIcon = getProviderIcon(option.provider, option.modelName);

  // Build subtitle from model capabilities; fall back to the raw model name
  // when no capability flags are set.
  const subtitle = useMemo(() => {
    const parts: string[] = [];
    if (option.supportsReasoning) parts.push("reasoning");
    if (option.supportsImageInput) parts.push("multi-modal");
    if (parts.length === 0 && option.modelName) return option.modelName;
    return parts.join(", ");
  }, [option]);

  return (
    <button
      type="button"
      disabled={isDisabled}
      onClick={onToggle}
      className={cn(
        "flex items-center gap-1.5 w-full rounded-08 p-1.5 text-left transition-colors",
        isSelected ? "bg-action-link-01" : "hover:bg-background-tint-02",
        isDisabled && !isSelected && "opacity-50 cursor-not-allowed"
      )}
    >
      <div className="flex items-center justify-center size-5 shrink-0 p-0.5">
        <ProviderIcon size={16} />
      </div>
      <div className="flex flex-col flex-1 min-w-0">
        <Text
          mainUiAction
          nowrap
          className={cn(isSelected ? "text-action-link-03" : "text-text-04")}
        >
          {option.displayName}
        </Text>
        {subtitle && (
          <Text secondaryBody text03 nowrap>
            {subtitle}
          </Text>
        )}
      </div>
      {isSelected && (
        <Text secondaryBody nowrap className="text-action-link-05 shrink-0">
          Added
        </Text>
      )}
    </button>
  );
}
|
||||
|
||||
export default function ModelSelector({
|
||||
llmManager,
|
||||
selectedModels,
|
||||
onAdd,
|
||||
onRemove,
|
||||
onReplace,
|
||||
}: ModelSelectorProps) {
|
||||
const [open, setOpen] = useState(false);
|
||||
const [searchQuery, setSearchQuery] = useState("");
|
||||
const scrollContainerRef = useRef<HTMLDivElement>(null);
|
||||
// null = add mode (via + button), number = replace mode (via pill click)
|
||||
const [replacingIndex, setReplacingIndex] = useState<number | null>(null);
|
||||
|
||||
const isMultiModel = selectedModels.length > 1;
|
||||
const atMax = selectedModels.length >= MAX_MODELS;
|
||||
|
||||
const llmOptions = useMemo(
|
||||
() => buildLlmOptions(llmManager.llmProviders),
|
||||
[llmManager.llmProviders]
|
||||
);
|
||||
|
||||
const selectedKeys = useMemo(
|
||||
() => new Set(selectedModels.map((m) => `${m.provider}:${m.modelName}`)),
|
||||
[selectedModels]
|
||||
);
|
||||
|
||||
const filteredOptions = useMemo(() => {
|
||||
if (!searchQuery.trim()) return llmOptions;
|
||||
const query = searchQuery.toLowerCase();
|
||||
return llmOptions.filter(
|
||||
(opt) =>
|
||||
opt.displayName.toLowerCase().includes(query) ||
|
||||
opt.modelName.toLowerCase().includes(query) ||
|
||||
(opt.vendor && opt.vendor.toLowerCase().includes(query))
|
||||
);
|
||||
}, [llmOptions, searchQuery]);
|
||||
|
||||
const groupedOptions = useMemo(
|
||||
() => groupLlmOptions(filteredOptions),
|
||||
[filteredOptions]
|
||||
);
|
||||
|
||||
const isSearching = searchQuery.trim().length > 0;
|
||||
|
||||
// In replace mode, other selected models (not the one being replaced) are disabled
|
||||
const otherSelectedKeys = useMemo(() => {
|
||||
if (replacingIndex === null) return new Set<string>();
|
||||
return new Set(
|
||||
selectedModels
|
||||
.filter((_, i) => i !== replacingIndex)
|
||||
.map((m) => `${m.provider}:${m.modelName}`)
|
||||
);
|
||||
}, [selectedModels, replacingIndex]);
|
||||
|
||||
// Current model at the replacing index (shows as "selected" in replace mode)
|
||||
const replacingKey = useMemo(() => {
|
||||
if (replacingIndex === null) return null;
|
||||
const m = selectedModels[replacingIndex];
|
||||
return m ? `${m.provider}:${m.modelName}` : null;
|
||||
}, [selectedModels, replacingIndex]);
|
||||
|
||||
const getItemState = (optKey: string) => {
|
||||
if (replacingIndex !== null) {
|
||||
// Replace mode
|
||||
return {
|
||||
isSelected: optKey === replacingKey,
|
||||
isDisabled: otherSelectedKeys.has(optKey),
|
||||
};
|
||||
}
|
||||
// Add mode
|
||||
return {
|
||||
isSelected: selectedKeys.has(optKey),
|
||||
isDisabled: !selectedKeys.has(optKey) && atMax,
|
||||
};
|
||||
};
|
||||
|
||||
const handleSelectModel = (option: LLMOption) => {
|
||||
const model: SelectedModel = {
|
||||
name: option.name,
|
||||
provider: option.provider,
|
||||
modelName: option.modelName,
|
||||
displayName: option.displayName,
|
||||
};
|
||||
|
||||
if (replacingIndex !== null) {
|
||||
// Replace mode: swap the model at the clicked pill's index
|
||||
onReplace(replacingIndex, model);
|
||||
setOpen(false);
|
||||
setReplacingIndex(null);
|
||||
setSearchQuery("");
|
||||
return;
|
||||
}
|
||||
|
||||
// Add mode: toggle (add/remove)
|
||||
const key = `${option.provider}:${option.modelName}`;
|
||||
const existingIndex = selectedModels.findIndex(
|
||||
(m) => `${m.provider}:${m.modelName}` === key
|
||||
);
|
||||
if (existingIndex >= 0) {
|
||||
onRemove(existingIndex);
|
||||
} else if (!atMax) {
|
||||
onAdd(model);
|
||||
}
|
||||
};
|
||||
|
||||
const handleOpenChange = (nextOpen: boolean) => {
|
||||
setOpen(nextOpen);
|
||||
if (!nextOpen) {
|
||||
setReplacingIndex(null);
|
||||
setSearchQuery("");
|
||||
}
|
||||
};
|
||||
|
||||
const handlePillClick = (index: number) => {
|
||||
setReplacingIndex(index);
|
||||
setOpen(true);
|
||||
};
|
||||
|
||||
return (
|
||||
<div className="flex items-center justify-end gap-1 p-1">
|
||||
{/* (+) Add model button — hidden at max models */}
|
||||
{!atMax && (
|
||||
<Popover open={open} onOpenChange={handleOpenChange}>
|
||||
<Popover.Trigger asChild>
|
||||
<Button
|
||||
prominence="tertiary"
|
||||
icon={SvgPlusCircle}
|
||||
size="sm"
|
||||
tooltip="Add Model"
|
||||
/>
|
||||
</Popover.Trigger>
|
||||
|
||||
<Popover.Content side="top" align="start" width="lg">
|
||||
<Section gap={0.25}>
|
||||
<InputTypeIn
|
||||
leftSearchIcon
|
||||
variant="internal"
|
||||
value={searchQuery}
|
||||
onChange={(e) => setSearchQuery(e.target.value)}
|
||||
placeholder="Search models..."
|
||||
/>
|
||||
|
||||
<PopoverMenu scrollContainerRef={scrollContainerRef}>
|
||||
{groupedOptions.length === 0
|
||||
? [
|
||||
<div key="empty" className="py-3 px-2">
|
||||
<Text secondaryBody text03>
|
||||
No models found
|
||||
</Text>
|
||||
</div>,
|
||||
]
|
||||
: groupedOptions.length === 1
|
||||
? [
|
||||
<div key="single" className="flex flex-col gap-0.5">
|
||||
{groupedOptions[0]!.options.map((opt) => {
|
||||
const key = `${opt.provider}:${opt.modelName}`;
|
||||
const state = getItemState(key);
|
||||
return (
|
||||
<ModelItem
|
||||
key={opt.modelName}
|
||||
option={opt}
|
||||
isSelected={state.isSelected}
|
||||
isDisabled={state.isDisabled}
|
||||
onToggle={() => handleSelectModel(opt)}
|
||||
/>
|
||||
);
|
||||
})}
|
||||
</div>,
|
||||
]
|
||||
: [
|
||||
<ModelGroupAccordion
|
||||
key="accordion"
|
||||
groups={groupedOptions}
|
||||
isSearching={isSearching}
|
||||
getItemState={getItemState}
|
||||
onToggle={handleSelectModel}
|
||||
/>,
|
||||
]}
|
||||
</PopoverMenu>
|
||||
</Section>
|
||||
</Popover.Content>
|
||||
</Popover>
|
||||
)}
|
||||
|
||||
{/* Divider + model pills */}
|
||||
{selectedModels.length > 0 && (
|
||||
<>
|
||||
<BarDivider />
|
||||
{selectedModels.map((model, index) => (
|
||||
<div
|
||||
key={`${model.provider}:${model.modelName}`}
|
||||
className="flex items-center gap-1"
|
||||
>
|
||||
{index > 0 && <BarDivider />}
|
||||
<ModelPill
|
||||
model={model}
|
||||
isMultiModel={isMultiModel}
|
||||
onRemove={() => onRemove(index)}
|
||||
onClick={() => handlePillClick(index)}
|
||||
/>
|
||||
</div>
|
||||
))}
|
||||
</>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
/** Props for {@link ModelGroupAccordion}. */
interface ModelGroupAccordionProps {
  /** Provider groups, each carrying an icon, display name, and model options. */
  groups: LLMOptionGroup[];
  /**
   * When true, every group is forced expanded and manual expand/collapse is
   * ignored (presumably driven by an active search query — confirm in caller).
   */
  isSearching: boolean;
  /** Returns the selected/disabled flags for the option identified by `key`. */
  getItemState: (key: string) => { isSelected: boolean; isDisabled: boolean };
  /** Invoked when the user clicks a model entry. */
  onToggle: (option: LLMOption) => void;
}
|
||||
|
||||
/**
 * Multi-expand accordion listing model options grouped by provider.
 *
 * Each group header shows the provider icon/name with a chevron indicating
 * its expanded state; the content renders one `ModelItem` per option, with
 * selection state supplied by `getItemState`. While `isSearching` is true,
 * all groups are shown expanded and user toggling is suppressed.
 */
function ModelGroupAccordion({
  groups,
  isSearching,
  getItemState,
  onToggle,
}: ModelGroupAccordionProps) {
  const allKeys = groups.map((g) => g.key);
  // Start with only the first group expanded. NOTE(review): the initial
  // value is captured once on mount; if `groups` arrives empty and fills in
  // later, the "" placeholder stays in state (harmless — it matches no key).
  const [expandedGroups, setExpandedGroups] = useState<string[]>([
    allKeys[0] ?? "",
  ]);

  // Searching overrides the user's expand/collapse choices without
  // destroying them — collapsed state is restored when the search ends.
  const effectiveExpanded = isSearching ? allKeys : expandedGroups;

  return (
    <AccordionPrimitive.Root
      type="multiple"
      value={effectiveExpanded}
      onValueChange={(value) => {
        // Ignore toggle attempts while search forces everything open.
        if (!isSearching) setExpandedGroups(value);
      }}
      className="w-full flex flex-col"
    >
      {groups.map((group) => {
        const isExpanded = effectiveExpanded.includes(group.key);
        return (
          <AccordionPrimitive.Item
            key={group.key}
            value={group.key}
            className="pt-1"
          >
            {/* Group header: provider icon + name on the left, chevron on the right */}
            <AccordionPrimitive.Header className="flex">
              <AccordionPrimitive.Trigger className="flex items-center rounded-08 hover:bg-background-tint-02 w-full py-1">
                <div className="flex items-center gap-1 shrink-0">
                  <div className="flex items-center justify-center size-5 shrink-0">
                    <group.Icon size={16} />
                  </div>
                  <Text secondaryBody text03 nowrap className="px-0.5">
                    {group.displayName}
                  </Text>
                </div>
                <div className="flex-1" />
                <div className="flex items-center justify-center size-6 shrink-0">
                  {isExpanded ? (
                    <SvgChevronDown className="h-4 w-4 stroke-text-04 shrink-0" />
                  ) : (
                    <SvgChevronRight className="h-4 w-4 stroke-text-04 shrink-0" />
                  )}
                </div>
              </AccordionPrimitive.Trigger>
            </AccordionPrimitive.Header>
            {/* Group body: one ModelItem per option in this provider group */}
            <AccordionPrimitive.Content className="overflow-hidden data-[state=closed]:animate-accordion-up data-[state=open]:animate-accordion-down">
              <div className="flex flex-col gap-0.5 pt-0 pb-0">
                {group.options.map((opt) => {
                  const key = `${opt.provider}:${opt.modelName}`;
                  const state = getItemState(key);
                  return (
                    <ModelItem
                      key={key}
                      option={opt}
                      isSelected={state.isSelected}
                      isDisabled={state.isDisabled}
                      onToggle={() => onToggle(opt)}
                    />
                  );
                })}
              </div>
            </AccordionPrimitive.Content>
          </AccordionPrimitive.Item>
        );
      })}
    </AccordionPrimitive.Root>
  );
}
|
||||
Reference in New Issue
Block a user