mirror of
https://github.com/onyx-dot-app/onyx.git
synced 2026-04-05 15:02:43 +00:00
Compare commits
16 Commits
cli/v0.1.2
...
add-multip
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
931988145d | ||
|
|
9dc554f2a1 | ||
|
|
647addb026 | ||
|
|
d2c6150f86 | ||
|
|
ec571c0ea5 | ||
|
|
7a437ebb81 | ||
|
|
cf5f7e0936 | ||
|
|
6c299ccec5 | ||
|
|
1fc4c3a930 | ||
|
|
89d6e04938 | ||
|
|
8bf3559bbf | ||
|
|
1f9024c0b3 | ||
|
|
9f0bbf0e17 | ||
|
|
0aef537ba5 | ||
|
|
bd918980a8 | ||
|
|
e6e42fdbf6 |
@@ -0,0 +1,36 @@
|
||||
"""add multi-modal response support to chat_message
|
||||
|
||||
Revision ID: 34ef1e82a4fa
|
||||
Revises: e8f0d2a38171
|
||||
Create Date: 2025-12-04 14:53:05.821715
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "34ef1e82a4fa"
|
||||
down_revision = "e8f0d2a38171"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Add model_provider column to track which LLM provider generated the response
|
||||
op.add_column(
|
||||
"chat_message",
|
||||
sa.Column("model_provider", sa.String(), nullable=True),
|
||||
)
|
||||
|
||||
# Add model_name column to track which specific model generated the response
|
||||
op.add_column(
|
||||
"chat_message",
|
||||
sa.Column("model_name", sa.String(), nullable=True),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Drop the columns in reverse order
|
||||
op.drop_column("chat_message", "model_name")
|
||||
op.drop_column("chat_message", "model_provider")
|
||||
@@ -97,6 +97,8 @@ class OnyxAnswerPiece(BaseModel):
|
||||
class MessageResponseIDInfo(BaseModel):
    """IDs reserved for a chat turn, sent to the frontend before streaming.

    Lets the client associate incoming stream packets with the DB rows that
    will eventually hold the user message and the assistant response.
    """

    user_message_id: int | None
    reserved_assistant_message_id: int
    # Provider / specific model that will generate this reserved assistant
    # response. None in single-model flows that predate multi-model support.
    model_provider: str | None = None
    model_name: str | None = None
|
||||
|
||||
|
||||
class StreamingError(BaseModel):
|
||||
|
||||
277
backend/onyx/chat/multi_model_stream.py
Normal file
277
backend/onyx/chat/multi_model_stream.py
Normal file
@@ -0,0 +1,277 @@
|
||||
"""
|
||||
Multi-model streaming infrastructure for concurrent LLM execution.
|
||||
|
||||
This module provides classes for running multiple LLM models concurrently
|
||||
and merging their streaming outputs into a single stream.
|
||||
"""
|
||||
|
||||
import contextvars
|
||||
import threading
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Iterator
|
||||
from dataclasses import dataclass
|
||||
from dataclasses import field
|
||||
from queue import Empty
|
||||
from queue import Queue
|
||||
from typing import Any
|
||||
|
||||
from onyx.chat.chat_state import ChatStateContainer
|
||||
from onyx.chat.emitter import Emitter
|
||||
from onyx.server.query_and_chat.streaming_models import OverallStop
|
||||
from onyx.server.query_and_chat.streaming_models import Packet
|
||||
from onyx.server.query_and_chat.streaming_models import PacketException
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
@dataclass
class ModelStreamContext:
    """Context for a single model's streaming execution."""

    # Unique identifier for the model (e.g. "openai:gpt-4")
    model_id: str
    # Emitter that tags this model's packets before forwarding to the merger
    emitter: "ModelTaggingEmitter"
    # Accumulates partial results (answer tokens, citations, tool calls)
    state_container: ChatStateContainer
    # Background thread running this model's LLM loop; None until started
    thread: threading.Thread | None = None
    # Set once the model's stream has finished (normally or with an error)
    completed: bool = False
    # Exception raised by the model's thread, if any
    error: Exception | None = None
    # Additional data that may be needed for saving chat turns
    extra_data: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
class ModelTaggingEmitter(Emitter):
    """Emitter that tags packets with model_id and forwards to a merged queue.

    Wraps the standard Emitter so that every packet carries a model
    identifier, which the frontend uses to route packets to the correct
    model's UI in multi-model streams.
    """

    def __init__(
        self,
        model_id: str,
        merged_queue: Queue[Packet | None],
        merger: "MultiModelStreamMerger",
    ):
        # Keep a local bus so existing code that reaches into emitter.bus
        # keeps working; the real traffic goes through merged_queue.
        super().__init__(bus=Queue())
        self.model_id = model_id
        self.merged_queue = merged_queue
        self.merger = merger

    def emit(self, packet: Packet) -> None:
        """Forward *packet* to the merged queue, tagged with this model_id.

        An OverallStop payload additionally marks this model as complete on
        the merger.
        """
        self.merged_queue.put(
            Packet(
                turn_index=packet.turn_index,
                obj=packet.obj,
                model_id=self.model_id,
            )
        )

        if isinstance(packet.obj, OverallStop):
            self.merger.mark_model_complete(self.model_id)
|
||||
|
||||
|
||||
class MultiModelStreamMerger:
    """Merges packet streams from multiple concurrent LLM executions.

    This class manages the concurrent execution of multiple LLM models and
    merges their streaming outputs into a single stream. Each model runs in
    its own thread and emits packets to a shared queue.

    Usage:
        merger = MultiModelStreamMerger()

        # Register models and get their emitters
        for model in models:
            ctx = merger.register_model(model_id)
            # Start thread with ctx.emitter

        # Stream merged packets
        for packet in merger.stream(is_connected):
            yield packet
    """

    def __init__(self) -> None:
        # None is the sentinel that unblocks stream() once every model is done
        self.merged_queue: Queue[Packet | None] = Queue()
        self.model_contexts: dict[str, ModelStreamContext] = {}
        # Guards model_contexts and the completion counters
        self._lock = threading.Lock()
        self._completed_count = 0
        self._total_models = 0
        self._all_complete = threading.Event()

    def register_model(self, model_id: str) -> ModelStreamContext:
        """Register a model for concurrent streaming.

        Args:
            model_id: Unique identifier for the model (e.g., "openai:gpt-4")

        Returns:
            ModelStreamContext containing the emitter to use for this model.

        Raises:
            ValueError: If the model_id was already registered.
        """
        with self._lock:
            if model_id in self.model_contexts:
                raise ValueError(f"Model {model_id} already registered")

            emitter = ModelTaggingEmitter(
                model_id=model_id,
                merged_queue=self.merged_queue,
                merger=self,
            )
            state_container = ChatStateContainer()

            context = ModelStreamContext(
                model_id=model_id,
                emitter=emitter,
                state_container=state_container,
            )
            self.model_contexts[model_id] = context
            self._total_models += 1

        return context

    def mark_model_complete(
        self, model_id: str, error: Exception | None = None
    ) -> None:
        """Mark a model's stream as complete. Idempotent.

        Called automatically when an OverallStop packet is emitted and from
        the thread wrapper's finally block; may also be called manually on
        error. Repeat calls for the same model are no-ops.

        Args:
            model_id: The model that has completed
            error: Optional exception if the model failed
        """
        with self._lock:
            if model_id in self.model_contexts:
                ctx = self.model_contexts[model_id]
                if not ctx.completed:
                    ctx.completed = True
                    ctx.error = error
                    self._completed_count += 1
                    logger.debug(
                        f"Model {model_id} completed "
                        f"({self._completed_count}/{self._total_models})"
                    )

                    if self._completed_count >= self._total_models:
                        # All models complete - send sentinel to unblock stream
                        self.merged_queue.put(None)
                        self._all_complete.set()

    def start_model_thread(
        self,
        model_id: str,
        target_func: Callable[..., None],
        *args: Any,
        **kwargs: Any,
    ) -> None:
        """Start a background thread for a model's LLM execution.

        Args:
            model_id: The model to run
            target_func: The function to run (should accept emitter as first arg)
            *args: Additional positional arguments for target_func
            **kwargs: Additional keyword arguments for target_func

        Raises:
            ValueError: If the model_id was never registered.
        """
        ctx = self.model_contexts.get(model_id)
        if ctx is None:
            raise ValueError(f"Model {model_id} not registered")

        # Copy context vars for the new thread (important for tenant_id, etc.)
        context = contextvars.copy_context()

        def thread_target() -> None:
            try:
                context.run(
                    target_func,
                    ctx.emitter,
                    *args,
                    state_container=ctx.state_container,
                    **kwargs,
                )
            except Exception as e:
                logger.exception(f"Model {model_id} failed: {e}")
                # Emit error packet so the frontend can surface the failure
                ctx.emitter.emit(
                    Packet(
                        turn_index=0,
                        obj=PacketException(type="error", exception=e),
                    )
                )
                self.mark_model_complete(model_id, error=e)
            finally:
                # BUGFIX: previously completion was only recorded on the error
                # path (or via an OverallStop packet). If target_func returned
                # without emitting OverallStop, the model was never marked
                # complete and stream() blocked until the external stop signal.
                # mark_model_complete is idempotent, so this is a no-op on the
                # normal and error paths that already recorded completion.
                self.mark_model_complete(model_id)

        thread = threading.Thread(target=thread_target, daemon=True)
        ctx.thread = thread
        thread.start()

    def stream(
        self,
        is_connected: Callable[[], bool],
        poll_timeout: float = 0.3,
    ) -> Iterator[Packet]:
        """Yield merged packets from all models.

        This generator yields packets from all models as they arrive,
        checking the stop signal periodically.

        Args:
            is_connected: Callable that returns False when stop signal is set
            poll_timeout: Timeout for queue polling (seconds)

        Yields:
            Packet objects with model_id set for routing
        """
        while True:
            try:
                packet = self.merged_queue.get(timeout=poll_timeout)
            except Empty:
                # Check stop signal
                if not is_connected():
                    logger.debug("Stop signal detected, stopping all models")
                    break

                if self._all_complete.is_set():
                    # BUGFIX: all models finished between polls. Packets (and
                    # the None sentinel) may have been enqueued after our last
                    # timed-out get; drain them so final packets are not
                    # silently dropped before terminating.
                    while True:
                        try:
                            leftover = self.merged_queue.get_nowait()
                        except Empty:
                            break
                        if leftover is not None:
                            yield leftover
                    break
                continue

            if packet is None:
                # All models complete
                break

            # Exception packets are yielded rather than raised so that other
            # models can continue streaming; the frontend shows the error.
            if isinstance(packet.obj, PacketException):
                logger.error(
                    f"Error from model {packet.model_id}: {packet.obj.exception}"
                )
                yield packet
                continue

            # OverallStop indicates one model finished - don't break here,
            # wait for all models (the None sentinel ends the loop).
            yield packet

    def wait_for_threads(self, timeout: float | None = None) -> None:
        """Wait for all model threads to complete.

        Args:
            timeout: Optional timeout per thread (seconds)
        """
        for ctx in self.model_contexts.values():
            if ctx.thread is not None and ctx.thread.is_alive():
                ctx.thread.join(timeout=timeout)

    def get_model_contexts(self) -> dict[str, ModelStreamContext]:
        """Get all model contexts for post-processing (e.g., saving chat turns)."""
        return self.model_contexts.copy()
|
||||
@@ -22,6 +22,7 @@ from onyx.chat.models import ExtractedProjectFiles
|
||||
from onyx.chat.models import MessageResponseIDInfo
|
||||
from onyx.chat.models import ProjectFileMetadata
|
||||
from onyx.chat.models import StreamingError
|
||||
from onyx.chat.multi_model_stream import MultiModelStreamMerger
|
||||
from onyx.chat.prompt_utils import calculate_reserved_tokens
|
||||
from onyx.chat.save_chat import save_chat_turn
|
||||
from onyx.chat.stop_signal_checker import is_connected as check_stop_signal
|
||||
@@ -51,6 +52,7 @@ from onyx.file_store.utils import verify_user_files
|
||||
from onyx.llm.factory import get_llm_token_counter
|
||||
from onyx.llm.factory import get_llms_for_persona
|
||||
from onyx.llm.interfaces import LLM
|
||||
from onyx.llm.override_models import LLMOverride
|
||||
from onyx.llm.utils import litellm_exception_to_error_msg
|
||||
from onyx.onyxbot.slack.models import SlackContext
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
@@ -321,10 +323,18 @@ def stream_chat_message_objects(
|
||||
"Must specify a set of documents for chat or specify search options"
|
||||
)
|
||||
|
||||
# Build the list of LLM overrides to process (multi-model support)
|
||||
llm_overrides: list[LLMOverride | None] = []
|
||||
if new_msg_req.llm_overrides and len(new_msg_req.llm_overrides) > 0:
|
||||
llm_overrides = list(new_msg_req.llm_overrides)
|
||||
else:
|
||||
llm_overrides = [new_msg_req.llm_override or chat_session.llm_override]
|
||||
|
||||
# Use first model for initial setup (token counting for user message)
|
||||
llm, fast_llm = get_llms_for_persona(
|
||||
persona=persona,
|
||||
user=user,
|
||||
llm_override=new_msg_req.llm_override or chat_session.llm_override,
|
||||
llm_override=llm_overrides[0],
|
||||
additional_headers=litellm_additional_headers,
|
||||
long_term_logger=long_term_logger,
|
||||
)
|
||||
@@ -372,104 +382,14 @@ def stream_chat_message_objects(
|
||||
|
||||
custom_agent_prompt = get_custom_agent_prompt(persona, chat_session)
|
||||
|
||||
reserved_token_count = calculate_reserved_tokens(
|
||||
db_session=db_session,
|
||||
persona_system_prompt=custom_agent_prompt or "",
|
||||
token_counter=token_counter,
|
||||
files=last_chat_message.files,
|
||||
memories=memories,
|
||||
)
|
||||
|
||||
# Process projects, if all of the files fit in the context, it doesn't need to use RAG
|
||||
extracted_project_files = _extract_project_file_texts_and_images(
|
||||
project_id=chat_session.project_id,
|
||||
user_id=user_id,
|
||||
llm_max_context_window=llm.config.max_input_tokens,
|
||||
reserved_token_count=reserved_token_count,
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
# There are cases where the internal search tool should be disabled
|
||||
# If the user is in a project, it should not use other sources / generic search
|
||||
# If they are in a project but using a custom agent, it should use the agent setup
|
||||
# (which means it can use search)
|
||||
# However if in a project and there are more files than can fit in the context,
|
||||
# it should use the search tool with the project filter on
|
||||
disable_internal_search = bool(
|
||||
chat_session.project_id
|
||||
and persona.id is DEFAULT_PERSONA_ID
|
||||
and (
|
||||
extracted_project_files.project_file_texts
|
||||
or not extracted_project_files.project_as_filter
|
||||
)
|
||||
)
|
||||
|
||||
emitter = get_default_emitter()
|
||||
|
||||
# Construct tools based on the persona configurations
|
||||
tool_dict = construct_tools(
|
||||
persona=persona,
|
||||
db_session=db_session,
|
||||
emitter=emitter,
|
||||
user=user,
|
||||
llm=llm,
|
||||
fast_llm=fast_llm,
|
||||
search_tool_config=SearchToolConfig(
|
||||
user_selected_filters=user_selected_filters,
|
||||
project_id=(
|
||||
chat_session.project_id
|
||||
if extracted_project_files.project_as_filter
|
||||
else None
|
||||
),
|
||||
bypass_acl=bypass_acl,
|
||||
slack_context=slack_context,
|
||||
),
|
||||
custom_tool_config=CustomToolConfig(
|
||||
chat_session_id=chat_session_id,
|
||||
message_id=user_message.id if user_message else None,
|
||||
additional_headers=custom_tool_additional_headers,
|
||||
),
|
||||
allowed_tool_ids=new_msg_req.allowed_tool_ids,
|
||||
disable_internal_search=disable_internal_search,
|
||||
)
|
||||
tools: list[Tool] = []
|
||||
for tool_list in tool_dict.values():
|
||||
tools.extend(tool_list)
|
||||
|
||||
# TODO Once summarization is done, we don't need to load all the files from the beginning anymore.
|
||||
# load all files needed for this chat chain in memory
|
||||
files = load_all_chat_files(chat_history, db_session)
|
||||
|
||||
# TODO Need to think of some way to support selected docs from the sidebar
|
||||
|
||||
# Reserve a message id for the assistant response for frontend to track packets
|
||||
assistant_response = reserve_message_id(
|
||||
db_session=db_session,
|
||||
chat_session_id=chat_session_id,
|
||||
parent_message=user_message.id,
|
||||
message_type=MessageType.ASSISTANT,
|
||||
)
|
||||
|
||||
yield MessageResponseIDInfo(
|
||||
user_message_id=user_message.id,
|
||||
reserved_assistant_message_id=assistant_response.id,
|
||||
)
|
||||
|
||||
# Build a mapping of tool_id to tool_name for history reconstruction
|
||||
all_tools = get_tools(db_session)
|
||||
tool_id_to_name_map = {tool.id: tool.name for tool in all_tools}
|
||||
|
||||
# Convert the chat history into a simple format that is free of any DB objects
|
||||
# and is easy to parse for the agent loop
|
||||
simple_chat_history = convert_chat_history(
|
||||
chat_history=chat_history,
|
||||
files=files,
|
||||
project_image_files=extracted_project_files.project_image_files,
|
||||
additional_context=additional_context,
|
||||
token_counter=token_counter,
|
||||
tool_id_to_name_map=tool_id_to_name_map,
|
||||
)
|
||||
|
||||
redis_client = get_redis_client()
|
||||
|
||||
reset_cancel_status(
|
||||
@@ -480,76 +400,242 @@ def stream_chat_message_objects(
|
||||
def check_is_connected() -> bool:
|
||||
return check_stop_signal(chat_session_id, redis_client)
|
||||
|
||||
# Create state container for accumulating partial results
|
||||
state_container = ChatStateContainer()
|
||||
# Setup for multi-model (concurrent) vs single-model (synchronous)
|
||||
is_multi_model = len(llm_overrides) > 1
|
||||
merger = MultiModelStreamMerger() if is_multi_model else None
|
||||
|
||||
# Run the LLM loop with explicit wrapper for stop signal handling
|
||||
# The wrapper runs run_llm_loop in a background thread and polls every 300ms
|
||||
# for stop signals. run_llm_loop itself doesn't know about stopping.
|
||||
# Note: DB session is not thread safe but nothing else uses it and the
|
||||
# reference is passed directly so it's ok.
|
||||
yield from run_chat_llm_with_state_containers(
|
||||
run_llm_loop,
|
||||
emitter=emitter,
|
||||
state_container=state_container,
|
||||
is_connected=check_is_connected, # Not passed through to run_llm_loop
|
||||
simple_chat_history=simple_chat_history,
|
||||
tools=tools,
|
||||
custom_agent_prompt=custom_agent_prompt,
|
||||
project_files=extracted_project_files,
|
||||
persona=persona,
|
||||
memories=memories,
|
||||
llm=llm,
|
||||
token_counter=token_counter,
|
||||
db_session=db_session,
|
||||
forced_tool_id=(
|
||||
new_msg_req.forced_tool_ids[0] if new_msg_req.forced_tool_ids else None
|
||||
),
|
||||
)
|
||||
# Track state for single-model case (used after the loop)
|
||||
single_model_state: tuple[ChatStateContainer, ChatMessage, LLM] | None = None
|
||||
|
||||
# Process each model
|
||||
for llm_override in llm_overrides:
|
||||
# Get LLM for this model
|
||||
llm, fast_llm = get_llms_for_persona(
|
||||
persona=persona,
|
||||
user=user,
|
||||
llm_override=llm_override,
|
||||
additional_headers=litellm_additional_headers,
|
||||
long_term_logger=long_term_logger,
|
||||
)
|
||||
token_counter = get_llm_token_counter(llm)
|
||||
|
||||
reserved_token_count = calculate_reserved_tokens(
|
||||
db_session=db_session,
|
||||
persona_system_prompt=custom_agent_prompt or "",
|
||||
token_counter=token_counter,
|
||||
files=last_chat_message.files,
|
||||
memories=memories,
|
||||
)
|
||||
|
||||
# Process projects, if all of the files fit in the context, it doesn't need to use RAG
|
||||
extracted_project_files = _extract_project_file_texts_and_images(
|
||||
project_id=chat_session.project_id,
|
||||
user_id=user_id,
|
||||
llm_max_context_window=llm.config.max_input_tokens,
|
||||
reserved_token_count=reserved_token_count,
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
# There are cases where the internal search tool should be disabled
|
||||
# If the user is in a project, it should not use other sources / generic search
|
||||
# If they are in a project but using a custom agent, it should use the agent setup
|
||||
# (which means it can use search)
|
||||
# However if in a project and there are more files than can fit in the context,
|
||||
# it should use the search tool with the project filter on
|
||||
disable_internal_search = bool(
|
||||
chat_session.project_id
|
||||
and persona.id is DEFAULT_PERSONA_ID
|
||||
and (
|
||||
extracted_project_files.project_file_texts
|
||||
or not extracted_project_files.project_as_filter
|
||||
)
|
||||
)
|
||||
|
||||
# Get the appropriate emitter
|
||||
if merger is not None:
|
||||
model_id = f"{llm.config.model_provider}:{llm.config.model_name}"
|
||||
model_ctx = merger.register_model(model_id)
|
||||
emitter = model_ctx.emitter
|
||||
else:
|
||||
emitter = get_default_emitter()
|
||||
# Construct tools based on the persona configurations
|
||||
tool_dict = construct_tools(
|
||||
persona=persona,
|
||||
db_session=db_session,
|
||||
emitter=emitter,
|
||||
user=user,
|
||||
llm=llm,
|
||||
fast_llm=fast_llm,
|
||||
search_tool_config=SearchToolConfig(
|
||||
user_selected_filters=user_selected_filters,
|
||||
project_id=(
|
||||
chat_session.project_id
|
||||
if extracted_project_files.project_as_filter
|
||||
else None
|
||||
),
|
||||
bypass_acl=bypass_acl,
|
||||
slack_context=slack_context,
|
||||
),
|
||||
custom_tool_config=CustomToolConfig(
|
||||
chat_session_id=chat_session_id,
|
||||
message_id=user_message.id if user_message else None,
|
||||
additional_headers=custom_tool_additional_headers,
|
||||
),
|
||||
allowed_tool_ids=new_msg_req.allowed_tool_ids,
|
||||
disable_internal_search=disable_internal_search,
|
||||
)
|
||||
tools: list[Tool] = []
|
||||
for tool_list in tool_dict.values():
|
||||
tools.extend(tool_list)
|
||||
|
||||
# Reserve a message id for the assistant response for frontend to track packets
|
||||
assistant_response = reserve_message_id(
|
||||
db_session=db_session,
|
||||
chat_session_id=chat_session_id,
|
||||
parent_message=user_message.id,
|
||||
message_type=MessageType.ASSISTANT,
|
||||
model_provider=llm.config.model_provider,
|
||||
model_name=llm.config.model_name,
|
||||
)
|
||||
|
||||
yield MessageResponseIDInfo(
|
||||
user_message_id=user_message.id,
|
||||
reserved_assistant_message_id=assistant_response.id,
|
||||
model_provider=llm.config.model_provider,
|
||||
model_name=llm.config.model_name,
|
||||
)
|
||||
|
||||
# Convert the chat history into a simple format that is free of any DB objects
|
||||
# and is easy to parse for the agent loop
|
||||
simple_chat_history = convert_chat_history(
|
||||
chat_history=chat_history,
|
||||
files=files,
|
||||
project_image_files=extracted_project_files.project_image_files,
|
||||
additional_context=additional_context,
|
||||
token_counter=token_counter,
|
||||
tool_id_to_name_map=tool_id_to_name_map,
|
||||
)
|
||||
|
||||
if merger is not None:
|
||||
# Multi-model: start thread for concurrent execution
|
||||
model_ctx.extra_data["assistant_response"] = assistant_response
|
||||
model_ctx.extra_data["llm"] = llm
|
||||
|
||||
merger.start_model_thread(
|
||||
model_id=model_id,
|
||||
target_func=run_llm_loop,
|
||||
simple_chat_history=simple_chat_history,
|
||||
tools=tools,
|
||||
custom_agent_prompt=custom_agent_prompt,
|
||||
project_files=extracted_project_files,
|
||||
persona=persona,
|
||||
memories=memories,
|
||||
llm=llm,
|
||||
token_counter=token_counter,
|
||||
db_session=db_session,
|
||||
forced_tool_id=(
|
||||
new_msg_req.forced_tool_ids[0]
|
||||
if new_msg_req.forced_tool_ids
|
||||
else None
|
||||
),
|
||||
)
|
||||
else:
|
||||
# Single model: run synchronously
|
||||
state_container = ChatStateContainer()
|
||||
single_model_state = (state_container, assistant_response, llm)
|
||||
|
||||
# Run the LLM loop with explicit wrapper for stop signal handling
|
||||
# The wrapper runs run_llm_loop in a background thread and polls every 300ms
|
||||
# for stop signals. run_llm_loop itself doesn't know about stopping.
|
||||
# Note: DB session is not thread safe but nothing else uses it and the
|
||||
# reference is passed directly so it's ok.
|
||||
yield from run_chat_llm_with_state_containers(
|
||||
run_llm_loop,
|
||||
emitter=emitter,
|
||||
state_container=state_container,
|
||||
is_connected=check_is_connected, # Not passed through to run_llm_loop
|
||||
simple_chat_history=simple_chat_history,
|
||||
tools=tools,
|
||||
custom_agent_prompt=custom_agent_prompt,
|
||||
project_files=extracted_project_files,
|
||||
persona=persona,
|
||||
memories=memories,
|
||||
llm=llm,
|
||||
token_counter=token_counter,
|
||||
db_session=db_session,
|
||||
forced_tool_id=(
|
||||
new_msg_req.forced_tool_ids[0]
|
||||
if new_msg_req.forced_tool_ids
|
||||
else None
|
||||
),
|
||||
)
|
||||
|
||||
# After loop: handle streaming for multi-model
|
||||
if merger is not None:
|
||||
yield from merger.stream(is_connected=check_is_connected)
|
||||
merger.wait_for_threads()
|
||||
|
||||
# Determine if stopped by user
|
||||
completed_normally = check_is_connected()
|
||||
if not completed_normally:
|
||||
logger.debug(f"Chat session {chat_session_id} stopped by user")
|
||||
|
||||
# Build final answer based on completion status
|
||||
if completed_normally:
|
||||
if state_container.answer_tokens is None:
|
||||
raise RuntimeError(
|
||||
"LLM run completed normally but did not return an answer."
|
||||
)
|
||||
final_answer = state_container.answer_tokens
|
||||
# Build list of (state_container, assistant_response, error) for saving
|
||||
if merger is not None:
|
||||
models_to_save = [
|
||||
(ctx.state_container, ctx.extra_data["assistant_response"], ctx.error)
|
||||
for ctx in merger.get_model_contexts().values()
|
||||
]
|
||||
else:
|
||||
# Stopped by user - append stop message
|
||||
if state_container.answer_tokens:
|
||||
final_answer = (
|
||||
state_container.answer_tokens
|
||||
+ " ... The generation was stopped by the user here."
|
||||
)
|
||||
assert single_model_state is not None
|
||||
state_container, assistant_response, _ = single_model_state
|
||||
models_to_save = [(state_container, assistant_response, None)]
|
||||
|
||||
# Save chat turns for all models
|
||||
for state_container, assistant_response, error in models_to_save:
|
||||
if error is not None:
|
||||
final_answer = f"Error: {str(error)}"
|
||||
elif completed_normally:
|
||||
if state_container.answer_tokens is None:
|
||||
if merger is not None:
|
||||
final_answer = "No response generated."
|
||||
else:
|
||||
raise RuntimeError(
|
||||
"LLM run completed normally but did not return an answer."
|
||||
)
|
||||
else:
|
||||
final_answer = state_container.answer_tokens
|
||||
else:
|
||||
final_answer = "The generation was stopped by the user."
|
||||
|
||||
# Build citation_docs_info from accumulated citations in state container
|
||||
citation_docs_info: list[CitationDocInfo] = []
|
||||
seen_citation_nums: set[int] = set()
|
||||
for citation_num, search_doc in state_container.citation_to_doc.items():
|
||||
if citation_num not in seen_citation_nums:
|
||||
seen_citation_nums.add(citation_num)
|
||||
citation_docs_info.append(
|
||||
CitationDocInfo(
|
||||
search_doc=search_doc,
|
||||
citation_number=citation_num,
|
||||
# Stopped by user - append stop message
|
||||
if state_container.answer_tokens:
|
||||
final_answer = (
|
||||
state_container.answer_tokens
|
||||
+ " ... The generation was stopped by the user here."
|
||||
)
|
||||
)
|
||||
else:
|
||||
final_answer = "The generation was stopped by the user."
|
||||
|
||||
save_chat_turn(
|
||||
message_text=final_answer,
|
||||
reasoning_tokens=state_container.reasoning_tokens,
|
||||
citation_docs_info=citation_docs_info,
|
||||
tool_calls=state_container.tool_calls,
|
||||
db_session=db_session,
|
||||
assistant_message=assistant_response,
|
||||
)
|
||||
# Build citation_docs_info from accumulated citations
|
||||
citation_docs_info: list[CitationDocInfo] = []
|
||||
seen_citation_nums: set[int] = set()
|
||||
for citation_num, search_doc in state_container.citation_to_doc.items():
|
||||
if citation_num not in seen_citation_nums:
|
||||
seen_citation_nums.add(citation_num)
|
||||
citation_docs_info.append(
|
||||
CitationDocInfo(
|
||||
search_doc=search_doc,
|
||||
citation_number=citation_num,
|
||||
)
|
||||
)
|
||||
|
||||
save_chat_turn(
|
||||
message_text=final_answer,
|
||||
reasoning_tokens=state_container.reasoning_tokens,
|
||||
citation_docs_info=citation_docs_info,
|
||||
tool_calls=state_container.tool_calls,
|
||||
db_session=db_session,
|
||||
assistant_message=assistant_response,
|
||||
)
|
||||
|
||||
except ValueError as e:
|
||||
logger.exception("Failed to process chat message.")
|
||||
|
||||
@@ -620,6 +620,8 @@ def reserve_message_id(
|
||||
chat_session_id: UUID,
|
||||
parent_message: int,
|
||||
message_type: MessageType = MessageType.ASSISTANT,
|
||||
model_provider: str | None = None,
|
||||
model_name: str | None = None,
|
||||
) -> ChatMessage:
|
||||
# Create an temporary holding chat message to the updated and saved at the end
|
||||
empty_message = ChatMessage(
|
||||
@@ -629,6 +631,8 @@ def reserve_message_id(
|
||||
message="Response was termination prior to completion, try regenerating.",
|
||||
token_count=15,
|
||||
message_type=message_type,
|
||||
model_provider=model_provider,
|
||||
model_name=model_name,
|
||||
)
|
||||
|
||||
# Add the empty message to the session
|
||||
@@ -661,6 +665,8 @@ def create_new_chat_message(
|
||||
commit: bool = True,
|
||||
reserved_message_id: int | None = None,
|
||||
reasoning_tokens: str | None = None,
|
||||
model_provider: str | None = None,
|
||||
model_name: str | None = None,
|
||||
) -> ChatMessage:
|
||||
if reserved_message_id is not None:
|
||||
# Edit existing message
|
||||
@@ -676,6 +682,8 @@ def create_new_chat_message(
|
||||
existing_message.files = files
|
||||
existing_message.error = error
|
||||
existing_message.reasoning_tokens = reasoning_tokens
|
||||
existing_message.model_provider = model_provider
|
||||
existing_message.model_name = model_name
|
||||
new_chat_message = existing_message
|
||||
else:
|
||||
# Create new message
|
||||
@@ -689,6 +697,8 @@ def create_new_chat_message(
|
||||
files=files,
|
||||
error=error,
|
||||
reasoning_tokens=reasoning_tokens,
|
||||
model_provider=model_provider,
|
||||
model_name=model_name,
|
||||
)
|
||||
db_session.add(new_chat_message)
|
||||
|
||||
@@ -874,6 +884,8 @@ def translate_db_message_to_chat_message_detail(
|
||||
files=chat_message.files or [],
|
||||
error=chat_message.error,
|
||||
current_feedback=current_feedback,
|
||||
model_provider=chat_message.model_provider,
|
||||
model_name=chat_message.model_name,
|
||||
)
|
||||
|
||||
return chat_msg_detail
|
||||
|
||||
@@ -2142,6 +2142,9 @@ class ChatMessage(Base):
|
||||
DateTime(timezone=True), server_default=func.now()
|
||||
)
|
||||
|
||||
model_provider: Mapped[str | None] = mapped_column(String, nullable=True)
|
||||
model_name: Mapped[str | None] = mapped_column(String, nullable=True)
|
||||
|
||||
# Relationships
|
||||
chat_session: Mapped[ChatSession] = relationship("ChatSession")
|
||||
|
||||
|
||||
@@ -109,6 +109,11 @@ class CreateChatMessageRequest(ChunkContext):
|
||||
llm_override: LLMOverride | None = None
|
||||
prompt_override: PromptOverride | None = None
|
||||
|
||||
# List of LLM overrides to generate responses from
|
||||
# If provided, generates one response per override (all sharing the same parent message)
|
||||
# Takes precedence over llm_override if both are provided
|
||||
llm_overrides: list[LLMOverride] | None = None
|
||||
|
||||
# Allows the caller to override the temperature for the chat session
|
||||
# this does persist in the chat thread details
|
||||
temperature_override: float | None = None
|
||||
@@ -245,6 +250,9 @@ class ChatMessageDetail(BaseModel):
|
||||
error: str | None = None
|
||||
current_feedback: str | None = None # "like" | "dislike" | null
|
||||
|
||||
model_provider: str | None = None
|
||||
model_name: str | None = None
|
||||
|
||||
def model_dump(self, *args: list, **kwargs: dict[str, Any]) -> dict[str, Any]: # type: ignore
|
||||
initial_dict = super().model_dump(mode="json", *args, **kwargs) # type: ignore
|
||||
initial_dict["time_sent"] = self.time_sent.isoformat()
|
||||
|
||||
@@ -260,6 +260,7 @@ PacketObj = Union[
|
||||
class Packet(BaseModel):
|
||||
turn_index: int | None
|
||||
obj: Annotated[PacketObj, Field(discriminator="type")]
|
||||
model_id: str | None = None # Format: "{provider}:{model_name}" e.g. "openai:gpt-4"
|
||||
|
||||
|
||||
# This is for replaying it back from the DB to the frontend
|
||||
|
||||
@@ -10,6 +10,7 @@ import { EnterpriseSettings } from "@/app/admin/settings/interfaces";
|
||||
import { FileDescriptor } from "@/app/chat/interfaces";
|
||||
import { MemoizedAIMessage } from "../message/messageComponents/MemoizedAIMessage";
|
||||
import { ProjectFile } from "../projects/projectsService";
|
||||
import { ModelResponse } from "../message/messageComponents/ModelResponseTabs";
|
||||
|
||||
interface MessagesDisplayProps {
|
||||
messageHistory: Message[];
|
||||
@@ -76,6 +77,72 @@ export const MessagesDisplay: React.FC<MessagesDisplayProps> = ({
|
||||
// Stable fallbacks to avoid changing prop identities on each render
|
||||
const emptyDocs = useMemo<OnyxDocument[]>(() => [], []);
|
||||
const emptyChildrenIds = useMemo<number[]>(() => [], []);
|
||||
|
||||
// Build a map of parentNodeId -> sibling assistant Messages for multi-model/regeneration grouping
|
||||
// Uses completeMessageTree to know about ALL sibling groups across all branches
|
||||
// Also track which nodeIds should be skipped (all but the first in each group)
|
||||
const { siblingGroupMap, nodeIdsToSkip } = useMemo(() => {
|
||||
const groupMap = new Map<number, Message[]>();
|
||||
const skipNodeIds = new Set<number>();
|
||||
|
||||
// First, build group info from the COMPLETE message tree (all branches)
|
||||
// This ensures we know about sibling groups even when viewing a different branch
|
||||
if (completeMessageTree) {
|
||||
for (const msg of Array.from(completeMessageTree.values())) {
|
||||
if (msg.type === "assistant" && msg.parentNodeId !== null) {
|
||||
const existing = groupMap.get(msg.parentNodeId);
|
||||
if (existing) {
|
||||
existing.push(msg);
|
||||
} else {
|
||||
groupMap.set(msg.parentNodeId, [msg]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Also include messages from messageHistory (for streaming/pre-created nodes
|
||||
// that might not be in completeMessageTree yet)
|
||||
for (const msg of messageHistory) {
|
||||
if (msg.type === "assistant" && msg.parentNodeId !== null) {
|
||||
const existing = groupMap.get(msg.parentNodeId);
|
||||
if (existing) {
|
||||
// Check if this message is already in the group (by nodeId)
|
||||
if (!existing.some((m) => m.nodeId === msg.nodeId)) {
|
||||
existing.push(msg);
|
||||
}
|
||||
} else {
|
||||
groupMap.set(msg.parentNodeId, [msg]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// For each group with multiple siblings, sort for consistent ordering
|
||||
// Primary sort: by model name (for consistent order after refresh)
|
||||
// Fallback sort: by nodeId (for stable order during streaming when model names aren't set yet)
|
||||
for (const [, messages] of Array.from(groupMap.entries())) {
|
||||
if (messages.length > 1) {
|
||||
messages.sort((a: Message, b: Message) => {
|
||||
const aName = `${a.modelProvider || ""}:${a.modelName || ""}`;
|
||||
const bName = `${b.modelProvider || ""}:${b.modelName || ""}`;
|
||||
// If both have model names, sort by model name
|
||||
// Otherwise, fall back to nodeId for stable ordering during streaming
|
||||
if (aName !== ":" && bName !== ":") {
|
||||
return aName.localeCompare(bName);
|
||||
}
|
||||
return a.nodeId - b.nodeId;
|
||||
});
|
||||
// Skip all except the first (representative) message
|
||||
for (let i = 1; i < messages.length; i++) {
|
||||
const msg = messages[i];
|
||||
if (msg) {
|
||||
skipNodeIds.add(msg.nodeId);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return { siblingGroupMap: groupMap, nodeIdsToSkip: skipNodeIds };
|
||||
}, [messageHistory, completeMessageTree]);
|
||||
const createRegenerator = useCallback(
|
||||
(regenerationRequest: {
|
||||
messageId: number;
|
||||
@@ -172,10 +239,56 @@ export const MessagesDisplay: React.FC<MessagesDisplayProps> = ({
|
||||
);
|
||||
}
|
||||
|
||||
// NOTE: it's fine to use the previous entry in messageHistory
|
||||
// since this is a "parsed" version of the message tree
|
||||
// so the previous message is guaranteed to be the parent of the current message
|
||||
const previousMessage = i !== 0 ? messageHistory[i - 1] : null;
|
||||
// For assistant messages, we need to find the actual parent (user message)
|
||||
// by looking up parentNodeId, since messageHistory may contain sibling
|
||||
// assistant messages (multi-model or regenerations) that share the same parent
|
||||
const previousMessage = (() => {
|
||||
if (i === 0) return null;
|
||||
// For assistant messages, find the parent by parentNodeId
|
||||
if (message.type === "assistant" && message.parentNodeId !== null) {
|
||||
return (
|
||||
messageHistory.find((m) => m.nodeId === message.parentNodeId) ??
|
||||
null
|
||||
);
|
||||
}
|
||||
// For user messages, the previous entry is the parent
|
||||
return messageHistory[i - 1] ?? null;
|
||||
})();
|
||||
|
||||
// Multi-model/regeneration sibling grouping:
|
||||
// Skip messages that are not the first in their sibling group
|
||||
if (nodeIdsToSkip.has(message.nodeId)) {
|
||||
return null;
|
||||
}
|
||||
|
||||
// Build modelResponses from all sibling messages with the same parent
|
||||
let modelResponses: ModelResponse[] | undefined;
|
||||
|
||||
if (message.type === "assistant" && message.parentNodeId !== null) {
|
||||
const siblingMessages = siblingGroupMap.get(message.parentNodeId);
|
||||
if (siblingMessages && siblingMessages.length > 1) {
|
||||
modelResponses = siblingMessages.map((msg) => ({
|
||||
model: {
|
||||
name: msg.modelProvider || "",
|
||||
provider: msg.modelProvider || "",
|
||||
modelName: msg.modelName || "",
|
||||
},
|
||||
// Include the actual message for this model's response
|
||||
message: msg,
|
||||
}));
|
||||
}
|
||||
}
|
||||
|
||||
// Filter out non-representative messages from ALL sibling groups for branch switching
|
||||
// Each multi-model/regeneration group should appear as a single branch option
|
||||
const switchableMessages = (() => {
|
||||
const allChildren =
|
||||
parentMessage?.childrenNodeIds ?? emptyChildrenIds;
|
||||
// Filter out all nodeIds that are "non-representative" in their sibling group
|
||||
// nodeIdsToSkip contains all messages except the first one in each group
|
||||
return allChildren.filter((nodeId) => !nodeIdsToSkip.has(nodeId));
|
||||
})();
|
||||
|
||||
return (
|
||||
<div
|
||||
className="text-text"
|
||||
@@ -196,11 +309,11 @@ export const MessagesDisplay: React.FC<MessagesDisplayProps> = ({
|
||||
overriddenModel={llmManager.currentLlm?.modelName}
|
||||
nodeId={message.nodeId}
|
||||
llmManager={llmManager}
|
||||
otherMessagesCanSwitchTo={
|
||||
parentMessage?.childrenNodeIds ?? emptyChildrenIds
|
||||
}
|
||||
otherMessagesCanSwitchTo={switchableMessages}
|
||||
onMessageSelection={onMessageSelection}
|
||||
researchType={message.researchType}
|
||||
modelResponses={modelResponses}
|
||||
latestChildNodeId={parentMessage?.latestChildNodeId}
|
||||
/>
|
||||
</div>
|
||||
);
|
||||
|
||||
@@ -17,7 +17,6 @@ import {
|
||||
MessageTreeState,
|
||||
upsertMessages,
|
||||
SYSTEM_NODE_ID,
|
||||
buildImmediateMessages,
|
||||
buildEmptyMessage,
|
||||
} from "../services/messageTree";
|
||||
import { MinimalPersonaSnapshot } from "@/app/admin/assistants/interfaces";
|
||||
@@ -592,69 +591,74 @@ export function useChatController({
|
||||
|
||||
// Add user message immediately to the message tree so that the chat
|
||||
// immediately reflects the user message
|
||||
// Assistant nodes are created on-demand when MessageResponseIDInfo arrives
|
||||
let initialUserNode: Message;
|
||||
let initialAssistantNode: Message;
|
||||
|
||||
if (regenerationRequest) {
|
||||
// For regeneration: keep the existing user message, only create new assistant
|
||||
// For regeneration: keep the existing user message
|
||||
initialUserNode = regenerationRequest.parentMessage;
|
||||
initialAssistantNode = buildEmptyMessage({
|
||||
messageType: "assistant",
|
||||
parentNodeId: initialUserNode.nodeId,
|
||||
nodeIdOffset: 1,
|
||||
});
|
||||
} else {
|
||||
// For new messages or editing: create/update user message and assistant
|
||||
// For new messages or editing: create/update user message
|
||||
const parentNodeIdForMessage = messageToResend
|
||||
? messageToResend.parentNodeId || SYSTEM_NODE_ID
|
||||
: parentMessage?.nodeId || SYSTEM_NODE_ID;
|
||||
const result = buildImmediateMessages(
|
||||
parentNodeIdForMessage,
|
||||
currMessage,
|
||||
projectFilesToFileDescriptors(currentMessageFiles),
|
||||
messageToResend
|
||||
);
|
||||
initialUserNode = result.initialUserNode;
|
||||
initialAssistantNode = result.initialAssistantNode;
|
||||
initialUserNode = messageToResend
|
||||
? { ...messageToResend }
|
||||
: buildEmptyMessage({
|
||||
messageType: "user",
|
||||
parentNodeId: parentNodeIdForMessage,
|
||||
message: currMessage,
|
||||
files: projectFilesToFileDescriptors(currentMessageFiles),
|
||||
});
|
||||
}
|
||||
|
||||
// make messages appear + clear input bar
|
||||
const messagesToUpsert = regenerationRequest
|
||||
? [initialAssistantNode] // Only upsert the new assistant for regeneration
|
||||
: [initialUserNode, initialAssistantNode]; // Upsert both for normal/edit flow
|
||||
currentMessageTreeLocal = upsertToCompleteMessageTree({
|
||||
messages: messagesToUpsert,
|
||||
completeMessageTreeOverride: currentMessageTreeLocal,
|
||||
chatSessionId: frozenSessionId,
|
||||
});
|
||||
// Track assistant nodes (supports both single and multi-model)
|
||||
// Nodes are created on-demand when MessageResponseIDInfo arrives with backend IDs
|
||||
interface AssistantNodeData {
|
||||
node: Message;
|
||||
packets: Packet[];
|
||||
documents: OnyxDocument[];
|
||||
citations: CitationMap | null;
|
||||
finalMessage: BackendMessage | null;
|
||||
}
|
||||
// Map model_id (e.g. "openai:gpt-4") to index in assistantNodes array
|
||||
const modelIdToIndex: Map<string, number> = new Map();
|
||||
const assistantNodes: AssistantNodeData[] = [];
|
||||
|
||||
// Only upsert user message immediately; assistant nodes created when backend IDs arrive
|
||||
if (!regenerationRequest) {
|
||||
currentMessageTreeLocal = upsertToCompleteMessageTree({
|
||||
messages: [initialUserNode],
|
||||
completeMessageTreeOverride: currentMessageTreeLocal,
|
||||
chatSessionId: frozenSessionId,
|
||||
});
|
||||
}
|
||||
resetInputBar();
|
||||
|
||||
let answer = "";
|
||||
|
||||
// Shared state across all models
|
||||
const stopReason: StreamStopReason | null = null;
|
||||
let query: string | null = null;
|
||||
let retrievalType: RetrievalType =
|
||||
selectedDocuments.length > 0
|
||||
? RetrievalType.SelectedDocs
|
||||
: RetrievalType.None;
|
||||
let documents: OnyxDocument[] = selectedDocuments;
|
||||
let citations: CitationMap | null = null;
|
||||
let aiMessageImages: FileDescriptor[] | null = null;
|
||||
let error: string | null = null;
|
||||
let stackTrace: string | null = null;
|
||||
|
||||
let finalMessage: BackendMessage | null = null;
|
||||
let toolCall: ToolCallMetadata | null = null;
|
||||
let files = projectFilesToFileDescriptors(currentMessageFiles);
|
||||
let packets: Packet[] = [];
|
||||
|
||||
let newUserMessageId: number | null = null;
|
||||
let newAssistantMessageId: number | null = null;
|
||||
|
||||
try {
|
||||
const lastSuccessfulMessageId = getLastSuccessfulMessageId(
|
||||
currentMessageTreeLocal
|
||||
);
|
||||
// Read the CURRENT message tree from store to get the correct parent
|
||||
// This is important for multi-model responses where tab switching updates
|
||||
// latestChildNodeId - we need the fresh state, not the captured closure
|
||||
const freshMessageTree =
|
||||
useChatSessionStore.getState().sessions.get(frozenSessionId)
|
||||
?.messageTree || currentMessageTreeLocal;
|
||||
const lastSuccessfulMessageId =
|
||||
getLastSuccessfulMessageId(freshMessageTree);
|
||||
const disabledToolIds = liveAssistant
|
||||
? assistantPreferences?.[liveAssistant?.id]?.disabled_tool_ids
|
||||
: undefined;
|
||||
@@ -704,6 +708,14 @@ export function useChatController({
|
||||
searchParams?.get(SEARCH_PARAM_NAMES.MODEL_VERSION) ||
|
||||
undefined,
|
||||
temperature: llmManager.temperature || undefined,
|
||||
// Multi-model support: if multiple LLMs are selected and no single override, send all selected models
|
||||
llmOverrides:
|
||||
!modelOverride && llmManager.selectedLlms.length > 1
|
||||
? llmManager.selectedLlms.map((llm) => ({
|
||||
model_provider: llm.name,
|
||||
model_version: llm.modelName,
|
||||
}))
|
||||
: undefined,
|
||||
systemPromptOverride:
|
||||
searchParams?.get(SEARCH_PARAM_NAMES.SYSTEM_PROMPT) || undefined,
|
||||
useExistingUserMessage: isSeededChat,
|
||||
@@ -746,8 +758,44 @@ export function useChatController({
|
||||
if (
|
||||
(packet as MessageResponseIDInfo).reserved_assistant_message_id
|
||||
) {
|
||||
newAssistantMessageId = (packet as MessageResponseIDInfo)
|
||||
.reserved_assistant_message_id;
|
||||
const msgInfo = packet as MessageResponseIDInfo;
|
||||
const messageId = msgInfo.reserved_assistant_message_id;
|
||||
|
||||
// Create assistant node on-demand with nodeId = messageId
|
||||
const newNode: Message = {
|
||||
nodeId: messageId,
|
||||
messageId: messageId,
|
||||
message: "",
|
||||
type: "assistant",
|
||||
files: [],
|
||||
toolCall: null,
|
||||
parentNodeId: initialUserNode.nodeId,
|
||||
packets: [],
|
||||
modelProvider: msgInfo.model_provider || undefined,
|
||||
modelName: msgInfo.model_name || undefined,
|
||||
};
|
||||
|
||||
const nodeIndex = assistantNodes.length;
|
||||
assistantNodes.push({
|
||||
node: newNode,
|
||||
packets: [],
|
||||
documents: selectedDocuments,
|
||||
citations: null,
|
||||
finalMessage: null,
|
||||
});
|
||||
|
||||
// Map model_id to index for packet routing
|
||||
if (msgInfo.model_provider && msgInfo.model_name) {
|
||||
const modelId = `${msgInfo.model_provider}:${msgInfo.model_name}`;
|
||||
modelIdToIndex.set(modelId, nodeIndex);
|
||||
}
|
||||
|
||||
// Upsert the new assistant node to the tree
|
||||
currentMessageTreeLocal = upsertToCompleteMessageTree({
|
||||
messages: [newNode],
|
||||
completeMessageTreeOverride: currentMessageTreeLocal,
|
||||
chatSessionId: frozenSessionId,
|
||||
});
|
||||
}
|
||||
|
||||
if (Object.hasOwn(packet, "user_files")) {
|
||||
@@ -782,7 +830,16 @@ export function useChatController({
|
||||
|
||||
throw new Error((packet as StreamingError).error);
|
||||
} else if (Object.hasOwn(packet, "message_id")) {
|
||||
finalMessage = packet as BackendMessage;
|
||||
// Route finalMessage to target assistant
|
||||
const backendMsg = packet as BackendMessage;
|
||||
const msgModelId = (backendMsg as any).model_id as
|
||||
| string
|
||||
| undefined;
|
||||
const targetIndex = msgModelId
|
||||
? modelIdToIndex.get(msgModelId) ?? 0
|
||||
: 0;
|
||||
const targetNode = assistantNodes[targetIndex];
|
||||
if (targetNode) targetNode.finalMessage = backendMsg;
|
||||
} else if (Object.hasOwn(packet, "stop_reason")) {
|
||||
const stop_reason = (packet as StreamStopInfo).stop_reason;
|
||||
if (stop_reason === StreamStopReason.CONTEXT_LENGTH) {
|
||||
@@ -790,29 +847,35 @@ export function useChatController({
|
||||
}
|
||||
} else if (Object.hasOwn(packet, "obj")) {
|
||||
console.debug("Object packet:", JSON.stringify(packet));
|
||||
packets.push(packet as Packet);
|
||||
const typedPacket = packet as Packet;
|
||||
|
||||
// Check if the packet contains document information
|
||||
const packetObj = (packet as Packet).obj;
|
||||
// Route packet to target assistant by model_id, or first assistant if single model
|
||||
const targetIndex = typedPacket.model_id
|
||||
? modelIdToIndex.get(typedPacket.model_id) ?? 0
|
||||
: 0;
|
||||
const target = assistantNodes[targetIndex];
|
||||
if (!target) continue;
|
||||
|
||||
target.packets.push(typedPacket);
|
||||
|
||||
// Update per-assistant state based on packet type
|
||||
const packetObj = typedPacket.obj;
|
||||
if (packetObj.type === "citation_info") {
|
||||
// Individual citation packet from backend streaming
|
||||
const citationInfo = packetObj as {
|
||||
type: "citation_info";
|
||||
citation_number: number;
|
||||
document_id: string;
|
||||
};
|
||||
// Incrementally build citations map
|
||||
citations = {
|
||||
...(citations || {}),
|
||||
target.citations = {
|
||||
...(target.citations || {}),
|
||||
[citationInfo.citation_number]: citationInfo.document_id,
|
||||
};
|
||||
} else if (packetObj.type === "citation_delta") {
|
||||
// Batched citation packet (for backwards compatibility)
|
||||
const citationDelta = packetObj as CitationDelta;
|
||||
if (citationDelta.citations) {
|
||||
citations = {
|
||||
...(citations || {}),
|
||||
target.citations = {
|
||||
...(target.citations || {}),
|
||||
...Object.fromEntries(
|
||||
citationDelta.citations.map((c) => [
|
||||
c.citation_num,
|
||||
@@ -824,10 +887,10 @@ export function useChatController({
|
||||
} else if (packetObj.type === "message_start") {
|
||||
const messageStart = packetObj as MessageStart;
|
||||
if (messageStart.final_documents) {
|
||||
documents = messageStart.final_documents;
|
||||
target.documents = messageStart.final_documents;
|
||||
updateSelectedNodeForDocDisplay(
|
||||
frozenSessionId,
|
||||
initialAssistantNode.nodeId
|
||||
target.node.nodeId
|
||||
);
|
||||
}
|
||||
}
|
||||
@@ -840,30 +903,42 @@ export function useChatController({
|
||||
parentMessage =
|
||||
parentMessage || currentMessageTreeLocal?.get(SYSTEM_NODE_ID)!;
|
||||
|
||||
// Build messages to upsert
|
||||
const messagesToUpsertInLoop: Message[] = [
|
||||
{
|
||||
...initialUserNode,
|
||||
messageId: newUserMessageId ?? undefined,
|
||||
files: files,
|
||||
},
|
||||
];
|
||||
|
||||
// Add all assistant nodes
|
||||
for (const assistantData of assistantNodes) {
|
||||
messagesToUpsertInLoop.push({
|
||||
...assistantData.node,
|
||||
message: error || "",
|
||||
type: error ? "error" : ("assistant" as const),
|
||||
retrievalType,
|
||||
query: assistantData.finalMessage?.rephrased_query || query,
|
||||
documents: assistantData.documents,
|
||||
citations:
|
||||
assistantData.finalMessage?.citations ||
|
||||
assistantData.citations ||
|
||||
{},
|
||||
files:
|
||||
assistantData.finalMessage?.files || aiMessageImages || [],
|
||||
toolCall: assistantData.finalMessage?.tool_call || toolCall,
|
||||
stackTrace: stackTrace,
|
||||
overridden_model: assistantData.finalMessage?.overridden_model,
|
||||
stopReason: stopReason,
|
||||
packets: assistantData.packets,
|
||||
modelProvider: assistantData.node.modelProvider,
|
||||
modelName: assistantData.node.modelName,
|
||||
});
|
||||
}
|
||||
|
||||
currentMessageTreeLocal = upsertToCompleteMessageTree({
|
||||
messages: [
|
||||
{
|
||||
...initialUserNode,
|
||||
messageId: newUserMessageId ?? undefined,
|
||||
files: files,
|
||||
},
|
||||
{
|
||||
...initialAssistantNode,
|
||||
messageId: newAssistantMessageId ?? undefined,
|
||||
message: error || answer,
|
||||
type: error ? "error" : "assistant",
|
||||
retrievalType,
|
||||
query: finalMessage?.rephrased_query || query,
|
||||
documents: documents,
|
||||
citations: finalMessage?.citations || citations || {},
|
||||
files: finalMessage?.files || aiMessageImages || [],
|
||||
toolCall: finalMessage?.tool_call || toolCall,
|
||||
stackTrace: stackTrace,
|
||||
overridden_model: finalMessage?.overridden_model,
|
||||
stopReason: stopReason,
|
||||
packets: packets,
|
||||
},
|
||||
],
|
||||
messages: messagesToUpsertInLoop,
|
||||
// Pass the latest map state
|
||||
completeMessageTreeOverride: currentMessageTreeLocal,
|
||||
chatSessionId: frozenSessionId!,
|
||||
@@ -873,6 +948,9 @@ export function useChatController({
|
||||
} catch (e: any) {
|
||||
console.log("Error:", e);
|
||||
const errorMsg = e.message;
|
||||
// Use existing assistant node if available, otherwise create temp error node
|
||||
const errorNodeId =
|
||||
assistantNodes[0]?.node.nodeId ?? -1 * Date.now() - 1;
|
||||
currentMessageTreeLocal = upsertToCompleteMessageTree({
|
||||
messages: [
|
||||
{
|
||||
@@ -890,7 +968,7 @@ export function useChatController({
|
||||
packets: [],
|
||||
},
|
||||
{
|
||||
nodeId: initialAssistantNode.nodeId,
|
||||
nodeId: errorNodeId,
|
||||
message: errorMsg,
|
||||
type: "error",
|
||||
files: aiMessageImages || [],
|
||||
|
||||
@@ -143,6 +143,9 @@ export interface Message {
|
||||
|
||||
// feedback state
|
||||
currentFeedback?: FeedbackType | null;
|
||||
|
||||
modelProvider?: string;
|
||||
modelName?: string;
|
||||
}
|
||||
|
||||
export interface BackendChatSession {
|
||||
@@ -192,6 +195,9 @@ export interface BackendMessage {
|
||||
tool_call: ToolCallFinalResult | null;
|
||||
current_feedback: string | null;
|
||||
|
||||
model_provider: string | null;
|
||||
model_name: string | null;
|
||||
|
||||
sub_questions: SubQuestionDetail[];
|
||||
// Keeping existing properties
|
||||
comments: any;
|
||||
@@ -203,6 +209,8 @@ export interface BackendMessage {
|
||||
export interface MessageResponseIDInfo {
|
||||
user_message_id: number | null;
|
||||
reserved_assistant_message_id: number;
|
||||
model_provider?: string | null;
|
||||
model_name?: string | null;
|
||||
}
|
||||
|
||||
export interface UserKnowledgeFilePacket {
|
||||
|
||||
@@ -13,7 +13,14 @@ import { FeedbackType } from "@/app/chat/interfaces";
|
||||
import { OnyxDocument } from "@/lib/search/interfaces";
|
||||
import CitedSourcesToggle from "@/app/chat/message/messageComponents/CitedSourcesToggle";
|
||||
import { TooltipGroup } from "@/components/tooltip/CustomTooltip";
|
||||
import { useRef, useState, useEffect, useCallback, RefObject } from "react";
|
||||
import {
|
||||
useRef,
|
||||
useState,
|
||||
useEffect,
|
||||
useCallback,
|
||||
useMemo,
|
||||
RefObject,
|
||||
} from "react";
|
||||
import {
|
||||
useChatSessionStore,
|
||||
useDocumentSidebarVisible,
|
||||
@@ -46,6 +53,11 @@ import FeedbackModal, {
|
||||
} from "../../components/modal/FeedbackModal";
|
||||
import { usePopup } from "@/components/admin/connectors/Popup";
|
||||
import { useFeedbackController } from "../../hooks/useFeedbackController";
|
||||
import {
|
||||
ModelResponse,
|
||||
ModelResponseTabs,
|
||||
useModelResponses,
|
||||
} from "./ModelResponseTabs";
|
||||
|
||||
export interface AIMessageProps {
|
||||
rawPackets: Packet[];
|
||||
@@ -56,6 +68,11 @@ export interface AIMessageProps {
|
||||
llmManager: LlmManager | null;
|
||||
otherMessagesCanSwitchTo?: number[];
|
||||
onMessageSelection?: (nodeId: number) => void;
|
||||
// Multi-model responses: when multiple models are selected, each has its own response
|
||||
// If undefined or length <= 1, renders normally without tabs
|
||||
modelResponses?: ModelResponse[];
|
||||
// The nodeId of the latest/selected child in the sibling group (for tab selection)
|
||||
latestChildNodeId?: number | null;
|
||||
}
|
||||
|
||||
export default function AIMessage({
|
||||
@@ -67,11 +84,56 @@ export default function AIMessage({
|
||||
llmManager,
|
||||
otherMessagesCanSwitchTo,
|
||||
onMessageSelection,
|
||||
modelResponses,
|
||||
latestChildNodeId,
|
||||
}: AIMessageProps) {
|
||||
const markdownRef = useRef<HTMLDivElement>(null);
|
||||
const { popup, setPopup } = usePopup();
|
||||
const { handleFeedbackChange } = useFeedbackController({ setPopup });
|
||||
|
||||
// Multi-model response state
|
||||
const {
|
||||
activeIndex: activeModelIndex,
|
||||
setActiveIndex: setActiveModelIndex,
|
||||
hasMultipleResponses,
|
||||
activeResponse,
|
||||
} = useModelResponses(modelResponses, latestChildNodeId);
|
||||
|
||||
// Handler for tab changes - switches branches via onMessageSelection
|
||||
// This updates latestChildNodeId and persists to backend, enabling true branching
|
||||
const handleTabChange = useCallback(
|
||||
(index: number) => {
|
||||
setActiveModelIndex(index);
|
||||
// Switch branches via onMessageSelection - this updates the message tree
|
||||
// and persists to backend so subsequent messages branch from this response
|
||||
const selectedMessage = modelResponses?.[index]?.message;
|
||||
if (selectedMessage?.nodeId && onMessageSelection) {
|
||||
onMessageSelection(selectedMessage.nodeId);
|
||||
}
|
||||
},
|
||||
[setActiveModelIndex, modelResponses, onMessageSelection]
|
||||
);
|
||||
|
||||
// DEBUG: Log modelResponses
|
||||
if (modelResponses && modelResponses.length > 0) {
|
||||
console.log(
|
||||
"[AIMessage] nodeId:",
|
||||
nodeId,
|
||||
"modelResponses:",
|
||||
modelResponses.length,
|
||||
"hasMultipleResponses:",
|
||||
hasMultipleResponses
|
||||
);
|
||||
}
|
||||
|
||||
// When in multi-model mode, use the active response's packets instead of rawPackets
|
||||
const effectivePackets = useMemo(() => {
|
||||
if (hasMultipleResponses && activeResponse?.message?.packets) {
|
||||
return activeResponse.message.packets;
|
||||
}
|
||||
return rawPackets;
|
||||
}, [hasMultipleResponses, activeResponse, rawPackets]);
|
||||
|
||||
const modal = useCreateModal();
|
||||
const [feedbackModalProps, setFeedbackModalProps] =
|
||||
useState<FeedbackModalProps | null>(null);
|
||||
@@ -139,7 +201,8 @@ export default function AIMessage({
|
||||
);
|
||||
|
||||
const [finalAnswerComing, _setFinalAnswerComing] = useState(
|
||||
isFinalAnswerComing(rawPackets) || isStreamingComplete(rawPackets)
|
||||
isFinalAnswerComing(effectivePackets) ||
|
||||
isStreamingComplete(effectivePackets)
|
||||
);
|
||||
const setFinalAnswerComing = (value: boolean) => {
|
||||
_setFinalAnswerComing(value);
|
||||
@@ -147,7 +210,7 @@ export default function AIMessage({
|
||||
};
|
||||
|
||||
const [displayComplete, _setDisplayComplete] = useState(
|
||||
isStreamingComplete(rawPackets)
|
||||
isStreamingComplete(effectivePackets)
|
||||
);
|
||||
const setDisplayComplete = (value: boolean) => {
|
||||
_setDisplayComplete(value);
|
||||
@@ -155,7 +218,7 @@ export default function AIMessage({
|
||||
};
|
||||
|
||||
const [stopPacketSeen, _setStopPacketSeen] = useState(
|
||||
isStreamingComplete(rawPackets)
|
||||
isStreamingComplete(effectivePackets)
|
||||
);
|
||||
const setStopPacketSeen = (value: boolean) => {
|
||||
_setStopPacketSeen(value);
|
||||
@@ -173,12 +236,20 @@ export default function AIMessage({
|
||||
const groupedPacketsRef = useRef<{ turn_index: number; packets: Packet[] }[]>(
|
||||
[]
|
||||
);
|
||||
const finalAnswerComingRef = useRef<boolean>(isFinalAnswerComing(rawPackets));
|
||||
const displayCompleteRef = useRef<boolean>(isStreamingComplete(rawPackets));
|
||||
const stopPacketSeenRef = useRef<boolean>(isStreamingComplete(rawPackets));
|
||||
const finalAnswerComingRef = useRef<boolean>(
|
||||
isFinalAnswerComing(effectivePackets)
|
||||
);
|
||||
const displayCompleteRef = useRef<boolean>(
|
||||
isStreamingComplete(effectivePackets)
|
||||
);
|
||||
const stopPacketSeenRef = useRef<boolean>(
|
||||
isStreamingComplete(effectivePackets)
|
||||
);
|
||||
// Track turn_index values for graceful SECTION_END injection
|
||||
const seenTurnIndicesRef = useRef<Set<number>>(new Set());
|
||||
const turnIndicesWithSectionEndRef = useRef<Set<number>>(new Set());
|
||||
// Track previous activeModelIndex for synchronous reset detection
|
||||
const prevActiveModelIndexRef = useRef<number>(activeModelIndex);
|
||||
|
||||
// Reset incremental state when switching messages or when stream resets
|
||||
const resetState = () => {
|
||||
@@ -189,9 +260,9 @@ export default function AIMessage({
|
||||
documentMapRef.current = new Map();
|
||||
groupedPacketsMapRef.current = new Map();
|
||||
groupedPacketsRef.current = [];
|
||||
finalAnswerComingRef.current = isFinalAnswerComing(rawPackets);
|
||||
displayCompleteRef.current = isStreamingComplete(rawPackets);
|
||||
stopPacketSeenRef.current = isStreamingComplete(rawPackets);
|
||||
finalAnswerComingRef.current = isFinalAnswerComing(effectivePackets);
|
||||
displayCompleteRef.current = isStreamingComplete(effectivePackets);
|
||||
stopPacketSeenRef.current = isStreamingComplete(effectivePackets);
|
||||
seenTurnIndicesRef.current = new Set();
|
||||
turnIndicesWithSectionEndRef.current = new Set();
|
||||
};
|
||||
@@ -199,8 +270,19 @@ export default function AIMessage({
|
||||
resetState();
|
||||
}, [nodeId]);
|
||||
|
||||
// SYNCHRONOUS reset when switching model tabs - must happen BEFORE packet processing
|
||||
// useEffect runs AFTER render, but packet processing happens DURING render,
|
||||
// so we need to detect tab changes synchronously to avoid mixing packets
|
||||
if (
|
||||
hasMultipleResponses &&
|
||||
prevActiveModelIndexRef.current !== activeModelIndex
|
||||
) {
|
||||
resetState();
|
||||
prevActiveModelIndexRef.current = activeModelIndex;
|
||||
}
|
||||
|
||||
// If the upstream replaces packets with a shorter list (reset), clear state
|
||||
if (lastProcessedIndexRef.current > rawPackets.length) {
|
||||
if (lastProcessedIndexRef.current > effectivePackets.length) {
|
||||
resetState();
|
||||
}
|
||||
|
||||
@@ -239,9 +321,13 @@ export default function AIMessage({
|
||||
};
|
||||
|
||||
// Process only the new packets synchronously for this render
|
||||
if (rawPackets.length > lastProcessedIndexRef.current) {
|
||||
for (let i = lastProcessedIndexRef.current; i < rawPackets.length; i++) {
|
||||
const packet = rawPackets[i];
|
||||
if (effectivePackets.length > lastProcessedIndexRef.current) {
|
||||
for (
|
||||
let i = lastProcessedIndexRef.current;
|
||||
i < effectivePackets.length;
|
||||
i++
|
||||
) {
|
||||
const packet = effectivePackets[i];
|
||||
if (!packet) continue;
|
||||
|
||||
const currentTurnIndex = packet.turn_index;
|
||||
@@ -368,7 +454,7 @@ export default function AIMessage({
|
||||
.filter(({ packets }) => hasContentPackets(packets))
|
||||
.sort((a, b) => a.turn_index - b.turn_index);
|
||||
|
||||
lastProcessedIndexRef.current = rawPackets.length;
|
||||
lastProcessedIndexRef.current = effectivePackets.length;
|
||||
}
|
||||
|
||||
const citations = citationsRef.current;
|
||||
@@ -431,6 +517,26 @@ export default function AIMessage({
|
||||
<div className="w-full">
|
||||
<div className="max-w-message-max break-words">
|
||||
<div className="w-full desktop:ml-4">
|
||||
{/* Multi-model response tabs */}
|
||||
{hasMultipleResponses && modelResponses && (
|
||||
<ModelResponseTabs
|
||||
modelResponses={modelResponses}
|
||||
activeIndex={activeModelIndex}
|
||||
onTabChange={handleTabChange}
|
||||
/>
|
||||
)}
|
||||
|
||||
{/* Show active model name badge when multiple responses */}
|
||||
{hasMultipleResponses && activeResponse && (
|
||||
<div className="text-xs text-text-03 mb-2">
|
||||
Response from{" "}
|
||||
<span className="font-medium text-text-02">
|
||||
{activeResponse.model.modelName ||
|
||||
activeResponse.model.provider}
|
||||
</span>
|
||||
</div>
|
||||
)}
|
||||
|
||||
<div className="max-w-message-max break-words">
|
||||
<div
|
||||
ref={markdownRef}
|
||||
@@ -550,7 +656,9 @@ export default function AIMessage({
|
||||
)}
|
||||
|
||||
<CopyIconButton
|
||||
getCopyText={() => getTextContent(rawPackets)}
|
||||
getCopyText={() =>
|
||||
getTextContent(effectivePackets)
|
||||
}
|
||||
tertiary
|
||||
data-testid="AIMessage/copy-button"
|
||||
/>
|
||||
@@ -583,13 +691,20 @@ export default function AIMessage({
|
||||
<div data-testid="AIMessage/regenerate">
|
||||
<LLMPopover
|
||||
llmManager={llmManager}
|
||||
currentModelName={chatState.overriddenModel}
|
||||
currentModelName={
|
||||
// Use the model from the active response (for multi-model)
|
||||
// or fall back to session-level override
|
||||
activeResponse?.message?.modelName ||
|
||||
activeResponse?.model.modelName ||
|
||||
chatState.overriddenModel
|
||||
}
|
||||
onSelect={(modelName) => {
|
||||
const llmDescriptor =
|
||||
parseLlmDescriptor(modelName);
|
||||
chatState.regenerate!(llmDescriptor);
|
||||
}}
|
||||
folded
|
||||
singleSelectMode={true}
|
||||
/>
|
||||
</div>
|
||||
)}
|
||||
|
||||
@@ -5,6 +5,7 @@ import { OnyxDocument, MinimalOnyxDocument } from "@/lib/search/interfaces";
|
||||
import AIMessage from "./AIMessage";
|
||||
import { LlmDescriptor, LlmManager } from "@/lib/hooks";
|
||||
import { ProjectFile } from "@/app/chat/projects/projectsService";
|
||||
import { ModelResponse } from "./ModelResponseTabs";
|
||||
|
||||
interface BaseMemoizedAIMessageProps {
|
||||
rawPackets: any[];
|
||||
@@ -21,6 +22,10 @@ interface BaseMemoizedAIMessageProps {
|
||||
llmManager: LlmManager | null;
|
||||
projectFiles?: ProjectFile[];
|
||||
researchType?: string | null;
|
||||
// Multi-model responses
|
||||
modelResponses?: ModelResponse[];
|
||||
// The nodeId of the latest/selected child in the sibling group (for tab selection)
|
||||
latestChildNodeId?: number | null;
|
||||
}
|
||||
|
||||
interface InternalMemoizedAIMessageProps extends BaseMemoizedAIMessageProps {
|
||||
@@ -54,6 +59,8 @@ const InternalMemoizedAIMessage = React.memo(
|
||||
llmManager,
|
||||
projectFiles,
|
||||
researchType,
|
||||
modelResponses,
|
||||
latestChildNodeId,
|
||||
}: InternalMemoizedAIMessageProps) {
|
||||
const chatState = React.useMemo(
|
||||
() => ({
|
||||
@@ -88,6 +95,8 @@ const InternalMemoizedAIMessage = React.memo(
|
||||
llmManager={llmManager}
|
||||
otherMessagesCanSwitchTo={otherMessagesCanSwitchTo}
|
||||
onMessageSelection={onMessageSelection}
|
||||
modelResponses={modelResponses}
|
||||
latestChildNodeId={latestChildNodeId}
|
||||
/>
|
||||
);
|
||||
}
|
||||
@@ -110,6 +119,8 @@ export const MemoizedAIMessage = ({
|
||||
llmManager,
|
||||
projectFiles,
|
||||
researchType,
|
||||
modelResponses,
|
||||
latestChildNodeId,
|
||||
}: MemoizedAIMessageProps) => {
|
||||
const regenerate = useMemo(() => {
|
||||
if (messageId === undefined) {
|
||||
@@ -145,6 +156,8 @@ export const MemoizedAIMessage = ({
|
||||
llmManager={llmManager}
|
||||
projectFiles={projectFiles}
|
||||
researchType={researchType}
|
||||
modelResponses={modelResponses}
|
||||
latestChildNodeId={latestChildNodeId}
|
||||
/>
|
||||
);
|
||||
};
|
||||
|
||||
154
web/src/app/chat/message/messageComponents/ModelResponseTabs.tsx
Normal file
154
web/src/app/chat/message/messageComponents/ModelResponseTabs.tsx
Normal file
@@ -0,0 +1,154 @@
|
||||
"use client";
|
||||
|
||||
import { useState, useMemo, useEffect, useRef } from "react";
|
||||
import { LlmDescriptor } from "@/lib/hooks";
|
||||
import { getProviderIcon } from "@/app/admin/configuration/llm/utils";
|
||||
import { cn } from "@/lib/utils";
|
||||
import { isStreamingComplete } from "@/app/chat/services/packetUtils";
|
||||
import { FiCheck } from "react-icons/fi";
|
||||
|
||||
import { Message } from "@/app/chat/interfaces";
|
||||
|
||||
// A single model's response within a multi-model assistant turn.
export interface ModelResponse {
  // The LLM (provider + model) that produced this response.
  model: LlmDescriptor;
  // The actual message data for this model's response
  // (may be absent before the response has started streaming).
  message?: Message;
}

interface ModelResponseTabsProps {
  // All sibling responses to render as tabs, one per model.
  modelResponses: ModelResponse[];
  // Index of the currently-displayed response.
  activeIndex: number;
  // Invoked with the index of the tab the user clicked.
  onTabChange: (index: number) => void;
}
|
||||
|
||||
// Streaming indicator component - pulsing dot animation
|
||||
function StreamingIndicator() {
|
||||
return (
|
||||
<span className="relative flex h-2 w-2">
|
||||
<span className="animate-ping absolute inline-flex h-full w-full rounded-full bg-amber-400 opacity-75" />
|
||||
<span className="relative inline-flex rounded-full h-2 w-2 bg-amber-500" />
|
||||
</span>
|
||||
);
|
||||
}
|
||||
|
||||
// Completion indicator component - checkmark
|
||||
function CompletedIndicator() {
|
||||
return (
|
||||
<span className="flex items-center justify-center h-3.5 w-3.5 rounded-full bg-emerald-500/20">
|
||||
<FiCheck className="h-2.5 w-2.5 text-emerald-600" strokeWidth={3} />
|
||||
</span>
|
||||
);
|
||||
}
|
||||
|
||||
export function ModelResponseTabs({
|
||||
modelResponses,
|
||||
activeIndex,
|
||||
onTabChange,
|
||||
}: ModelResponseTabsProps) {
|
||||
if (modelResponses.length <= 1) {
|
||||
return null;
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="flex items-center gap-1 mb-3 pb-2 border-b border-border-01">
|
||||
{modelResponses.map((response, index) => {
|
||||
const isActive = index === activeIndex;
|
||||
const Icon = getProviderIcon(
|
||||
response.model.provider,
|
||||
response.model.modelName
|
||||
);
|
||||
|
||||
// Determine streaming status for this model's response
|
||||
const packets = response.message?.packets || [];
|
||||
const isComplete = isStreamingComplete(packets);
|
||||
const hasStartedStreaming = packets.length > 0;
|
||||
|
||||
return (
|
||||
<button
|
||||
key={`${response.model.provider}-${response.model.modelName}-${index}`}
|
||||
onClick={() => onTabChange(index)}
|
||||
className={cn(
|
||||
"flex items-center gap-1.5 px-3 py-1.5 rounded-lg text-sm transition-colors",
|
||||
isActive
|
||||
? "bg-background-emphasis text-text-01 font-medium"
|
||||
: "text-text-03 hover:bg-background-hover hover:text-text-02"
|
||||
)}
|
||||
>
|
||||
<Icon size={16} />
|
||||
<span className="max-w-[120px] truncate">
|
||||
{response.model.modelName || response.model.provider}
|
||||
</span>
|
||||
{/* Status indicator: streaming or complete */}
|
||||
{hasStartedStreaming && !isComplete && <StreamingIndicator />}
|
||||
{isComplete && <CompletedIndicator />}
|
||||
</button>
|
||||
);
|
||||
})}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
// Hook to manage multi-model response state
|
||||
export function useModelResponses(
|
||||
modelResponses?: ModelResponse[],
|
||||
latestChildNodeId?: number | null
|
||||
) {
|
||||
// Calculate initial index based on latestChildNodeId
|
||||
const initialIndex = useMemo(() => {
|
||||
if (!modelResponses || modelResponses.length === 0) return 0;
|
||||
if (latestChildNodeId === undefined || latestChildNodeId === null) return 0;
|
||||
|
||||
const index = modelResponses.findIndex(
|
||||
(r) => r.message?.nodeId === latestChildNodeId
|
||||
);
|
||||
return index >= 0 ? index : 0;
|
||||
}, [modelResponses, latestChildNodeId]);
|
||||
|
||||
const [activeIndex, setActiveIndex] = useState(initialIndex);
|
||||
|
||||
// Track previous modelResponses length to detect new responses (regeneration)
|
||||
const prevLengthRef = useRef(modelResponses?.length ?? 0);
|
||||
|
||||
// Auto-switch to new tab when a response is added (regeneration case)
|
||||
// Also sync with latestChildNodeId when it changes (e.g., on load or branch switch)
|
||||
useEffect(() => {
|
||||
const currentLength = modelResponses?.length ?? 0;
|
||||
|
||||
// If a new response was added, switch to it (regeneration case)
|
||||
if (currentLength > prevLengthRef.current && currentLength > 0) {
|
||||
setActiveIndex(currentLength - 1);
|
||||
}
|
||||
// If latestChildNodeId changed (e.g., loading chat), sync to it
|
||||
else if (
|
||||
latestChildNodeId !== undefined &&
|
||||
latestChildNodeId !== null &&
|
||||
modelResponses
|
||||
) {
|
||||
const targetIndex = modelResponses.findIndex(
|
||||
(r) => r.message?.nodeId === latestChildNodeId
|
||||
);
|
||||
if (targetIndex >= 0 && targetIndex !== activeIndex) {
|
||||
setActiveIndex(targetIndex);
|
||||
}
|
||||
}
|
||||
|
||||
prevLengthRef.current = currentLength;
|
||||
}, [modelResponses, latestChildNodeId, activeIndex]);
|
||||
|
||||
// Reset active index if it's out of bounds
|
||||
const safeActiveIndex = useMemo(() => {
|
||||
if (!modelResponses || modelResponses.length === 0) return 0;
|
||||
return Math.min(activeIndex, modelResponses.length - 1);
|
||||
}, [activeIndex, modelResponses]);
|
||||
|
||||
const hasMultipleResponses = (modelResponses?.length ?? 0) > 1;
|
||||
|
||||
const activeResponse = modelResponses?.[safeActiveIndex];
|
||||
|
||||
return {
|
||||
activeIndex: safeActiveIndex,
|
||||
setActiveIndex,
|
||||
hasMultipleResponses,
|
||||
activeResponse,
|
||||
};
|
||||
}
|
||||
@@ -180,6 +180,7 @@ export interface SendMessageParams {
|
||||
useAgentSearch?: boolean;
|
||||
enabledToolIds?: number[];
|
||||
forcedToolIds?: number[];
|
||||
llmOverrides?: { model_provider: string; model_version: string }[];
|
||||
}
|
||||
|
||||
export async function* sendMessage({
|
||||
@@ -203,6 +204,7 @@ export async function* sendMessage({
|
||||
useAgentSearch,
|
||||
enabledToolIds,
|
||||
forcedToolIds,
|
||||
llmOverrides,
|
||||
}: SendMessageParams): AsyncGenerator<PacketType, void, unknown> {
|
||||
const documentsAreSelected =
|
||||
selectedDocumentIds && selectedDocumentIds.length > 0;
|
||||
@@ -232,14 +234,18 @@ export async function* sendMessage({
|
||||
system_prompt: systemPromptOverride,
|
||||
}
|
||||
: null,
|
||||
// Multi-model response support: if llmOverrides is provided with multiple models,
|
||||
// send it as llm_overrides; otherwise use single llm_override for backwards compatibility
|
||||
llm_override:
|
||||
temperature || modelVersion
|
||||
!llmOverrides && (temperature || modelVersion)
|
||||
? {
|
||||
temperature,
|
||||
model_provider: modelProvider,
|
||||
model_version: modelVersion,
|
||||
}
|
||||
: null,
|
||||
llm_overrides:
|
||||
llmOverrides && llmOverrides.length > 0 ? llmOverrides : null,
|
||||
use_existing_user_message: useExistingUserMessage,
|
||||
use_agentic_search: useAgentSearch ?? false,
|
||||
allowed_tool_ids: enabledToolIds,
|
||||
@@ -550,6 +556,8 @@ export function processRawChatHistory(
|
||||
overridden_model: messageInfo.overridden_model,
|
||||
packets: packetsForMessage || [],
|
||||
currentFeedback: messageInfo.current_feedback as FeedbackType | null,
|
||||
modelProvider: messageInfo.model_provider ?? undefined,
|
||||
modelName: messageInfo.model_name ?? undefined,
|
||||
};
|
||||
|
||||
messages.set(messageInfo.message_id, message);
|
||||
|
||||
@@ -254,6 +254,17 @@ export function getLatestMessageChain(messages: MessageTreeState): Message[] {
|
||||
return chain;
|
||||
}
|
||||
|
||||
// Build a map of parentNodeId -> sibling assistant messages
|
||||
// This is used to include all multi-model responses (or regenerations) when we encounter one
|
||||
const siblingAssistantMap = new Map<number, Message[]>();
|
||||
for (const msg of Array.from(messages.values())) {
|
||||
if (msg.type === "assistant" && msg.parentNodeId !== null) {
|
||||
const existing = siblingAssistantMap.get(msg.parentNodeId) || [];
|
||||
existing.push(msg);
|
||||
siblingAssistantMap.set(msg.parentNodeId, existing);
|
||||
}
|
||||
}
|
||||
|
||||
// Find the root message
|
||||
let root: Message | undefined;
|
||||
if (messages.has(SYSTEM_NODE_ID)) {
|
||||
@@ -285,6 +296,9 @@ export function getLatestMessageChain(messages: MessageTreeState): Message[] {
|
||||
chain.push(root);
|
||||
}
|
||||
|
||||
// Track which parent nodes we've already added sibling groups for (to avoid duplicates)
|
||||
const addedSiblingGroups = new Set<number>();
|
||||
|
||||
while (
|
||||
currentMessage?.latestChildNodeId !== null &&
|
||||
currentMessage?.latestChildNodeId !== undefined
|
||||
@@ -292,8 +306,46 @@ export function getLatestMessageChain(messages: MessageTreeState): Message[] {
|
||||
const nextNodeId = currentMessage.latestChildNodeId;
|
||||
const nextMessage = messages.get(nextNodeId);
|
||||
if (nextMessage) {
|
||||
chain.push(nextMessage);
|
||||
currentMessage = nextMessage;
|
||||
// Check if this is an assistant message that might have siblings (multi-model or regenerations)
|
||||
if (
|
||||
nextMessage.type === "assistant" &&
|
||||
nextMessage.parentNodeId !== null &&
|
||||
!addedSiblingGroups.has(nextMessage.parentNodeId)
|
||||
) {
|
||||
// Get all sibling assistant messages that share the same parent
|
||||
const siblingMessages = siblingAssistantMap.get(
|
||||
nextMessage.parentNodeId
|
||||
);
|
||||
if (siblingMessages && siblingMessages.length > 1) {
|
||||
// Multiple siblings - add ALL of them for tabbed display
|
||||
// Sort by nodeId to ensure consistent ordering
|
||||
const sortedSiblings = [...siblingMessages].sort(
|
||||
(a, b) => a.nodeId - b.nodeId
|
||||
);
|
||||
for (const msg of sortedSiblings) {
|
||||
chain.push(msg);
|
||||
}
|
||||
addedSiblingGroups.add(nextMessage.parentNodeId);
|
||||
// Continue from the one that was "latest" (the one we were following)
|
||||
currentMessage = nextMessage;
|
||||
} else {
|
||||
// Only one assistant child, add normally
|
||||
chain.push(nextMessage);
|
||||
addedSiblingGroups.add(nextMessage.parentNodeId);
|
||||
currentMessage = nextMessage;
|
||||
}
|
||||
} else if (
|
||||
nextMessage.type === "assistant" &&
|
||||
nextMessage.parentNodeId !== null &&
|
||||
addedSiblingGroups.has(nextMessage.parentNodeId)
|
||||
) {
|
||||
// Already added this sibling group, just move to next
|
||||
currentMessage = nextMessage;
|
||||
} else {
|
||||
// Normal message (user message or no parent)
|
||||
chain.push(nextMessage);
|
||||
currentMessage = nextMessage;
|
||||
}
|
||||
} else {
|
||||
console.warn(
|
||||
`Chain broken: Message with nodeId ${nextNodeId} not found.`
|
||||
@@ -414,6 +466,8 @@ interface BuildEmptyMessageParams {
|
||||
message?: string;
|
||||
files?: FileDescriptor[];
|
||||
nodeIdOffset?: number;
|
||||
modelProvider?: string;
|
||||
modelName?: string;
|
||||
}
|
||||
|
||||
export const buildEmptyMessage = (params: BuildEmptyMessageParams): Message => {
|
||||
@@ -427,6 +481,8 @@ export const buildEmptyMessage = (params: BuildEmptyMessageParams): Message => {
|
||||
toolCall: null,
|
||||
parentNodeId: params.parentNodeId,
|
||||
packets: [],
|
||||
modelProvider: params.modelProvider,
|
||||
modelName: params.modelName,
|
||||
};
|
||||
};
|
||||
|
||||
|
||||
@@ -226,6 +226,7 @@ export type ObjTypes =
|
||||
export interface Packet {
|
||||
turn_index: number;
|
||||
obj: ObjTypes;
|
||||
model_id?: string | null; // Format: "{provider}:{model_name}" e.g. "openai:gpt-4"
|
||||
}
|
||||
|
||||
export interface ChatPacket {
|
||||
|
||||
@@ -484,7 +484,12 @@ export interface LlmDescriptor {
|
||||
}
|
||||
|
||||
export interface LlmManager {
|
||||
// Multi-model selection (array of 1-4 models)
|
||||
selectedLlms: LlmDescriptor[];
|
||||
updateSelectedLlms: (llms: LlmDescriptor[]) => void;
|
||||
// Convenience getter for backwards compatibility - returns first selected model
|
||||
currentLlm: LlmDescriptor;
|
||||
// Legacy single-model update - updates first model in selection
|
||||
updateCurrentLlm: (newOverride: LlmDescriptor) => void;
|
||||
temperature: number;
|
||||
updateTemperature: (temperature: number) => void;
|
||||
@@ -564,11 +569,20 @@ export function useLlmManager(
|
||||
const [userHasManuallyOverriddenLLM, setUserHasManuallyOverriddenLLM] =
|
||||
useState(false);
|
||||
const [chatSession, setChatSession] = useState<ChatSession | null>(null);
|
||||
const [currentLlm, setCurrentLlm] = useState<LlmDescriptor>({
|
||||
name: "",
|
||||
provider: "",
|
||||
modelName: "",
|
||||
});
|
||||
const [selectedLlms, setSelectedLlms] = useState<LlmDescriptor[]>([
|
||||
{ name: "", provider: "", modelName: "" },
|
||||
]);
|
||||
|
||||
// Convenience getter for backwards compatibility
|
||||
const currentLlm = useMemo(
|
||||
() => selectedLlms[0] || { name: "", provider: "", modelName: "" },
|
||||
[selectedLlms]
|
||||
);
|
||||
|
||||
// Helper to set the primary (first) model while preserving others
|
||||
const setPrimaryLlm = (llm: LlmDescriptor) => {
|
||||
setSelectedLlms((prev) => [llm, ...prev.slice(1)]);
|
||||
};
|
||||
|
||||
const llmUpdate = () => {
|
||||
/* Should be called when the live assistant or current chat session changes */
|
||||
@@ -588,11 +602,11 @@ export function useLlmManager(
|
||||
}
|
||||
|
||||
if (currentChatSession?.current_alternate_model) {
|
||||
setCurrentLlm(
|
||||
setPrimaryLlm(
|
||||
getValidLlmDescriptor(currentChatSession.current_alternate_model)
|
||||
);
|
||||
} else if (liveAssistant?.llm_model_version_override) {
|
||||
setCurrentLlm(
|
||||
setPrimaryLlm(
|
||||
getValidLlmDescriptor(liveAssistant.llm_model_version_override)
|
||||
);
|
||||
} else if (userHasManuallyOverriddenLLM) {
|
||||
@@ -600,14 +614,14 @@ export function useLlmManager(
|
||||
// current chat session, use the override
|
||||
return;
|
||||
} else if (user?.preferences?.default_model) {
|
||||
setCurrentLlm(getValidLlmDescriptor(user.preferences.default_model));
|
||||
setPrimaryLlm(getValidLlmDescriptor(user.preferences.default_model));
|
||||
} else {
|
||||
const defaultProvider = llmProviders.find(
|
||||
(provider) => provider.is_default_provider
|
||||
);
|
||||
|
||||
if (defaultProvider) {
|
||||
setCurrentLlm({
|
||||
setPrimaryLlm({
|
||||
name: defaultProvider.name,
|
||||
provider: defaultProvider.provider,
|
||||
modelName: defaultProvider.default_model_name,
|
||||
@@ -665,20 +679,30 @@ export function useLlmManager(
|
||||
setImageFilesPresent(present);
|
||||
};
|
||||
|
||||
// Manually set the LLM
|
||||
// Update all selected LLMs (multi-select)
|
||||
const updateSelectedLlms = (llms: LlmDescriptor[]) => {
|
||||
// Ensure at least one model is selected
|
||||
if (llms.length === 0) {
|
||||
return;
|
||||
}
|
||||
setSelectedLlms(llms);
|
||||
setUserHasManuallyOverriddenLLM(true);
|
||||
};
|
||||
|
||||
// Legacy: Manually set the primary LLM (backwards compatibility)
|
||||
const updateCurrentLlm = (newLlm: LlmDescriptor) => {
|
||||
setCurrentLlm(newLlm);
|
||||
setPrimaryLlm(newLlm);
|
||||
setUserHasManuallyOverriddenLLM(true);
|
||||
};
|
||||
|
||||
const updateCurrentLlmToModelName = (modelName: string) => {
|
||||
setCurrentLlm(getValidLlmDescriptor(modelName));
|
||||
setPrimaryLlm(getValidLlmDescriptor(modelName));
|
||||
setUserHasManuallyOverriddenLLM(true);
|
||||
};
|
||||
|
||||
const updateModelOverrideBasedOnChatSession = (chatSession?: ChatSession) => {
|
||||
if (chatSession && chatSession.current_alternate_model?.length > 0) {
|
||||
setCurrentLlm(getValidLlmDescriptor(chatSession.current_alternate_model));
|
||||
setPrimaryLlm(getValidLlmDescriptor(chatSession.current_alternate_model));
|
||||
}
|
||||
};
|
||||
|
||||
@@ -757,6 +781,10 @@ export function useLlmManager(
|
||||
|
||||
return {
|
||||
updateModelOverrideBasedOnChatSession,
|
||||
// Multi-model selection
|
||||
selectedLlms,
|
||||
updateSelectedLlms,
|
||||
// Backwards compatibility (single model)
|
||||
currentLlm,
|
||||
updateCurrentLlm,
|
||||
temperature,
|
||||
|
||||
@@ -53,6 +53,11 @@ export interface LLMPopoverProps {
|
||||
onSelect?: (value: string) => void;
|
||||
currentModelName?: string;
|
||||
disabled?: boolean;
|
||||
maxSelection?: number;
|
||||
minSelection?: number;
|
||||
// Single-select mode: for one-shot actions like regeneration
|
||||
// Selects model, calls onSelect, and closes popover (doesn't update llmManager.selectedLlms)
|
||||
singleSelectMode?: boolean;
|
||||
}
|
||||
|
||||
export default function LLMPopover({
|
||||
@@ -61,14 +66,35 @@ export default function LLMPopover({
|
||||
onSelect,
|
||||
currentModelName,
|
||||
disabled = false,
|
||||
maxSelection = 4,
|
||||
minSelection = 1,
|
||||
singleSelectMode = false,
|
||||
}: LLMPopoverProps) {
|
||||
const llmProviders = llmManager.llmProviders;
|
||||
const isLoadingProviders = llmManager.isLoadingProviders;
|
||||
|
||||
// Use selectedLlms from llmManager directly
|
||||
const selectedModels = llmManager.selectedLlms;
|
||||
const onSelectionChange = llmManager.updateSelectedLlms;
|
||||
|
||||
const [open, setOpen] = useState(false);
|
||||
const [searchQuery, setSearchQuery] = useState("");
|
||||
const { user } = useUser();
|
||||
|
||||
// Helper to check if a model is selected
|
||||
const isModelSelected = useCallback(
|
||||
(option: LLMOption) => {
|
||||
return selectedModels.some(
|
||||
(m) =>
|
||||
m.modelName === option.modelName && m.provider === option.provider
|
||||
);
|
||||
},
|
||||
[selectedModels]
|
||||
);
|
||||
|
||||
// Check if we're at max selection
|
||||
const isAtMaxSelection = selectedModels.length >= maxSelection;
|
||||
|
||||
const [localTemperature, setLocalTemperature] = useState(
|
||||
llmManager.temperature ?? 0.5
|
||||
);
|
||||
@@ -245,6 +271,35 @@ export default function LLMPopover({
|
||||
return currentModel;
|
||||
}, [llmProviders, llmManager.currentLlm.modelName]);
|
||||
|
||||
// Get display name for a model (used in multi-select mode)
|
||||
const getModelDisplayName = useCallback(
|
||||
(model: LlmDescriptor) => {
|
||||
if (!llmProviders) return model.modelName;
|
||||
|
||||
for (const provider of llmProviders) {
|
||||
const config = provider.model_configurations.find(
|
||||
(m) => m.name === model.modelName
|
||||
);
|
||||
if (config) {
|
||||
return config.display_name || config.name;
|
||||
}
|
||||
}
|
||||
return model.modelName;
|
||||
},
|
||||
[llmProviders]
|
||||
);
|
||||
|
||||
// Get trigger display text for multi-select mode
|
||||
const multiSelectTriggerText = useMemo(() => {
|
||||
if (selectedModels.length === 0) {
|
||||
return "Select models";
|
||||
}
|
||||
if (selectedModels.length === 1) {
|
||||
return getModelDisplayName(selectedModels[0]!);
|
||||
}
|
||||
return `${selectedModels.length} models selected`;
|
||||
}, [selectedModels, getModelDisplayName]);
|
||||
|
||||
// Determine which group the current model belongs to (for auto-expand)
|
||||
const currentGroupKey = useMemo(() => {
|
||||
const currentModel = llmManager.currentLlm.modelName;
|
||||
@@ -318,13 +373,55 @@ export default function LLMPopover({
|
||||
};
|
||||
|
||||
const handleSelectModel = (option: LLMOption) => {
|
||||
llmManager.updateCurrentLlm({
|
||||
const newDescriptor: LlmDescriptor = {
|
||||
modelName: option.modelName,
|
||||
provider: option.provider,
|
||||
name: option.name,
|
||||
} as LlmDescriptor);
|
||||
} as LlmDescriptor;
|
||||
|
||||
if (singleSelectMode) {
|
||||
// Single-select mode: just call onSelect and close (for actions like regeneration)
|
||||
onSelect?.(
|
||||
structureValue(option.name, option.provider, option.modelName)
|
||||
);
|
||||
setOpen(false);
|
||||
return;
|
||||
}
|
||||
|
||||
// Multi-select mode: toggle selection
|
||||
const isSelected = isModelSelected(option);
|
||||
|
||||
if (isSelected) {
|
||||
// Don't allow deselecting if at minimum
|
||||
if (selectedModels.length <= minSelection) {
|
||||
return;
|
||||
}
|
||||
const newSelection = selectedModels.filter(
|
||||
(m) =>
|
||||
!(m.modelName === option.modelName && m.provider === option.provider)
|
||||
);
|
||||
onSelectionChange(newSelection);
|
||||
} else {
|
||||
// Don't allow selecting if at maximum
|
||||
if (selectedModels.length >= maxSelection) {
|
||||
return;
|
||||
}
|
||||
onSelectionChange([...selectedModels, newDescriptor]);
|
||||
}
|
||||
|
||||
// Also call legacy onSelect for backwards compatibility
|
||||
onSelect?.(structureValue(option.name, option.provider, option.modelName));
|
||||
setOpen(false);
|
||||
};
|
||||
|
||||
// Handle removing a model from selection (for chips)
|
||||
const handleRemoveModel = (model: LlmDescriptor) => {
|
||||
if (selectedModels.length <= minSelection) {
|
||||
return;
|
||||
}
|
||||
const newSelection = selectedModels.filter(
|
||||
(m) => !(m.modelName === model.modelName && m.provider === model.provider)
|
||||
);
|
||||
onSelectionChange?.(newSelection);
|
||||
};
|
||||
|
||||
return (
|
||||
@@ -335,10 +432,15 @@ export default function LLMPopover({
|
||||
leftIcon={
|
||||
folded
|
||||
? SvgRefreshCw
|
||||
: getProviderIcon(
|
||||
llmManager.currentLlm.provider,
|
||||
llmManager.currentLlm.modelName
|
||||
)
|
||||
: selectedModels.length > 0
|
||||
? getProviderIcon(
|
||||
selectedModels[0]!.provider,
|
||||
selectedModels[0]!.modelName
|
||||
)
|
||||
: getProviderIcon(
|
||||
llmManager.currentLlm.provider,
|
||||
llmManager.currentLlm.modelName
|
||||
)
|
||||
}
|
||||
onClick={() => setOpen(true)}
|
||||
transient={open}
|
||||
@@ -347,12 +449,79 @@ export default function LLMPopover({
|
||||
disabled={disabled}
|
||||
className={disabled ? "bg-transparent" : ""}
|
||||
>
|
||||
{currentLlmDisplayName}
|
||||
{multiSelectTriggerText}
|
||||
</SelectButton>
|
||||
</div>
|
||||
</PopoverTrigger>
|
||||
<PopoverContent side="top" align="end" className="w-[280px] p-1">
|
||||
<div className="flex flex-col gap-1">
|
||||
{/* Selection Summary (hidden in single-select mode) */}
|
||||
{!singleSelectMode && selectedModels.length > 0 && (
|
||||
<div className="px-2 py-2 border-b border-border-02">
|
||||
<div className="flex items-center justify-between mb-2">
|
||||
<Text secondaryBody text03>
|
||||
Selected ({selectedModels.length}/{maxSelection})
|
||||
</Text>
|
||||
{isAtMaxSelection && (
|
||||
<Text
|
||||
secondaryBody
|
||||
className="text-amber-600 dark:text-amber-400 text-xs"
|
||||
>
|
||||
Max reached
|
||||
</Text>
|
||||
)}
|
||||
</div>
|
||||
{selectedModels.length === 0 ? (
|
||||
<Text secondaryBody text03 className="text-text-04">
|
||||
Select up to {maxSelection} models
|
||||
</Text>
|
||||
) : (
|
||||
<div className="flex flex-wrap gap-1">
|
||||
{selectedModels.map((model) => (
|
||||
<div
|
||||
key={`${model.provider}-${model.modelName}`}
|
||||
className="inline-flex items-center gap-1 px-2 py-1 bg-action-link-01 text-text-inverse rounded-full text-xs"
|
||||
>
|
||||
<span className="max-w-[100px] truncate">
|
||||
{getModelDisplayName(model)}
|
||||
</span>
|
||||
{selectedModels.length > minSelection && (
|
||||
<button
|
||||
onClick={(e) => {
|
||||
e.stopPropagation();
|
||||
handleRemoveModel(model);
|
||||
}}
|
||||
className="ml-0.5 hover:bg-white/20 rounded-full p-0.5 transition-colors"
|
||||
title="Remove model"
|
||||
>
|
||||
<svg
|
||||
className="w-3 h-3"
|
||||
fill="none"
|
||||
stroke="currentColor"
|
||||
viewBox="0 0 24 24"
|
||||
>
|
||||
<path
|
||||
strokeLinecap="round"
|
||||
strokeLinejoin="round"
|
||||
strokeWidth={2}
|
||||
d="M6 18L18 6M6 6l12 12"
|
||||
/>
|
||||
</svg>
|
||||
</button>
|
||||
)}
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
)}
|
||||
{selectedModels.length > 0 &&
|
||||
selectedModels.length < minSelection && (
|
||||
<Text secondaryBody className="text-text-04 text-xs mt-1">
|
||||
At least {minSelection} model required
|
||||
</Text>
|
||||
)}
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* Search Input */}
|
||||
<InputTypeIn
|
||||
ref={searchInputRef}
|
||||
@@ -439,11 +608,18 @@ export default function LLMPopover({
|
||||
<div className="flex flex-col gap-1">
|
||||
{group.options.map((option) => {
|
||||
// Match by both modelName AND provider to handle same model name across providers
|
||||
const isSelected =
|
||||
option.modelName ===
|
||||
llmManager.currentLlm.modelName &&
|
||||
option.provider ===
|
||||
llmManager.currentLlm.provider;
|
||||
const isSelected = isModelSelected(option);
|
||||
|
||||
// In single-select mode, check if this is the current model being replaced
|
||||
const isCurrentModel =
|
||||
singleSelectMode &&
|
||||
currentModelName === option.modelName;
|
||||
|
||||
// Disable unselected items when at max (only in multi-select mode)
|
||||
const isDisabled =
|
||||
!singleSelectMode &&
|
||||
!isSelected &&
|
||||
isAtMaxSelection;
|
||||
|
||||
// Build description with version info
|
||||
const description =
|
||||
@@ -456,18 +632,42 @@ export default function LLMPopover({
|
||||
<div
|
||||
key={`${option.name}-${option.modelName}`}
|
||||
ref={
|
||||
isSelected ? selectedItemRef : undefined
|
||||
isSelected || isCurrentModel
|
||||
? selectedItemRef
|
||||
: undefined
|
||||
}
|
||||
title={
|
||||
isDisabled
|
||||
? `Maximum of ${maxSelection} models selected`
|
||||
: undefined
|
||||
}
|
||||
>
|
||||
<LineItem
|
||||
selected={isSelected}
|
||||
selected={
|
||||
singleSelectMode
|
||||
? isCurrentModel
|
||||
: isSelected
|
||||
}
|
||||
description={description}
|
||||
onClick={() =>
|
||||
!isDisabled &&
|
||||
handleSelectModel(option)
|
||||
}
|
||||
className="pl-7"
|
||||
className={`pl-7 ${
|
||||
isDisabled
|
||||
? "opacity-50 cursor-not-allowed"
|
||||
: ""
|
||||
}`}
|
||||
rightChildren={
|
||||
isSelected ? (
|
||||
// In single-select mode: show "Current" badge for the model being replaced
|
||||
// In multi-select mode: show checkmark for selected models
|
||||
singleSelectMode ? (
|
||||
isCurrentModel ? (
|
||||
<span className="text-[10px] font-medium text-text-03 bg-background-emphasis px-1.5 py-0.5 rounded">
|
||||
Current
|
||||
</span>
|
||||
) : null
|
||||
) : isSelected ? (
|
||||
<SvgCheck className="h-4 w-4 stroke-action-link-05 shrink-0" />
|
||||
) : null
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user