mirror of
https://github.com/onyx-dot-app/onyx.git
synced 2026-03-25 01:22:45 +00:00
Compare commits
21 Commits
multi-mode
...
dane/max-c
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
cf9bd7e511 | ||
|
|
b5dd17a371 | ||
|
|
d62d0c1864 | ||
|
|
2c92742c62 | ||
|
|
1e1402e4f1 | ||
|
|
440818a082 | ||
|
|
bd9f40d1c1 | ||
|
|
c85e090c13 | ||
|
|
d72df59063 | ||
|
|
867442bc54 | ||
|
|
f752761e46 | ||
|
|
a760d1cf33 | ||
|
|
acffd55ce4 | ||
|
|
3a4be4a7d9 | ||
|
|
7c0e7eddbd | ||
|
|
2e5763c9ab | ||
|
|
5c45345521 | ||
|
|
0665f31a7d | ||
|
|
17442ed2d0 | ||
|
|
5b0c2f3c18 | ||
|
|
cff564eb6a |
@@ -1,36 +0,0 @@
|
||||
"""add preferred_response_id and model_display_name to chat_message
|
||||
|
||||
Revision ID: a3f8b2c1d4e5
|
||||
Create Date: 2026-03-22
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "a3f8b2c1d4e5"
|
||||
down_revision = "b728689f45b1"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
"chat_message",
|
||||
sa.Column(
|
||||
"preferred_response_id",
|
||||
sa.Integer(),
|
||||
sa.ForeignKey("chat_message.id"),
|
||||
nullable=True,
|
||||
),
|
||||
)
|
||||
op.add_column(
|
||||
"chat_message",
|
||||
sa.Column("model_display_name", sa.String(), nullable=True),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("chat_message", "model_display_name")
|
||||
op.drop_column("chat_message", "preferred_response_id")
|
||||
@@ -8,7 +8,6 @@ from onyx.configs.constants import MessageType
|
||||
from onyx.context.search.models import SearchDoc
|
||||
from onyx.file_store.models import InMemoryChatFile
|
||||
from onyx.server.query_and_chat.models import MessageResponseIDInfo
|
||||
from onyx.server.query_and_chat.models import MultiModelMessageResponseIDInfo
|
||||
from onyx.server.query_and_chat.streaming_models import CitationInfo
|
||||
from onyx.server.query_and_chat.streaming_models import GeneratedImage
|
||||
from onyx.server.query_and_chat.streaming_models import Packet
|
||||
@@ -36,13 +35,7 @@ class CreateChatSessionID(BaseModel):
|
||||
chat_session_id: UUID
|
||||
|
||||
|
||||
AnswerStreamPart = (
|
||||
Packet
|
||||
| MessageResponseIDInfo
|
||||
| MultiModelMessageResponseIDInfo
|
||||
| StreamingError
|
||||
| CreateChatSessionID
|
||||
)
|
||||
AnswerStreamPart = Packet | MessageResponseIDInfo | StreamingError | CreateChatSessionID
|
||||
|
||||
AnswerStream = Iterator[AnswerStreamPart]
|
||||
|
||||
|
||||
@@ -4,11 +4,9 @@ An overview can be found in the README.md file in this directory.
|
||||
"""
|
||||
|
||||
import io
|
||||
import queue
|
||||
import re
|
||||
import traceback
|
||||
from collections.abc import Callable
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from contextvars import Token
|
||||
from uuid import UUID
|
||||
|
||||
@@ -30,7 +28,6 @@ from onyx.chat.compression import calculate_total_history_tokens
|
||||
from onyx.chat.compression import compress_chat_history
|
||||
from onyx.chat.compression import find_summary_for_branch
|
||||
from onyx.chat.compression import get_compression_params
|
||||
from onyx.chat.emitter import Emitter
|
||||
from onyx.chat.emitter import get_default_emitter
|
||||
from onyx.chat.llm_loop import EmptyLLMResponseError
|
||||
from onyx.chat.llm_loop import run_llm_loop
|
||||
@@ -62,8 +59,6 @@ from onyx.db.chat import create_new_chat_message
|
||||
from onyx.db.chat import get_chat_session_by_id
|
||||
from onyx.db.chat import get_or_create_root_message
|
||||
from onyx.db.chat import reserve_message_id
|
||||
from onyx.db.chat import reserve_multi_model_message_ids
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.memory import get_memories
|
||||
from onyx.db.models import ChatMessage
|
||||
from onyx.db.models import ChatSession
|
||||
@@ -82,20 +77,16 @@ from onyx.llm.factory import get_llm_for_persona
|
||||
from onyx.llm.factory import get_llm_token_counter
|
||||
from onyx.llm.interfaces import LLM
|
||||
from onyx.llm.interfaces import LLMUserIdentity
|
||||
from onyx.llm.override_models import LLMOverride
|
||||
from onyx.llm.request_context import reset_llm_mock_response
|
||||
from onyx.llm.request_context import set_llm_mock_response
|
||||
from onyx.llm.utils import litellm_exception_to_error_msg
|
||||
from onyx.onyxbot.slack.models import SlackContext
|
||||
from onyx.server.query_and_chat.models import AUTO_PLACE_AFTER_LATEST_MESSAGE
|
||||
from onyx.server.query_and_chat.models import MessageResponseIDInfo
|
||||
from onyx.server.query_and_chat.models import MultiModelMessageResponseIDInfo
|
||||
from onyx.server.query_and_chat.models import SendMessageRequest
|
||||
from onyx.server.query_and_chat.placement import Placement
|
||||
from onyx.server.query_and_chat.streaming_models import AgentResponseDelta
|
||||
from onyx.server.query_and_chat.streaming_models import AgentResponseStart
|
||||
from onyx.server.query_and_chat.streaming_models import CitationInfo
|
||||
from onyx.server.query_and_chat.streaming_models import OverallStop
|
||||
from onyx.server.query_and_chat.streaming_models import Packet
|
||||
from onyx.server.usage_limits import check_llm_cost_limit_for_provider
|
||||
from onyx.tools.constants import SEARCH_TOOL_ID
|
||||
@@ -1006,568 +997,6 @@ def handle_stream_message_objects(
|
||||
logger.exception("Error in setting processing status")
|
||||
|
||||
|
||||
def _build_model_display_name(override: LLMOverride) -> str:
|
||||
"""Build a human-readable display name from an LLM override."""
|
||||
if override.display_name:
|
||||
return override.display_name
|
||||
if override.model_version:
|
||||
return override.model_version
|
||||
if override.model_provider:
|
||||
return override.model_provider
|
||||
return "unknown"
|
||||
|
||||
|
||||
# Sentinel placed on the merged queue when a model thread finishes.
|
||||
_MODEL_DONE = object()
|
||||
|
||||
|
||||
class _ModelIndexEmitter(Emitter):
|
||||
"""Emitter that tags packets with model_index and forwards directly to a shared queue.
|
||||
|
||||
Unlike the standard Emitter (which accumulates in a local bus), this puts
|
||||
packets into the shared merged_queue in real-time as they're emitted. This
|
||||
enables true parallel streaming — packets from multiple models interleave
|
||||
on the wire instead of arriving in bursts after each model completes.
|
||||
"""
|
||||
|
||||
def __init__(self, model_idx: int, merged_queue: queue.Queue) -> None:
|
||||
super().__init__(queue.Queue()) # bus exists for compat, unused
|
||||
self._model_idx = model_idx
|
||||
self._merged_queue = merged_queue
|
||||
|
||||
def emit(self, packet: Packet) -> None:
|
||||
tagged_placement = Placement(
|
||||
turn_index=packet.placement.turn_index if packet.placement else 0,
|
||||
tab_index=packet.placement.tab_index if packet.placement else 0,
|
||||
sub_turn_index=(
|
||||
packet.placement.sub_turn_index if packet.placement else None
|
||||
),
|
||||
model_index=self._model_idx,
|
||||
)
|
||||
tagged_packet = Packet(placement=tagged_placement, obj=packet.obj)
|
||||
self._merged_queue.put((self._model_idx, tagged_packet))
|
||||
|
||||
|
||||
def run_multi_model_stream(
|
||||
new_msg_req: SendMessageRequest,
|
||||
user: User,
|
||||
db_session: Session,
|
||||
llm_overrides: list[LLMOverride],
|
||||
litellm_additional_headers: dict[str, str] | None = None,
|
||||
custom_tool_additional_headers: dict[str, str] | None = None,
|
||||
mcp_headers: dict[str, str] | None = None,
|
||||
) -> AnswerStream:
|
||||
# TODO: The setup logic below (session resolution through tool construction)
|
||||
# is duplicated from handle_stream_message_objects. Extract into a shared
|
||||
# _ChatStreamContext dataclass + _prepare_chat_stream_context() factory so
|
||||
# both paths call the same setup code. Tracked as follow-up refactor.
|
||||
"""Run 2-3 LLMs in parallel and yield their packets tagged with model_index.
|
||||
|
||||
Resource management:
|
||||
- Each model thread gets its OWN db_session (SQLAlchemy sessions are not thread-safe)
|
||||
- The caller's db_session is used only for setup (before threads launch) and
|
||||
completion callbacks (after threads finish)
|
||||
- ThreadPoolExecutor is bounded to len(overrides) workers
|
||||
- All threads are joined in the finally block regardless of success/failure
|
||||
- Queue-based merging avoids busy-waiting
|
||||
"""
|
||||
n_models = len(llm_overrides)
|
||||
if n_models < 2 or n_models > 3:
|
||||
raise ValueError(f"Multi-model requires 2-3 overrides, got {n_models}")
|
||||
if new_msg_req.deep_research:
|
||||
raise ValueError("Multi-model is not supported with deep research")
|
||||
|
||||
tenant_id = get_current_tenant_id()
|
||||
cache: CacheBackend | None = None
|
||||
chat_session: ChatSession | None = None
|
||||
|
||||
user_id = user.id
|
||||
if user.is_anonymous:
|
||||
llm_user_identifier = "anonymous_user"
|
||||
else:
|
||||
llm_user_identifier = user.email or str(user_id)
|
||||
|
||||
try:
|
||||
# ── Session setup (same as single-model path) ──────────────────
|
||||
if not new_msg_req.chat_session_id:
|
||||
if not new_msg_req.chat_session_info:
|
||||
raise RuntimeError(
|
||||
"Must specify a chat session id or chat session info"
|
||||
)
|
||||
chat_session = create_chat_session_from_request(
|
||||
chat_session_request=new_msg_req.chat_session_info,
|
||||
user_id=user_id,
|
||||
db_session=db_session,
|
||||
)
|
||||
yield CreateChatSessionID(chat_session_id=chat_session.id)
|
||||
else:
|
||||
chat_session = get_chat_session_by_id(
|
||||
chat_session_id=new_msg_req.chat_session_id,
|
||||
user_id=user_id,
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
persona = chat_session.persona
|
||||
message_text = new_msg_req.message
|
||||
|
||||
# ── Build N LLM instances and validate costs ───────────────────
|
||||
llms: list[LLM] = []
|
||||
model_display_names: list[str] = []
|
||||
for override in llm_overrides:
|
||||
llm = get_llm_for_persona(
|
||||
persona=persona,
|
||||
user=user,
|
||||
llm_override=override,
|
||||
additional_headers=litellm_additional_headers,
|
||||
)
|
||||
check_llm_cost_limit_for_provider(
|
||||
db_session=db_session,
|
||||
tenant_id=tenant_id,
|
||||
llm_provider_api_key=llm.config.api_key,
|
||||
)
|
||||
llms.append(llm)
|
||||
model_display_names.append(_build_model_display_name(override))
|
||||
|
||||
# Use first LLM for token counting (context window is checked per-model
|
||||
# but token counting is model-agnostic enough for setup purposes)
|
||||
token_counter = get_llm_token_counter(llms[0])
|
||||
|
||||
verify_user_files(
|
||||
user_files=new_msg_req.file_descriptors,
|
||||
user_id=user_id,
|
||||
db_session=db_session,
|
||||
project_id=chat_session.project_id,
|
||||
)
|
||||
|
||||
# ── Chat history chain (shared across all models) ──────────────
|
||||
chat_history = create_chat_history_chain(
|
||||
chat_session_id=chat_session.id, db_session=db_session
|
||||
)
|
||||
|
||||
root_message = get_or_create_root_message(
|
||||
chat_session_id=chat_session.id, db_session=db_session
|
||||
)
|
||||
|
||||
if new_msg_req.parent_message_id == AUTO_PLACE_AFTER_LATEST_MESSAGE:
|
||||
parent_message = chat_history[-1] if chat_history else root_message
|
||||
elif (
|
||||
new_msg_req.parent_message_id is None
|
||||
or new_msg_req.parent_message_id == root_message.id
|
||||
):
|
||||
parent_message = root_message
|
||||
chat_history = []
|
||||
else:
|
||||
parent_message = None
|
||||
for i in range(len(chat_history) - 1, -1, -1):
|
||||
if chat_history[i].id == new_msg_req.parent_message_id:
|
||||
parent_message = chat_history[i]
|
||||
chat_history = chat_history[: i + 1]
|
||||
break
|
||||
|
||||
if parent_message is None:
|
||||
raise ValueError(
|
||||
"The new message sent is not on the latest mainline of messages"
|
||||
)
|
||||
|
||||
if parent_message.message_type == MessageType.USER:
|
||||
user_message = parent_message
|
||||
else:
|
||||
user_message = create_new_chat_message(
|
||||
chat_session_id=chat_session.id,
|
||||
parent_message=parent_message,
|
||||
message=message_text,
|
||||
token_count=token_counter(message_text),
|
||||
message_type=MessageType.USER,
|
||||
files=new_msg_req.file_descriptors,
|
||||
db_session=db_session,
|
||||
commit=True,
|
||||
)
|
||||
chat_history.append(user_message)
|
||||
|
||||
available_files = _collect_available_file_ids(
|
||||
chat_history=chat_history,
|
||||
project_id=chat_session.project_id,
|
||||
user_id=user_id,
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
summary_message = find_summary_for_branch(db_session, chat_history)
|
||||
summarized_file_metadata: dict[str, FileToolMetadata] = {}
|
||||
if summary_message and summary_message.last_summarized_message_id:
|
||||
cutoff_id = summary_message.last_summarized_message_id
|
||||
for msg in chat_history:
|
||||
if msg.id > cutoff_id or not msg.files:
|
||||
continue
|
||||
for fd in msg.files:
|
||||
file_id = fd.get("id")
|
||||
if not file_id:
|
||||
continue
|
||||
summarized_file_metadata[file_id] = FileToolMetadata(
|
||||
file_id=file_id,
|
||||
filename=fd.get("name") or "unknown",
|
||||
approx_char_count=0,
|
||||
)
|
||||
chat_history = [m for m in chat_history if m.id > cutoff_id]
|
||||
|
||||
user_memory_context = get_memories(user, db_session)
|
||||
custom_agent_prompt = get_custom_agent_prompt(persona, chat_session)
|
||||
|
||||
prompt_memory_context = (
|
||||
user_memory_context
|
||||
if user.use_memories
|
||||
else user_memory_context.without_memories()
|
||||
)
|
||||
|
||||
max_reserved_system_prompt_tokens_str = (persona.system_prompt or "") + (
|
||||
custom_agent_prompt or ""
|
||||
)
|
||||
|
||||
reserved_token_count = calculate_reserved_tokens(
|
||||
db_session=db_session,
|
||||
persona_system_prompt=max_reserved_system_prompt_tokens_str,
|
||||
token_counter=token_counter,
|
||||
files=new_msg_req.file_descriptors,
|
||||
user_memory_context=prompt_memory_context,
|
||||
)
|
||||
|
||||
context_user_files = resolve_context_user_files(
|
||||
persona=persona,
|
||||
project_id=chat_session.project_id,
|
||||
user_id=user_id,
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
# Use the smallest context window across all models for safety
|
||||
min_context_window = min(llm.config.max_input_tokens for llm in llms)
|
||||
|
||||
extracted_context_files = extract_context_files(
|
||||
user_files=context_user_files,
|
||||
llm_max_context_window=min_context_window,
|
||||
reserved_token_count=reserved_token_count,
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
search_params = determine_search_params(
|
||||
persona_id=persona.id,
|
||||
project_id=chat_session.project_id,
|
||||
extracted_context_files=extracted_context_files,
|
||||
)
|
||||
|
||||
if persona.user_files:
|
||||
existing = set(available_files.user_file_ids)
|
||||
for uf in persona.user_files:
|
||||
if uf.id not in existing:
|
||||
available_files.user_file_ids.append(uf.id)
|
||||
|
||||
all_tools = get_tools(db_session)
|
||||
tool_id_to_name_map = {tool.id: tool.name for tool in all_tools}
|
||||
|
||||
search_tool_id = next(
|
||||
(tool.id for tool in all_tools if tool.in_code_tool_id == SEARCH_TOOL_ID),
|
||||
None,
|
||||
)
|
||||
|
||||
forced_tool_id = new_msg_req.forced_tool_id
|
||||
if (
|
||||
search_params.search_usage == SearchToolUsage.DISABLED
|
||||
and forced_tool_id is not None
|
||||
and search_tool_id is not None
|
||||
and forced_tool_id == search_tool_id
|
||||
):
|
||||
forced_tool_id = None
|
||||
|
||||
files = load_all_chat_files(chat_history, db_session)
|
||||
chat_files_for_tools = _convert_loaded_files_to_chat_files(files)
|
||||
|
||||
# ── Reserve N assistant message IDs ────────────────────────────
|
||||
reserved_messages = reserve_multi_model_message_ids(
|
||||
db_session=db_session,
|
||||
chat_session_id=chat_session.id,
|
||||
parent_message_id=user_message.id,
|
||||
model_display_names=model_display_names,
|
||||
)
|
||||
|
||||
yield MultiModelMessageResponseIDInfo(
|
||||
user_message_id=user_message.id,
|
||||
reserved_assistant_message_ids=[m.id for m in reserved_messages],
|
||||
model_names=model_display_names,
|
||||
)
|
||||
|
||||
has_file_reader_tool = any(
|
||||
tool.in_code_tool_id == "file_reader" for tool in all_tools
|
||||
)
|
||||
|
||||
chat_history_result = convert_chat_history(
|
||||
chat_history=chat_history,
|
||||
files=files,
|
||||
context_image_files=extracted_context_files.image_files,
|
||||
additional_context=new_msg_req.additional_context,
|
||||
token_counter=token_counter,
|
||||
tool_id_to_name_map=tool_id_to_name_map,
|
||||
)
|
||||
simple_chat_history = chat_history_result.simple_messages
|
||||
|
||||
all_injected_file_metadata: dict[str, FileToolMetadata] = (
|
||||
chat_history_result.all_injected_file_metadata
|
||||
if has_file_reader_tool
|
||||
else {}
|
||||
)
|
||||
if summarized_file_metadata:
|
||||
for fid, meta in summarized_file_metadata.items():
|
||||
all_injected_file_metadata.setdefault(fid, meta)
|
||||
|
||||
if summary_message is not None:
|
||||
summary_simple = ChatMessageSimple(
|
||||
message=summary_message.message,
|
||||
token_count=summary_message.token_count,
|
||||
message_type=MessageType.ASSISTANT,
|
||||
)
|
||||
simple_chat_history.insert(0, summary_simple)
|
||||
|
||||
# ── Stop signal and processing status ──────────────────────────
|
||||
cache = get_cache_backend()
|
||||
reset_cancel_status(chat_session.id, cache)
|
||||
|
||||
def check_is_connected() -> bool:
|
||||
return check_stop_signal(chat_session.id, cache)
|
||||
|
||||
set_processing_status(
|
||||
chat_session_id=chat_session.id,
|
||||
cache=cache,
|
||||
value=True,
|
||||
)
|
||||
|
||||
# Release the main session's read transaction before the long stream
|
||||
db_session.commit()
|
||||
|
||||
# ── Parallel model execution ───────────────────────────────────
|
||||
# Each model thread writes tagged packets to this shared queue.
|
||||
# Sentinel _MODEL_DONE signals that a thread finished.
|
||||
merged_queue: queue.Queue[tuple[int, Packet | Exception | object]] = (
|
||||
queue.Queue()
|
||||
)
|
||||
|
||||
# Track per-model state containers for completion callbacks
|
||||
state_containers: list[ChatStateContainer] = [
|
||||
ChatStateContainer() for _ in range(n_models)
|
||||
]
|
||||
# Track which models completed successfully (for completion callbacks)
|
||||
model_succeeded: list[bool] = [False] * n_models
|
||||
|
||||
user_identity = LLMUserIdentity(
|
||||
user_id=llm_user_identifier,
|
||||
session_id=str(chat_session.id),
|
||||
)
|
||||
|
||||
def _run_model(model_idx: int) -> None:
|
||||
"""Run a single model in a worker thread.
|
||||
|
||||
Uses _ModelIndexEmitter so packets flow directly to merged_queue
|
||||
in real-time (not batched after completion). This enables true
|
||||
parallel streaming where both models' tokens interleave on the wire.
|
||||
|
||||
DB access: tools may need a session during execution (e.g., search
|
||||
tool). Each thread creates its own session via context manager.
|
||||
"""
|
||||
model_emitter = _ModelIndexEmitter(model_idx, merged_queue)
|
||||
sc = state_containers[model_idx]
|
||||
model_llm = llms[model_idx]
|
||||
|
||||
try:
|
||||
# Each model thread gets its own DB session for tool execution.
|
||||
# The session is scoped to the thread and closed when done.
|
||||
with get_session_with_current_tenant() as thread_db_session:
|
||||
# Construct tools per-thread with thread-local DB session
|
||||
thread_tool_dict = construct_tools(
|
||||
persona=persona,
|
||||
db_session=thread_db_session,
|
||||
emitter=model_emitter,
|
||||
user=user,
|
||||
llm=model_llm,
|
||||
search_tool_config=SearchToolConfig(
|
||||
user_selected_filters=new_msg_req.internal_search_filters,
|
||||
project_id_filter=search_params.project_id_filter,
|
||||
persona_id_filter=search_params.persona_id_filter,
|
||||
bypass_acl=False,
|
||||
enable_slack_search=_should_enable_slack_search(
|
||||
persona, new_msg_req.internal_search_filters
|
||||
),
|
||||
),
|
||||
custom_tool_config=CustomToolConfig(
|
||||
chat_session_id=chat_session.id,
|
||||
message_id=user_message.id,
|
||||
additional_headers=custom_tool_additional_headers,
|
||||
mcp_headers=mcp_headers,
|
||||
),
|
||||
file_reader_tool_config=FileReaderToolConfig(
|
||||
user_file_ids=available_files.user_file_ids,
|
||||
chat_file_ids=available_files.chat_file_ids,
|
||||
),
|
||||
allowed_tool_ids=new_msg_req.allowed_tool_ids,
|
||||
search_usage_forcing_setting=search_params.search_usage,
|
||||
)
|
||||
model_tools: list[Tool] = []
|
||||
for tool_list in thread_tool_dict.values():
|
||||
model_tools.extend(tool_list)
|
||||
|
||||
# Run the LLM loop — this blocks until the model finishes.
|
||||
# Packets flow to merged_queue in real-time via the emitter.
|
||||
run_llm_loop(
|
||||
emitter=model_emitter,
|
||||
state_container=sc,
|
||||
simple_chat_history=simple_chat_history,
|
||||
tools=model_tools,
|
||||
custom_agent_prompt=custom_agent_prompt,
|
||||
context_files=extracted_context_files,
|
||||
persona=persona,
|
||||
user_memory_context=user_memory_context,
|
||||
llm=model_llm,
|
||||
token_counter=get_llm_token_counter(model_llm),
|
||||
db_session=thread_db_session,
|
||||
forced_tool_id=forced_tool_id,
|
||||
user_identity=user_identity,
|
||||
chat_session_id=str(chat_session.id),
|
||||
chat_files=chat_files_for_tools,
|
||||
include_citations=new_msg_req.include_citations,
|
||||
all_injected_file_metadata=all_injected_file_metadata,
|
||||
inject_memories_in_prompt=user.use_memories,
|
||||
)
|
||||
|
||||
model_succeeded[model_idx] = True
|
||||
|
||||
except Exception as e:
|
||||
merged_queue.put((model_idx, e))
|
||||
|
||||
finally:
|
||||
merged_queue.put((model_idx, _MODEL_DONE))
|
||||
|
||||
# Launch model threads via ThreadPoolExecutor (bounded, context-propagating)
|
||||
executor = ThreadPoolExecutor(
|
||||
max_workers=n_models,
|
||||
thread_name_prefix="multi-model",
|
||||
)
|
||||
futures = []
|
||||
try:
|
||||
for i in range(n_models):
|
||||
futures.append(executor.submit(_run_model, i))
|
||||
|
||||
# ── Main thread: merge and yield packets ───────────────────
|
||||
models_remaining = n_models
|
||||
while models_remaining > 0:
|
||||
try:
|
||||
model_idx, item = merged_queue.get(timeout=0.3)
|
||||
except queue.Empty:
|
||||
# Check cancellation during idle periods
|
||||
if not check_is_connected():
|
||||
yield Packet(
|
||||
placement=Placement(turn_index=0),
|
||||
obj=OverallStop(type="stop", stop_reason="user_cancelled"),
|
||||
)
|
||||
return
|
||||
continue
|
||||
|
||||
if item is _MODEL_DONE:
|
||||
models_remaining -= 1
|
||||
continue
|
||||
|
||||
if isinstance(item, Exception):
|
||||
# Yield error as a tagged StreamingError packet
|
||||
error_msg = str(item)
|
||||
stack_trace = "".join(
|
||||
traceback.format_exception(type(item), item, item.__traceback__)
|
||||
)
|
||||
# Redact API keys from error messages
|
||||
model_llm = llms[model_idx]
|
||||
if model_llm.config.api_key and len(model_llm.config.api_key) > 2:
|
||||
error_msg = error_msg.replace(
|
||||
model_llm.config.api_key, "[REDACTED_API_KEY]"
|
||||
)
|
||||
stack_trace = stack_trace.replace(
|
||||
model_llm.config.api_key, "[REDACTED_API_KEY]"
|
||||
)
|
||||
|
||||
yield StreamingError(
|
||||
error=error_msg,
|
||||
stack_trace=stack_trace,
|
||||
error_code="MODEL_ERROR",
|
||||
is_retryable=True,
|
||||
details={
|
||||
"model": model_llm.config.model_name,
|
||||
"provider": model_llm.config.model_provider,
|
||||
"model_index": model_idx,
|
||||
},
|
||||
)
|
||||
models_remaining -= 1
|
||||
continue
|
||||
|
||||
if isinstance(item, Packet):
|
||||
# Packet is already tagged with model_index by _ModelIndexEmitter
|
||||
yield item
|
||||
|
||||
# ── Completion: save each successful model's response ──────
|
||||
# Run completion callbacks on the main thread using the main
|
||||
# session. This is safe because all worker threads have exited
|
||||
# by this point (merged_queue fully drained).
|
||||
for i in range(n_models):
|
||||
if not model_succeeded[i]:
|
||||
continue
|
||||
try:
|
||||
llm_loop_completion_handle(
|
||||
state_container=state_containers[i],
|
||||
is_connected=check_is_connected,
|
||||
db_session=db_session,
|
||||
assistant_message=reserved_messages[i],
|
||||
llm=llms[i],
|
||||
reserved_tokens=reserved_token_count,
|
||||
)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
f"Failed completion for model {i} "
|
||||
f"({model_display_names[i]})"
|
||||
)
|
||||
|
||||
yield Packet(
|
||||
placement=Placement(turn_index=0),
|
||||
obj=OverallStop(type="stop", stop_reason="complete"),
|
||||
)
|
||||
|
||||
finally:
|
||||
# Ensure all threads are cleaned up regardless of how we exit
|
||||
executor.shutdown(wait=True, cancel_futures=True)
|
||||
|
||||
except ValueError as e:
|
||||
logger.exception("Failed to process multi-model chat message.")
|
||||
yield StreamingError(
|
||||
error=str(e),
|
||||
error_code="VALIDATION_ERROR",
|
||||
is_retryable=True,
|
||||
)
|
||||
db_session.rollback()
|
||||
return
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(f"Failed multi-model chat: {e}")
|
||||
stack_trace = traceback.format_exc()
|
||||
yield StreamingError(
|
||||
error=str(e),
|
||||
stack_trace=stack_trace,
|
||||
error_code="MULTI_MODEL_ERROR",
|
||||
is_retryable=True,
|
||||
)
|
||||
db_session.rollback()
|
||||
|
||||
finally:
|
||||
try:
|
||||
if cache is not None and chat_session is not None:
|
||||
set_processing_status(
|
||||
chat_session_id=chat_session.id,
|
||||
cache=cache,
|
||||
value=False,
|
||||
)
|
||||
except Exception:
|
||||
logger.exception("Error clearing processing status")
|
||||
|
||||
|
||||
def llm_loop_completion_handle(
|
||||
state_container: ChatStateContainer,
|
||||
is_connected: Callable[[], bool],
|
||||
|
||||
@@ -787,6 +787,10 @@ MINI_CHUNK_SIZE = 150
|
||||
# This is the number of regular chunks per large chunk
|
||||
LARGE_CHUNK_RATIO = 4
|
||||
|
||||
# The maximum number of chunks that can be held for 1 document processing batch
|
||||
# The purpose of this is to set an upper bound on memory usage
|
||||
MAX_CHUNKS_PER_DOC_BATCH = int(os.environ.get("MAX_CHUNKS_PER_DOC_BATCH") or 1000)
|
||||
|
||||
# Include the document level metadata in each chunk. If the metadata is too long, then it is thrown out
|
||||
# We don't want the metadata to overwhelm the actual contents of the chunk
|
||||
SKIP_METADATA_IN_CHUNK = os.environ.get("SKIP_METADATA_IN_CHUNK", "").lower() == "true"
|
||||
|
||||
@@ -602,79 +602,6 @@ def reserve_message_id(
|
||||
return empty_message
|
||||
|
||||
|
||||
def reserve_multi_model_message_ids(
|
||||
db_session: Session,
|
||||
chat_session_id: UUID,
|
||||
parent_message_id: int,
|
||||
model_display_names: list[str],
|
||||
) -> list[ChatMessage]:
|
||||
"""Reserve N assistant message placeholders for multi-model parallel streaming.
|
||||
|
||||
All messages share the same parent (the user message). The parent's
|
||||
latest_child_message_id points to the LAST reserved message so that the
|
||||
default history-chain walker picks it up.
|
||||
"""
|
||||
reserved: list[ChatMessage] = []
|
||||
for display_name in model_display_names:
|
||||
msg = ChatMessage(
|
||||
chat_session_id=chat_session_id,
|
||||
parent_message_id=parent_message_id,
|
||||
latest_child_message_id=None,
|
||||
message="Response was terminated prior to completion, try regenerating.",
|
||||
token_count=15,
|
||||
message_type=MessageType.ASSISTANT,
|
||||
model_display_name=display_name,
|
||||
)
|
||||
db_session.add(msg)
|
||||
reserved.append(msg)
|
||||
|
||||
# Flush to assign IDs without committing yet
|
||||
db_session.flush()
|
||||
|
||||
# Point parent's latest_child to the last reserved message
|
||||
parent = (
|
||||
db_session.query(ChatMessage)
|
||||
.filter(ChatMessage.id == parent_message_id)
|
||||
.first()
|
||||
)
|
||||
if parent:
|
||||
parent.latest_child_message_id = reserved[-1].id
|
||||
|
||||
db_session.commit()
|
||||
return reserved
|
||||
|
||||
|
||||
def set_preferred_response(
|
||||
db_session: Session,
|
||||
user_message_id: int,
|
||||
preferred_assistant_message_id: int,
|
||||
) -> None:
|
||||
"""Set the preferred assistant response for a multi-model user message.
|
||||
|
||||
Validates that the user message is a USER type and that the preferred
|
||||
assistant message is a direct child of that user message.
|
||||
"""
|
||||
user_msg = db_session.query(ChatMessage).get(user_message_id)
|
||||
if user_msg is None:
|
||||
raise ValueError(f"User message {user_message_id} not found")
|
||||
if user_msg.message_type != MessageType.USER:
|
||||
raise ValueError(f"Message {user_message_id} is not a user message")
|
||||
|
||||
assistant_msg = db_session.query(ChatMessage).get(preferred_assistant_message_id)
|
||||
if assistant_msg is None:
|
||||
raise ValueError(
|
||||
f"Assistant message {preferred_assistant_message_id} not found"
|
||||
)
|
||||
if assistant_msg.parent_message_id != user_message_id:
|
||||
raise ValueError(
|
||||
f"Assistant message {preferred_assistant_message_id} is not a child "
|
||||
f"of user message {user_message_id}"
|
||||
)
|
||||
|
||||
user_msg.preferred_response_id = preferred_assistant_message_id
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def create_new_chat_message(
|
||||
chat_session_id: UUID,
|
||||
parent_message: ChatMessage,
|
||||
@@ -897,8 +824,6 @@ def translate_db_message_to_chat_message_detail(
|
||||
error=chat_message.error,
|
||||
current_feedback=current_feedback,
|
||||
processing_duration_seconds=chat_message.processing_duration_seconds,
|
||||
preferred_response_id=chat_message.preferred_response_id,
|
||||
model_display_name=chat_message.model_display_name,
|
||||
)
|
||||
|
||||
return chat_msg_detail
|
||||
|
||||
@@ -2645,15 +2645,6 @@ class ChatMessage(Base):
|
||||
nullable=True,
|
||||
)
|
||||
|
||||
# For multi-model turns: the user message points to which assistant response
|
||||
# was selected as the preferred one to continue the conversation with.
|
||||
preferred_response_id: Mapped[int | None] = mapped_column(
|
||||
ForeignKey("chat_message.id"), nullable=True
|
||||
)
|
||||
|
||||
# The display name of the model that generated this assistant message
|
||||
model_display_name: Mapped[str | None] = mapped_column(String, nullable=True)
|
||||
|
||||
# What does this message contain
|
||||
reasoning_tokens: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||
message: Mapped[str] = mapped_column(Text)
|
||||
@@ -2721,12 +2712,6 @@ class ChatMessage(Base):
|
||||
remote_side="ChatMessage.id",
|
||||
)
|
||||
|
||||
preferred_response: Mapped["ChatMessage | None"] = relationship(
|
||||
"ChatMessage",
|
||||
foreign_keys=[preferred_response_id],
|
||||
remote_side="ChatMessage.id",
|
||||
)
|
||||
|
||||
# Chat messages only need to know their immediate tool call children
|
||||
# If there are nested tool calls, they are stored in the tool_call_children relationship.
|
||||
tool_calls: Mapped[list["ToolCall"] | None] = relationship(
|
||||
|
||||
@@ -5,6 +5,7 @@ accidentally reaches the vector DB layer will fail loudly instead of timing
|
||||
out against a nonexistent Vespa/OpenSearch instance.
|
||||
"""
|
||||
|
||||
from collections.abc import Iterable
|
||||
from typing import Any
|
||||
|
||||
from onyx.context.search.models import IndexFilters
|
||||
@@ -66,7 +67,7 @@ class DisabledDocumentIndex(DocumentIndex):
|
||||
# ------------------------------------------------------------------
|
||||
def index(
|
||||
self,
|
||||
chunks: list[DocMetadataAwareIndexChunk], # noqa: ARG002
|
||||
chunks: Iterable[DocMetadataAwareIndexChunk], # noqa: ARG002
|
||||
index_batch_params: IndexBatchParams, # noqa: ARG002
|
||||
) -> set[DocumentInsertionRecord]:
|
||||
raise RuntimeError(VECTOR_DB_DISABLED_ERROR)
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import abc
|
||||
from collections.abc import Iterable
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
@@ -206,7 +207,7 @@ class Indexable(abc.ABC):
|
||||
@abc.abstractmethod
|
||||
def index(
|
||||
self,
|
||||
chunks: list[DocMetadataAwareIndexChunk],
|
||||
chunks: Iterable[DocMetadataAwareIndexChunk],
|
||||
index_batch_params: IndexBatchParams,
|
||||
) -> set[DocumentInsertionRecord]:
|
||||
"""
|
||||
@@ -226,8 +227,8 @@ class Indexable(abc.ABC):
|
||||
it is done automatically outside of this code.
|
||||
|
||||
Parameters:
|
||||
- chunks: Document chunks with all of the information needed for indexing to the document
|
||||
index.
|
||||
- chunks: Document chunks with all of the information needed for
|
||||
indexing to the document index.
|
||||
- tenant_id: The tenant id of the user whose chunks are being indexed
|
||||
- large_chunks_enabled: Whether large chunks are enabled
|
||||
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import abc
|
||||
from collections.abc import Iterable
|
||||
from typing import Self
|
||||
|
||||
from pydantic import BaseModel
|
||||
@@ -209,10 +210,10 @@ class Indexable(abc.ABC):
|
||||
@abc.abstractmethod
|
||||
def index(
|
||||
self,
|
||||
chunks: list[DocMetadataAwareIndexChunk],
|
||||
chunks: Iterable[DocMetadataAwareIndexChunk],
|
||||
indexing_metadata: IndexingMetadata,
|
||||
) -> list[DocumentInsertionRecord]:
|
||||
"""Indexes a list of document chunks into the document index.
|
||||
"""Indexes an iterable of document chunks into the document index.
|
||||
|
||||
This is often a batch operation including chunks from multiple
|
||||
documents.
|
||||
|
||||
@@ -1,11 +1,12 @@
|
||||
import json
|
||||
from collections import defaultdict
|
||||
from collections.abc import Iterable
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
from opensearchpy import NotFoundError
|
||||
|
||||
from onyx.access.models import DocumentAccess
|
||||
from onyx.configs.app_configs import MAX_CHUNKS_PER_DOC_BATCH
|
||||
from onyx.configs.app_configs import VERIFY_CREATE_OPENSEARCH_INDEX_ON_INIT_MT
|
||||
from onyx.configs.chat_configs import NUM_RETURNED_HITS
|
||||
from onyx.configs.chat_configs import TITLE_CONTENT_RATIO
|
||||
@@ -350,7 +351,7 @@ class OpenSearchOldDocumentIndex(OldDocumentIndex):
|
||||
|
||||
def index(
|
||||
self,
|
||||
chunks: list[DocMetadataAwareIndexChunk],
|
||||
chunks: Iterable[DocMetadataAwareIndexChunk],
|
||||
index_batch_params: IndexBatchParams,
|
||||
) -> set[OldDocumentInsertionRecord]:
|
||||
"""
|
||||
@@ -646,10 +647,10 @@ class OpenSearchDocumentIndex(DocumentIndex):
|
||||
|
||||
def index(
|
||||
self,
|
||||
chunks: list[DocMetadataAwareIndexChunk],
|
||||
indexing_metadata: IndexingMetadata, # noqa: ARG002
|
||||
chunks: Iterable[DocMetadataAwareIndexChunk],
|
||||
indexing_metadata: IndexingMetadata,
|
||||
) -> list[DocumentInsertionRecord]:
|
||||
"""Indexes a list of document chunks into the document index.
|
||||
"""Indexes an iterable of document chunks into the document index.
|
||||
|
||||
Groups chunks by document ID and for each document, deletes existing
|
||||
chunks and indexes the new chunks in bulk.
|
||||
@@ -672,29 +673,34 @@ class OpenSearchDocumentIndex(DocumentIndex):
|
||||
document is newly indexed or had already existed and was just
|
||||
updated.
|
||||
"""
|
||||
# Group chunks by document ID.
|
||||
doc_id_to_chunks: dict[str, list[DocMetadataAwareIndexChunk]] = defaultdict(
|
||||
list
|
||||
total_chunks = sum(
|
||||
cc.new_chunk_cnt
|
||||
for cc in indexing_metadata.doc_id_to_chunk_cnt_diff.values()
|
||||
)
|
||||
for chunk in chunks:
|
||||
doc_id_to_chunks[chunk.source_document.id].append(chunk)
|
||||
logger.debug(
|
||||
f"[OpenSearchDocumentIndex] Indexing {len(chunks)} chunks from {len(doc_id_to_chunks)} "
|
||||
f"[OpenSearchDocumentIndex] Indexing {total_chunks} chunks from {len(indexing_metadata.doc_id_to_chunk_cnt_diff)} "
|
||||
f"documents for index {self._index_name}."
|
||||
)
|
||||
|
||||
document_indexing_results: list[DocumentInsertionRecord] = []
|
||||
# Try to index per-document.
|
||||
for _, chunks in doc_id_to_chunks.items():
|
||||
deleted_doc_ids: set[str] = set()
|
||||
# Buffer chunks per document as they arrive from the iterable.
|
||||
# When the document ID changes flush the buffered chunks.
|
||||
current_doc_id: str | None = None
|
||||
current_chunks: list[DocMetadataAwareIndexChunk] = []
|
||||
|
||||
def _flush_chunks(doc_chunks: list[DocMetadataAwareIndexChunk]) -> None:
|
||||
assert len(doc_chunks) > 0, "doc_chunks is empty"
|
||||
|
||||
# Create a batch of OpenSearch-formatted chunks for bulk insertion.
|
||||
# Do this before deleting existing chunks to reduce the amount of
|
||||
# time the document index has no content for a given document, and
|
||||
# to reduce the chance of entering a state where we delete chunks,
|
||||
# then some error happens, and never successfully index new chunks.
|
||||
# Since we are doing this in batches, an error occurring midway
|
||||
# can result in a state where chunks are deleted and not all the
|
||||
# new chunks have been indexed.
|
||||
chunk_batch: list[DocumentChunk] = [
|
||||
_convert_onyx_chunk_to_opensearch_document(chunk) for chunk in chunks
|
||||
_convert_onyx_chunk_to_opensearch_document(chunk)
|
||||
for chunk in doc_chunks
|
||||
]
|
||||
onyx_document: Document = chunks[0].source_document
|
||||
onyx_document: Document = doc_chunks[0].source_document
|
||||
# First delete the doc's chunks from the index. This is so that
|
||||
# there are no dangling chunks in the index, in the event that the
|
||||
# new document's content contains fewer chunks than the previous
|
||||
@@ -703,22 +709,43 @@ class OpenSearchDocumentIndex(DocumentIndex):
|
||||
# if the chunk count has actually decreased. This assumes that
|
||||
# overlapping chunks are perfectly overwritten. If we can't
|
||||
# guarantee that then we need the code as-is.
|
||||
num_chunks_deleted = self.delete(
|
||||
onyx_document.id, onyx_document.chunk_count
|
||||
)
|
||||
# If we see that chunks were deleted we assume the doc already
|
||||
# existed.
|
||||
document_insertion_record = DocumentInsertionRecord(
|
||||
document_id=onyx_document.id,
|
||||
already_existed=num_chunks_deleted > 0,
|
||||
)
|
||||
if onyx_document.id not in deleted_doc_ids:
|
||||
num_chunks_deleted = self.delete(
|
||||
onyx_document.id, onyx_document.chunk_count
|
||||
)
|
||||
deleted_doc_ids.add(onyx_document.id)
|
||||
# If we see that chunks were deleted we assume the doc already
|
||||
# existed. We record the result before bulk_index_documents
|
||||
# runs. If indexing raises, this entire result list is discarded
|
||||
# by the caller's retry logic, so early recording is safe.
|
||||
document_indexing_results.append(
|
||||
DocumentInsertionRecord(
|
||||
document_id=onyx_document.id,
|
||||
already_existed=num_chunks_deleted > 0,
|
||||
)
|
||||
)
|
||||
# Now index. This will raise if a chunk of the same ID exists, which
|
||||
# we do not expect because we should have deleted all chunks.
|
||||
self._client.bulk_index_documents(
|
||||
documents=chunk_batch,
|
||||
tenant_state=self._tenant_state,
|
||||
)
|
||||
document_indexing_results.append(document_insertion_record)
|
||||
|
||||
for chunk in chunks:
|
||||
doc_id = chunk.source_document.id
|
||||
if doc_id != current_doc_id:
|
||||
if current_chunks:
|
||||
_flush_chunks(current_chunks)
|
||||
current_doc_id = doc_id
|
||||
current_chunks = [chunk]
|
||||
elif len(current_chunks) >= MAX_CHUNKS_PER_DOC_BATCH:
|
||||
_flush_chunks(current_chunks)
|
||||
current_chunks = [chunk]
|
||||
else:
|
||||
current_chunks.append(chunk)
|
||||
|
||||
if current_chunks:
|
||||
_flush_chunks(current_chunks)
|
||||
|
||||
return document_indexing_results
|
||||
|
||||
|
||||
@@ -6,6 +6,7 @@ import re
|
||||
import time
|
||||
import urllib
|
||||
import zipfile
|
||||
from collections.abc import Iterable
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from datetime import timedelta
|
||||
@@ -461,7 +462,7 @@ class VespaIndex(DocumentIndex):
|
||||
|
||||
def index(
|
||||
self,
|
||||
chunks: list[DocMetadataAwareIndexChunk],
|
||||
chunks: Iterable[DocMetadataAwareIndexChunk],
|
||||
index_batch_params: IndexBatchParams,
|
||||
) -> set[OldDocumentInsertionRecord]:
|
||||
"""
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
import concurrent.futures
|
||||
import logging
|
||||
import random
|
||||
from collections.abc import Generator
|
||||
from collections.abc import Iterable
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
@@ -8,6 +10,7 @@ import httpx
|
||||
from pydantic import BaseModel
|
||||
from retry import retry
|
||||
|
||||
from onyx.configs.app_configs import MAX_CHUNKS_PER_DOC_BATCH
|
||||
from onyx.configs.app_configs import RECENCY_BIAS_MULTIPLIER
|
||||
from onyx.configs.app_configs import RERANK_COUNT
|
||||
from onyx.configs.chat_configs import DOC_TIME_DECAY
|
||||
@@ -318,7 +321,7 @@ class VespaDocumentIndex(DocumentIndex):
|
||||
|
||||
def index(
|
||||
self,
|
||||
chunks: list[DocMetadataAwareIndexChunk],
|
||||
chunks: Iterable[DocMetadataAwareIndexChunk],
|
||||
indexing_metadata: IndexingMetadata,
|
||||
) -> list[DocumentInsertionRecord]:
|
||||
doc_id_to_chunk_cnt_diff = indexing_metadata.doc_id_to_chunk_cnt_diff
|
||||
@@ -338,22 +341,31 @@ class VespaDocumentIndex(DocumentIndex):
|
||||
|
||||
# Vespa has restrictions on valid characters, yet document IDs come from
|
||||
# external w.r.t. this class. We need to sanitize them.
|
||||
cleaned_chunks: list[DocMetadataAwareIndexChunk] = [
|
||||
clean_chunk_id_copy(chunk) for chunk in chunks
|
||||
]
|
||||
assert len(cleaned_chunks) == len(
|
||||
chunks
|
||||
), "Bug: Cleaned chunks and input chunks have different lengths."
|
||||
#
|
||||
# Instead of materializing all cleaned chunks upfront, we stream them
|
||||
# through a generator that cleans IDs and builds the original-ID mapping
|
||||
# incrementally as chunks flow into Vespa.
|
||||
def _clean_and_track(
|
||||
chunks_iter: Iterable[DocMetadataAwareIndexChunk],
|
||||
id_map: dict[str, str],
|
||||
seen_ids: set[str],
|
||||
) -> Generator[DocMetadataAwareIndexChunk, None, None]:
|
||||
"""Cleans chunk IDs and builds the original-ID mapping
|
||||
incrementally as chunks flow through, avoiding a separate
|
||||
materialization pass."""
|
||||
for chunk in chunks_iter:
|
||||
original_id = chunk.source_document.id
|
||||
cleaned = clean_chunk_id_copy(chunk)
|
||||
cleaned_id = cleaned.source_document.id
|
||||
# Needed so the final DocumentInsertionRecord returned can have
|
||||
# the original document ID. cleaned_chunks might not contain IDs
|
||||
# exactly as callers supplied them.
|
||||
id_map[cleaned_id] = original_id
|
||||
seen_ids.add(cleaned_id)
|
||||
yield cleaned
|
||||
|
||||
# Needed so the final DocumentInsertionRecord returned can have the
|
||||
# original document ID. cleaned_chunks might not contain IDs exactly as
|
||||
# callers supplied them.
|
||||
new_document_id_to_original_document_id: dict[str, str] = dict()
|
||||
for i, cleaned_chunk in enumerate(cleaned_chunks):
|
||||
old_chunk = chunks[i]
|
||||
new_document_id_to_original_document_id[
|
||||
cleaned_chunk.source_document.id
|
||||
] = old_chunk.source_document.id
|
||||
new_document_id_to_original_document_id: dict[str, str] = {}
|
||||
all_cleaned_doc_ids: set[str] = set()
|
||||
|
||||
existing_docs: set[str] = set()
|
||||
|
||||
@@ -409,8 +421,16 @@ class VespaDocumentIndex(DocumentIndex):
|
||||
executor=executor,
|
||||
)
|
||||
|
||||
# Insert new Vespa documents.
|
||||
for chunk_batch in batch_generator(cleaned_chunks, BATCH_SIZE):
|
||||
# Insert new Vespa documents, streaming through the cleaning
|
||||
# pipeline so chunks are never fully materialized.
|
||||
cleaned_chunks = _clean_and_track(
|
||||
chunks,
|
||||
new_document_id_to_original_document_id,
|
||||
all_cleaned_doc_ids,
|
||||
)
|
||||
for chunk_batch in batch_generator(
|
||||
cleaned_chunks, min(BATCH_SIZE, MAX_CHUNKS_PER_DOC_BATCH)
|
||||
):
|
||||
batch_index_vespa_chunks(
|
||||
chunks=chunk_batch,
|
||||
index_name=self._index_name,
|
||||
@@ -419,10 +439,6 @@ class VespaDocumentIndex(DocumentIndex):
|
||||
executor=executor,
|
||||
)
|
||||
|
||||
all_cleaned_doc_ids: set[str] = {
|
||||
chunk.source_document.id for chunk in cleaned_chunks
|
||||
}
|
||||
|
||||
return [
|
||||
DocumentInsertionRecord(
|
||||
document_id=new_document_id_to_original_document_id[cleaned_doc_id],
|
||||
|
||||
@@ -11,7 +11,6 @@ class LLMOverride(BaseModel):
|
||||
model_provider: str | None = None
|
||||
model_version: str | None = None
|
||||
temperature: float | None = None
|
||||
display_name: str | None = None
|
||||
|
||||
# This disables the "model_" protected namespace for pydantic
|
||||
model_config = {"protected_namespaces": ()}
|
||||
|
||||
@@ -29,7 +29,6 @@ from onyx.chat.models import ChatFullResponse
|
||||
from onyx.chat.models import CreateChatSessionID
|
||||
from onyx.chat.process_message import gather_stream_full
|
||||
from onyx.chat.process_message import handle_stream_message_objects
|
||||
from onyx.chat.process_message import run_multi_model_stream
|
||||
from onyx.chat.prompt_utils import get_default_base_system_prompt
|
||||
from onyx.chat.stop_signal_checker import set_fence
|
||||
from onyx.configs.app_configs import WEB_DOMAIN
|
||||
@@ -47,7 +46,6 @@ from onyx.db.chat import get_chat_messages_by_session
|
||||
from onyx.db.chat import get_chat_session_by_id
|
||||
from onyx.db.chat import get_chat_sessions_by_user
|
||||
from onyx.db.chat import set_as_latest_chat_message
|
||||
from onyx.db.chat import set_preferred_response
|
||||
from onyx.db.chat import translate_db_message_to_chat_message_detail
|
||||
from onyx.db.chat import update_chat_session
|
||||
from onyx.db.chat_search import search_chat_sessions
|
||||
@@ -83,7 +81,6 @@ from onyx.server.query_and_chat.models import ChatSessionUpdateRequest
|
||||
from onyx.server.query_and_chat.models import MessageOrigin
|
||||
from onyx.server.query_and_chat.models import RenameChatSessionResponse
|
||||
from onyx.server.query_and_chat.models import SendMessageRequest
|
||||
from onyx.server.query_and_chat.models import SetPreferredResponseRequest
|
||||
from onyx.server.query_and_chat.models import UpdateChatSessionTemperatureRequest
|
||||
from onyx.server.query_and_chat.models import UpdateChatSessionThreadRequest
|
||||
from onyx.server.query_and_chat.session_loading import (
|
||||
@@ -573,38 +570,6 @@ def handle_send_chat_message(
|
||||
if get_hashed_api_key_from_request(request) or get_hashed_pat_from_request(request):
|
||||
chat_message_req.origin = MessageOrigin.API
|
||||
|
||||
# Multi-model streaming path: 2-3 LLMs in parallel (streaming only)
|
||||
is_multi_model = (
|
||||
chat_message_req.llm_overrides is not None
|
||||
and len(chat_message_req.llm_overrides) > 1
|
||||
)
|
||||
if is_multi_model and chat_message_req.stream:
|
||||
|
||||
def multi_model_stream_generator() -> Generator[str, None, None]:
|
||||
try:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
for obj in run_multi_model_stream(
|
||||
new_msg_req=chat_message_req,
|
||||
user=user,
|
||||
db_session=db_session,
|
||||
llm_overrides=chat_message_req.llm_overrides, # type: ignore[arg-type]
|
||||
litellm_additional_headers=extract_headers(
|
||||
request.headers, LITELLM_PASS_THROUGH_HEADERS
|
||||
),
|
||||
custom_tool_additional_headers=get_custom_tool_additional_request_headers(
|
||||
request.headers
|
||||
),
|
||||
mcp_headers=chat_message_req.mcp_headers,
|
||||
):
|
||||
yield get_json_line(obj.model_dump())
|
||||
except Exception as e:
|
||||
logger.exception("Error in multi-model streaming")
|
||||
yield json.dumps({"error": str(e)})
|
||||
|
||||
return StreamingResponse(
|
||||
multi_model_stream_generator(), media_type="text/event-stream"
|
||||
)
|
||||
|
||||
# Non-streaming path: consume all packets and return complete response
|
||||
if not chat_message_req.stream:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
@@ -695,26 +660,6 @@ def set_message_as_latest(
|
||||
)
|
||||
|
||||
|
||||
@router.put("/set-preferred-response")
|
||||
def set_preferred_response_endpoint(
|
||||
request_body: SetPreferredResponseRequest,
|
||||
_user: User = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> None:
|
||||
"""Set the preferred assistant response for a multi-model turn."""
|
||||
from onyx.error_handling.error_codes import OnyxErrorCode
|
||||
from onyx.error_handling.exceptions import OnyxError
|
||||
|
||||
try:
|
||||
set_preferred_response(
|
||||
db_session=db_session,
|
||||
user_message_id=request_body.user_message_id,
|
||||
preferred_assistant_message_id=request_body.preferred_response_id,
|
||||
)
|
||||
except ValueError as e:
|
||||
raise OnyxError(OnyxErrorCode.INVALID_INPUT, str(e))
|
||||
|
||||
|
||||
@router.post("/create-chat-message-feedback")
|
||||
def create_chat_feedback(
|
||||
feedback: ChatFeedbackRequest,
|
||||
|
||||
@@ -41,16 +41,6 @@ class MessageResponseIDInfo(BaseModel):
|
||||
reserved_assistant_message_id: int
|
||||
|
||||
|
||||
class MultiModelMessageResponseIDInfo(BaseModel):
|
||||
"""Sent at the start of a multi-model streaming response.
|
||||
Contains the user message ID and the reserved assistant message IDs
|
||||
for each model being run in parallel."""
|
||||
|
||||
user_message_id: int | None
|
||||
reserved_assistant_message_ids: list[int]
|
||||
model_names: list[str]
|
||||
|
||||
|
||||
class SourceTag(Tag):
|
||||
source: DocumentSource
|
||||
|
||||
@@ -96,9 +86,6 @@ class SendMessageRequest(BaseModel):
|
||||
message: str
|
||||
|
||||
llm_override: LLMOverride | None = None
|
||||
# For multi-model mode: up to 3 LLM overrides to run in parallel.
|
||||
# When provided with >1 entry, triggers multi-model streaming.
|
||||
llm_overrides: list[LLMOverride] | None = None
|
||||
# Test-only override for deterministic LiteLLM mock responses.
|
||||
mock_llm_response: str | None = None
|
||||
|
||||
@@ -224,8 +211,6 @@ class ChatMessageDetail(BaseModel):
|
||||
error: str | None = None
|
||||
current_feedback: str | None = None # "like" | "dislike" | null
|
||||
processing_duration_seconds: float | None = None
|
||||
preferred_response_id: int | None = None
|
||||
model_display_name: str | None = None
|
||||
|
||||
def model_dump(self, *args: list, **kwargs: dict[str, Any]) -> dict[str, Any]: # type: ignore
|
||||
initial_dict = super().model_dump(mode="json", *args, **kwargs) # type: ignore
|
||||
@@ -233,11 +218,6 @@ class ChatMessageDetail(BaseModel):
|
||||
return initial_dict
|
||||
|
||||
|
||||
class SetPreferredResponseRequest(BaseModel):
|
||||
user_message_id: int
|
||||
preferred_response_id: int
|
||||
|
||||
|
||||
class ChatSessionDetailResponse(BaseModel):
|
||||
chat_session_id: UUID
|
||||
description: str | None
|
||||
|
||||
@@ -8,5 +8,3 @@ class Placement(BaseModel):
|
||||
tab_index: int = 0
|
||||
# Used for tools/agents that call other tools, this currently doesn't support nested agents but can be added later
|
||||
sub_turn_index: int | None = None
|
||||
# For multi-model streaming: identifies which model (0, 1, 2) this packet belongs to.
|
||||
model_index: int | None = None
|
||||
|
||||
@@ -6,6 +6,7 @@ These tests assume Vespa and OpenSearch are running.
|
||||
import time
|
||||
import uuid
|
||||
from collections.abc import Generator
|
||||
from collections.abc import Iterator
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
@@ -21,6 +22,7 @@ from onyx.document_index.opensearch.opensearch_document_index import (
|
||||
)
|
||||
from onyx.document_index.vespa.index import VespaIndex
|
||||
from onyx.document_index.vespa.vespa_document_index import VespaDocumentIndex
|
||||
from onyx.indexing.models import DocMetadataAwareIndexChunk
|
||||
from tests.external_dependency_unit.constants import TEST_TENANT_ID
|
||||
from tests.external_dependency_unit.document_index.conftest import EMBEDDING_DIM
|
||||
from tests.external_dependency_unit.document_index.conftest import make_chunk
|
||||
@@ -201,3 +203,25 @@ class TestDocumentIndexNew:
|
||||
assert len(result_map) == 2
|
||||
assert result_map[existing_doc] is True
|
||||
assert result_map[new_doc] is False
|
||||
|
||||
def test_index_accepts_generator(
|
||||
self,
|
||||
document_indices: list[DocumentIndexNew],
|
||||
tenant_context: None, # noqa: ARG002
|
||||
) -> None:
|
||||
"""index() accepts a generator (any iterable), not just a list."""
|
||||
for document_index in document_indices:
|
||||
doc_id = f"test_gen_{uuid.uuid4().hex[:8]}"
|
||||
metadata = make_indexing_metadata([doc_id], old_counts=[0], new_counts=[3])
|
||||
|
||||
def chunk_gen() -> Iterator[DocMetadataAwareIndexChunk]:
|
||||
for i in range(3):
|
||||
yield make_chunk(doc_id, chunk_id=i)
|
||||
|
||||
results = document_index.index(
|
||||
chunks=chunk_gen(), indexing_metadata=metadata
|
||||
)
|
||||
|
||||
assert len(results) == 1
|
||||
assert results[0].document_id == doc_id
|
||||
assert results[0].already_existed is False
|
||||
|
||||
@@ -5,6 +5,7 @@ These tests assume Vespa and OpenSearch are running.
|
||||
|
||||
import time
|
||||
from collections.abc import Generator
|
||||
from collections.abc import Iterator
|
||||
|
||||
import pytest
|
||||
|
||||
@@ -166,3 +167,29 @@ class TestDocumentIndexOld:
|
||||
batch_retrieval=True,
|
||||
)
|
||||
assert len(inference_chunks) == 0
|
||||
|
||||
def test_index_accepts_generator(
|
||||
self,
|
||||
document_indices: list[DocumentIndex],
|
||||
tenant_context: None, # noqa: ARG002
|
||||
) -> None:
|
||||
"""index() accepts a generator (any iterable), not just a list."""
|
||||
for document_index in document_indices:
|
||||
|
||||
def chunk_gen() -> Iterator[DocMetadataAwareIndexChunk]:
|
||||
for i in range(3):
|
||||
yield make_chunk("test_doc_gen", chunk_id=i)
|
||||
|
||||
index_batch_params = IndexBatchParams(
|
||||
doc_id_to_previous_chunk_cnt={"test_doc_gen": 0},
|
||||
doc_id_to_new_chunk_cnt={"test_doc_gen": 3},
|
||||
tenant_id=get_current_tenant_id(),
|
||||
large_chunks_enabled=False,
|
||||
)
|
||||
|
||||
results = document_index.index(chunk_gen(), index_batch_params)
|
||||
|
||||
assert len(results) == 1
|
||||
record = results.pop()
|
||||
assert record.document_id == "test_doc_gen"
|
||||
assert record.already_existed is False
|
||||
|
||||
@@ -1,206 +0,0 @@
|
||||
"""Unit tests for multi-model streaming validation and DB helpers.
|
||||
|
||||
These are pure unit tests — no real database or LLM calls required.
|
||||
The validation logic in run_multi_model_stream fires before any external
|
||||
calls, so we can trigger it with lightweight mocks.
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from onyx.configs.constants import MessageType
|
||||
from onyx.db.chat import set_preferred_response
|
||||
from onyx.llm.override_models import LLMOverride
|
||||
from onyx.server.query_and_chat.models import SendMessageRequest
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_request(**kwargs: Any) -> SendMessageRequest:
|
||||
defaults: dict[str, Any] = {
|
||||
"message": "hello",
|
||||
"chat_session_id": uuid4(),
|
||||
}
|
||||
defaults.update(kwargs)
|
||||
return SendMessageRequest(**defaults)
|
||||
|
||||
|
||||
def _make_override(provider: str = "openai", version: str = "gpt-4") -> LLMOverride:
|
||||
return LLMOverride(model_provider=provider, model_version=version)
|
||||
|
||||
|
||||
def _start_stream(req: SendMessageRequest, overrides: list[LLMOverride]) -> None:
|
||||
"""Advance the generator one step to trigger early validation."""
|
||||
from onyx.chat.process_message import run_multi_model_stream
|
||||
|
||||
user = MagicMock()
|
||||
user.is_anonymous = False
|
||||
user.email = "test@example.com"
|
||||
db = MagicMock()
|
||||
|
||||
gen = run_multi_model_stream(req, user, db, overrides)
|
||||
# Calling next() executes until the first yield OR raises.
|
||||
# Validation errors are raised before any yield.
|
||||
next(gen)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# run_multi_model_stream — validation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestRunMultiModelStreamValidation:
|
||||
def test_single_override_raises(self) -> None:
|
||||
"""Exactly 1 override is not multi-model — must raise."""
|
||||
req = _make_request()
|
||||
with pytest.raises(ValueError, match="2-3"):
|
||||
_start_stream(req, [_make_override()])
|
||||
|
||||
def test_four_overrides_raises(self) -> None:
|
||||
"""4 overrides exceeds maximum — must raise."""
|
||||
req = _make_request()
|
||||
with pytest.raises(ValueError, match="2-3"):
|
||||
_start_stream(
|
||||
req,
|
||||
[
|
||||
_make_override("openai", "gpt-4"),
|
||||
_make_override("anthropic", "claude-3"),
|
||||
_make_override("google", "gemini-pro"),
|
||||
_make_override("cohere", "command-r"),
|
||||
],
|
||||
)
|
||||
|
||||
def test_zero_overrides_raises(self) -> None:
|
||||
"""Empty override list raises."""
|
||||
req = _make_request()
|
||||
with pytest.raises(ValueError, match="2-3"):
|
||||
_start_stream(req, [])
|
||||
|
||||
def test_deep_research_raises(self) -> None:
|
||||
"""deep_research=True is incompatible with multi-model."""
|
||||
req = _make_request(deep_research=True)
|
||||
with pytest.raises(ValueError, match="not supported"):
|
||||
_start_stream(
|
||||
req, [_make_override(), _make_override("anthropic", "claude-3")]
|
||||
)
|
||||
|
||||
def test_exactly_two_overrides_is_minimum(self) -> None:
|
||||
"""Boundary: 1 override fails, 2 passes — ensures fence-post is correct."""
|
||||
req = _make_request()
|
||||
# 1 override must fail
|
||||
with pytest.raises(ValueError, match="2-3"):
|
||||
_start_stream(req, [_make_override()])
|
||||
# 2 overrides must NOT raise ValueError (may raise later due to missing session, that's OK)
|
||||
try:
|
||||
_start_stream(
|
||||
req, [_make_override(), _make_override("anthropic", "claude-3")]
|
||||
)
|
||||
except ValueError as exc:
|
||||
pytest.fail(f"2 overrides should pass validation, got ValueError: {exc}")
|
||||
except Exception:
|
||||
pass # Any other error means validation passed
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# set_preferred_response — validation (mocked db)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSetPreferredResponseValidation:
|
||||
def test_user_message_not_found(self) -> None:
|
||||
db = MagicMock()
|
||||
db.query.return_value.get.return_value = None
|
||||
|
||||
with pytest.raises(ValueError, match="not found"):
|
||||
set_preferred_response(
|
||||
db, user_message_id=999, preferred_assistant_message_id=1
|
||||
)
|
||||
|
||||
def test_wrong_message_type(self) -> None:
|
||||
"""Cannot set preferred response on a non-USER message."""
|
||||
db = MagicMock()
|
||||
user_msg = MagicMock()
|
||||
user_msg.message_type = MessageType.ASSISTANT # wrong type
|
||||
|
||||
db.query.return_value.get.return_value = user_msg
|
||||
|
||||
with pytest.raises(ValueError, match="not a user message"):
|
||||
set_preferred_response(
|
||||
db, user_message_id=1, preferred_assistant_message_id=2
|
||||
)
|
||||
|
||||
def test_assistant_message_not_found(self) -> None:
|
||||
db = MagicMock()
|
||||
user_msg = MagicMock()
|
||||
user_msg.message_type = MessageType.USER
|
||||
|
||||
# First call returns user_msg, second call (for assistant) returns None
|
||||
db.query.return_value.get.side_effect = [user_msg, None]
|
||||
|
||||
with pytest.raises(ValueError, match="not found"):
|
||||
set_preferred_response(
|
||||
db, user_message_id=1, preferred_assistant_message_id=2
|
||||
)
|
||||
|
||||
def test_assistant_not_child_of_user(self) -> None:
|
||||
db = MagicMock()
|
||||
user_msg = MagicMock()
|
||||
user_msg.message_type = MessageType.USER
|
||||
|
||||
assistant_msg = MagicMock()
|
||||
assistant_msg.parent_message_id = 999 # different parent
|
||||
|
||||
db.query.return_value.get.side_effect = [user_msg, assistant_msg]
|
||||
|
||||
with pytest.raises(ValueError, match="not a child"):
|
||||
set_preferred_response(
|
||||
db, user_message_id=1, preferred_assistant_message_id=2
|
||||
)
|
||||
|
||||
def test_valid_call_sets_preferred_response_id(self) -> None:
|
||||
db = MagicMock()
|
||||
user_msg = MagicMock()
|
||||
user_msg.message_type = MessageType.USER
|
||||
|
||||
assistant_msg = MagicMock()
|
||||
assistant_msg.parent_message_id = 1 # correct parent
|
||||
|
||||
db.query.return_value.get.side_effect = [user_msg, assistant_msg]
|
||||
|
||||
set_preferred_response(db, user_message_id=1, preferred_assistant_message_id=2)
|
||||
|
||||
assert user_msg.preferred_response_id == 2
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# LLMOverride — display_name field
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestLLMOverrideDisplayName:
|
||||
def test_display_name_defaults_none(self) -> None:
|
||||
override = LLMOverride(model_provider="openai", model_version="gpt-4")
|
||||
assert override.display_name is None
|
||||
|
||||
def test_display_name_set(self) -> None:
|
||||
override = LLMOverride(
|
||||
model_provider="openai",
|
||||
model_version="gpt-4",
|
||||
display_name="GPT-4 Turbo",
|
||||
)
|
||||
assert override.display_name == "GPT-4 Turbo"
|
||||
|
||||
def test_display_name_serializes(self) -> None:
|
||||
override = LLMOverride(
|
||||
model_provider="anthropic",
|
||||
model_version="claude-opus-4-6",
|
||||
display_name="Claude Opus",
|
||||
)
|
||||
d = override.model_dump()
|
||||
assert d["display_name"] == "Claude Opus"
|
||||
@@ -1,134 +0,0 @@
|
||||
"""Unit tests for multi-model answer generation types.
|
||||
|
||||
Tests cover:
|
||||
- Placement.model_index serialization
|
||||
- MultiModelMessageResponseIDInfo round-trip
|
||||
- SendMessageRequest.llm_overrides backward compatibility
|
||||
- ChatMessageDetail new fields
|
||||
"""
|
||||
|
||||
from uuid import uuid4
|
||||
|
||||
from onyx.llm.override_models import LLMOverride
|
||||
from onyx.server.query_and_chat.models import ChatMessageDetail
|
||||
from onyx.server.query_and_chat.models import MultiModelMessageResponseIDInfo
|
||||
from onyx.server.query_and_chat.models import SendMessageRequest
|
||||
from onyx.server.query_and_chat.placement import Placement
|
||||
|
||||
|
||||
class TestPlacementModelIndex:
|
||||
def test_default_none(self) -> None:
|
||||
p = Placement(turn_index=0)
|
||||
assert p.model_index is None
|
||||
|
||||
def test_set_value(self) -> None:
|
||||
p = Placement(turn_index=0, model_index=2)
|
||||
assert p.model_index == 2
|
||||
|
||||
def test_serializes(self) -> None:
|
||||
p = Placement(turn_index=0, tab_index=1, model_index=1)
|
||||
d = p.model_dump()
|
||||
assert d["model_index"] == 1
|
||||
|
||||
def test_none_excluded_when_default(self) -> None:
|
||||
p = Placement(turn_index=0)
|
||||
d = p.model_dump()
|
||||
assert d["model_index"] is None
|
||||
|
||||
|
||||
class TestMultiModelMessageResponseIDInfo:
|
||||
def test_round_trip(self) -> None:
|
||||
info = MultiModelMessageResponseIDInfo(
|
||||
user_message_id=42,
|
||||
reserved_assistant_message_ids=[43, 44, 45],
|
||||
model_names=["gpt-4", "claude-opus", "gemini-pro"],
|
||||
)
|
||||
d = info.model_dump()
|
||||
restored = MultiModelMessageResponseIDInfo(**d)
|
||||
assert restored.user_message_id == 42
|
||||
assert restored.reserved_assistant_message_ids == [43, 44, 45]
|
||||
assert restored.model_names == ["gpt-4", "claude-opus", "gemini-pro"]
|
||||
|
||||
def test_null_user_message_id(self) -> None:
|
||||
info = MultiModelMessageResponseIDInfo(
|
||||
user_message_id=None,
|
||||
reserved_assistant_message_ids=[1, 2],
|
||||
model_names=["a", "b"],
|
||||
)
|
||||
assert info.user_message_id is None
|
||||
|
||||
|
||||
class TestSendMessageRequestOverrides:
|
||||
def test_llm_overrides_default_none(self) -> None:
|
||||
req = SendMessageRequest(
|
||||
message="hello",
|
||||
chat_session_id=uuid4(),
|
||||
)
|
||||
assert req.llm_overrides is None
|
||||
|
||||
def test_llm_overrides_accepts_list(self) -> None:
|
||||
overrides = [
|
||||
LLMOverride(model_provider="openai", model_version="gpt-4"),
|
||||
LLMOverride(model_provider="anthropic", model_version="claude-opus"),
|
||||
]
|
||||
req = SendMessageRequest(
|
||||
message="hello",
|
||||
chat_session_id=uuid4(),
|
||||
llm_overrides=overrides,
|
||||
)
|
||||
assert req.llm_overrides is not None
|
||||
assert len(req.llm_overrides) == 2
|
||||
|
||||
def test_backward_compat_single_override(self) -> None:
|
||||
req = SendMessageRequest(
|
||||
message="hello",
|
||||
chat_session_id=uuid4(),
|
||||
llm_override=LLMOverride(model_provider="openai", model_version="gpt-4"),
|
||||
)
|
||||
assert req.llm_override is not None
|
||||
assert req.llm_overrides is None
|
||||
|
||||
|
||||
class TestChatMessageDetailMultiModel:
|
||||
def test_defaults_none(self) -> None:
|
||||
from onyx.configs.constants import MessageType
|
||||
|
||||
detail = ChatMessageDetail(
|
||||
message_id=1,
|
||||
message="hello",
|
||||
message_type=MessageType.ASSISTANT,
|
||||
time_sent="2026-03-22T00:00:00Z",
|
||||
files=[],
|
||||
)
|
||||
assert detail.preferred_response_id is None
|
||||
assert detail.model_display_name is None
|
||||
|
||||
def test_set_values(self) -> None:
|
||||
from onyx.configs.constants import MessageType
|
||||
|
||||
detail = ChatMessageDetail(
|
||||
message_id=1,
|
||||
message="hello",
|
||||
message_type=MessageType.USER,
|
||||
time_sent="2026-03-22T00:00:00Z",
|
||||
files=[],
|
||||
preferred_response_id=42,
|
||||
model_display_name="GPT-4",
|
||||
)
|
||||
assert detail.preferred_response_id == 42
|
||||
assert detail.model_display_name == "GPT-4"
|
||||
|
||||
def test_serializes(self) -> None:
|
||||
from onyx.configs.constants import MessageType
|
||||
|
||||
detail = ChatMessageDetail(
|
||||
message_id=1,
|
||||
message="hello",
|
||||
message_type=MessageType.ASSISTANT,
|
||||
time_sent="2026-03-22T00:00:00Z",
|
||||
files=[],
|
||||
model_display_name="Claude Opus",
|
||||
)
|
||||
d = detail.model_dump()
|
||||
assert d["model_display_name"] == "Claude Opus"
|
||||
assert d["preferred_response_id"] is None
|
||||
@@ -0,0 +1,226 @@
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
|
||||
from onyx.access.models import DocumentAccess
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import TextSection
|
||||
from onyx.document_index.interfaces_new import IndexingMetadata
|
||||
from onyx.document_index.interfaces_new import TenantState
|
||||
from onyx.document_index.opensearch.opensearch_document_index import (
|
||||
OpenSearchDocumentIndex,
|
||||
)
|
||||
from onyx.indexing.models import DocMetadataAwareIndexChunk
|
||||
|
||||
|
||||
def _make_chunk(
|
||||
doc_id: str,
|
||||
chunk_id: int,
|
||||
) -> DocMetadataAwareIndexChunk:
|
||||
"""Creates a minimal DocMetadataAwareIndexChunk for testing."""
|
||||
doc = Document(
|
||||
id=doc_id,
|
||||
sections=[TextSection(text="test", link="http://test.com")],
|
||||
source=DocumentSource.FILE,
|
||||
semantic_identifier="test_doc",
|
||||
metadata={},
|
||||
)
|
||||
access = DocumentAccess.build(
|
||||
user_emails=[],
|
||||
user_groups=[],
|
||||
external_user_emails=[],
|
||||
external_user_group_ids=[],
|
||||
is_public=True,
|
||||
)
|
||||
return DocMetadataAwareIndexChunk(
|
||||
chunk_id=chunk_id,
|
||||
blurb="test",
|
||||
content="test content",
|
||||
source_links={0: "http://test.com"},
|
||||
image_file_id=None,
|
||||
section_continuation=False,
|
||||
source_document=doc,
|
||||
title_prefix="",
|
||||
metadata_suffix_semantic="",
|
||||
metadata_suffix_keyword="",
|
||||
mini_chunk_texts=None,
|
||||
large_chunk_id=None,
|
||||
doc_summary="",
|
||||
chunk_context="",
|
||||
contextual_rag_reserved_tokens=0,
|
||||
embeddings={"full_embedding": [0.1] * 10, "mini_chunk_embeddings": []},
|
||||
title_embedding=[0.1] * 10,
|
||||
tenant_id="test_tenant",
|
||||
access=access,
|
||||
document_sets=set(),
|
||||
user_project=[],
|
||||
personas=[],
|
||||
boost=0,
|
||||
aggregated_chunk_boost_factor=1.0,
|
||||
ancestor_hierarchy_node_ids=[],
|
||||
)
|
||||
|
||||
|
||||
def _make_index() -> OpenSearchDocumentIndex:
|
||||
"""Creates an OpenSearchDocumentIndex with a mocked client."""
|
||||
mock_client = MagicMock()
|
||||
mock_client.bulk_index_documents = MagicMock()
|
||||
|
||||
tenant_state = TenantState(tenant_id="test_tenant", multitenant=False)
|
||||
|
||||
index = OpenSearchDocumentIndex.__new__(OpenSearchDocumentIndex)
|
||||
index._index_name = "test_index"
|
||||
index._client = mock_client
|
||||
index._tenant_state = tenant_state
|
||||
|
||||
return index
|
||||
|
||||
|
||||
def _make_metadata(doc_id: str, chunk_count: int) -> IndexingMetadata:
|
||||
return IndexingMetadata(
|
||||
doc_id_to_chunk_cnt_diff={
|
||||
doc_id: IndexingMetadata.ChunkCounts(
|
||||
old_chunk_cnt=0,
|
||||
new_chunk_cnt=chunk_count,
|
||||
),
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@patch(
|
||||
"onyx.document_index.opensearch.opensearch_document_index.MAX_CHUNKS_PER_DOC_BATCH",
|
||||
100,
|
||||
)
|
||||
def test_single_doc_under_batch_limit_flushes_once() -> None:
|
||||
"""A document with fewer chunks than MAX_CHUNKS_PER_DOC_BATCH should flush once."""
|
||||
index = _make_index()
|
||||
doc_id = "doc_1"
|
||||
num_chunks = 50
|
||||
chunks = [_make_chunk(doc_id, i) for i in range(num_chunks)]
|
||||
metadata = _make_metadata(doc_id, num_chunks)
|
||||
|
||||
with patch.object(index, "delete", return_value=0):
|
||||
index.index(chunks, metadata)
|
||||
|
||||
assert index._client.bulk_index_documents.call_count == 1
|
||||
batch_arg = index._client.bulk_index_documents.call_args_list[0]
|
||||
assert len(batch_arg.kwargs["documents"]) == num_chunks
|
||||
|
||||
|
||||
@patch(
|
||||
"onyx.document_index.opensearch.opensearch_document_index.MAX_CHUNKS_PER_DOC_BATCH",
|
||||
100,
|
||||
)
|
||||
def test_single_doc_over_batch_limit_flushes_multiple_times() -> None:
|
||||
"""A document with more chunks than MAX_CHUNKS_PER_DOC_BATCH should flush multiple times."""
|
||||
index = _make_index()
|
||||
doc_id = "doc_1"
|
||||
num_chunks = 250
|
||||
chunks = [_make_chunk(doc_id, i) for i in range(num_chunks)]
|
||||
metadata = _make_metadata(doc_id, num_chunks)
|
||||
|
||||
with patch.object(index, "delete", return_value=0):
|
||||
index.index(chunks, metadata)
|
||||
|
||||
# 250 chunks / 100 per batch = 3 flushes (100 + 100 + 50)
|
||||
assert index._client.bulk_index_documents.call_count == 3
|
||||
batch_sizes = [
|
||||
len(call.kwargs["documents"])
|
||||
for call in index._client.bulk_index_documents.call_args_list
|
||||
]
|
||||
assert batch_sizes == [100, 100, 50]
|
||||
|
||||
|
||||
@patch(
|
||||
"onyx.document_index.opensearch.opensearch_document_index.MAX_CHUNKS_PER_DOC_BATCH",
|
||||
100,
|
||||
)
|
||||
def test_single_doc_exactly_at_batch_limit() -> None:
|
||||
"""A document with exactly MAX_CHUNKS_PER_DOC_BATCH chunks should flush once
|
||||
(the flush happens on the next chunk, not at the boundary)."""
|
||||
index = _make_index()
|
||||
doc_id = "doc_1"
|
||||
num_chunks = 100
|
||||
chunks = [_make_chunk(doc_id, i) for i in range(num_chunks)]
|
||||
metadata = _make_metadata(doc_id, num_chunks)
|
||||
|
||||
with patch.object(index, "delete", return_value=0):
|
||||
index.index(chunks, metadata)
|
||||
|
||||
# 100 chunks hit the >= check on chunk 101 which doesn't exist,
|
||||
# so final flush handles all 100
|
||||
# Actually: the elif fires when len(current_chunks) >= 100, which happens
|
||||
# when current_chunks has 100 items and the 101st chunk arrives.
|
||||
# With exactly 100 chunks, the 100th chunk makes len == 99, then appended -> 100.
|
||||
# No 101st chunk arrives, so the final flush handles all 100.
|
||||
assert index._client.bulk_index_documents.call_count == 1
|
||||
|
||||
|
||||
@patch(
|
||||
"onyx.document_index.opensearch.opensearch_document_index.MAX_CHUNKS_PER_DOC_BATCH",
|
||||
100,
|
||||
)
|
||||
def test_single_doc_one_over_batch_limit() -> None:
|
||||
"""101 chunks for one doc: first 100 flushed when the 101st arrives, then
|
||||
the 101st is flushed at the end."""
|
||||
index = _make_index()
|
||||
doc_id = "doc_1"
|
||||
num_chunks = 101
|
||||
chunks = [_make_chunk(doc_id, i) for i in range(num_chunks)]
|
||||
metadata = _make_metadata(doc_id, num_chunks)
|
||||
|
||||
with patch.object(index, "delete", return_value=0):
|
||||
index.index(chunks, metadata)
|
||||
|
||||
assert index._client.bulk_index_documents.call_count == 2
|
||||
batch_sizes = [
|
||||
len(call.kwargs["documents"])
|
||||
for call in index._client.bulk_index_documents.call_args_list
|
||||
]
|
||||
assert batch_sizes == [100, 1]
|
||||
|
||||
|
||||
@patch(
|
||||
"onyx.document_index.opensearch.opensearch_document_index.MAX_CHUNKS_PER_DOC_BATCH",
|
||||
100,
|
||||
)
|
||||
def test_multiple_docs_each_under_limit_flush_per_doc() -> None:
|
||||
"""Multiple documents each under the batch limit should flush once per document."""
|
||||
index = _make_index()
|
||||
chunks = []
|
||||
for doc_idx in range(3):
|
||||
doc_id = f"doc_{doc_idx}"
|
||||
for chunk_idx in range(50):
|
||||
chunks.append(_make_chunk(doc_id, chunk_idx))
|
||||
|
||||
metadata = IndexingMetadata(
|
||||
doc_id_to_chunk_cnt_diff={
|
||||
f"doc_{i}": IndexingMetadata.ChunkCounts(old_chunk_cnt=0, new_chunk_cnt=50)
|
||||
for i in range(3)
|
||||
},
|
||||
)
|
||||
|
||||
with patch.object(index, "delete", return_value=0):
|
||||
index.index(chunks, metadata)
|
||||
|
||||
# 3 documents = 3 flushes (one per doc boundary + final)
|
||||
assert index._client.bulk_index_documents.call_count == 3
|
||||
|
||||
|
||||
@patch(
|
||||
"onyx.document_index.opensearch.opensearch_document_index.MAX_CHUNKS_PER_DOC_BATCH",
|
||||
100,
|
||||
)
|
||||
def test_delete_called_once_per_document() -> None:
|
||||
"""Even with multiple flushes for a single document, delete should only be
|
||||
called once per document."""
|
||||
index = _make_index()
|
||||
doc_id = "doc_1"
|
||||
num_chunks = 250
|
||||
chunks = [_make_chunk(doc_id, i) for i in range(num_chunks)]
|
||||
metadata = _make_metadata(doc_id, num_chunks)
|
||||
|
||||
with patch.object(index, "delete", return_value=0) as mock_delete:
|
||||
index.index(chunks, metadata)
|
||||
|
||||
mock_delete.assert_called_once_with(doc_id, None)
|
||||
@@ -0,0 +1,152 @@
|
||||
"""Unit tests for VespaDocumentIndex.index().
|
||||
|
||||
These tests mock all external I/O (HTTP calls, thread pools) and verify
|
||||
the streaming logic, ID cleaning/mapping, and DocumentInsertionRecord
|
||||
construction.
|
||||
"""
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
|
||||
from onyx.access.models import DocumentAccess
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import TextSection
|
||||
from onyx.document_index.interfaces import EnrichedDocumentIndexingInfo
|
||||
from onyx.document_index.interfaces_new import IndexingMetadata
|
||||
from onyx.document_index.interfaces_new import TenantState
|
||||
from onyx.document_index.vespa.vespa_document_index import VespaDocumentIndex
|
||||
from onyx.indexing.models import ChunkEmbedding
|
||||
from onyx.indexing.models import DocMetadataAwareIndexChunk
|
||||
from onyx.indexing.models import IndexChunk
|
||||
|
||||
|
||||
def _make_chunk(
|
||||
doc_id: str,
|
||||
chunk_id: int = 0,
|
||||
content: str = "test content",
|
||||
) -> DocMetadataAwareIndexChunk:
|
||||
doc = Document(
|
||||
id=doc_id,
|
||||
semantic_identifier="test_doc",
|
||||
sections=[TextSection(text=content, link=None)],
|
||||
source=DocumentSource.NOT_APPLICABLE,
|
||||
metadata={},
|
||||
)
|
||||
index_chunk = IndexChunk(
|
||||
chunk_id=chunk_id,
|
||||
blurb=content[:50],
|
||||
content=content,
|
||||
source_links=None,
|
||||
image_file_id=None,
|
||||
section_continuation=False,
|
||||
source_document=doc,
|
||||
title_prefix="",
|
||||
metadata_suffix_semantic="",
|
||||
metadata_suffix_keyword="",
|
||||
contextual_rag_reserved_tokens=0,
|
||||
doc_summary="",
|
||||
chunk_context="",
|
||||
mini_chunk_texts=None,
|
||||
large_chunk_id=None,
|
||||
embeddings=ChunkEmbedding(
|
||||
full_embedding=[0.1] * 10,
|
||||
mini_chunk_embeddings=[],
|
||||
),
|
||||
title_embedding=None,
|
||||
)
|
||||
access = DocumentAccess.build(
|
||||
user_emails=[],
|
||||
user_groups=[],
|
||||
external_user_emails=[],
|
||||
external_user_group_ids=[],
|
||||
is_public=True,
|
||||
)
|
||||
return DocMetadataAwareIndexChunk.from_index_chunk(
|
||||
index_chunk=index_chunk,
|
||||
access=access,
|
||||
document_sets=set(),
|
||||
user_project=[],
|
||||
personas=[],
|
||||
boost=0,
|
||||
aggregated_chunk_boost_factor=1.0,
|
||||
tenant_id="test_tenant",
|
||||
)
|
||||
|
||||
|
||||
def _make_indexing_metadata(
|
||||
doc_ids: list[str],
|
||||
old_counts: list[int],
|
||||
new_counts: list[int],
|
||||
) -> IndexingMetadata:
|
||||
return IndexingMetadata(
|
||||
doc_id_to_chunk_cnt_diff={
|
||||
doc_id: IndexingMetadata.ChunkCounts(
|
||||
old_chunk_cnt=old,
|
||||
new_chunk_cnt=new,
|
||||
)
|
||||
for doc_id, old, new in zip(doc_ids, old_counts, new_counts)
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def _stub_enrich(
|
||||
doc_id: str,
|
||||
old_chunk_cnt: int,
|
||||
) -> EnrichedDocumentIndexingInfo:
|
||||
"""Build an EnrichedDocumentIndexingInfo that says 'no chunks to delete'
|
||||
when old_chunk_cnt == 0, or 'has existing chunks' otherwise."""
|
||||
return EnrichedDocumentIndexingInfo(
|
||||
doc_id=doc_id,
|
||||
chunk_start_index=0,
|
||||
old_version=False,
|
||||
chunk_end_index=old_chunk_cnt,
|
||||
)
|
||||
|
||||
|
||||
@patch("onyx.document_index.vespa.vespa_document_index.batch_index_vespa_chunks")
|
||||
@patch("onyx.document_index.vespa.vespa_document_index.delete_vespa_chunks")
|
||||
@patch(
|
||||
"onyx.document_index.vespa.vespa_document_index.get_document_chunk_ids",
|
||||
return_value=[],
|
||||
)
|
||||
@patch("onyx.document_index.vespa.vespa_document_index._enrich_basic_chunk_info")
|
||||
@patch(
|
||||
"onyx.document_index.vespa.vespa_document_index.BATCH_SIZE",
|
||||
3,
|
||||
)
|
||||
def test_index_respects_batch_size(
|
||||
mock_enrich: MagicMock,
|
||||
mock_get_chunk_ids: MagicMock, # noqa: ARG001
|
||||
mock_delete: MagicMock, # noqa: ARG001
|
||||
mock_batch_index: MagicMock,
|
||||
) -> None:
|
||||
"""When chunks exceed BATCH_SIZE, batch_index_vespa_chunks is called
|
||||
multiple times with correctly sized batches."""
|
||||
mock_enrich.return_value = _stub_enrich("doc1", old_chunk_cnt=0)
|
||||
|
||||
index = VespaDocumentIndex(
|
||||
index_name="test_index",
|
||||
tenant_state=TenantState(tenant_id="test_tenant", multitenant=False),
|
||||
large_chunks_enabled=False,
|
||||
httpx_client=MagicMock(),
|
||||
)
|
||||
|
||||
chunks = [_make_chunk("doc1", chunk_id=i) for i in range(7)]
|
||||
metadata = _make_indexing_metadata(["doc1"], old_counts=[0], new_counts=[7])
|
||||
|
||||
results = index.index(chunks=chunks, indexing_metadata=metadata)
|
||||
|
||||
assert len(results) == 1
|
||||
|
||||
# With BATCH_SIZE=3 and 7 chunks: batches of 3, 3, 1
|
||||
assert mock_batch_index.call_count == 3
|
||||
batch_sizes = [len(c.kwargs["chunks"]) for c in mock_batch_index.call_args_list]
|
||||
assert batch_sizes == [3, 3, 1]
|
||||
|
||||
# Verify all chunks are accounted for and in order
|
||||
all_indexed = [
|
||||
chunk for c in mock_batch_index.call_args_list for chunk in c.kwargs["chunks"]
|
||||
]
|
||||
assert len(all_indexed) == 7
|
||||
assert [c.chunk_id for c in all_indexed] == list(range(7))
|
||||
@@ -11,22 +11,15 @@ import { MinimalPersonaSnapshot } from "@/app/admin/agents/interfaces";
|
||||
import { useState, useEffect } from "react";
|
||||
import { useSettingsContext } from "@/providers/SettingsProvider";
|
||||
import FrostedDiv from "@/refresh-components/FrostedDiv";
|
||||
import { cn } from "@/lib/utils";
|
||||
|
||||
export interface WelcomeMessageProps {
|
||||
agent?: MinimalPersonaSnapshot;
|
||||
isDefaultAgent: boolean;
|
||||
/** Optional right-aligned element rendered on the same row as the greeting (e.g. model selector). */
|
||||
rightChildren?: React.ReactNode;
|
||||
/** When true, the greeting/logo content is hidden (but space is preserved). Used at max models. */
|
||||
hideTitle?: boolean;
|
||||
}
|
||||
|
||||
export default function WelcomeMessage({
|
||||
agent,
|
||||
isDefaultAgent,
|
||||
rightChildren,
|
||||
hideTitle,
|
||||
}: WelcomeMessageProps) {
|
||||
const settings = useSettingsContext();
|
||||
const enterpriseSettings = settings?.enterpriseSettings;
|
||||
@@ -46,10 +39,8 @@ export default function WelcomeMessage({
|
||||
|
||||
if (isDefaultAgent) {
|
||||
content = (
|
||||
<div data-testid="onyx-logo" className="flex flex-col items-start gap-2">
|
||||
<div className="flex items-center justify-center size-9 p-0.5">
|
||||
<Logo folded size={32} />
|
||||
</div>
|
||||
<div data-testid="onyx-logo" className="flex flex-row items-center gap-4">
|
||||
<Logo folded size={32} />
|
||||
<Text as="p" headingH2>
|
||||
{greeting}
|
||||
</Text>
|
||||
@@ -57,15 +48,17 @@ export default function WelcomeMessage({
|
||||
);
|
||||
} else if (agent) {
|
||||
content = (
|
||||
<div
|
||||
data-testid="agent-name-display"
|
||||
className="flex flex-col items-start gap-2"
|
||||
>
|
||||
<AgentAvatar agent={agent} size={36} />
|
||||
<Text as="p" headingH2>
|
||||
{agent.name}
|
||||
</Text>
|
||||
</div>
|
||||
<>
|
||||
<div
|
||||
data-testid="agent-name-display"
|
||||
className="flex flex-row items-center gap-3"
|
||||
>
|
||||
<AgentAvatar agent={agent} size={36} />
|
||||
<Text as="p" headingH2>
|
||||
{agent.name}
|
||||
</Text>
|
||||
</div>
|
||||
</>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -76,24 +69,9 @@ export default function WelcomeMessage({
|
||||
return (
|
||||
<FrostedDiv
|
||||
data-testid="chat-intro"
|
||||
wrapperClassName="w-full"
|
||||
className="flex flex-col items-center justify-center gap-3 w-full max-w-[var(--app-page-main-content-width)] mx-auto"
|
||||
className="flex flex-col items-center justify-center gap-3 w-full max-w-[var(--app-page-main-content-width)]"
|
||||
>
|
||||
{rightChildren ? (
|
||||
<div className="flex items-end gap-2 w-full">
|
||||
<div
|
||||
className={cn(
|
||||
"flex-1 min-w-0 min-h-[80px] px-2 py-1",
|
||||
hideTitle && "invisible"
|
||||
)}
|
||||
>
|
||||
{content}
|
||||
</div>
|
||||
<div className="shrink-0">{rightChildren}</div>
|
||||
</div>
|
||||
) : (
|
||||
content
|
||||
)}
|
||||
{content}
|
||||
</FrostedDiv>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -159,10 +159,6 @@ export interface Message {
|
||||
overridden_model?: string;
|
||||
stopReason?: StreamStopReason | null;
|
||||
|
||||
// Multi-model answer generation
|
||||
preferredResponseId?: number | null;
|
||||
modelDisplayName?: string | null;
|
||||
|
||||
// new gen
|
||||
packets: Packet[];
|
||||
packetCount?: number; // Tracks packet count for React memo comparison (avoids reading from mutated array)
|
||||
@@ -235,9 +231,6 @@ export interface BackendMessage {
|
||||
parentMessageId: number | null;
|
||||
refined_answer_improvement: boolean | null;
|
||||
is_agentic: boolean | null;
|
||||
// Multi-model answer generation
|
||||
preferred_response_id: number | null;
|
||||
model_display_name: string | null;
|
||||
}
|
||||
|
||||
export interface MessageResponseIDInfo {
|
||||
@@ -245,12 +238,6 @@ export interface MessageResponseIDInfo {
|
||||
reserved_assistant_message_id: number; // TODO: rename to agent — https://linear.app/onyx-app/issue/ENG-3766
|
||||
}
|
||||
|
||||
export interface MultiModelMessageResponseIDInfo {
|
||||
user_message_id: number | null;
|
||||
reserved_assistant_message_ids: number[];
|
||||
model_names: string[];
|
||||
}
|
||||
|
||||
export interface UserKnowledgeFilePacket {
|
||||
user_files: FileDescriptor[];
|
||||
}
|
||||
|
||||
@@ -1,149 +0,0 @@
|
||||
"use client";
|
||||
|
||||
import { useCallback } from "react";
|
||||
import { Button } from "@opal/components";
|
||||
import { Hoverable } from "@opal/core";
|
||||
import { SvgEyeClosed, SvgMoreHorizontal, SvgX } from "@opal/icons";
|
||||
import Text from "@/refresh-components/texts/Text";
|
||||
import { getProviderIcon } from "@/app/admin/configuration/llm/utils";
|
||||
import AgentMessage, {
|
||||
AgentMessageProps,
|
||||
} from "@/app/app/message/messageComponents/AgentMessage";
|
||||
import { cn } from "@/lib/utils";
|
||||
|
||||
export interface MultiModelPanelProps {
|
||||
/** Index of this model in the selectedModels array (used for Hoverable group key) */
|
||||
modelIndex: number;
|
||||
/** Provider name for icon lookup */
|
||||
provider: string;
|
||||
/** Model name for icon lookup and display */
|
||||
modelName: string;
|
||||
/** Display-friendly model name */
|
||||
displayName: string;
|
||||
/** Whether this panel is the preferred/selected response */
|
||||
isPreferred: boolean;
|
||||
/** Whether this panel is currently hidden */
|
||||
isHidden: boolean;
|
||||
/** Whether this is a non-preferred panel in selection mode (pushed off-screen) */
|
||||
isNonPreferredInSelection: boolean;
|
||||
/** Callback when user clicks this panel to select as preferred */
|
||||
onSelect: () => void;
|
||||
/** Callback to hide/show this panel */
|
||||
onToggleVisibility: () => void;
|
||||
/** Props to pass through to AgentMessage */
|
||||
agentMessageProps: AgentMessageProps;
|
||||
}
|
||||
|
||||
export default function MultiModelPanel({
|
||||
modelIndex,
|
||||
provider,
|
||||
modelName,
|
||||
displayName,
|
||||
isPreferred,
|
||||
isHidden,
|
||||
isNonPreferredInSelection,
|
||||
onSelect,
|
||||
onToggleVisibility,
|
||||
agentMessageProps,
|
||||
}: MultiModelPanelProps) {
|
||||
const ProviderIcon = getProviderIcon(provider, modelName);
|
||||
|
||||
const handlePanelClick = useCallback(() => {
|
||||
if (!isHidden) {
|
||||
onSelect();
|
||||
}
|
||||
}, [isHidden, onSelect]);
|
||||
|
||||
// Hidden/collapsed panel — compact strip: icon + strikethrough name + eye icon
|
||||
if (isHidden) {
|
||||
return (
|
||||
<div className="flex items-center gap-1.5 rounded-08 bg-background-tint-00 px-2 py-1 opacity-50">
|
||||
<div className="flex items-center justify-center size-5 shrink-0">
|
||||
<ProviderIcon size={16} />
|
||||
</div>
|
||||
<Text secondaryBody text02 nowrap className="line-through">
|
||||
{displayName}
|
||||
</Text>
|
||||
<Button
|
||||
prominence="tertiary"
|
||||
icon={SvgEyeClosed}
|
||||
size="2xs"
|
||||
onClick={onToggleVisibility}
|
||||
tooltip="Show response"
|
||||
/>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
const hoverGroup = `panel-${modelIndex}`;
|
||||
|
||||
return (
|
||||
<Hoverable.Root group={hoverGroup}>
|
||||
<div
|
||||
className="flex flex-col min-w-0 gap-3 cursor-pointer"
|
||||
onClick={handlePanelClick}
|
||||
>
|
||||
{/* Panel header */}
|
||||
<div
|
||||
className={cn(
|
||||
"flex items-center gap-1.5 rounded-12 px-2 py-1",
|
||||
isPreferred ? "bg-background-tint-02" : "bg-background-tint-00"
|
||||
)}
|
||||
>
|
||||
<div className="flex items-center justify-center size-5 shrink-0">
|
||||
<ProviderIcon size={16} />
|
||||
</div>
|
||||
<Text mainUiAction text04 nowrap className="flex-1 min-w-0 truncate">
|
||||
{displayName}
|
||||
</Text>
|
||||
{isPreferred && (
|
||||
<Text secondaryBody nowrap className="text-action-link-05 shrink-0">
|
||||
Preferred Response
|
||||
</Text>
|
||||
)}
|
||||
<Button
|
||||
prominence="tertiary"
|
||||
icon={SvgMoreHorizontal}
|
||||
size="2xs"
|
||||
tooltip="More"
|
||||
onClick={(e) => e.stopPropagation()}
|
||||
/>
|
||||
<Button
|
||||
prominence="tertiary"
|
||||
icon={SvgX}
|
||||
size="2xs"
|
||||
onClick={(e) => {
|
||||
e.stopPropagation();
|
||||
onToggleVisibility();
|
||||
}}
|
||||
tooltip="Hide response"
|
||||
/>
|
||||
</div>
|
||||
|
||||
{/* "Select This Response" hover affordance */}
|
||||
{!isPreferred && !isNonPreferredInSelection && (
|
||||
<Hoverable.Item group={hoverGroup} variant="opacity-on-hover">
|
||||
<div className="flex justify-center pointer-events-none">
|
||||
<div className="flex items-center h-6 bg-background-tint-00 rounded-08 px-1 shadow-sm">
|
||||
<Text
|
||||
secondaryBody
|
||||
className="font-semibold text-text-03 px-1 whitespace-nowrap"
|
||||
>
|
||||
Select This Response
|
||||
</Text>
|
||||
</div>
|
||||
</div>
|
||||
</Hoverable.Item>
|
||||
)}
|
||||
|
||||
{/* Response body */}
|
||||
<div className={cn(isNonPreferredInSelection && "pointer-events-none")}>
|
||||
<AgentMessage
|
||||
{...agentMessageProps}
|
||||
hideFooter={isNonPreferredInSelection}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
</Hoverable.Root>
|
||||
);
|
||||
}
|
||||
@@ -1,229 +0,0 @@
|
||||
"use client";
|
||||
|
||||
import { useState, useCallback, useMemo } from "react";
|
||||
import { Packet } from "@/app/app/services/streamingModels";
|
||||
import { FullChatState } from "@/app/app/message/messageComponents/interfaces";
|
||||
import { FeedbackType, Message } from "@/app/app/interfaces";
|
||||
import { LlmManager } from "@/lib/hooks";
|
||||
import { RegenerationFactory } from "@/app/app/message/messageComponents/AgentMessage";
|
||||
import MultiModelPanel from "@/app/app/message/MultiModelPanel";
|
||||
import { cn } from "@/lib/utils";
|
||||
|
||||
export interface MultiModelResponse {
|
||||
modelIndex: number;
|
||||
provider: string;
|
||||
modelName: string;
|
||||
displayName: string;
|
||||
packets: Packet[];
|
||||
packetCount: number;
|
||||
nodeId: number;
|
||||
messageId?: number;
|
||||
isHighlighted?: boolean;
|
||||
currentFeedback?: FeedbackType | null;
|
||||
isGenerating?: boolean;
|
||||
}
|
||||
|
||||
export interface MultiModelResponseViewProps {
|
||||
responses: MultiModelResponse[];
|
||||
chatState: FullChatState;
|
||||
llmManager: LlmManager | null;
|
||||
onRegenerate?: RegenerationFactory;
|
||||
parentMessage?: Message | null;
|
||||
otherMessagesCanSwitchTo?: number[];
|
||||
onMessageSelection?: (nodeId: number) => void;
|
||||
}
|
||||
|
||||
export default function MultiModelResponseView({
|
||||
responses,
|
||||
chatState,
|
||||
llmManager,
|
||||
onRegenerate,
|
||||
parentMessage,
|
||||
otherMessagesCanSwitchTo,
|
||||
onMessageSelection,
|
||||
}: MultiModelResponseViewProps) {
|
||||
const [preferredIndex, setPreferredIndex] = useState<number | null>(null);
|
||||
const [hiddenPanels, setHiddenPanels] = useState<Set<number>>(new Set());
|
||||
|
||||
const isGenerating = useMemo(
|
||||
() => responses.some((r) => r.isGenerating),
|
||||
[responses]
|
||||
);
|
||||
|
||||
const visibleResponses = useMemo(
|
||||
() => responses.filter((r) => !hiddenPanels.has(r.modelIndex)),
|
||||
[responses, hiddenPanels]
|
||||
);
|
||||
|
||||
const hiddenResponses = useMemo(
|
||||
() => responses.filter((r) => hiddenPanels.has(r.modelIndex)),
|
||||
[responses, hiddenPanels]
|
||||
);
|
||||
|
||||
const toggleVisibility = useCallback(
|
||||
(modelIndex: number) => {
|
||||
setHiddenPanels((prev) => {
|
||||
const next = new Set(prev);
|
||||
if (next.has(modelIndex)) {
|
||||
next.delete(modelIndex);
|
||||
} else {
|
||||
// Don't hide the last visible panel
|
||||
const visibleCount = responses.length - next.size;
|
||||
if (visibleCount <= 1) return prev;
|
||||
next.add(modelIndex);
|
||||
}
|
||||
return next;
|
||||
});
|
||||
},
|
||||
[responses.length]
|
||||
);
|
||||
|
||||
const handleSelectPreferred = useCallback(
|
||||
(modelIndex: number) => {
|
||||
setPreferredIndex(modelIndex);
|
||||
const response = responses[modelIndex];
|
||||
if (!response) return;
|
||||
// Sync with message tree — mark this response as the latest child
|
||||
// so the next message chains from it.
|
||||
if (onMessageSelection) {
|
||||
onMessageSelection(response.nodeId);
|
||||
}
|
||||
},
|
||||
[responses, onMessageSelection]
|
||||
);
|
||||
|
||||
// Selection mode when preferred is set and not generating
|
||||
const showSelectionMode =
|
||||
preferredIndex !== null && !isGenerating && visibleResponses.length > 1;
|
||||
|
||||
// Build common panel props
|
||||
const buildPanelProps = useCallback(
|
||||
(response: MultiModelResponse, isNonPreferred: boolean) => ({
|
||||
modelIndex: response.modelIndex,
|
||||
provider: response.provider,
|
||||
modelName: response.modelName,
|
||||
displayName: response.displayName,
|
||||
isPreferred: preferredIndex === response.modelIndex,
|
||||
isHidden: false as const,
|
||||
isNonPreferredInSelection: isNonPreferred,
|
||||
onSelect: () => handleSelectPreferred(response.modelIndex),
|
||||
onToggleVisibility: () => toggleVisibility(response.modelIndex),
|
||||
agentMessageProps: {
|
||||
rawPackets: response.packets,
|
||||
packetCount: response.packetCount,
|
||||
chatState,
|
||||
nodeId: response.nodeId,
|
||||
messageId: response.messageId,
|
||||
currentFeedback: response.currentFeedback,
|
||||
llmManager,
|
||||
otherMessagesCanSwitchTo,
|
||||
onMessageSelection,
|
||||
onRegenerate,
|
||||
parentMessage,
|
||||
},
|
||||
}),
|
||||
[
|
||||
preferredIndex,
|
||||
handleSelectPreferred,
|
||||
toggleVisibility,
|
||||
chatState,
|
||||
llmManager,
|
||||
otherMessagesCanSwitchTo,
|
||||
onMessageSelection,
|
||||
onRegenerate,
|
||||
parentMessage,
|
||||
]
|
||||
);
|
||||
|
||||
// Shared renderer for hidden panels (inline in the flex row)
|
||||
const renderHiddenPanels = () =>
|
||||
hiddenResponses.map((r) => (
|
||||
<div key={r.modelIndex} className="w-[240px] shrink-0">
|
||||
<MultiModelPanel
|
||||
modelIndex={r.modelIndex}
|
||||
provider={r.provider}
|
||||
modelName={r.modelName}
|
||||
displayName={r.displayName}
|
||||
isPreferred={false}
|
||||
isHidden
|
||||
isNonPreferredInSelection={false}
|
||||
onSelect={() => handleSelectPreferred(r.modelIndex)}
|
||||
onToggleVisibility={() => toggleVisibility(r.modelIndex)}
|
||||
agentMessageProps={buildPanelProps(r, false).agentMessageProps}
|
||||
/>
|
||||
</div>
|
||||
));
|
||||
|
||||
if (showSelectionMode) {
|
||||
// ── Selection Layout ──
|
||||
// Preferred stays at normal chat width, centered.
|
||||
// Non-preferred panels are pushed to the viewport edges and clip off-screen.
|
||||
const preferredIdx = visibleResponses.findIndex(
|
||||
(r) => r.modelIndex === preferredIndex
|
||||
);
|
||||
const preferred = visibleResponses[preferredIdx];
|
||||
const leftPanels = visibleResponses.slice(0, preferredIdx);
|
||||
const rightPanels = visibleResponses.slice(preferredIdx + 1);
|
||||
|
||||
// Non-preferred panel width and gap between panels
|
||||
const PANEL_W = 400;
|
||||
const GAP = 16;
|
||||
|
||||
return (
|
||||
<div className="w-full relative overflow-hidden">
|
||||
{/* Preferred — centered at normal chat width, in flow to set container height */}
|
||||
{preferred && (
|
||||
<div className="w-full max-w-[720px] min-w-[400px] mx-auto">
|
||||
<MultiModelPanel {...buildPanelProps(preferred, false)} />
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* Non-preferred on the left — anchored to the left of the preferred panel */}
|
||||
{leftPanels.map((r, i) => (
|
||||
<div
|
||||
key={r.modelIndex}
|
||||
className="absolute top-0"
|
||||
style={{
|
||||
width: `${PANEL_W}px`,
|
||||
// Right edge of this panel sits just left of the preferred panel
|
||||
right: `calc(50% + var(--app-page-main-content-width) / 2 + ${
|
||||
GAP + i * (PANEL_W + GAP)
|
||||
}px)`,
|
||||
}}
|
||||
>
|
||||
<MultiModelPanel {...buildPanelProps(r, true)} />
|
||||
</div>
|
||||
))}
|
||||
|
||||
{/* Non-preferred on the right — anchored to the right of the preferred panel */}
|
||||
{rightPanels.map((r, i) => (
|
||||
<div
|
||||
key={r.modelIndex}
|
||||
className="absolute top-0"
|
||||
style={{
|
||||
width: `${PANEL_W}px`,
|
||||
// Left edge of this panel sits just right of the preferred panel
|
||||
left: `calc(50% + var(--app-page-main-content-width) / 2 + ${
|
||||
GAP + i * (PANEL_W + GAP)
|
||||
}px)`,
|
||||
}}
|
||||
>
|
||||
<MultiModelPanel {...buildPanelProps(r, true)} />
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
// ── Generation Layout (equal panels) ──
|
||||
return (
|
||||
<div className="flex gap-6 items-start justify-center">
|
||||
{visibleResponses.map((r) => (
|
||||
<div key={r.modelIndex} className="flex-1 min-w-[400px] max-w-[720px]">
|
||||
<MultiModelPanel {...buildPanelProps(r, false)} />
|
||||
</div>
|
||||
))}
|
||||
{renderHiddenPanels()}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -49,8 +49,6 @@ export interface AgentMessageProps {
|
||||
parentMessage?: Message | null;
|
||||
// Duration in seconds for processing this message (agent messages only)
|
||||
processingDurationSeconds?: number;
|
||||
/** Hide the feedback/toolbar footer (used in multi-model non-preferred panels) */
|
||||
hideFooter?: boolean;
|
||||
}
|
||||
|
||||
// TODO: Consider more robust comparisons:
|
||||
@@ -78,8 +76,7 @@ function arePropsEqual(
|
||||
prev.parentMessage?.messageId === next.parentMessage?.messageId &&
|
||||
prev.llmManager?.isLoadingProviders ===
|
||||
next.llmManager?.isLoadingProviders &&
|
||||
prev.processingDurationSeconds === next.processingDurationSeconds &&
|
||||
prev.hideFooter === next.hideFooter
|
||||
prev.processingDurationSeconds === next.processingDurationSeconds
|
||||
// Skip: chatState.regenerate, chatState.setPresentingDocument,
|
||||
// most of llmManager, onMessageSelection (function/object props)
|
||||
);
|
||||
@@ -98,7 +95,6 @@ const AgentMessage = React.memo(function AgentMessage({
|
||||
onRegenerate,
|
||||
parentMessage,
|
||||
processingDurationSeconds,
|
||||
hideFooter,
|
||||
}: AgentMessageProps) {
|
||||
const markdownRef = useRef<HTMLDivElement>(null);
|
||||
const finalAnswerRef = useRef<HTMLDivElement>(null);
|
||||
@@ -330,7 +326,7 @@ const AgentMessage = React.memo(function AgentMessage({
|
||||
</div>
|
||||
|
||||
{/* Feedback buttons - only show when streaming and rendering complete */}
|
||||
{isComplete && !hideFooter && (
|
||||
{isComplete && (
|
||||
<MessageToolbar
|
||||
nodeId={nodeId}
|
||||
messageId={messageId}
|
||||
|
||||
@@ -12,7 +12,6 @@ import {
|
||||
FileChatDisplay,
|
||||
Message,
|
||||
MessageResponseIDInfo,
|
||||
MultiModelMessageResponseIDInfo,
|
||||
ResearchType,
|
||||
RetrievalType,
|
||||
StreamingError,
|
||||
@@ -97,7 +96,6 @@ export type PacketType =
|
||||
| FileChatDisplay
|
||||
| StreamingError
|
||||
| MessageResponseIDInfo
|
||||
| MultiModelMessageResponseIDInfo
|
||||
| StreamStopInfo
|
||||
| UserKnowledgeFilePacket
|
||||
| Packet;
|
||||
@@ -111,13 +109,6 @@ export type MessageOrigin =
|
||||
| "slackbot"
|
||||
| "unknown";
|
||||
|
||||
export interface LLMOverride {
|
||||
model_provider: string;
|
||||
model_version: string;
|
||||
temperature?: number;
|
||||
display_name?: string;
|
||||
}
|
||||
|
||||
export interface SendMessageParams {
|
||||
message: string;
|
||||
fileDescriptors?: FileDescriptor[];
|
||||
@@ -133,8 +124,6 @@ export interface SendMessageParams {
|
||||
modelProvider?: string;
|
||||
modelVersion?: string;
|
||||
temperature?: number;
|
||||
// Multi-model: send multiple LLM overrides for parallel generation
|
||||
llmOverrides?: LLMOverride[];
|
||||
// Origin of the message for telemetry tracking
|
||||
origin?: MessageOrigin;
|
||||
// Additional context injected into the LLM call but not stored/shown in chat.
|
||||
@@ -155,7 +144,6 @@ export async function* sendMessage({
|
||||
modelProvider,
|
||||
modelVersion,
|
||||
temperature,
|
||||
llmOverrides,
|
||||
origin,
|
||||
additionalContext,
|
||||
}: SendMessageParams): AsyncGenerator<PacketType, void, unknown> {
|
||||
@@ -177,8 +165,6 @@ export async function* sendMessage({
|
||||
model_version: modelVersion,
|
||||
}
|
||||
: null,
|
||||
// Multi-model: list of LLM overrides for parallel generation
|
||||
llm_overrides: llmOverrides ?? null,
|
||||
// Default to "unknown" for consistency with backend; callers should set explicitly
|
||||
origin: origin ?? "unknown",
|
||||
additional_context: additionalContext ?? null,
|
||||
@@ -202,20 +188,6 @@ export async function* sendMessage({
|
||||
yield* handleSSEStream<PacketType>(response, signal);
|
||||
}
|
||||
|
||||
export async function setPreferredResponse(
|
||||
userMessageId: number,
|
||||
preferredResponseId: number
|
||||
): Promise<Response> {
|
||||
return fetch("/api/chat/set-preferred-response", {
|
||||
method: "PUT",
|
||||
headers: { "Content-Type": "application/json" },
|
||||
body: JSON.stringify({
|
||||
user_message_id: userMessageId,
|
||||
preferred_response_id: preferredResponseId,
|
||||
}),
|
||||
});
|
||||
}
|
||||
|
||||
export async function nameChatSession(chatSessionId: string) {
|
||||
const response = await fetch("/api/chat/rename-chat-session", {
|
||||
method: "PUT",
|
||||
@@ -385,9 +357,6 @@ export function processRawChatHistory(
|
||||
overridden_model: messageInfo.overridden_model,
|
||||
packets: packetsForMessage || [],
|
||||
currentFeedback: messageInfo.current_feedback as FeedbackType | null,
|
||||
// Multi-model answer generation
|
||||
preferredResponseId: messageInfo.preferred_response_id ?? null,
|
||||
modelDisplayName: messageInfo.model_display_name ?? null,
|
||||
};
|
||||
|
||||
messages.set(messageInfo.message_id, message);
|
||||
|
||||
@@ -403,7 +403,6 @@ export interface Placement {
|
||||
turn_index: number;
|
||||
tab_index?: number; // For parallel tool calls - tools with same turn_index but different tab_index run in parallel
|
||||
sub_turn_index?: number | null;
|
||||
model_index?: number | null; // For multi-model answer generation - identifies which model produced this packet
|
||||
}
|
||||
|
||||
// Packet wrapper for streaming objects
|
||||
|
||||
@@ -459,7 +459,6 @@ export default function NRFPage({ isSidePanel = false }: NRFPageProps) {
|
||||
onResubmit={handleResubmitLastMessage}
|
||||
deepResearchEnabled={deepResearchEnabled}
|
||||
anchorNodeId={anchorNodeId}
|
||||
selectedModels={[]}
|
||||
/>
|
||||
</ChatScrollContainer>
|
||||
</>
|
||||
|
||||
@@ -3,7 +3,6 @@
|
||||
import {
|
||||
buildChatUrl,
|
||||
getAvailableContextTokens,
|
||||
LLMOverride,
|
||||
nameChatSession,
|
||||
updateLlmOverrideForChatSession,
|
||||
} from "@/app/app/services/lib";
|
||||
@@ -34,7 +33,6 @@ import {
|
||||
FileDescriptor,
|
||||
Message,
|
||||
MessageResponseIDInfo,
|
||||
MultiModelMessageResponseIDInfo,
|
||||
RegenerationState,
|
||||
RetrievalType,
|
||||
StreamingError,
|
||||
@@ -72,7 +70,6 @@ import {
|
||||
} from "@/app/app/stores/useChatSessionStore";
|
||||
import { Packet, MessageStart } from "@/app/app/services/streamingModels";
|
||||
import useAgentPreferences from "@/hooks/useAgentPreferences";
|
||||
import { SelectedModel } from "@/refresh-components/popovers/ModelSelector";
|
||||
import { useForcedTools } from "@/lib/hooks/useForcedTools";
|
||||
import { ProjectFile, useProjectsContext } from "@/providers/ProjectsContext";
|
||||
import { useAppParams } from "@/hooks/appNavigation";
|
||||
@@ -97,8 +94,6 @@ export interface OnSubmitProps {
|
||||
regenerationRequest?: RegenerationRequest | null;
|
||||
// Additional context injected into the LLM call but not stored/shown in chat.
|
||||
additionalContext?: string;
|
||||
// Multi-model chat: up to 3 models selected for parallel comparison.
|
||||
selectedModels?: SelectedModel[];
|
||||
}
|
||||
|
||||
interface RegenerationRequest {
|
||||
@@ -375,10 +370,7 @@ export default function useChatController({
|
||||
modelOverride,
|
||||
regenerationRequest,
|
||||
additionalContext,
|
||||
selectedModels,
|
||||
}: OnSubmitProps) => {
|
||||
// Check if this is multi-model mode (2 or 3 models selected)
|
||||
const isMultiModelMode = selectedModels && selectedModels.length >= 2;
|
||||
const projectId = params(SEARCH_PARAM_NAMES.PROJECT_ID);
|
||||
{
|
||||
const params = new URLSearchParams(searchParams?.toString() || "");
|
||||
@@ -609,7 +601,6 @@ export default function useChatController({
|
||||
// immediately reflects the user message
|
||||
let initialUserNode: Message;
|
||||
let initialAgentNode: Message;
|
||||
let initialAssistantNodes: Message[] = [];
|
||||
|
||||
if (regenerationRequest) {
|
||||
// For regeneration: keep the existing user message, only create new agent
|
||||
@@ -632,30 +623,12 @@ export default function useChatController({
|
||||
);
|
||||
initialUserNode = result.initialUserNode;
|
||||
initialAgentNode = result.initialAgentNode;
|
||||
|
||||
// In multi-model mode, create N assistant nodes (one per selected model)
|
||||
if (isMultiModelMode && selectedModels) {
|
||||
for (let i = 0; i < selectedModels.length; i++) {
|
||||
initialAssistantNodes.push(
|
||||
buildEmptyMessage({
|
||||
messageType: "assistant",
|
||||
parentNodeId: initialUserNode.nodeId,
|
||||
nodeIdOffset: i + 1,
|
||||
})
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// make messages appear + clear input bar
|
||||
let messagesToUpsert: Message[];
|
||||
if (regenerationRequest) {
|
||||
messagesToUpsert = [initialAgentNode];
|
||||
} else if (isMultiModelMode) {
|
||||
messagesToUpsert = [initialUserNode, ...initialAssistantNodes];
|
||||
} else {
|
||||
messagesToUpsert = [initialUserNode, initialAgentNode];
|
||||
}
|
||||
const messagesToUpsert = regenerationRequest
|
||||
? [initialAgentNode] // Only upsert the new agent for regeneration
|
||||
: [initialUserNode, initialAgentNode]; // Upsert both for normal/edit flow
|
||||
currentMessageTreeLocal = upsertToCompleteMessageTree({
|
||||
messages: messagesToUpsert,
|
||||
completeMessageTreeOverride: currentMessageTreeLocal,
|
||||
@@ -689,24 +662,6 @@ export default function useChatController({
|
||||
let newUserMessageId: number | null = null;
|
||||
let newAgentMessageId: number | null = null;
|
||||
|
||||
// Multi-model mode state tracking (dynamically sized based on selected models)
|
||||
const numModels = selectedModels?.length ?? 0;
|
||||
let newAssistantMessageIds: (number | null)[] = isMultiModelMode
|
||||
? Array(numModels).fill(null)
|
||||
: [];
|
||||
let packetsPerModel: Packet[][] = isMultiModelMode
|
||||
? Array.from({ length: numModels }, () => [])
|
||||
: [];
|
||||
let modelDisplayNames: string[] = isMultiModelMode
|
||||
? selectedModels?.map((m) => m.displayName) ?? []
|
||||
: [];
|
||||
let documentsPerModel: OnyxDocument[][] = isMultiModelMode
|
||||
? Array.from({ length: numModels }, () => [])
|
||||
: [];
|
||||
let citationsPerModel: (CitationMap | null)[] = isMultiModelMode
|
||||
? Array(numModels).fill(null)
|
||||
: [];
|
||||
|
||||
try {
|
||||
const lastSuccessfulMessageId = getLastSuccessfulMessageId(
|
||||
currentMessageTreeLocal
|
||||
@@ -755,15 +710,13 @@ export default function useChatController({
|
||||
filterManager.timeRange,
|
||||
filterManager.selectedTags
|
||||
),
|
||||
modelProvider: isMultiModelMode
|
||||
? undefined
|
||||
: modelOverride?.name || llmManager.currentLlm.name || undefined,
|
||||
modelVersion: isMultiModelMode
|
||||
? undefined
|
||||
: modelOverride?.modelName ||
|
||||
llmManager.currentLlm.modelName ||
|
||||
searchParams?.get(SEARCH_PARAM_NAMES.MODEL_VERSION) ||
|
||||
undefined,
|
||||
modelProvider:
|
||||
modelOverride?.name || llmManager.currentLlm.name || undefined,
|
||||
modelVersion:
|
||||
modelOverride?.modelName ||
|
||||
llmManager.currentLlm.modelName ||
|
||||
searchParams?.get(SEARCH_PARAM_NAMES.MODEL_VERSION) ||
|
||||
undefined,
|
||||
temperature: llmManager.temperature || undefined,
|
||||
deepResearch,
|
||||
enabledToolIds:
|
||||
@@ -775,12 +728,6 @@ export default function useChatController({
|
||||
forcedToolId: effectiveForcedToolId,
|
||||
origin: messageOrigin,
|
||||
additionalContext,
|
||||
llmOverrides: isMultiModelMode
|
||||
? selectedModels!.map((model) => ({
|
||||
model_provider: model.name,
|
||||
model_version: model.modelName,
|
||||
}))
|
||||
: undefined,
|
||||
});
|
||||
|
||||
const delay = (ms: number) => {
|
||||
@@ -833,26 +780,6 @@ export default function useChatController({
|
||||
.reserved_assistant_message_id;
|
||||
}
|
||||
|
||||
// Multi-model: handle reserved IDs for N parallel model responses
|
||||
if (
|
||||
isMultiModelMode &&
|
||||
Object.hasOwn(packet, "reserved_assistant_message_ids") &&
|
||||
Array.isArray(
|
||||
(packet as MultiModelMessageResponseIDInfo)
|
||||
.reserved_assistant_message_ids
|
||||
)
|
||||
) {
|
||||
const multiPacket = packet as MultiModelMessageResponseIDInfo;
|
||||
newAssistantMessageIds =
|
||||
multiPacket.reserved_assistant_message_ids;
|
||||
newUserMessageId =
|
||||
multiPacket.user_message_id ?? newUserMessageId;
|
||||
// Capture backend model names for display on reload
|
||||
if (multiPacket.model_names?.length) {
|
||||
modelDisplayNames = multiPacket.model_names;
|
||||
}
|
||||
}
|
||||
|
||||
if (Object.hasOwn(packet, "user_files")) {
|
||||
const userFiles = (packet as UserKnowledgeFilePacket).user_files;
|
||||
// Ensure files are unique by id
|
||||
@@ -896,73 +823,32 @@ export default function useChatController({
|
||||
updateCanContinue(true, frozenSessionId);
|
||||
}
|
||||
} else if (Object.hasOwn(packet, "obj")) {
|
||||
const typedPacket = packet as Packet;
|
||||
packets.push(packet as Packet);
|
||||
packetsVersion++;
|
||||
|
||||
// In multi-model mode, route packets by model_index
|
||||
if (isMultiModelMode) {
|
||||
const modelIndex = typedPacket.placement?.model_index ?? 0;
|
||||
if (
|
||||
modelIndex >= 0 &&
|
||||
modelIndex < packetsPerModel.length &&
|
||||
packetsPerModel[modelIndex]
|
||||
) {
|
||||
packetsPerModel[modelIndex] = [
|
||||
...packetsPerModel[modelIndex]!,
|
||||
typedPacket,
|
||||
];
|
||||
// Check if the packet contains document information
|
||||
const packetObj = (packet as Packet).obj;
|
||||
|
||||
const packetObj = typedPacket.obj;
|
||||
|
||||
if (packetObj.type === "citation_info") {
|
||||
const citationInfo = packetObj as {
|
||||
type: "citation_info";
|
||||
citation_number: number;
|
||||
document_id: string;
|
||||
};
|
||||
citationsPerModel[modelIndex] = {
|
||||
...(citationsPerModel[modelIndex] || {}),
|
||||
[citationInfo.citation_number]: citationInfo.document_id,
|
||||
};
|
||||
} else if (packetObj.type === "message_start") {
|
||||
const messageStart = packetObj as MessageStart;
|
||||
if (messageStart.final_documents) {
|
||||
documentsPerModel[modelIndex] =
|
||||
messageStart.final_documents;
|
||||
if (modelIndex === 0 && initialAssistantNodes[0]) {
|
||||
updateSelectedNodeForDocDisplay(
|
||||
frozenSessionId,
|
||||
initialAssistantNodes[0].nodeId
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Single model mode
|
||||
packets.push(typedPacket);
|
||||
packetsVersion++;
|
||||
|
||||
const packetObj = typedPacket.obj;
|
||||
|
||||
if (packetObj.type === "citation_info") {
|
||||
const citationInfo = packetObj as {
|
||||
type: "citation_info";
|
||||
citation_number: number;
|
||||
document_id: string;
|
||||
};
|
||||
citations = {
|
||||
...(citations || {}),
|
||||
[citationInfo.citation_number]: citationInfo.document_id,
|
||||
};
|
||||
} else if (packetObj.type === "message_start") {
|
||||
const messageStart = packetObj as MessageStart;
|
||||
if (messageStart.final_documents) {
|
||||
documents = messageStart.final_documents;
|
||||
updateSelectedNodeForDocDisplay(
|
||||
frozenSessionId,
|
||||
initialAgentNode.nodeId
|
||||
);
|
||||
}
|
||||
if (packetObj.type === "citation_info") {
|
||||
// Individual citation packet from backend streaming
|
||||
const citationInfo = packetObj as {
|
||||
type: "citation_info";
|
||||
citation_number: number;
|
||||
document_id: string;
|
||||
};
|
||||
// Incrementally build citations map
|
||||
citations = {
|
||||
...(citations || {}),
|
||||
[citationInfo.citation_number]: citationInfo.document_id,
|
||||
};
|
||||
} else if (packetObj.type === "message_start") {
|
||||
const messageStart = packetObj as MessageStart;
|
||||
if (messageStart.final_documents) {
|
||||
documents = messageStart.final_documents;
|
||||
updateSelectedNodeForDocDisplay(
|
||||
frozenSessionId,
|
||||
initialAgentNode.nodeId
|
||||
);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
@@ -974,48 +860,8 @@ export default function useChatController({
|
||||
parentMessage =
|
||||
parentMessage || currentMessageTreeLocal?.get(SYSTEM_NODE_ID)!;
|
||||
|
||||
// Build messages to upsert based on mode
|
||||
let messagesToUpsertInLoop: Message[];
|
||||
|
||||
if (isMultiModelMode) {
|
||||
// Multi-model mode: update user node + all N assistant nodes
|
||||
const updatedUserNode = {
|
||||
...initialUserNode,
|
||||
messageId: newUserMessageId ?? undefined,
|
||||
files: files,
|
||||
};
|
||||
|
||||
const updatedAssistantNodes = initialAssistantNodes.map(
|
||||
(node, idx) => ({
|
||||
...node,
|
||||
messageId: newAssistantMessageIds[idx] ?? undefined,
|
||||
message: "",
|
||||
type: "assistant" as const,
|
||||
retrievalType,
|
||||
query: query,
|
||||
documents: documentsPerModel[idx] || [],
|
||||
citations: citationsPerModel[idx] || {},
|
||||
files: [] as FileDescriptor[],
|
||||
toolCall: null,
|
||||
stackTrace: null,
|
||||
overridden_model: selectedModels?.[idx]?.displayName,
|
||||
modelDisplayName:
|
||||
modelDisplayNames[idx] ||
|
||||
selectedModels?.[idx]?.displayName ||
|
||||
null,
|
||||
stopReason: stopReason,
|
||||
packets: packetsPerModel[idx] || [],
|
||||
packetCount: packetsPerModel[idx]?.length || 0,
|
||||
})
|
||||
);
|
||||
|
||||
messagesToUpsertInLoop = [
|
||||
updatedUserNode,
|
||||
...updatedAssistantNodes,
|
||||
];
|
||||
} else {
|
||||
// Single model mode (existing logic)
|
||||
messagesToUpsertInLoop = [
|
||||
currentMessageTreeLocal = upsertToCompleteMessageTree({
|
||||
messages: [
|
||||
{
|
||||
...initialUserNode,
|
||||
messageId: newUserMessageId ?? undefined,
|
||||
@@ -1048,11 +894,8 @@ export default function useChatController({
|
||||
: undefined;
|
||||
})(),
|
||||
},
|
||||
];
|
||||
}
|
||||
|
||||
currentMessageTreeLocal = upsertToCompleteMessageTree({
|
||||
messages: messagesToUpsertInLoop,
|
||||
],
|
||||
// Pass the latest map state
|
||||
completeMessageTreeOverride: currentMessageTreeLocal,
|
||||
chatSessionId: frozenSessionId!,
|
||||
});
|
||||
|
||||
@@ -1,232 +0,0 @@
|
||||
import { renderHook, act } from "@testing-library/react";
|
||||
import useMultiModelChat from "@/hooks/useMultiModelChat";
|
||||
import { LlmManager } from "@/lib/hooks";
|
||||
import { SelectedModel } from "@/refresh-components/popovers/ModelSelector";
|
||||
|
||||
// Mock buildLlmOptions — hook uses it internally for initialization.
|
||||
// Tests here focus on CRUD operations, not the initialization side-effect.
|
||||
jest.mock("@/refresh-components/popovers/LLMPopover", () => ({
|
||||
buildLlmOptions: jest.fn(() => []),
|
||||
}));
|
||||
|
||||
const makeLlmManager = (): LlmManager =>
|
||||
({
|
||||
llmProviders: [],
|
||||
currentLlm: { modelName: null, provider: null },
|
||||
isLoadingProviders: false,
|
||||
}) as unknown as LlmManager;
|
||||
|
||||
const makeModel = (provider: string, modelName: string): SelectedModel => ({
|
||||
name: provider,
|
||||
provider,
|
||||
modelName,
|
||||
displayName: `${provider}/${modelName}`,
|
||||
});
|
||||
|
||||
const GPT4 = makeModel("openai", "gpt-4");
|
||||
const CLAUDE = makeModel("anthropic", "claude-opus-4-6");
|
||||
const GEMINI = makeModel("google", "gemini-pro");
|
||||
const GPT4_TURBO = makeModel("openai", "gpt-4-turbo");
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// addModel
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
describe("addModel", () => {
|
||||
it("adds a model to an empty selection", () => {
|
||||
const { result } = renderHook(() => useMultiModelChat(makeLlmManager()));
|
||||
|
||||
act(() => {
|
||||
result.current.addModel(GPT4);
|
||||
});
|
||||
|
||||
expect(result.current.selectedModels).toHaveLength(1);
|
||||
expect(result.current.selectedModels[0]).toEqual(GPT4);
|
||||
});
|
||||
|
||||
it("does not add a duplicate model", () => {
|
||||
const { result } = renderHook(() => useMultiModelChat(makeLlmManager()));
|
||||
|
||||
act(() => {
|
||||
result.current.addModel(GPT4);
|
||||
result.current.addModel(GPT4); // duplicate
|
||||
});
|
||||
|
||||
expect(result.current.selectedModels).toHaveLength(1);
|
||||
});
|
||||
|
||||
it("enforces MAX_MODELS (3) cap", () => {
|
||||
const { result } = renderHook(() => useMultiModelChat(makeLlmManager()));
|
||||
|
||||
act(() => {
|
||||
result.current.addModel(GPT4);
|
||||
result.current.addModel(CLAUDE);
|
||||
result.current.addModel(GEMINI);
|
||||
result.current.addModel(GPT4_TURBO); // should be ignored
|
||||
});
|
||||
|
||||
expect(result.current.selectedModels).toHaveLength(3);
|
||||
});
|
||||
});
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// removeModel
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
describe("removeModel", () => {
|
||||
it("removes a model by index", () => {
|
||||
const { result } = renderHook(() => useMultiModelChat(makeLlmManager()));
|
||||
|
||||
act(() => {
|
||||
result.current.addModel(GPT4);
|
||||
result.current.addModel(CLAUDE);
|
||||
});
|
||||
|
||||
act(() => {
|
||||
result.current.removeModel(0); // remove GPT4
|
||||
});
|
||||
|
||||
expect(result.current.selectedModels).toHaveLength(1);
|
||||
expect(result.current.selectedModels[0]).toEqual(CLAUDE);
|
||||
});
|
||||
|
||||
it("handles out-of-range index gracefully", () => {
|
||||
const { result } = renderHook(() => useMultiModelChat(makeLlmManager()));
|
||||
|
||||
act(() => {
|
||||
result.current.addModel(GPT4);
|
||||
});
|
||||
|
||||
act(() => {
|
||||
result.current.removeModel(99); // no-op
|
||||
});
|
||||
|
||||
expect(result.current.selectedModels).toHaveLength(1);
|
||||
});
|
||||
});
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// replaceModel
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
describe("replaceModel", () => {
|
||||
it("replaces the model at the given index", () => {
|
||||
const { result } = renderHook(() => useMultiModelChat(makeLlmManager()));
|
||||
|
||||
act(() => {
|
||||
result.current.addModel(GPT4);
|
||||
result.current.addModel(CLAUDE);
|
||||
});
|
||||
|
||||
act(() => {
|
||||
result.current.replaceModel(0, GEMINI);
|
||||
});
|
||||
|
||||
expect(result.current.selectedModels[0]).toEqual(GEMINI);
|
||||
expect(result.current.selectedModels[1]).toEqual(CLAUDE);
|
||||
});
|
||||
|
||||
it("does not replace with a model already selected at another index", () => {
|
||||
const { result } = renderHook(() => useMultiModelChat(makeLlmManager()));
|
||||
|
||||
act(() => {
|
||||
result.current.addModel(GPT4);
|
||||
result.current.addModel(CLAUDE);
|
||||
});
|
||||
|
||||
act(() => {
|
||||
result.current.replaceModel(0, CLAUDE); // CLAUDE is already at index 1
|
||||
});
|
||||
|
||||
// Should be a no-op — GPT4 stays at index 0
|
||||
expect(result.current.selectedModels[0]).toEqual(GPT4);
|
||||
});
|
||||
});
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// isMultiModelActive
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
describe("isMultiModelActive", () => {
|
||||
it("is false with zero models", () => {
|
||||
const { result } = renderHook(() => useMultiModelChat(makeLlmManager()));
|
||||
expect(result.current.isMultiModelActive).toBe(false);
|
||||
});
|
||||
|
||||
it("is false with exactly one model", () => {
|
||||
const { result } = renderHook(() => useMultiModelChat(makeLlmManager()));
|
||||
|
||||
act(() => {
|
||||
result.current.addModel(GPT4);
|
||||
});
|
||||
|
||||
expect(result.current.isMultiModelActive).toBe(false);
|
||||
});
|
||||
|
||||
it("is true with two or more models", () => {
|
||||
const { result } = renderHook(() => useMultiModelChat(makeLlmManager()));
|
||||
|
||||
act(() => {
|
||||
result.current.addModel(GPT4);
|
||||
result.current.addModel(CLAUDE);
|
||||
});
|
||||
|
||||
expect(result.current.isMultiModelActive).toBe(true);
|
||||
});
|
||||
});
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// buildLlmOverrides
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
describe("buildLlmOverrides", () => {
|
||||
it("returns empty array when no models selected", () => {
|
||||
const { result } = renderHook(() => useMultiModelChat(makeLlmManager()));
|
||||
expect(result.current.buildLlmOverrides()).toEqual([]);
|
||||
});
|
||||
|
||||
it("maps selectedModels to LLMOverride format", () => {
|
||||
const { result } = renderHook(() => useMultiModelChat(makeLlmManager()));
|
||||
|
||||
act(() => {
|
||||
result.current.addModel(GPT4);
|
||||
result.current.addModel(CLAUDE);
|
||||
});
|
||||
|
||||
const overrides = result.current.buildLlmOverrides();
|
||||
|
||||
expect(overrides).toHaveLength(2);
|
||||
expect(overrides[0]).toEqual({
|
||||
model_provider: "openai",
|
||||
model_version: "gpt-4",
|
||||
display_name: "openai/gpt-4",
|
||||
});
|
||||
expect(overrides[1]).toEqual({
|
||||
model_provider: "anthropic",
|
||||
model_version: "claude-opus-4-6",
|
||||
display_name: "anthropic/claude-opus-4-6",
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// clearModels
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
describe("clearModels", () => {
|
||||
it("empties the selection", () => {
|
||||
const { result } = renderHook(() => useMultiModelChat(makeLlmManager()));
|
||||
|
||||
act(() => {
|
||||
result.current.addModel(GPT4);
|
||||
result.current.addModel(CLAUDE);
|
||||
});
|
||||
|
||||
act(() => {
|
||||
result.current.clearModels();
|
||||
});
|
||||
|
||||
expect(result.current.selectedModels).toHaveLength(0);
|
||||
expect(result.current.isMultiModelActive).toBe(false);
|
||||
});
|
||||
});
|
||||
@@ -1,191 +0,0 @@
|
||||
"use client";
|
||||
|
||||
import { useState, useCallback, useEffect, useMemo } from "react";
|
||||
import { SelectedModel } from "@/refresh-components/popovers/ModelSelector";
|
||||
import { LLMOverride } from "@/app/app/services/lib";
|
||||
import { LlmManager } from "@/lib/hooks";
|
||||
import { buildLlmOptions } from "@/refresh-components/popovers/LLMPopover";
|
||||
|
||||
const MAX_MODELS = 3;
|
||||
|
||||
export interface UseMultiModelChatReturn {
|
||||
/** Currently selected models for multi-model comparison. */
|
||||
selectedModels: SelectedModel[];
|
||||
/** Whether multi-model mode is active (>1 model selected). */
|
||||
isMultiModelActive: boolean;
|
||||
/** Add a model to the selection. */
|
||||
addModel: (model: SelectedModel) => void;
|
||||
/** Remove a model by index. */
|
||||
removeModel: (index: number) => void;
|
||||
/** Replace a model at a specific index with a new one. */
|
||||
replaceModel: (index: number, model: SelectedModel) => void;
|
||||
/** Clear all selected models. */
|
||||
clearModels: () => void;
|
||||
/** Build the LLMOverride[] array from selectedModels. */
|
||||
buildLlmOverrides: () => LLMOverride[];
|
||||
/**
|
||||
* Restore multi-model selection from model version strings (e.g. from chat history).
|
||||
* Matches against available llmOptions to reconstruct full SelectedModel objects.
|
||||
*/
|
||||
restoreFromModelNames: (modelNames: string[]) => void;
|
||||
/**
|
||||
* Switch to a single model by name (after user picks a preferred response).
|
||||
* Matches against llmOptions to find the full SelectedModel.
|
||||
*/
|
||||
selectSingleModel: (modelName: string) => void;
|
||||
}
|
||||
|
||||
export default function useMultiModelChat(
|
||||
llmManager: LlmManager
|
||||
): UseMultiModelChatReturn {
|
||||
const [selectedModels, setSelectedModels] = useState<SelectedModel[]>([]);
|
||||
const [defaultInitialized, setDefaultInitialized] = useState(false);
|
||||
|
||||
// Initialize with the default model from llmManager once providers load
|
||||
const llmOptions = useMemo(
|
||||
() =>
|
||||
llmManager.llmProviders ? buildLlmOptions(llmManager.llmProviders) : [],
|
||||
[llmManager.llmProviders]
|
||||
);
|
||||
|
||||
useEffect(() => {
|
||||
if (defaultInitialized) return;
|
||||
if (llmOptions.length === 0) return;
|
||||
const { currentLlm } = llmManager;
|
||||
// Don't initialize if currentLlm hasn't loaded yet
|
||||
if (!currentLlm.modelName) return;
|
||||
const match = llmOptions.find(
|
||||
(opt) =>
|
||||
opt.provider === currentLlm.provider &&
|
||||
opt.modelName === currentLlm.modelName
|
||||
);
|
||||
if (match) {
|
||||
setSelectedModels([
|
||||
{
|
||||
name: match.name,
|
||||
provider: match.provider,
|
||||
modelName: match.modelName,
|
||||
displayName: match.displayName,
|
||||
},
|
||||
]);
|
||||
setDefaultInitialized(true);
|
||||
}
|
||||
}, [llmOptions, llmManager.currentLlm, defaultInitialized, llmManager]);
|
||||
|
||||
const isMultiModelActive = selectedModels.length > 1;
|
||||
|
||||
const addModel = useCallback((model: SelectedModel) => {
|
||||
setSelectedModels((prev) => {
|
||||
if (prev.length >= MAX_MODELS) return prev;
|
||||
if (
|
||||
prev.some(
|
||||
(m) =>
|
||||
m.provider === model.provider && m.modelName === model.modelName
|
||||
)
|
||||
) {
|
||||
return prev;
|
||||
}
|
||||
return [...prev, model];
|
||||
});
|
||||
}, []);
|
||||
|
||||
const removeModel = useCallback((index: number) => {
|
||||
setSelectedModels((prev) => prev.filter((_, i) => i !== index));
|
||||
}, []);
|
||||
|
||||
const replaceModel = useCallback((index: number, model: SelectedModel) => {
|
||||
setSelectedModels((prev) => {
|
||||
// Don't replace with a model that's already selected elsewhere
|
||||
if (
|
||||
prev.some(
|
||||
(m, i) =>
|
||||
i !== index &&
|
||||
m.provider === model.provider &&
|
||||
m.modelName === model.modelName
|
||||
)
|
||||
) {
|
||||
return prev;
|
||||
}
|
||||
const next = [...prev];
|
||||
next[index] = model;
|
||||
return next;
|
||||
});
|
||||
}, []);
|
||||
|
||||
const clearModels = useCallback(() => {
|
||||
setSelectedModels([]);
|
||||
}, []);
|
||||
|
||||
const restoreFromModelNames = useCallback(
|
||||
(modelNames: string[]) => {
|
||||
if (modelNames.length < 2 || llmOptions.length === 0) return;
|
||||
const restored: SelectedModel[] = [];
|
||||
for (const name of modelNames) {
|
||||
// Try matching by modelName (raw version string like "claude-opus-4-6")
|
||||
// or by displayName (friendly name like "Claude Opus 4.6")
|
||||
const match = llmOptions.find(
|
||||
(opt) =>
|
||||
opt.modelName === name ||
|
||||
opt.displayName === name ||
|
||||
opt.name === name
|
||||
);
|
||||
if (match) {
|
||||
restored.push({
|
||||
name: match.name,
|
||||
provider: match.provider,
|
||||
modelName: match.modelName,
|
||||
displayName: match.displayName,
|
||||
});
|
||||
}
|
||||
}
|
||||
if (restored.length >= 2) {
|
||||
setSelectedModels(restored);
|
||||
setDefaultInitialized(true);
|
||||
}
|
||||
},
|
||||
[llmOptions]
|
||||
);
|
||||
|
||||
const selectSingleModel = useCallback(
|
||||
(modelName: string) => {
|
||||
if (llmOptions.length === 0) return;
|
||||
const match = llmOptions.find(
|
||||
(opt) =>
|
||||
opt.modelName === modelName ||
|
||||
opt.displayName === modelName ||
|
||||
opt.name === modelName
|
||||
);
|
||||
if (match) {
|
||||
setSelectedModels([
|
||||
{
|
||||
name: match.name,
|
||||
provider: match.provider,
|
||||
modelName: match.modelName,
|
||||
displayName: match.displayName,
|
||||
},
|
||||
]);
|
||||
}
|
||||
},
|
||||
[llmOptions]
|
||||
);
|
||||
|
||||
const buildLlmOverrides = useCallback((): LLMOverride[] => {
|
||||
return selectedModels.map((m) => ({
|
||||
model_provider: m.name,
|
||||
model_version: m.modelName,
|
||||
display_name: m.displayName,
|
||||
}));
|
||||
}, [selectedModels]);
|
||||
|
||||
return {
|
||||
selectedModels,
|
||||
isMultiModelActive,
|
||||
addModel,
|
||||
removeModel,
|
||||
replaceModel,
|
||||
clearModels,
|
||||
buildLlmOverrides,
|
||||
restoreFromModelNames,
|
||||
selectSingleModel,
|
||||
};
|
||||
}
|
||||
@@ -32,12 +32,6 @@ export interface FrostedDivProps extends React.HTMLAttributes<HTMLDivElement> {
|
||||
* Additional classes for the frost overlay element itself
|
||||
*/
|
||||
overlayClassName?: string;
|
||||
|
||||
/**
|
||||
* Additional classes for the outermost wrapper div (the `relative` container).
|
||||
* Useful for width propagation (e.g. `w-full`).
|
||||
*/
|
||||
wrapperClassName?: string;
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -66,14 +60,13 @@ export default function FrostedDiv({
|
||||
backdropBlur = "6px",
|
||||
borderRadius = "1rem",
|
||||
overlayClassName,
|
||||
wrapperClassName,
|
||||
className,
|
||||
style,
|
||||
children,
|
||||
...props
|
||||
}: FrostedDivProps) {
|
||||
return (
|
||||
<div className={cn("relative", wrapperClassName)}>
|
||||
<div className="relative">
|
||||
{/* Frost effect overlay - positioned behind content with bloom extending outward */}
|
||||
<div
|
||||
className={cn("absolute pointer-events-none", overlayClassName)}
|
||||
|
||||
@@ -1,469 +0,0 @@
|
||||
"use client";
|
||||
|
||||
import { useState, useMemo, useRef } from "react";
|
||||
import Popover, { PopoverMenu } from "@/refresh-components/Popover";
|
||||
import { LlmManager } from "@/lib/hooks";
|
||||
import { getProviderIcon } from "@/app/admin/configuration/llm/utils";
|
||||
import InputTypeIn from "@/refresh-components/inputs/InputTypeIn";
|
||||
import Text from "@/refresh-components/texts/Text";
|
||||
import { Button } from "@opal/components";
|
||||
import {
|
||||
SvgCheck,
|
||||
SvgChevronDown,
|
||||
SvgChevronRight,
|
||||
SvgColumn,
|
||||
SvgPlusCircle,
|
||||
SvgX,
|
||||
} from "@opal/icons";
|
||||
import { Section } from "@/layouts/general-layouts";
|
||||
import {
|
||||
LLMOption,
|
||||
LLMOptionGroup,
|
||||
} from "@/refresh-components/popovers/interfaces";
|
||||
import {
|
||||
buildLlmOptions,
|
||||
groupLlmOptions,
|
||||
} from "@/refresh-components/popovers/LLMPopover";
|
||||
import {
|
||||
Accordion,
|
||||
AccordionContent,
|
||||
AccordionItem,
|
||||
AccordionTrigger,
|
||||
} from "@/components/ui/accordion";
|
||||
import { cn } from "@/lib/utils";
|
||||
|
||||
const MAX_MODELS = 3;
|
||||
|
||||
export interface SelectedModel {
|
||||
name: string;
|
||||
provider: string;
|
||||
modelName: string;
|
||||
displayName: string;
|
||||
}
|
||||
|
||||
export interface ModelSelectorProps {
|
||||
llmManager: LlmManager;
|
||||
selectedModels: SelectedModel[];
|
||||
onAdd: (model: SelectedModel) => void;
|
||||
onRemove: (index: number) => void;
|
||||
onReplace: (index: number, model: SelectedModel) => void;
|
||||
}
|
||||
|
||||
/** Vertical 1px divider between model bar elements */
|
||||
function BarDivider() {
|
||||
return <div className="h-9 w-px bg-border-01 shrink-0" />;
|
||||
}
|
||||
|
||||
/** Individual model pill in the model bar */
|
||||
function ModelPill({
|
||||
model,
|
||||
isMultiModel,
|
||||
onRemove,
|
||||
onClick,
|
||||
}: {
|
||||
model: SelectedModel;
|
||||
isMultiModel: boolean;
|
||||
onRemove?: () => void;
|
||||
onClick?: () => void;
|
||||
}) {
|
||||
const ProviderIcon = getProviderIcon(model.provider, model.modelName);
|
||||
|
||||
return (
|
||||
<div
|
||||
role="button"
|
||||
tabIndex={0}
|
||||
onClick={onClick}
|
||||
onKeyDown={(e) => {
|
||||
if (e.key === "Enter" || e.key === " ") {
|
||||
e.preventDefault();
|
||||
onClick?.();
|
||||
}
|
||||
}}
|
||||
className={cn(
|
||||
"flex items-center gap-0.5 rounded-12 p-2 shrink-0 cursor-pointer",
|
||||
"hover:bg-background-tint-02 transition-colors",
|
||||
isMultiModel && "bg-background-tint-02"
|
||||
)}
|
||||
>
|
||||
<div className="flex items-center justify-center size-5 shrink-0 p-0.5">
|
||||
<ProviderIcon size={16} />
|
||||
</div>
|
||||
<Text mainUiAction text04 nowrap className="px-1">
|
||||
{model.displayName}
|
||||
</Text>
|
||||
{isMultiModel ? (
|
||||
<button
|
||||
type="button"
|
||||
onClick={(e) => {
|
||||
e.stopPropagation();
|
||||
onRemove?.();
|
||||
}}
|
||||
className="flex items-center justify-center size-4 shrink-0 hover:opacity-70"
|
||||
>
|
||||
<SvgX className="size-4 stroke-text-03" />
|
||||
</button>
|
||||
) : (
|
||||
<SvgChevronDown className="size-4 stroke-text-03 shrink-0" />
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
/** Model item row inside the add-model popover */
|
||||
function ModelItem({
|
||||
option,
|
||||
isSelected,
|
||||
isDisabled,
|
||||
onToggle,
|
||||
}: {
|
||||
option: LLMOption;
|
||||
isSelected: boolean;
|
||||
isDisabled: boolean;
|
||||
onToggle: () => void;
|
||||
}) {
|
||||
const ProviderIcon = getProviderIcon(option.provider, option.modelName);
|
||||
|
||||
// Build subtitle from model capabilities
|
||||
const subtitle = useMemo(() => {
|
||||
const parts: string[] = [];
|
||||
if (option.supportsReasoning) parts.push("reasoning");
|
||||
if (option.supportsImageInput) parts.push("multi-modal");
|
||||
if (parts.length === 0 && option.modelName) return option.modelName;
|
||||
return parts.join(", ");
|
||||
}, [option]);
|
||||
|
||||
return (
|
||||
<button
|
||||
type="button"
|
||||
disabled={isDisabled}
|
||||
onClick={onToggle}
|
||||
className={cn(
|
||||
"flex items-center gap-1.5 w-full rounded-08 p-1.5 text-left transition-colors",
|
||||
isSelected ? "bg-action-link-01" : "hover:bg-background-tint-02",
|
||||
isDisabled && !isSelected && "opacity-50 cursor-not-allowed"
|
||||
)}
|
||||
>
|
||||
<div className="flex items-center justify-center size-5 shrink-0 p-0.5">
|
||||
<ProviderIcon size={16} />
|
||||
</div>
|
||||
<div className="flex flex-col flex-1 min-w-0">
|
||||
<Text
|
||||
mainUiAction
|
||||
nowrap
|
||||
className={cn(isSelected ? "text-action-link-03" : "text-text-04")}
|
||||
>
|
||||
{option.displayName}
|
||||
</Text>
|
||||
{subtitle && (
|
||||
<Text secondaryBody text03 nowrap>
|
||||
{subtitle}
|
||||
</Text>
|
||||
)}
|
||||
</div>
|
||||
{isSelected && (
|
||||
<Text secondaryBody nowrap className="text-action-link-05 shrink-0">
|
||||
Added
|
||||
</Text>
|
||||
)}
|
||||
</button>
|
||||
);
|
||||
}
|
||||
|
||||
export default function ModelSelector({
|
||||
llmManager,
|
||||
selectedModels,
|
||||
onAdd,
|
||||
onRemove,
|
||||
onReplace,
|
||||
}: ModelSelectorProps) {
|
||||
const [open, setOpen] = useState(false);
|
||||
const [searchQuery, setSearchQuery] = useState("");
|
||||
const scrollContainerRef = useRef<HTMLDivElement>(null);
|
||||
// null = add mode (via + button), number = replace mode (via pill click)
|
||||
const [replacingIndex, setReplacingIndex] = useState<number | null>(null);
|
||||
|
||||
const isMultiModel = selectedModels.length > 1;
|
||||
const atMax = selectedModels.length >= MAX_MODELS;
|
||||
|
||||
const llmOptions = useMemo(
|
||||
() => buildLlmOptions(llmManager.llmProviders),
|
||||
[llmManager.llmProviders]
|
||||
);
|
||||
|
||||
const selectedKeys = useMemo(
|
||||
() => new Set(selectedModels.map((m) => `${m.provider}:${m.modelName}`)),
|
||||
[selectedModels]
|
||||
);
|
||||
|
||||
const filteredOptions = useMemo(() => {
|
||||
if (!searchQuery.trim()) return llmOptions;
|
||||
const query = searchQuery.toLowerCase();
|
||||
return llmOptions.filter(
|
||||
(opt) =>
|
||||
opt.displayName.toLowerCase().includes(query) ||
|
||||
opt.modelName.toLowerCase().includes(query) ||
|
||||
(opt.vendor && opt.vendor.toLowerCase().includes(query))
|
||||
);
|
||||
}, [llmOptions, searchQuery]);
|
||||
|
||||
const groupedOptions = useMemo(
|
||||
() => groupLlmOptions(filteredOptions),
|
||||
[filteredOptions]
|
||||
);
|
||||
|
||||
const isSearching = searchQuery.trim().length > 0;
|
||||
|
||||
// In replace mode, other selected models (not the one being replaced) are disabled
|
||||
const otherSelectedKeys = useMemo(() => {
|
||||
if (replacingIndex === null) return new Set<string>();
|
||||
return new Set(
|
||||
selectedModels
|
||||
.filter((_, i) => i !== replacingIndex)
|
||||
.map((m) => `${m.provider}:${m.modelName}`)
|
||||
);
|
||||
}, [selectedModels, replacingIndex]);
|
||||
|
||||
// Current model at the replacing index (shows as "selected" in replace mode)
|
||||
const replacingKey = useMemo(() => {
|
||||
if (replacingIndex === null) return null;
|
||||
const m = selectedModels[replacingIndex];
|
||||
return m ? `${m.provider}:${m.modelName}` : null;
|
||||
}, [selectedModels, replacingIndex]);
|
||||
|
||||
const getItemState = (optKey: string) => {
|
||||
if (replacingIndex !== null) {
|
||||
// Replace mode
|
||||
return {
|
||||
isSelected: optKey === replacingKey,
|
||||
isDisabled: otherSelectedKeys.has(optKey),
|
||||
};
|
||||
}
|
||||
// Add mode
|
||||
return {
|
||||
isSelected: selectedKeys.has(optKey),
|
||||
isDisabled: !selectedKeys.has(optKey) && atMax,
|
||||
};
|
||||
};
|
||||
|
||||
const handleSelectModel = (option: LLMOption) => {
|
||||
const model: SelectedModel = {
|
||||
name: option.name,
|
||||
provider: option.provider,
|
||||
modelName: option.modelName,
|
||||
displayName: option.displayName,
|
||||
};
|
||||
|
||||
if (replacingIndex !== null) {
|
||||
// Replace mode: swap the model at the clicked pill's index
|
||||
onReplace(replacingIndex, model);
|
||||
setOpen(false);
|
||||
setReplacingIndex(null);
|
||||
setSearchQuery("");
|
||||
return;
|
||||
}
|
||||
|
||||
// Add mode: toggle (add/remove)
|
||||
const key = `${option.provider}:${option.modelName}`;
|
||||
const existingIndex = selectedModels.findIndex(
|
||||
(m) => `${m.provider}:${m.modelName}` === key
|
||||
);
|
||||
if (existingIndex >= 0) {
|
||||
onRemove(existingIndex);
|
||||
} else if (!atMax) {
|
||||
onAdd(model);
|
||||
}
|
||||
};
|
||||
|
||||
const handleOpenChange = (nextOpen: boolean) => {
|
||||
setOpen(nextOpen);
|
||||
if (!nextOpen) {
|
||||
setReplacingIndex(null);
|
||||
setSearchQuery("");
|
||||
}
|
||||
};
|
||||
|
||||
const handlePillClick = (index: number) => {
|
||||
setReplacingIndex(index);
|
||||
setOpen(true);
|
||||
};
|
||||
|
||||
return (
|
||||
<div className="flex items-center justify-end gap-1 p-1">
|
||||
{/* (+) Add model button — hidden at max models */}
|
||||
{!atMax && (
|
||||
<Popover open={open} onOpenChange={handleOpenChange}>
|
||||
<Popover.Trigger asChild>
|
||||
<Button
|
||||
prominence="tertiary"
|
||||
icon={SvgPlusCircle}
|
||||
size="sm"
|
||||
tooltip="Add Model"
|
||||
/>
|
||||
</Popover.Trigger>
|
||||
|
||||
<Popover.Content side="top" align="start" width="lg">
|
||||
<Section gap={0.25}>
|
||||
<InputTypeIn
|
||||
leftSearchIcon
|
||||
variant="internal"
|
||||
value={searchQuery}
|
||||
onChange={(e) => setSearchQuery(e.target.value)}
|
||||
placeholder="Search models..."
|
||||
/>
|
||||
|
||||
<PopoverMenu scrollContainerRef={scrollContainerRef}>
|
||||
{groupedOptions.length === 0
|
||||
? [
|
||||
<div key="empty" className="py-3 px-2">
|
||||
<Text secondaryBody text03>
|
||||
No models found
|
||||
</Text>
|
||||
</div>,
|
||||
]
|
||||
: groupedOptions.length === 1
|
||||
? [
|
||||
<div key="single" className="flex flex-col gap-0.5">
|
||||
{groupedOptions[0]!.options.map((opt) => {
|
||||
const key = `${opt.provider}:${opt.modelName}`;
|
||||
const state = getItemState(key);
|
||||
return (
|
||||
<ModelItem
|
||||
key={opt.modelName}
|
||||
option={opt}
|
||||
isSelected={state.isSelected}
|
||||
isDisabled={state.isDisabled}
|
||||
onToggle={() => handleSelectModel(opt)}
|
||||
/>
|
||||
);
|
||||
})}
|
||||
</div>,
|
||||
]
|
||||
: [
|
||||
<ModelGroupAccordion
|
||||
key="accordion"
|
||||
groups={groupedOptions}
|
||||
isSearching={isSearching}
|
||||
getItemState={getItemState}
|
||||
onToggle={handleSelectModel}
|
||||
/>,
|
||||
]}
|
||||
</PopoverMenu>
|
||||
|
||||
<div className="border-t border-border-01 mt-1 pt-1">
|
||||
<button
|
||||
type="button"
|
||||
className="flex items-center gap-1.5 w-full rounded-08 p-1.5 text-left hover:bg-background-tint-02 transition-colors"
|
||||
>
|
||||
<SvgColumn className="size-5 stroke-text-03 shrink-0" />
|
||||
<Text mainUiAction text04>
|
||||
Compare Model
|
||||
</Text>
|
||||
</button>
|
||||
</div>
|
||||
</Section>
|
||||
</Popover.Content>
|
||||
</Popover>
|
||||
)}
|
||||
|
||||
{/* Divider + model pills */}
|
||||
{selectedModels.length > 0 && (
|
||||
<>
|
||||
<BarDivider />
|
||||
{selectedModels.map((model, index) => (
|
||||
<div
|
||||
key={`${model.provider}:${model.modelName}`}
|
||||
className="flex items-center gap-1"
|
||||
>
|
||||
{index > 0 && <BarDivider />}
|
||||
<ModelPill
|
||||
model={model}
|
||||
isMultiModel={isMultiModel}
|
||||
onRemove={() => onRemove(index)}
|
||||
onClick={() => handlePillClick(index)}
|
||||
/>
|
||||
</div>
|
||||
))}
|
||||
</>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
interface ModelGroupAccordionProps {
|
||||
groups: LLMOptionGroup[];
|
||||
isSearching: boolean;
|
||||
getItemState: (key: string) => { isSelected: boolean; isDisabled: boolean };
|
||||
onToggle: (option: LLMOption) => void;
|
||||
}
|
||||
|
||||
function ModelGroupAccordion({
|
||||
groups,
|
||||
isSearching,
|
||||
getItemState,
|
||||
onToggle,
|
||||
}: ModelGroupAccordionProps) {
|
||||
const allKeys = groups.map((g) => g.key);
|
||||
const [expandedGroups, setExpandedGroups] = useState<string[]>([
|
||||
allKeys[0] ?? "",
|
||||
]);
|
||||
|
||||
const effectiveExpanded = isSearching ? allKeys : expandedGroups;
|
||||
|
||||
return (
|
||||
<Accordion
|
||||
type="multiple"
|
||||
value={effectiveExpanded}
|
||||
onValueChange={(value) => {
|
||||
if (!isSearching) setExpandedGroups(value);
|
||||
}}
|
||||
className="w-full flex flex-col"
|
||||
>
|
||||
{groups.map((group) => {
|
||||
const isExpanded = effectiveExpanded.includes(group.key);
|
||||
return (
|
||||
<AccordionItem
|
||||
key={group.key}
|
||||
value={group.key}
|
||||
className="border-none pt-1"
|
||||
>
|
||||
<AccordionTrigger className="flex items-center rounded-08 hover:no-underline hover:bg-background-tint-02 group [&>svg]:hidden w-full py-1">
|
||||
<div className="flex items-center gap-1 shrink-0">
|
||||
<div className="flex items-center justify-center size-5 shrink-0">
|
||||
<group.Icon size={16} />
|
||||
</div>
|
||||
<Text secondaryBody text03 nowrap className="px-0.5">
|
||||
{group.displayName}
|
||||
</Text>
|
||||
</div>
|
||||
<div className="flex-1" />
|
||||
<div className="flex items-center justify-center size-6 shrink-0">
|
||||
{isExpanded ? (
|
||||
<SvgChevronDown className="h-4 w-4 stroke-text-04 shrink-0" />
|
||||
) : (
|
||||
<SvgChevronRight className="h-4 w-4 stroke-text-04 shrink-0" />
|
||||
)}
|
||||
</div>
|
||||
</AccordionTrigger>
|
||||
<AccordionContent className="pb-0 pt-0">
|
||||
<div className="flex flex-col gap-0.5">
|
||||
{group.options.map((opt) => {
|
||||
const key = `${opt.provider}:${opt.modelName}`;
|
||||
const state = getItemState(key);
|
||||
return (
|
||||
<ModelItem
|
||||
key={key}
|
||||
option={opt}
|
||||
isSelected={state.isSelected}
|
||||
isDisabled={state.isDisabled}
|
||||
onToggle={() => onToggle(opt)}
|
||||
/>
|
||||
);
|
||||
})}
|
||||
</div>
|
||||
</AccordionContent>
|
||||
</AccordionItem>
|
||||
);
|
||||
})}
|
||||
</Accordion>
|
||||
);
|
||||
}
|
||||
@@ -15,8 +15,6 @@ import {
|
||||
} from "@/providers/SettingsProvider";
|
||||
import Dropzone from "react-dropzone";
|
||||
import AppInputBar, { AppInputBarHandle } from "@/sections/input/AppInputBar";
|
||||
import ModelSelector from "@/refresh-components/popovers/ModelSelector";
|
||||
import useMultiModelChat from "@/hooks/useMultiModelChat";
|
||||
import useChatSessions from "@/hooks/useChatSessions";
|
||||
import useCCPairs from "@/hooks/useCCPairs";
|
||||
import useTags from "@/hooks/useTags";
|
||||
@@ -66,7 +64,6 @@ import { SvgChevronDown, SvgFileText } from "@opal/icons";
|
||||
import { Button } from "@opal/components";
|
||||
import Spacer from "@/refresh-components/Spacer";
|
||||
import useAppFocus from "@/hooks/useAppFocus";
|
||||
import { useAppSidebarContext } from "@/providers/AppSidebarProvider";
|
||||
import { useQueryController } from "@/providers/QueryControllerProvider";
|
||||
import WelcomeMessage from "@/app/app/components/WelcomeMessage";
|
||||
import ChatUI from "@/sections/chat/ChatUI";
|
||||
@@ -367,30 +364,6 @@ export default function AppPage({ firstMessage }: ChatPageProps) {
|
||||
const autoScrollEnabled = user?.preferences?.auto_scroll !== false;
|
||||
const isStreaming = currentChatState === "streaming";
|
||||
|
||||
const multiModel = useMultiModelChat(llmManager);
|
||||
|
||||
// Auto-fold sidebar when multi-model is active (needs full width)
|
||||
const { folded: sidebarFolded, setFolded: setSidebarFolded } =
|
||||
useAppSidebarContext();
|
||||
const preMultiModelFoldedRef = useRef<boolean | null>(null);
|
||||
|
||||
useEffect(() => {
|
||||
if (
|
||||
multiModel.isMultiModelActive &&
|
||||
preMultiModelFoldedRef.current === null
|
||||
) {
|
||||
preMultiModelFoldedRef.current = sidebarFolded;
|
||||
setSidebarFolded(true);
|
||||
} else if (
|
||||
!multiModel.isMultiModelActive &&
|
||||
preMultiModelFoldedRef.current !== null
|
||||
) {
|
||||
setSidebarFolded(preMultiModelFoldedRef.current);
|
||||
preMultiModelFoldedRef.current = null;
|
||||
}
|
||||
// eslint-disable-next-line react-hooks/exhaustive-deps
|
||||
}, [multiModel.isMultiModelActive]);
|
||||
|
||||
const {
|
||||
onSubmit,
|
||||
stopGenerating,
|
||||
@@ -490,9 +463,6 @@ export default function AppPage({ firstMessage }: ChatPageProps) {
|
||||
message,
|
||||
currentMessageFiles,
|
||||
deepResearch: deepResearchEnabledForCurrentWorkflow,
|
||||
selectedModels: multiModel.isMultiModelActive
|
||||
? multiModel.selectedModels
|
||||
: undefined,
|
||||
});
|
||||
if (showOnboarding || !onboardingDismissed) {
|
||||
finishOnboarding();
|
||||
@@ -503,8 +473,6 @@ export default function AppPage({ firstMessage }: ChatPageProps) {
|
||||
onSubmit,
|
||||
currentMessageFiles,
|
||||
deepResearchEnabledForCurrentWorkflow,
|
||||
multiModel.isMultiModelActive,
|
||||
multiModel.selectedModels,
|
||||
showOnboarding,
|
||||
onboardingDismissed,
|
||||
finishOnboarding,
|
||||
@@ -543,9 +511,6 @@ export default function AppPage({ firstMessage }: ChatPageProps) {
|
||||
message,
|
||||
currentMessageFiles,
|
||||
deepResearch: deepResearchEnabledForCurrentWorkflow,
|
||||
selectedModels: multiModel.isMultiModelActive
|
||||
? multiModel.selectedModels
|
||||
: undefined,
|
||||
});
|
||||
if (showOnboarding || !onboardingDismissed) {
|
||||
finishOnboarding();
|
||||
@@ -570,7 +535,6 @@ export default function AppPage({ firstMessage }: ChatPageProps) {
|
||||
showOnboarding,
|
||||
onboardingDismissed,
|
||||
finishOnboarding,
|
||||
multiModel,
|
||||
]
|
||||
);
|
||||
|
||||
@@ -711,7 +675,7 @@ export default function AppPage({ firstMessage }: ChatPageProps) {
|
||||
style={gridStyle}
|
||||
>
|
||||
{/* ── Top row: ChatUI / WelcomeMessage / ProjectUI ── */}
|
||||
<div className="row-start-1 min-h-0 overflow-y-hidden flex flex-col items-center">
|
||||
<div className="row-start-1 min-h-0 overflow-hidden flex flex-col items-center">
|
||||
{/* ChatUI */}
|
||||
<Fade
|
||||
show={
|
||||
@@ -740,7 +704,6 @@ export default function AppPage({ firstMessage }: ChatPageProps) {
|
||||
stopGenerating={stopGenerating}
|
||||
onResubmit={handleResubmitLastMessage}
|
||||
anchorNodeId={anchorNodeId}
|
||||
selectedModels={multiModel.selectedModels}
|
||||
/>
|
||||
</ChatScrollContainer>
|
||||
</Fade>
|
||||
@@ -767,16 +730,6 @@ export default function AppPage({ firstMessage }: ChatPageProps) {
|
||||
<WelcomeMessage
|
||||
agent={liveAgent}
|
||||
isDefaultAgent={isDefaultAgent}
|
||||
hideTitle={multiModel.selectedModels.length >= 3}
|
||||
rightChildren={
|
||||
<ModelSelector
|
||||
llmManager={llmManager}
|
||||
selectedModels={multiModel.selectedModels}
|
||||
onAdd={multiModel.addModel}
|
||||
onRemove={multiModel.removeModel}
|
||||
onReplace={multiModel.replaceModel}
|
||||
/>
|
||||
}
|
||||
/>
|
||||
<Spacer rem={1.5} />
|
||||
</Fade>
|
||||
@@ -838,15 +791,6 @@ export default function AppPage({ firstMessage }: ChatPageProps) {
|
||||
isSearch ? "h-[14px]" : "h-0"
|
||||
)}
|
||||
/>
|
||||
{appFocus.isChat() && (
|
||||
<ModelSelector
|
||||
llmManager={llmManager}
|
||||
selectedModels={multiModel.selectedModels}
|
||||
onAdd={multiModel.addModel}
|
||||
onRemove={multiModel.removeModel}
|
||||
onReplace={multiModel.replaceModel}
|
||||
/>
|
||||
)}
|
||||
<AppInputBar
|
||||
ref={chatInputBarRef}
|
||||
deepResearchEnabled={
|
||||
|
||||
@@ -347,19 +347,18 @@ const ChatScrollContainer = React.memo(
|
||||
const contentMask = buildContentMask();
|
||||
|
||||
return (
|
||||
<div className="flex flex-col flex-1 min-h-0 w-full relative overflow-y-hidden mb-1">
|
||||
<div className="flex flex-col flex-1 min-h-0 w-full relative overflow-hidden mb-1">
|
||||
<div
|
||||
key={sessionId}
|
||||
ref={scrollContainerRef}
|
||||
data-testid="chat-scroll-container"
|
||||
className={cn(
|
||||
"flex flex-col flex-1 min-h-0 overflow-y-auto overflow-x-auto",
|
||||
"flex flex-col flex-1 min-h-0 overflow-y-auto overflow-x-hidden",
|
||||
hideScrollbar ? "no-scrollbar" : "default-scrollbar"
|
||||
)}
|
||||
onScroll={handleScroll}
|
||||
style={{
|
||||
scrollbarGutter: "stable both-edges",
|
||||
overflowX: "auto",
|
||||
// Apply mask to fade content opacity at edges
|
||||
maskImage: contentMask,
|
||||
WebkitMaskImage: contentMask,
|
||||
@@ -367,7 +366,7 @@ const ChatScrollContainer = React.memo(
|
||||
>
|
||||
<div
|
||||
ref={contentWrapperRef}
|
||||
className="min-w-full flex-1 flex flex-col items-center px-4"
|
||||
className="w-full flex-1 flex flex-col items-center px-4"
|
||||
data-scroll-ready={isScrollReady}
|
||||
style={{
|
||||
visibility: isScrollReady ? "visible" : "hidden",
|
||||
|
||||
@@ -8,10 +8,7 @@ import { ErrorBanner } from "@/app/app/message/Resubmit";
|
||||
import { MinimalPersonaSnapshot } from "@/app/admin/agents/interfaces";
|
||||
import { LlmDescriptor, LlmManager } from "@/lib/hooks";
|
||||
import AgentMessage from "@/app/app/message/messageComponents/AgentMessage";
|
||||
import MultiModelResponseView, {
|
||||
MultiModelResponse,
|
||||
} from "@/app/app/message/MultiModelResponseView";
|
||||
import { SelectedModel } from "@/refresh-components/popovers/ModelSelector";
|
||||
import Spacer from "@/refresh-components/Spacer";
|
||||
import DynamicBottomSpacer from "@/components/chat/DynamicBottomSpacer";
|
||||
import {
|
||||
useCurrentMessageHistory,
|
||||
@@ -20,8 +17,6 @@ import {
|
||||
useUncaughtError,
|
||||
} from "@/app/app/stores/useChatSessionStore";
|
||||
|
||||
const MSG_MAX_W = "max-w-[720px] min-w-[400px]";
|
||||
|
||||
export interface ChatUIProps {
|
||||
liveAgent: MinimalPersonaSnapshot;
|
||||
llmManager: LlmManager;
|
||||
@@ -42,7 +37,6 @@ export interface ChatUIProps {
|
||||
forceSearch?: boolean;
|
||||
};
|
||||
forceSearch?: boolean;
|
||||
selectedModels?: SelectedModel[];
|
||||
}) => Promise<void>;
|
||||
deepResearchEnabled: boolean;
|
||||
currentMessageFiles: any[];
|
||||
@@ -54,9 +48,6 @@ export interface ChatUIProps {
|
||||
* Used by DynamicBottomSpacer to position the push-up effect.
|
||||
*/
|
||||
anchorNodeId?: number;
|
||||
|
||||
/** Currently selected models for multi-model comparison. */
|
||||
selectedModels: SelectedModel[];
|
||||
}
|
||||
|
||||
const ChatUI = React.memo(
|
||||
@@ -71,7 +62,6 @@ const ChatUI = React.memo(
|
||||
currentMessageFiles,
|
||||
onResubmit,
|
||||
anchorNodeId,
|
||||
selectedModels,
|
||||
}: ChatUIProps) => {
|
||||
// Get messages and error state from store
|
||||
const messages = useCurrentMessageHistory();
|
||||
@@ -86,11 +76,9 @@ const ChatUI = React.memo(
|
||||
const onSubmitRef = useRef(onSubmit);
|
||||
const deepResearchEnabledRef = useRef(deepResearchEnabled);
|
||||
const currentMessageFilesRef = useRef(currentMessageFiles);
|
||||
const selectedModelsRef = useRef(selectedModels);
|
||||
onSubmitRef.current = onSubmit;
|
||||
deepResearchEnabledRef.current = deepResearchEnabled;
|
||||
currentMessageFilesRef.current = currentMessageFiles;
|
||||
selectedModelsRef.current = selectedModels;
|
||||
|
||||
const createRegenerator = useCallback(
|
||||
(regenerationRequest: {
|
||||
@@ -115,72 +103,19 @@ const ChatUI = React.memo(
|
||||
|
||||
const handleEditWithMessageId = useCallback(
|
||||
(editedContent: string, msgId: number) => {
|
||||
const models = selectedModelsRef.current;
|
||||
onSubmitRef.current({
|
||||
message: editedContent,
|
||||
messageIdToResend: msgId,
|
||||
currentMessageFiles: [],
|
||||
deepResearch: deepResearchEnabledRef.current,
|
||||
selectedModels: models.length >= 2 ? models : undefined,
|
||||
});
|
||||
},
|
||||
[]
|
||||
);
|
||||
|
||||
// Helper to check if a user message has multi-model responses
|
||||
const getMultiModelResponses = useCallback(
|
||||
(userMessage: Message): MultiModelResponse[] | null => {
|
||||
if (!messageTree) return null;
|
||||
|
||||
const childrenNodeIds = userMessage.childrenNodeIds || [];
|
||||
if (childrenNodeIds.length < 2) return null;
|
||||
|
||||
const childMessages = childrenNodeIds
|
||||
.map((nodeId) => messageTree.get(nodeId))
|
||||
.filter(
|
||||
(msg): msg is Message =>
|
||||
msg !== undefined && msg.type === "assistant"
|
||||
);
|
||||
|
||||
if (childMessages.length < 2) return null;
|
||||
|
||||
// Distinguish multi-model from regenerations: multi-model messages
|
||||
// have modelDisplayName or overridden_model set. Regenerations don't.
|
||||
// During streaming, overridden_model is set. On reload, modelDisplayName is set.
|
||||
const multiModelChildren = childMessages.filter(
|
||||
(msg) => msg.modelDisplayName || msg.overridden_model
|
||||
);
|
||||
if (multiModelChildren.length < 2) return null;
|
||||
|
||||
const latestChildNodeId = userMessage.latestChildNodeId;
|
||||
|
||||
return childMessages.map((msg, idx) => {
|
||||
// During streaming, overridden_model has the friendly display name.
|
||||
// On reload from history, modelDisplayName has the DB-stored name.
|
||||
const name = msg.overridden_model || msg.modelDisplayName || "Model";
|
||||
return {
|
||||
modelIndex: idx,
|
||||
provider: "",
|
||||
modelName: name,
|
||||
displayName: name,
|
||||
packets: msg.packets || [],
|
||||
packetCount: msg.packetCount || msg.packets?.length || 0,
|
||||
nodeId: msg.nodeId,
|
||||
messageId: msg.messageId,
|
||||
isHighlighted: msg.nodeId === latestChildNodeId,
|
||||
currentFeedback: msg.currentFeedback,
|
||||
isGenerating: msg.is_generating || false,
|
||||
};
|
||||
});
|
||||
},
|
||||
[messageTree]
|
||||
);
|
||||
|
||||
return (
|
||||
<>
|
||||
{/* No max-width on container — individual messages control their own width.
|
||||
Multi-model responses use full width while normal messages stay centered. */}
|
||||
<div className="flex flex-col w-full h-full pt-4 pb-8 pr-1 gap-12">
|
||||
<div className="flex flex-col w-full max-w-[var(--app-page-main-content-width)] h-full pt-4 pb-8 pr-1 gap-12">
|
||||
{messages.map((message, i) => {
|
||||
const messageReactComponentKey = `message-${message.nodeId}`;
|
||||
const parentMessage = message.parentNodeId
|
||||
@@ -190,63 +125,32 @@ const ChatUI = React.memo(
|
||||
const nextMessage =
|
||||
messages.length > i + 1 ? messages[i + 1] : null;
|
||||
|
||||
// Check for multi-model responses
|
||||
const multiModelResponses = getMultiModelResponses(message);
|
||||
|
||||
return (
|
||||
<div
|
||||
id={messageReactComponentKey}
|
||||
key={messageReactComponentKey}
|
||||
className="flex flex-col gap-12 w-full"
|
||||
>
|
||||
{/* Human message stays at normal chat width */}
|
||||
<div className={`w-full ${MSG_MAX_W} self-center`}>
|
||||
<HumanMessage
|
||||
disableSwitchingForStreaming={
|
||||
(nextMessage && nextMessage.is_generating) || false
|
||||
}
|
||||
stopGenerating={stopGenerating}
|
||||
content={message.message}
|
||||
files={message.files}
|
||||
messageId={message.messageId}
|
||||
nodeId={message.nodeId}
|
||||
onEdit={handleEditWithMessageId}
|
||||
otherMessagesCanSwitchTo={
|
||||
parentMessage?.childrenNodeIds ?? emptyChildrenIds
|
||||
}
|
||||
onMessageSelection={onMessageSelection}
|
||||
/>
|
||||
</div>
|
||||
|
||||
{/* Multi-model response uses full width */}
|
||||
{multiModelResponses && (
|
||||
<MultiModelResponseView
|
||||
responses={multiModelResponses}
|
||||
chatState={{
|
||||
agent: liveAgent,
|
||||
docs: emptyDocs,
|
||||
citations: undefined,
|
||||
setPresentingDocument,
|
||||
overriddenModel: llmManager.currentLlm?.modelName,
|
||||
}}
|
||||
llmManager={llmManager}
|
||||
onRegenerate={createRegenerator}
|
||||
parentMessage={message}
|
||||
otherMessagesCanSwitchTo={
|
||||
parentMessage?.childrenNodeIds ?? emptyChildrenIds
|
||||
}
|
||||
onMessageSelection={onMessageSelection}
|
||||
/>
|
||||
)}
|
||||
<HumanMessage
|
||||
disableSwitchingForStreaming={
|
||||
(nextMessage && nextMessage.is_generating) || false
|
||||
}
|
||||
stopGenerating={stopGenerating}
|
||||
content={message.message}
|
||||
files={message.files}
|
||||
messageId={message.messageId}
|
||||
nodeId={message.nodeId}
|
||||
onEdit={handleEditWithMessageId}
|
||||
otherMessagesCanSwitchTo={
|
||||
parentMessage?.childrenNodeIds ?? emptyChildrenIds
|
||||
}
|
||||
onMessageSelection={onMessageSelection}
|
||||
/>
|
||||
</div>
|
||||
);
|
||||
} else if (message.type === "assistant") {
|
||||
if ((error || loadError) && i === messages.length - 1) {
|
||||
return (
|
||||
<div
|
||||
key={`error-${message.nodeId}`}
|
||||
className={`p-4 w-full ${MSG_MAX_W} self-center`}
|
||||
>
|
||||
<div key={`error-${message.nodeId}`} className="p-4">
|
||||
<ErrorBanner
|
||||
resubmit={onResubmit}
|
||||
error={error || loadError || ""}
|
||||
@@ -260,16 +164,6 @@ const ChatUI = React.memo(
|
||||
}
|
||||
|
||||
const previousMessage = i !== 0 ? messages[i - 1] : null;
|
||||
|
||||
// Check if this assistant message is part of a multi-model response
|
||||
// If so, skip rendering since it's already rendered in MultiModelResponseView
|
||||
if (
|
||||
previousMessage?.type === "user" &&
|
||||
getMultiModelResponses(previousMessage)
|
||||
) {
|
||||
return null;
|
||||
}
|
||||
|
||||
const chatStateData = {
|
||||
agent: liveAgent,
|
||||
docs: message.documents ?? emptyDocs,
|
||||
@@ -283,7 +177,6 @@ const ChatUI = React.memo(
|
||||
<div
|
||||
id={`message-${message.nodeId}`}
|
||||
key={messageReactComponentKey}
|
||||
className={`w-full ${MSG_MAX_W} self-center`}
|
||||
>
|
||||
<AgentMessage
|
||||
rawPackets={message.packets}
|
||||
@@ -313,7 +206,7 @@ const ChatUI = React.memo(
|
||||
{(((error !== null || loadError !== null) &&
|
||||
messages[messages.length - 1]?.type === "user") ||
|
||||
messages[messages.length - 1]?.type === "error") && (
|
||||
<div className={`p-4 w-full ${MSG_MAX_W} self-center`}>
|
||||
<div className="p-4">
|
||||
<ErrorBanner
|
||||
resubmit={onResubmit}
|
||||
error={error || loadError || ""}
|
||||
|
||||
@@ -10,6 +10,7 @@ import React, {
|
||||
} from "react";
|
||||
import LineItem from "@/refresh-components/buttons/LineItem";
|
||||
import { MinimalPersonaSnapshot } from "@/app/admin/agents/interfaces";
|
||||
import LLMPopover from "@/refresh-components/popovers/LLMPopover";
|
||||
import { InputPrompt } from "@/app/app/interfaces";
|
||||
import { FilterManager, LlmManager, useFederatedConnectors } from "@/lib/hooks";
|
||||
import usePromptShortcuts from "@/hooks/usePromptShortcuts";
|
||||
@@ -19,7 +20,7 @@ import { MinimalOnyxDocument } from "@/lib/search/interfaces";
|
||||
import { ChatState } from "@/app/app/interfaces";
|
||||
import { useForcedTools } from "@/lib/hooks/useForcedTools";
|
||||
import useAppFocus from "@/hooks/useAppFocus";
|
||||
import { cn } from "@/lib/utils";
|
||||
import { cn, isImageFile } from "@/lib/utils";
|
||||
import { Disabled } from "@opal/core";
|
||||
import { useUser } from "@/providers/UserProvider";
|
||||
import {
|
||||
@@ -422,6 +423,11 @@ const AppInputBar = React.memo(
|
||||
return currentMessageFiles.length > 1;
|
||||
}, [currentMessageFiles]);
|
||||
|
||||
const hasImageFiles = useMemo(
|
||||
() => currentMessageFiles.some((f) => isImageFile(f.name)),
|
||||
[currentMessageFiles]
|
||||
);
|
||||
|
||||
// Check if the agent has search tools available (internal search or web search)
|
||||
// AND if deep research is globally enabled in admin settings
|
||||
const showDeepResearch = useMemo(() => {
|
||||
@@ -609,6 +615,16 @@ const AppInputBar = React.memo(
|
||||
|
||||
{/* Bottom right controls */}
|
||||
<div className="flex flex-row items-center gap-1">
|
||||
<div
|
||||
data-testid="AppInputBar/llm-popover-trigger"
|
||||
className={cn(controlsLoading && "invisible")}
|
||||
>
|
||||
<LLMPopover
|
||||
llmManager={llmManager}
|
||||
requiresImageInput={hasImageFiles}
|
||||
disabled={disabled}
|
||||
/>
|
||||
</div>
|
||||
{showMicButton &&
|
||||
(sttEnabled ? (
|
||||
<MicrophoneButton
|
||||
|
||||
Reference in New Issue
Block a user