Compare commits

..

3 Commits

Author SHA1 Message Date
Nik
bab95d8bf0 fix(chat): remove duplicate drain_done declaration after rebase 2026-03-31 20:02:29 -07:00
Nik
eb7bc74e1b fix(chat): persist LLM response on HTTP disconnect via drain_done + worker self-completion
When the HTTP client disconnects, Starlette throws GeneratorExit into the
drain loop generator. The old code called executor.shutdown(wait=False) with
no completion handling, leaving the assistant DB message as the TERMINATED
placeholder forever (regressing test_send_message_disconnect_and_cleanup).

New design:
- drain_done (threading.Event) signals emitters to return immediately instead
  of blocking on queue.put — no retry loops, no daemon threads
- One-time queue drain in the else block releases any in-progress puts so
  workers exit within milliseconds
- Workers self-complete: after run_llm_loop returns, each worker checks
  drain_done.is_set() and, if true, opens its own DB session and calls
  llm_loop_completion_handle directly

Unit test updated to reflect the async self-completion semantics: the test
blocks the worker inside run_llm_loop until gen.close() sets drain_done,
then waits for completion_called inside the patch context (while mocks are
still active) to avoid calling the real get_session_with_current_tenant.
2026-03-31 20:02:29 -07:00
Nik
29da0aefb5 feat(chat): add multi-model parallel streaming (N=2-3 LLMs side-by-side)
Adds support for running 2-3 LLMs in parallel within a single chat turn,
with responses streamed interleaved to the frontend via the merged queue
infrastructure introduced in the preceding PR.

Backend changes
- process_message.py: restore llm_overrides param on build_chat_turn and
  _stream_chat_turn; restore is_multi branching for LLM setup, context
  window sizing, and message ID reservation; add _build_model_display_name
  and handle_multi_model_stream (public multi-model entrypoint)
- db/chat.py: add reserve_multi_model_message_ids (reserves N assistant
  message placeholders sharing the same parent), set_preferred_response
  (marks one response as the user's preferred), and extend
  translate_db_message_to_chat_message_detail with preferred_response_id
  and model_display_name fields
- chat_backend.py: route requests with llm_overrides >1 through
  handle_multi_model_stream; reject non-streaming multi-model requests with
  OnyxError; add /set-preferred-response endpoint

Tests
- test_multi_model_streaming.py: unit tests for _run_models drain loop
  (arrival-order yield, error isolation, cancellation), handle_multi_model_stream
  validation guards, and N=1 backwards-compatibility
2026-03-31 20:01:21 -07:00
4 changed files with 1019 additions and 146 deletions

View File

@@ -66,6 +66,7 @@ from onyx.db.chat import create_new_chat_message
from onyx.db.chat import get_chat_session_by_id
from onyx.db.chat import get_or_create_root_message
from onyx.db.chat import reserve_message_id
from onyx.db.chat import reserve_multi_model_message_ids
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.enums import HookPoint
from onyx.db.memory import get_memories
@@ -95,12 +96,15 @@ from onyx.llm.factory import get_llm_for_persona
from onyx.llm.factory import get_llm_token_counter
from onyx.llm.interfaces import LLM
from onyx.llm.interfaces import LLMUserIdentity
from onyx.llm.override_models import LLMOverride
from onyx.llm.request_context import reset_llm_mock_response
from onyx.llm.request_context import set_llm_mock_response
from onyx.llm.utils import litellm_exception_to_error_msg
from onyx.onyxbot.slack.models import SlackContext
from onyx.server.query_and_chat.models import AUTO_PLACE_AFTER_LATEST_MESSAGE
from onyx.server.query_and_chat.models import MessageResponseIDInfo
from onyx.server.query_and_chat.models import ModelResponseSlot
from onyx.server.query_and_chat.models import MultiModelMessageResponseIDInfo
from onyx.server.query_and_chat.models import SendMessageRequest
from onyx.server.query_and_chat.placement import Placement
from onyx.server.query_and_chat.streaming_models import AgentResponseDelta
@@ -310,9 +314,7 @@ def extract_context_files(
return _empty_extracted_context_files()
aggregate_tokens = sum(uf.token_count or 0 for uf in user_files)
max_actual_tokens = (
llm_max_context_window - reserved_token_count
) * max_llm_context_percentage
max_actual_tokens = (llm_max_context_window - reserved_token_count) * max_llm_context_percentage
if aggregate_tokens >= max_actual_tokens:
tool_metadata = []
@@ -502,6 +504,8 @@ def build_chat_turn(
new_msg_req: SendMessageRequest,
user: User,
db_session: Session,
# None → single-model (persona default LLM); non-empty list → multi-model (one LLM per override)
llm_overrides: list[LLMOverride] | None,
*,
litellm_additional_headers: dict[str, str] | None = None,
custom_tool_additional_headers: dict[str, str] | None = None,
@@ -514,23 +518,26 @@ def build_chat_turn(
# NOTE: not stored in the database, only passed in to the LLM as context
additional_context: str | None = None,
) -> Generator[AnswerStreamPart, None, ChatTurnSetup]:
"""Setup generator for a single-model chat turn.
"""Shared setup generator for both single-model and multi-model chat turns.
Yields the packet(s) the frontend needs for request tracking, then returns an
immutable ``ChatTurnSetup`` containing everything the execution strategy needs.
Callers use::
setup = yield from build_chat_turn(new_msg_req, ...)
setup = yield from build_chat_turn(new_msg_req, ..., llm_overrides=...)
to forward yielded packets upstream while receiving the return value locally.
Args:
llm_overrides: ``None`` → single-model (persona default LLM).
Non-empty list → multi-model (one LLM per override).
"""
tenant_id = get_current_tenant_id()
is_multi = bool(llm_overrides)
user_id = user.id
llm_user_identifier = (
"anonymous_user" if user.is_anonymous else (user.email or str(user_id))
)
llm_user_identifier = "anonymous_user" if user.is_anonymous else (user.email or str(user_id))
# ── Session resolution ───────────────────────────────────────────────────
if not new_msg_req.chat_session_id:
@@ -559,9 +566,7 @@ def build_chat_turn(
persona = chat_session.persona
message_text = new_msg_req.message
user_identity = LLMUserIdentity(
user_id=llm_user_identifier, session_id=str(chat_session.id)
)
user_identity = LLMUserIdentity(user_id=llm_user_identifier, session_id=str(chat_session.id))
# Milestone tracking, most devs using the API don't need to understand this
mt_cloud_telemetry(
@@ -583,21 +588,42 @@ def build_chat_turn(
)
# Check LLM cost limits before using the LLM (only for Onyx-managed keys),
# then build the LLM instance.
primary_llm = get_llm_for_persona(
persona=persona,
user=user,
llm_override=new_msg_req.llm_override or chat_session.llm_override,
additional_headers=litellm_additional_headers,
)
check_llm_cost_limit_for_provider(
db_session=db_session,
tenant_id=tenant_id,
llm_provider_api_key=primary_llm.config.api_key,
)
llms = [primary_llm]
model_display_names = [""]
token_counter = get_llm_token_counter(primary_llm)
# then build the LLM instance(s).
if is_multi:
assert llm_overrides is not None
llms: list[LLM] = []
model_display_names: list[str] = []
for override in llm_overrides:
llm = get_llm_for_persona(
persona=persona,
user=user,
llm_override=override,
additional_headers=litellm_additional_headers,
)
check_llm_cost_limit_for_provider(
db_session=db_session,
tenant_id=tenant_id,
llm_provider_api_key=llm.config.api_key,
)
llms.append(llm)
model_display_names.append(_build_model_display_name(override))
# Use first LLM for token counting — model-agnostic for setup purposes
token_counter = get_llm_token_counter(llms[0])
else:
primary_llm = get_llm_for_persona(
persona=persona,
user=user,
llm_override=new_msg_req.llm_override or chat_session.llm_override,
additional_headers=litellm_additional_headers,
)
check_llm_cost_limit_for_provider(
db_session=db_session,
tenant_id=tenant_id,
llm_provider_api_key=primary_llm.config.api_key,
)
llms = [primary_llm]
model_display_names = [""]
token_counter = get_llm_token_counter(primary_llm)
# Verify that the user-specified files actually belong to the user
verify_user_files(
@@ -608,24 +634,17 @@ def build_chat_turn(
)
# Re-create linear history of messages
chat_history = create_chat_history_chain(
chat_session_id=chat_session.id, db_session=db_session
)
chat_history = create_chat_history_chain(chat_session_id=chat_session.id, db_session=db_session)
# Determine the parent message based on the request:
# - AUTO_PLACE_AFTER_LATEST_MESSAGE (-1): auto-place after latest message in chain
# - None or root ID: regeneration from root (first message)
# - positive int: place after that specific parent message
root_message = get_or_create_root_message(
chat_session_id=chat_session.id, db_session=db_session
)
root_message = get_or_create_root_message(chat_session_id=chat_session.id, db_session=db_session)
if new_msg_req.parent_message_id == AUTO_PLACE_AFTER_LATEST_MESSAGE:
parent_message = chat_history[-1] if chat_history else root_message
elif (
new_msg_req.parent_message_id is None
or new_msg_req.parent_message_id == root_message.id
):
elif new_msg_req.parent_message_id is None or new_msg_req.parent_message_id == root_message.id:
# Regeneration from root — clear history so we start fresh
parent_message = root_message
chat_history = []
@@ -639,9 +658,7 @@ def build_chat_turn(
break
if parent_message is None:
raise ValueError(
"The new message sent is not on the latest mainline of messages"
)
raise ValueError("The new message sent is not on the latest mainline of messages")
# ── Query Processing hook + user message ─────────────────────────────────
# Skipped on regeneration (parent is USER type): message already exists/was accepted.
@@ -666,9 +683,7 @@ def build_chat_turn(
).model_dump(),
response_type=QueryProcessingResponse,
)
message_text = _resolve_query_processing_hook_result(
hook_result, message_text
)
message_text = _resolve_query_processing_hook_result(hook_result, message_text)
user_message = create_new_chat_message(
chat_session_id=chat_session.id,
@@ -730,16 +745,10 @@ def build_chat_turn(
# When use_memories is disabled, strip memories from the prompt context but keep
# user info/preferences. The full context is still passed to the LLM loop for
# memory tool persistence.
prompt_memory_context = (
user_memory_context
if user.use_memories
else user_memory_context.without_memories()
)
prompt_memory_context = user_memory_context if user.use_memories else user_memory_context.without_memories()
# ── Token reservation ────────────────────────────────────────────────────
max_reserved_system_prompt_tokens_str = (persona.system_prompt or "") + (
custom_agent_prompt or ""
)
max_reserved_system_prompt_tokens_str = (persona.system_prompt or "") + (custom_agent_prompt or "")
reserved_token_count = calculate_reserved_tokens(
db_session=db_session,
persona_system_prompt=max_reserved_system_prompt_tokens_str,
@@ -758,7 +767,8 @@ def build_chat_turn(
db_session=db_session,
)
llm_max_context_window = llms[0].config.max_input_tokens
# For multi-model: use the smallest context window across all models for safety
llm_max_context_window = min(llm.config.max_input_tokens for llm in llms) if is_multi else llms[0].config.max_input_tokens
extracted_context_files = extract_context_files(
user_files=context_user_files,
@@ -783,9 +793,7 @@ def build_chat_turn(
all_tools = get_tools(db_session)
tool_id_to_name_map = {tool.id: tool.name for tool in all_tools}
search_tool_id = next(
(tool.id for tool in all_tools if tool.in_code_tool_id == SEARCH_TOOL_ID), None
)
search_tool_id = next((tool.id for tool in all_tools if tool.in_code_tool_id == SEARCH_TOOL_ID), None)
forced_tool_id = new_msg_req.forced_tool_id
if (
@@ -802,24 +810,37 @@ def build_chat_turn(
# Convert loaded files to ChatFile format for tools like PythonTool
chat_files_for_tools = _convert_loaded_files_to_chat_files(files)
# ── Reserve assistant message ID → yield to frontend ─────────────────────
assistant_response = reserve_message_id(
db_session=db_session,
chat_session_id=chat_session.id,
parent_message=user_message.id,
message_type=MessageType.ASSISTANT,
)
reserved_messages = [assistant_response]
yield MessageResponseIDInfo(
user_message_id=user_message.id,
reserved_assistant_message_id=assistant_response.id,
)
# ── Reserve assistant message ID(s) → yield to frontend ──────────────────
if is_multi:
assert llm_overrides is not None
reserved_messages = reserve_multi_model_message_ids(
db_session=db_session,
chat_session_id=chat_session.id,
parent_message_id=user_message.id,
model_display_names=model_display_names,
)
yield MultiModelMessageResponseIDInfo(
user_message_id=user_message.id,
responses=[
ModelResponseSlot(message_id=m.id, model_name=name) for m, name in zip(reserved_messages, model_display_names)
],
)
else:
assistant_response = reserve_message_id(
db_session=db_session,
chat_session_id=chat_session.id,
parent_message=user_message.id,
message_type=MessageType.ASSISTANT,
)
reserved_messages = [assistant_response]
yield MessageResponseIDInfo(
user_message_id=user_message.id,
reserved_assistant_message_id=assistant_response.id,
)
# Convert the chat history into a simple format that is free of any DB objects
# and is easy to parse for the agent loop.
has_file_reader_tool = any(
tool.in_code_tool_id == FILE_READER_TOOL_ID for tool in persona.tools
)
has_file_reader_tool = any(tool.in_code_tool_id == FILE_READER_TOOL_ID for tool in persona.tools)
chat_history_result = convert_chat_history(
chat_history=chat_history,
@@ -848,9 +869,7 @@ def build_chat_turn(
all_injected_file_metadata.setdefault(fid, meta)
if all_injected_file_metadata:
logger.debug(
f"FileReader: file metadata for LLM: {[(fid, m.filename) for fid, m in all_injected_file_metadata.items()]}"
)
logger.debug(f"FileReader: file metadata for LLM: {[(fid, m.filename) for fid, m in all_injected_file_metadata.items()]}")
if summary_message is not None:
summary_simple = ChatMessageSimple(
@@ -955,19 +974,11 @@ def _run_models(
merged_queue: queue.Queue[tuple[int, Packet | Exception | object]] = queue.Queue()
state_containers: list[ChatStateContainer] = [
(
external_state_container
if (external_state_container is not None and i == 0)
else ChatStateContainer()
)
(external_state_container if (external_state_container is not None and i == 0) else ChatStateContainer())
for i in range(n_models)
]
model_succeeded: list[bool] = [False] * n_models
# Set when the drain loop exits early (HTTP disconnect / GeneratorExit).
# Signals emitters to skip future puts and workers to self-complete.
drain_done = threading.Event()
def _run_model(model_idx: int) -> None:
"""Run one LLM loop inside a worker thread, writing packets to ``merged_queue``."""
model_emitter = Emitter(
@@ -995,9 +1006,7 @@ def _run_models(
persona_id_filter=setup.search_params.persona_id_filter,
bypass_acl=setup.bypass_acl,
slack_context=setup.slack_context,
enable_slack_search=_should_enable_slack_search(
setup.persona, setup.new_msg_req.internal_search_filters
),
enable_slack_search=_should_enable_slack_search(setup.persona, setup.new_msg_req.internal_search_filters),
),
custom_tool_config=CustomToolConfig(
chat_session_id=setup.chat_session.id,
@@ -1012,25 +1021,15 @@ def _run_models(
allowed_tool_ids=setup.new_msg_req.allowed_tool_ids,
search_usage_forcing_setting=setup.search_params.search_usage,
)
model_tools = [
tool
for tool_list in thread_tool_dict.values()
for tool in tool_list
]
model_tools = [tool for tool_list in thread_tool_dict.values() for tool in tool_list]
if setup.forced_tool_id and setup.forced_tool_id not in {
tool.id for tool in model_tools
}:
raise ValueError(
f"Forced tool {setup.forced_tool_id} not found in tools"
)
if setup.forced_tool_id and setup.forced_tool_id not in {tool.id for tool in model_tools}:
raise ValueError(f"Forced tool {setup.forced_tool_id} not found in tools")
# Per-thread copy: run_llm_loop mutates simple_chat_history in-place.
if n_models == 1 and setup.new_msg_req.deep_research:
if setup.chat_session.project_id:
raise RuntimeError(
"Deep research is not supported for projects"
)
raise RuntimeError("Deep research is not supported for projects")
run_deep_research_llm_loop(
emitter=model_emitter,
state_container=sc,
@@ -1081,9 +1080,7 @@ def _run_models(
if drain_done.is_set() and model_succeeded[model_idx]:
try:
with get_session_with_current_tenant() as self_complete_db:
assistant_message = self_complete_db.get(
ChatMessage, setup.reserved_messages[model_idx].id
)
assistant_message = self_complete_db.get(ChatMessage, setup.reserved_messages[model_idx].id)
if assistant_message is not None:
llm_loop_completion_handle(
state_container=state_containers[model_idx],
@@ -1105,9 +1102,7 @@ def _run_models(
# Copy contextvars before submitting futures — ThreadPoolExecutor does NOT
# auto-propagate contextvars in Python 3.11; threads would inherit a blank context.
ctx = contextvars.copy_context()
executor = ThreadPoolExecutor(
max_workers=n_models, thread_name_prefix="multi-model"
)
executor = ThreadPoolExecutor(max_workers=n_models, thread_name_prefix="multi-model")
_completion_done: bool = False
try:
for i in range(n_models):
@@ -1129,18 +1124,14 @@ def _run_models(
llm_loop_completion_handle(
state_container=state_containers[i],
# partial captures model_succeeded[i] by value at loop time, not by reference
is_connected=functools.partial(
bool, model_succeeded[i]
),
is_connected=functools.partial(bool, model_succeeded[i]),
db_session=db_session,
assistant_message=setup.reserved_messages[i],
llm=setup.llms[i],
reserved_tokens=setup.reserved_token_count,
)
except Exception:
logger.exception(
f"Failed completion for model {i} on disconnect ({setup.model_display_names[i]})"
)
logger.exception(f"Failed completion for model {i} on disconnect ({setup.model_display_names[i]})")
yield Packet(
placement=Placement(turn_index=0),
obj=OverallStop(type="stop", stop_reason="user_cancelled"),
@@ -1158,17 +1149,11 @@ def _run_models(
# Do NOT decrement models_remaining — _run_model's finally always posts
# _MODEL_DONE, which is the sole completion signal.
error_msg = str(item)
stack_trace = "".join(
traceback.format_exception(type(item), item, item.__traceback__)
)
stack_trace = "".join(traceback.format_exception(type(item), item, item.__traceback__))
model_llm = setup.llms[model_idx]
if model_llm.config.api_key and len(model_llm.config.api_key) > 2:
error_msg = error_msg.replace(
model_llm.config.api_key, "[REDACTED_API_KEY]"
)
stack_trace = stack_trace.replace(
model_llm.config.api_key, "[REDACTED_API_KEY]"
)
error_msg = error_msg.replace(model_llm.config.api_key, "[REDACTED_API_KEY]")
stack_trace = stack_trace.replace(model_llm.config.api_key, "[REDACTED_API_KEY]")
yield StreamingError(
error=error_msg,
stack_trace=stack_trace,
@@ -1203,9 +1188,7 @@ def _run_models(
reserved_tokens=setup.reserved_token_count,
)
except Exception:
logger.exception(
f"Failed completion for model {i} ({setup.model_display_names[i]})"
)
logger.exception(f"Failed completion for model {i} ({setup.model_display_names[i]})")
_completion_done = True
finally:
@@ -1232,6 +1215,7 @@ def _stream_chat_turn(
new_msg_req: SendMessageRequest,
user: User,
db_session: Session,
llm_overrides: list[LLMOverride] | None = None,
litellm_additional_headers: dict[str, str] | None = None,
custom_tool_additional_headers: dict[str, str] | None = None,
mcp_headers: dict[str, str] | None = None,
@@ -1240,19 +1224,23 @@ def _stream_chat_turn(
slack_context: SlackContext | None = None,
external_state_container: ChatStateContainer | None = None,
) -> AnswerStream:
"""Private implementation for single-model chat turn streaming.
"""Private implementation for single-model and multi-model chat turn streaming.
Builds the turn context via ``build_chat_turn``, then streams packets from
``_run_models`` back to the caller. Handles setup errors, LLM errors, and
cancellation uniformly, saving whatever partial state has been accumulated
before re-raising or yielding a terminal error packet.
Not called directly — use ``handle_stream_message_objects`` as the public entrypoint.
Not called directly — use the public wrappers:
- ``handle_stream_message_objects`` for single-model (N=1) requests.
- ``handle_multi_model_stream`` for side-by-side multi-model comparison (N>1).
Args:
new_msg_req: The incoming chat request from the user.
user: Authenticated user; may be anonymous for public personas.
db_session: Database session for this request.
llm_overrides: ``None`` → single-model (persona default LLM).
Non-empty list → multi-model (one LLM per override, 23 items).
litellm_additional_headers: Extra headers forwarded to the LLM provider.
custom_tool_additional_headers: Extra headers for custom tool HTTP calls.
mcp_headers: Extra headers for MCP tool calls.
@@ -1269,9 +1257,7 @@ def _stream_chat_turn(
followed by a terminal ``Packet`` containing ``OverallStop``.
"""
if new_msg_req.mock_llm_response is not None and not INTEGRATION_TESTS_MODE:
raise ValueError(
"mock_llm_response can only be used when INTEGRATION_TESTS_MODE=true"
)
raise ValueError("mock_llm_response can only be used when INTEGRATION_TESTS_MODE=true")
mock_response_token: Token[str | None] | None = None
setup: ChatTurnSetup | None = None
@@ -1281,6 +1267,7 @@ def _stream_chat_turn(
new_msg_req=new_msg_req,
user=user,
db_session=db_session,
llm_overrides=llm_overrides,
litellm_additional_headers=litellm_additional_headers,
custom_tool_additional_headers=custom_tool_additional_headers,
mcp_headers=mcp_headers,
@@ -1324,9 +1311,7 @@ def _stream_chat_turn(
except EmptyLLMResponseError as e:
stack_trace = traceback.format_exc()
logger.warning(
f"LLM returned an empty response (provider={e.provider}, model={e.model}, tool_choice={e.tool_choice})"
)
logger.warning(f"LLM returned an empty response (provider={e.provider}, model={e.model}, tool_choice={e.tool_choice})")
yield StreamingError(
error=e.client_error_msg,
stack_trace=stack_trace,
@@ -1346,16 +1331,10 @@ def _stream_chat_turn(
llm = setup.llms[0] if setup else None
if llm:
client_error_msg, error_code, is_retryable = litellm_exception_to_error_msg(
e, llm
)
client_error_msg, error_code, is_retryable = litellm_exception_to_error_msg(e, llm)
if llm.config.api_key and len(llm.config.api_key) > 2:
client_error_msg = client_error_msg.replace(
llm.config.api_key, "[REDACTED_API_KEY]"
)
stack_trace = stack_trace.replace(
llm.config.api_key, "[REDACTED_API_KEY]"
)
client_error_msg = client_error_msg.replace(llm.config.api_key, "[REDACTED_API_KEY]")
stack_trace = stack_trace.replace(llm.config.api_key, "[REDACTED_API_KEY]")
yield StreamingError(
error=client_error_msg,
stack_trace=stack_trace,
@@ -1401,11 +1380,12 @@ def handle_stream_message_objects(
slack_context: SlackContext | None = None,
external_state_container: ChatStateContainer | None = None,
) -> AnswerStream:
"""Single-model streaming entrypoint. Delegates to ``_stream_chat_turn``."""
"""Single-model streaming entrypoint. For multi-model comparison, use ``handle_multi_model_stream``."""
yield from _stream_chat_turn(
new_msg_req=new_msg_req,
user=user,
db_session=db_session,
llm_overrides=None,
litellm_additional_headers=litellm_additional_headers,
custom_tool_additional_headers=custom_tool_additional_headers,
mcp_headers=mcp_headers,
@@ -1416,6 +1396,70 @@ def handle_stream_message_objects(
)
def _build_model_display_name(override: LLMOverride) -> str:
"""Build a human-readable display name from an LLM override."""
if override.display_name:
return override.display_name
if override.model_version:
return override.model_version
if override.model_provider:
return override.model_provider
return "unknown"
def handle_multi_model_stream(
new_msg_req: SendMessageRequest,
user: User,
db_session: Session,
llm_overrides: list[LLMOverride],
litellm_additional_headers: dict[str, str] | None = None,
custom_tool_additional_headers: dict[str, str] | None = None,
mcp_headers: dict[str, str] | None = None,
) -> AnswerStream:
"""Thin wrapper for side-by-side multi-model comparison (23 models).
Validates the override list and delegates to ``_stream_chat_turn``,
which handles both single-model and multi-model execution via the same path.
Args:
new_msg_req: The incoming chat request. ``deep_research`` must be ``False``.
user: Authenticated user making the request.
db_session: Database session for this request.
llm_overrides: Exactly 2 or 3 ``LLMOverride`` objects — one per model to run.
litellm_additional_headers: Extra headers forwarded to each LLM provider.
custom_tool_additional_headers: Extra headers for custom tool HTTP calls.
mcp_headers: Extra headers for MCP tool calls.
Returns:
Generator yielding interleaved ``Packet`` objects from all models, each tagged
with ``model_index`` in its placement.
"""
n_models = len(llm_overrides)
if n_models < 2 or n_models > 3:
yield StreamingError(
error=f"Multi-model requires 2-3 overrides, got {n_models}",
error_code="VALIDATION_ERROR",
is_retryable=False,
)
return
if new_msg_req.deep_research:
yield StreamingError(
error="Multi-model is not supported with deep research",
error_code="VALIDATION_ERROR",
is_retryable=False,
)
return
yield from _stream_chat_turn(
new_msg_req=new_msg_req,
user=user,
db_session=db_session,
llm_overrides=llm_overrides,
litellm_additional_headers=litellm_additional_headers,
custom_tool_additional_headers=custom_tool_additional_headers,
mcp_headers=mcp_headers,
)
def llm_loop_completion_handle(
state_container: ChatStateContainer,
is_connected: Callable[[], bool],
@@ -1441,17 +1485,13 @@ def llm_loop_completion_handle(
completed_normally = is_connected()
if completed_normally:
if answer_tokens is None:
raise RuntimeError(
"LLM run completed normally but did not return an answer."
)
raise RuntimeError("LLM run completed normally but did not return an answer.")
final_answer = answer_tokens
else:
# Stopped by user - append stop message
logger.debug(f"Chat session {chat_session_id} stopped by user")
if answer_tokens:
final_answer = (
answer_tokens + " ... \n\nGeneration was stopped by the user."
)
final_answer = answer_tokens + " ... \n\nGeneration was stopped by the user."
else:
final_answer = "The generation was stopped by the user."

View File

@@ -617,6 +617,92 @@ def reserve_message_id(
return empty_message
def reserve_multi_model_message_ids(
db_session: Session,
chat_session_id: UUID,
parent_message_id: int,
model_display_names: list[str],
) -> list[ChatMessage]:
"""Reserve N assistant message placeholders for multi-model parallel streaming.
All messages share the same parent (the user message). The parent's
latest_child_message_id points to the LAST reserved message so that the
default history-chain walker picks it up.
"""
reserved: list[ChatMessage] = []
for display_name in model_display_names:
msg = ChatMessage(
chat_session_id=chat_session_id,
parent_message_id=parent_message_id,
latest_child_message_id=None,
message="Response was terminated prior to completion, try regenerating.",
token_count=15, # placeholder; updated on completion by llm_loop_completion_handle
message_type=MessageType.ASSISTANT,
model_display_name=display_name,
)
db_session.add(msg)
reserved.append(msg)
# Flush to assign IDs without committing yet
db_session.flush()
# Point parent's latest_child to the last reserved message
parent = (
db_session.query(ChatMessage)
.filter(ChatMessage.id == parent_message_id)
.first()
)
if parent:
parent.latest_child_message_id = reserved[-1].id
db_session.commit()
return reserved
def set_preferred_response(
db_session: Session,
user_message_id: int,
preferred_assistant_message_id: int,
) -> None:
"""Mark one assistant response as the user's preferred choice in a multi-model turn.
Also advances ``latest_child_message_id`` so the preferred response becomes
the active branch for any subsequent messages in the conversation.
Args:
db_session: Active database session.
user_message_id: Primary key of the ``USER``-type ``ChatMessage`` whose
preferred response is being set.
preferred_assistant_message_id: Primary key of the ``ASSISTANT``-type
``ChatMessage`` to prefer. Must be a direct child of ``user_message_id``.
Raises:
ValueError: If either message is not found, if ``user_message_id`` does not
refer to a USER message, or if the assistant message is not a direct child
of the user message.
"""
user_msg = db_session.get(ChatMessage, user_message_id)
if user_msg is None:
raise ValueError(f"User message {user_message_id} not found")
if user_msg.message_type != MessageType.USER:
raise ValueError(f"Message {user_message_id} is not a user message")
assistant_msg = db_session.get(ChatMessage, preferred_assistant_message_id)
if assistant_msg is None:
raise ValueError(
f"Assistant message {preferred_assistant_message_id} not found"
)
if assistant_msg.parent_message_id != user_message_id:
raise ValueError(
f"Assistant message {preferred_assistant_message_id} is not a child "
f"of user message {user_message_id}"
)
user_msg.preferred_response_id = preferred_assistant_message_id
user_msg.latest_child_message_id = preferred_assistant_message_id
db_session.commit()
def create_new_chat_message(
chat_session_id: UUID,
parent_message: ChatMessage,
@@ -839,6 +925,8 @@ def translate_db_message_to_chat_message_detail(
error=chat_message.error,
current_feedback=current_feedback,
processing_duration_seconds=chat_message.processing_duration_seconds,
preferred_response_id=chat_message.preferred_response_id,
model_display_name=chat_message.model_display_name,
)
return chat_msg_detail

View File

@@ -28,6 +28,7 @@ from onyx.chat.chat_utils import extract_headers
from onyx.chat.models import ChatFullResponse
from onyx.chat.models import CreateChatSessionID
from onyx.chat.process_message import gather_stream_full
from onyx.chat.process_message import handle_multi_model_stream
from onyx.chat.process_message import handle_stream_message_objects
from onyx.chat.prompt_utils import get_default_base_system_prompt
from onyx.chat.stop_signal_checker import set_fence
@@ -46,6 +47,7 @@ from onyx.db.chat import get_chat_messages_by_session
from onyx.db.chat import get_chat_session_by_id
from onyx.db.chat import get_chat_sessions_by_user
from onyx.db.chat import set_as_latest_chat_message
from onyx.db.chat import set_preferred_response
from onyx.db.chat import translate_db_message_to_chat_message_detail
from onyx.db.chat import update_chat_session
from onyx.db.chat_search import search_chat_sessions
@@ -60,6 +62,8 @@ from onyx.db.persona import get_persona_by_id
from onyx.db.usage import increment_usage
from onyx.db.usage import UsageType
from onyx.db.user_file import get_file_id_by_user_file_id
from onyx.error_handling.error_codes import OnyxErrorCode
from onyx.error_handling.exceptions import OnyxError
from onyx.file_store.file_store import get_default_file_store
from onyx.llm.constants import LlmProviderNames
from onyx.llm.factory import get_default_llm
@@ -81,6 +85,7 @@ from onyx.server.query_and_chat.models import ChatSessionUpdateRequest
from onyx.server.query_and_chat.models import MessageOrigin
from onyx.server.query_and_chat.models import RenameChatSessionResponse
from onyx.server.query_and_chat.models import SendMessageRequest
from onyx.server.query_and_chat.models import SetPreferredResponseRequest
from onyx.server.query_and_chat.models import UpdateChatSessionTemperatureRequest
from onyx.server.query_and_chat.models import UpdateChatSessionThreadRequest
from onyx.server.query_and_chat.session_loading import (
@@ -570,6 +575,46 @@ def handle_send_chat_message(
if get_hashed_api_key_from_request(request) or get_hashed_pat_from_request(request):
chat_message_req.origin = MessageOrigin.API
# Multi-model streaming path: 2-3 LLMs in parallel (streaming only)
is_multi_model = (
chat_message_req.llm_overrides is not None
and len(chat_message_req.llm_overrides) > 1
)
if is_multi_model and chat_message_req.stream:
# Narrowed here; is_multi_model already checked llm_overrides is not None
llm_overrides = chat_message_req.llm_overrides or []
def multi_model_stream_generator() -> Generator[str, None, None]:
try:
with get_session_with_current_tenant() as db_session:
for obj in handle_multi_model_stream(
new_msg_req=chat_message_req,
user=user,
db_session=db_session,
llm_overrides=llm_overrides,
litellm_additional_headers=extract_headers(
request.headers, LITELLM_PASS_THROUGH_HEADERS
),
custom_tool_additional_headers=get_custom_tool_additional_request_headers(
request.headers
),
mcp_headers=chat_message_req.mcp_headers,
):
yield get_json_line(obj.model_dump())
except Exception as e:
logger.exception("Error in multi-model streaming")
yield json.dumps({"error": str(e)})
return StreamingResponse(
multi_model_stream_generator(), media_type="text/event-stream"
)
if is_multi_model and not chat_message_req.stream:
raise OnyxError(
OnyxErrorCode.INVALID_INPUT,
"Multi-model mode (llm_overrides with >1 entry) requires stream=True.",
)
# Non-streaming path: consume all packets and return complete response
if not chat_message_req.stream:
with get_session_with_current_tenant() as db_session:
@@ -660,6 +705,30 @@ def set_message_as_latest(
)
@router.put("/set-preferred-response")
def set_preferred_response_endpoint(
request_body: SetPreferredResponseRequest,
user: User | None = Depends(current_user),
db_session: Session = Depends(get_session),
) -> None:
"""Set the preferred assistant response for a multi-model turn."""
try:
# Ownership check: get_chat_message raises ValueError if the message
# doesn't belong to this user, preventing cross-user mutation.
get_chat_message(
chat_message_id=request_body.user_message_id,
user_id=user.id if user else None,
db_session=db_session,
)
set_preferred_response(
db_session=db_session,
user_message_id=request_body.user_message_id,
preferred_assistant_message_id=request_body.preferred_response_id,
)
except ValueError as e:
raise OnyxError(OnyxErrorCode.INVALID_INPUT, str(e))
@router.post("/create-chat-message-feedback")
def create_chat_feedback(
feedback: ChatFeedbackRequest,

View File

@@ -0,0 +1,676 @@
"""Unit tests for multi-model streaming validation and DB helpers.
These are pure unit tests — no real database or LLM calls required.
The validation logic in handle_multi_model_stream fires before any external
calls, so we can trigger it with lightweight mocks.
"""
import time
from collections.abc import Generator
from typing import Any
from typing import cast
from unittest.mock import MagicMock
from unittest.mock import patch
from uuid import uuid4
import pytest
from onyx.chat.models import StreamingError
from onyx.configs.constants import MessageType
from onyx.db.chat import set_preferred_response
from onyx.llm.override_models import LLMOverride
from onyx.server.query_and_chat.models import SendMessageRequest
from onyx.server.query_and_chat.placement import Placement
from onyx.server.query_and_chat.streaming_models import OverallStop
from onyx.server.query_and_chat.streaming_models import Packet
from onyx.server.query_and_chat.streaming_models import ReasoningStart
from onyx.utils.variable_functionality import global_version
@pytest.fixture(autouse=True)
def _restore_ee_version() -> Generator[None, None, None]:
"""Reset EE global state after each test.
Importing onyx.chat.process_message triggers set_is_ee_based_on_env_variable()
(via the celery client import chain). Without this fixture, the EE flag stays
True for the rest of the session and breaks unrelated tests that mock Confluence
or other connectors and assume EE is disabled.
"""
original = global_version._is_ee
yield
global_version._is_ee = original
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _make_request(**kwargs: Any) -> SendMessageRequest:
defaults: dict[str, Any] = {
"message": "hello",
"chat_session_id": uuid4(),
}
defaults.update(kwargs)
return SendMessageRequest(**defaults)
def _make_override(provider: str = "openai", version: str = "gpt-4") -> LLMOverride:
return LLMOverride(model_provider=provider, model_version=version)
def _first_from_stream(req: SendMessageRequest, overrides: list[LLMOverride]) -> Any:
"""Return the first item yielded by handle_multi_model_stream."""
from onyx.chat.process_message import handle_multi_model_stream
user = MagicMock()
user.is_anonymous = False
user.email = "test@example.com"
db = MagicMock()
gen = handle_multi_model_stream(req, user, db, overrides)
return next(gen)
# ---------------------------------------------------------------------------
# handle_multi_model_stream — validation
# ---------------------------------------------------------------------------
class TestRunMultiModelStreamValidation:
def test_single_override_yields_error(self) -> None:
"""Exactly 1 override is not multi-model — yields StreamingError."""
req = _make_request()
result = _first_from_stream(req, [_make_override()])
assert isinstance(result, StreamingError)
assert "2-3" in result.error
def test_four_overrides_yields_error(self) -> None:
"""4 overrides exceeds maximum — yields StreamingError."""
req = _make_request()
result = _first_from_stream(
req,
[
_make_override("openai", "gpt-4"),
_make_override("anthropic", "claude-3"),
_make_override("google", "gemini-pro"),
_make_override("cohere", "command-r"),
],
)
assert isinstance(result, StreamingError)
assert "2-3" in result.error
def test_zero_overrides_yields_error(self) -> None:
"""Empty override list yields StreamingError."""
req = _make_request()
result = _first_from_stream(req, [])
assert isinstance(result, StreamingError)
assert "2-3" in result.error
def test_deep_research_yields_error(self) -> None:
"""deep_research=True is incompatible with multi-model — yields StreamingError."""
req = _make_request(deep_research=True)
result = _first_from_stream(
req, [_make_override(), _make_override("anthropic", "claude-3")]
)
assert isinstance(result, StreamingError)
assert "not supported" in result.error
def test_exactly_two_overrides_is_minimum(self) -> None:
"""Boundary: 1 override yields error, 2 overrides passes validation."""
req = _make_request()
# 1 override must yield a StreamingError
result = _first_from_stream(req, [_make_override()])
assert isinstance(
result, StreamingError
), "1 override should yield StreamingError"
# 2 overrides must NOT yield a validation StreamingError (may raise later due to
# missing session, that's OK — validation itself passed)
try:
result2 = _first_from_stream(
req, [_make_override(), _make_override("anthropic", "claude-3")]
)
if isinstance(result2, StreamingError) and "2-3" in result2.error:
pytest.fail(
f"2 overrides should pass validation, got StreamingError: {result2.error}"
)
except Exception:
pass # Any non-validation error means validation passed
# ---------------------------------------------------------------------------
# set_preferred_response — validation (mocked db)
# ---------------------------------------------------------------------------
class TestSetPreferredResponseValidation:
def test_user_message_not_found(self) -> None:
db = MagicMock()
db.get.return_value = None
with pytest.raises(ValueError, match="not found"):
set_preferred_response(
db, user_message_id=999, preferred_assistant_message_id=1
)
def test_wrong_message_type(self) -> None:
"""Cannot set preferred response on a non-USER message."""
db = MagicMock()
user_msg = MagicMock()
user_msg.message_type = MessageType.ASSISTANT # wrong type
db.get.return_value = user_msg
with pytest.raises(ValueError, match="not a user message"):
set_preferred_response(
db, user_message_id=1, preferred_assistant_message_id=2
)
def test_assistant_message_not_found(self) -> None:
db = MagicMock()
user_msg = MagicMock()
user_msg.message_type = MessageType.USER
# First call returns user_msg, second call (for assistant) returns None
db.get.side_effect = [user_msg, None]
with pytest.raises(ValueError, match="not found"):
set_preferred_response(
db, user_message_id=1, preferred_assistant_message_id=2
)
def test_assistant_not_child_of_user(self) -> None:
db = MagicMock()
user_msg = MagicMock()
user_msg.message_type = MessageType.USER
assistant_msg = MagicMock()
assistant_msg.parent_message_id = 999 # different parent
db.get.side_effect = [user_msg, assistant_msg]
with pytest.raises(ValueError, match="not a child"):
set_preferred_response(
db, user_message_id=1, preferred_assistant_message_id=2
)
def test_valid_call_sets_preferred_response_id(self) -> None:
db = MagicMock()
user_msg = MagicMock()
user_msg.message_type = MessageType.USER
assistant_msg = MagicMock()
assistant_msg.parent_message_id = 1 # correct parent
db.get.side_effect = [user_msg, assistant_msg]
set_preferred_response(db, user_message_id=1, preferred_assistant_message_id=2)
assert user_msg.preferred_response_id == 2
assert user_msg.latest_child_message_id == 2
# ---------------------------------------------------------------------------
# LLMOverride — display_name field
# ---------------------------------------------------------------------------
class TestLLMOverrideDisplayName:
def test_display_name_defaults_none(self) -> None:
override = LLMOverride(model_provider="openai", model_version="gpt-4")
assert override.display_name is None
def test_display_name_set(self) -> None:
override = LLMOverride(
model_provider="openai",
model_version="gpt-4",
display_name="GPT-4 Turbo",
)
assert override.display_name == "GPT-4 Turbo"
def test_display_name_serializes(self) -> None:
override = LLMOverride(
model_provider="anthropic",
model_version="claude-opus-4-6",
display_name="Claude Opus",
)
d = override.model_dump()
assert d["display_name"] == "Claude Opus"
# ---------------------------------------------------------------------------
# _run_models — drain loop behaviour
# ---------------------------------------------------------------------------
def _make_setup(n_models: int = 1) -> MagicMock:
"""Minimal ChatTurnSetup mock whose fields pass Pydantic validation in _run_model."""
setup = MagicMock()
setup.llms = [MagicMock() for _ in range(n_models)]
setup.model_display_names = [f"model-{i}" for i in range(n_models)]
setup.check_is_connected = MagicMock(return_value=True)
setup.reserved_messages = [MagicMock() for _ in range(n_models)]
setup.reserved_token_count = 100
# Fields consumed by SearchToolConfig / CustomToolConfig / FileReaderToolConfig
# constructors inside _run_model — must be typed correctly for Pydantic.
setup.new_msg_req.deep_research = False
setup.new_msg_req.internal_search_filters = None
setup.new_msg_req.allowed_tool_ids = None
setup.new_msg_req.include_citations = True
setup.search_params.project_id_filter = None
setup.search_params.persona_id_filter = None
setup.bypass_acl = False
setup.slack_context = None
setup.available_files.user_file_ids = []
setup.available_files.chat_file_ids = []
setup.forced_tool_id = None
setup.simple_chat_history = []
setup.chat_session.id = uuid4()
setup.user_message.id = None
setup.custom_tool_additional_headers = None
setup.mcp_headers = None
return setup
_RUN_MODELS_PATCHES = [
patch("onyx.chat.process_message.run_llm_loop"),
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
patch("onyx.chat.process_message.construct_tools", return_value={}),
patch("onyx.chat.process_message.get_session_with_current_tenant"),
patch("onyx.chat.process_message.llm_loop_completion_handle"),
patch("onyx.chat.process_message.get_llm_token_counter", return_value=lambda _: 0),
]
def _run_models_collect(setup: MagicMock) -> list:
"""Drive _run_models to completion and return all yielded items."""
from onyx.chat.process_message import _run_models
return list(_run_models(setup, MagicMock(), MagicMock()))
class TestRunModels:
"""Tests for the _run_models worker-thread drain loop.
All external dependencies (LLM, DB, tools) are patched out. Worker threads
still run but return immediately since run_llm_loop is mocked.
"""
def test_n1_overall_stop_from_llm_loop_passes_through(self) -> None:
"""OverallStop emitted by run_llm_loop is passed through the drain loop unchanged."""
def emit_stop(**kwargs: Any) -> None:
kwargs["emitter"].emit(
Packet(
placement=Placement(turn_index=0),
obj=OverallStop(stop_reason="complete"),
)
)
with (
patch("onyx.chat.process_message.run_llm_loop", side_effect=emit_stop),
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
patch("onyx.chat.process_message.construct_tools", return_value={}),
patch("onyx.chat.process_message.get_session_with_current_tenant"),
patch("onyx.chat.process_message.llm_loop_completion_handle"),
patch(
"onyx.chat.process_message.get_llm_token_counter",
return_value=lambda _: 0,
),
):
packets = _run_models_collect(_make_setup(n_models=1))
stops = [
p
for p in packets
if isinstance(p, Packet) and isinstance(p.obj, OverallStop)
]
assert len(stops) == 1
stop_obj = stops[0].obj
assert isinstance(stop_obj, OverallStop)
assert stop_obj.stop_reason == "complete"
def test_n1_emitted_packet_has_model_index_zero(self) -> None:
"""Single-model path: model_index is 0 (Emitter defaults model_idx=0)."""
def emit_one(**kwargs: Any) -> None:
kwargs["emitter"].emit(
Packet(placement=Placement(turn_index=0), obj=ReasoningStart())
)
with (
patch("onyx.chat.process_message.run_llm_loop", side_effect=emit_one),
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
patch("onyx.chat.process_message.construct_tools", return_value={}),
patch("onyx.chat.process_message.get_session_with_current_tenant"),
patch("onyx.chat.process_message.llm_loop_completion_handle"),
patch(
"onyx.chat.process_message.get_llm_token_counter",
return_value=lambda _: 0,
),
):
packets = _run_models_collect(_make_setup(n_models=1))
reasoning = [
p
for p in packets
if isinstance(p, Packet) and isinstance(p.obj, ReasoningStart)
]
assert len(reasoning) == 1
assert reasoning[0].placement.model_index == 0
def test_n2_each_model_packet_tagged_with_its_index(self) -> None:
"""Multi-model path: packets from model 0 get index=0, model 1 gets index=1."""
def emit_one(**kwargs: Any) -> None:
# _model_idx is set by _run_model based on position in setup.llms
emitter = kwargs["emitter"]
emitter.emit(
Packet(placement=Placement(turn_index=0), obj=ReasoningStart())
)
with (
patch("onyx.chat.process_message.run_llm_loop", side_effect=emit_one),
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
patch("onyx.chat.process_message.construct_tools", return_value={}),
patch("onyx.chat.process_message.get_session_with_current_tenant"),
patch("onyx.chat.process_message.llm_loop_completion_handle"),
patch(
"onyx.chat.process_message.get_llm_token_counter",
return_value=lambda _: 0,
),
):
packets = _run_models_collect(_make_setup(n_models=2))
reasoning = [
p
for p in packets
if isinstance(p, Packet) and isinstance(p.obj, ReasoningStart)
]
assert len(reasoning) == 2
indices = {p.placement.model_index for p in reasoning}
assert indices == {0, 1}
def test_model_error_yields_streaming_error(self) -> None:
"""An exception inside a worker thread is surfaced as a StreamingError."""
def always_fail(**_kwargs: Any) -> None:
raise RuntimeError("intentional test failure")
with (
patch("onyx.chat.process_message.run_llm_loop", side_effect=always_fail),
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
patch("onyx.chat.process_message.construct_tools", return_value={}),
patch("onyx.chat.process_message.get_session_with_current_tenant"),
patch("onyx.chat.process_message.llm_loop_completion_handle"),
patch(
"onyx.chat.process_message.get_llm_token_counter",
return_value=lambda _: 0,
),
):
packets = _run_models_collect(_make_setup(n_models=1))
errors = [p for p in packets if isinstance(p, StreamingError)]
assert len(errors) == 1
assert errors[0].error_code == "MODEL_ERROR"
assert "intentional test failure" in errors[0].error
def test_one_model_error_does_not_stop_other_models(self) -> None:
"""A failing model yields StreamingError; the surviving model's packets still arrive."""
def fail_model_0_succeed_model_1(**kwargs: Any) -> None:
emitter = kwargs["emitter"]
# _model_idx is always int (0 for N=1, 0/1/2… for N>1)
if emitter._model_idx == 0:
raise RuntimeError("model 0 failed")
emitter.emit(
Packet(placement=Placement(turn_index=0), obj=ReasoningStart())
)
with (
patch(
"onyx.chat.process_message.run_llm_loop",
side_effect=fail_model_0_succeed_model_1,
),
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
patch("onyx.chat.process_message.construct_tools", return_value={}),
patch("onyx.chat.process_message.get_session_with_current_tenant"),
patch("onyx.chat.process_message.llm_loop_completion_handle"),
patch(
"onyx.chat.process_message.get_llm_token_counter",
return_value=lambda _: 0,
),
):
packets = _run_models_collect(_make_setup(n_models=2))
errors = [p for p in packets if isinstance(p, StreamingError)]
assert len(errors) == 1
reasoning = [
p
for p in packets
if isinstance(p, Packet) and isinstance(p.obj, ReasoningStart)
]
assert len(reasoning) == 1
assert reasoning[0].placement.model_index == 1
def test_cancellation_yields_user_cancelled_stop(self) -> None:
"""If check_is_connected returns False, drain loop emits user_cancelled."""
def slow_llm(**_kwargs: Any) -> None:
time.sleep(0.3) # Outlasts the 50 ms queue-poll interval
setup = _make_setup(n_models=1)
setup.check_is_connected = MagicMock(return_value=False)
with (
patch("onyx.chat.process_message.run_llm_loop", side_effect=slow_llm),
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
patch("onyx.chat.process_message.construct_tools", return_value={}),
patch("onyx.chat.process_message.get_session_with_current_tenant"),
patch("onyx.chat.process_message.llm_loop_completion_handle"),
patch(
"onyx.chat.process_message.get_llm_token_counter",
return_value=lambda _: 0,
),
):
packets = _run_models_collect(setup)
stops = [
p
for p in packets
if isinstance(p, Packet) and isinstance(p.obj, OverallStop)
]
assert any(
isinstance(s.obj, OverallStop) and s.obj.stop_reason == "user_cancelled"
for s in stops
)
def test_completion_handle_called_on_disconnect(self) -> None:
"""llm_loop_completion_handle must still be called even when user disconnects.
Regression test for the disconnect-cleanup bug: the old
run_chat_loop_with_state_containers always called completion_callback in
its finally block (even on disconnect) so the DB message was updated from
the TERMINATED placeholder to a partial answer. The new _run_models must
replicate this — otherwise the integration test
test_send_message_disconnect_and_cleanup fails because the message stays
as "Response was terminated prior to completion, try regenerating."
"""
def slow_llm(**_kwargs: Any) -> None:
time.sleep(0.3)
setup = _make_setup(n_models=2)
setup.check_is_connected = MagicMock(return_value=False)
with (
patch("onyx.chat.process_message.run_llm_loop", side_effect=slow_llm),
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
patch("onyx.chat.process_message.construct_tools", return_value={}),
patch("onyx.chat.process_message.get_session_with_current_tenant"),
patch(
"onyx.chat.process_message.llm_loop_completion_handle"
) as mock_handle,
patch(
"onyx.chat.process_message.get_llm_token_counter",
return_value=lambda _: 0,
),
):
_run_models_collect(setup)
# Must be called once per model, not zero times
assert mock_handle.call_count == 2
def test_completion_handle_called_for_each_successful_model(self) -> None:
"""llm_loop_completion_handle must be called once per model that succeeded."""
setup = _make_setup(n_models=2)
with (
patch("onyx.chat.process_message.run_llm_loop"),
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
patch("onyx.chat.process_message.construct_tools", return_value={}),
patch("onyx.chat.process_message.get_session_with_current_tenant"),
patch(
"onyx.chat.process_message.llm_loop_completion_handle"
) as mock_handle,
patch(
"onyx.chat.process_message.get_llm_token_counter",
return_value=lambda _: 0,
),
):
_run_models_collect(setup)
assert mock_handle.call_count == 2
def test_completion_handle_not_called_for_failed_model(self) -> None:
"""llm_loop_completion_handle must be skipped for a model that raised."""
def always_fail(**_kwargs: Any) -> None:
raise RuntimeError("fail")
with (
patch("onyx.chat.process_message.run_llm_loop", side_effect=always_fail),
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
patch("onyx.chat.process_message.construct_tools", return_value={}),
patch("onyx.chat.process_message.get_session_with_current_tenant"),
patch(
"onyx.chat.process_message.llm_loop_completion_handle"
) as mock_handle,
patch(
"onyx.chat.process_message.get_llm_token_counter",
return_value=lambda _: 0,
),
):
_run_models_collect(_make_setup(n_models=1))
mock_handle.assert_not_called()
def test_http_disconnect_completion_via_generator_exit(self) -> None:
"""GeneratorExit from HTTP disconnect triggers worker self-completion.
When the HTTP client closes the connection, Starlette throws GeneratorExit
into the stream generator. The finally block sets drain_done (signalling
emitters to stop blocking) and calls executor.shutdown(wait=False) so the
server thread is never blocked. Worker threads detect drain_done.is_set()
after run_llm_loop completes and self-persist the result via
llm_loop_completion_handle using their own DB session.
This is the primary regression for test_send_message_disconnect_and_cleanup:
the integration test disconnects mid-stream and expects the DB message to be
updated from the TERMINATED placeholder to the real response.
"""
import threading
# Signals the worker to unblock from run_llm_loop after gen.close() returns.
# This guarantees drain_done is set BEFORE the worker returns from run_llm_loop,
# so the self-completion path (drain_done.is_set() check) is always taken.
disconnect_received = threading.Event()
# Set by the llm_loop_completion_handle mock when called.
completion_called = threading.Event()
def emit_then_complete(**kwargs: Any) -> None:
"""Emit one packet (to give the drain loop a yield point), then block
until the main thread signals that gen.close() has been called. This
ensures drain_done is set before we return so model_succeeded is checked
against a set drain_done — no race condition.
"""
emitter = kwargs["emitter"]
emitter.emit(
Packet(placement=Placement(turn_index=0), obj=ReasoningStart())
)
disconnect_received.wait(timeout=5)
setup = _make_setup(n_models=1)
# is_connected() always True — HTTP disconnect does NOT set the Redis stop fence.
setup.check_is_connected = MagicMock(return_value=True)
with (
patch(
"onyx.chat.process_message.run_llm_loop",
side_effect=emit_then_complete,
),
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
patch("onyx.chat.process_message.construct_tools", return_value={}),
patch("onyx.chat.process_message.get_session_with_current_tenant"),
patch(
"onyx.chat.process_message.llm_loop_completion_handle",
side_effect=lambda *_, **__: completion_called.set(),
) as mock_handle,
patch(
"onyx.chat.process_message.get_llm_token_counter",
return_value=lambda _: 0,
),
):
from onyx.chat.process_message import _run_models
# cast to Generator so .close() is available; _run_models returns
# AnswerStream (= Iterator) but the actual object is always a generator.
gen = cast(Generator, _run_models(setup, MagicMock(), MagicMock()))
# Advance to the first yielded packet — generator suspends at `yield item`.
first = next(gen)
assert isinstance(first, Packet)
# Simulate Starlette closing the stream on HTTP client disconnect.
# GeneratorExit is thrown at the `yield item` suspension point.
gen.close()
# Unblock the worker now that drain_done has been set by gen.close().
disconnect_received.set()
# Worker self-completes asynchronously (executor.shutdown(wait=False)).
# Wait here, inside the patch context, so that get_session_with_current_tenant
# and llm_loop_completion_handle mocks are still active when the worker calls them.
assert completion_called.wait(
timeout=5
), "worker must self-complete via drain_done within 5 seconds"
assert (
mock_handle.call_count == 1
), "completion handle must be called once for the successful model"
def test_external_state_container_used_for_model_zero(self) -> None:
"""When provided, external_state_container is used as state_containers[0]."""
from onyx.chat.chat_state import ChatStateContainer
from onyx.chat.process_message import _run_models
external = ChatStateContainer()
setup = _make_setup(n_models=1)
with (
patch("onyx.chat.process_message.run_llm_loop") as mock_llm,
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
patch("onyx.chat.process_message.construct_tools", return_value={}),
patch("onyx.chat.process_message.get_session_with_current_tenant"),
patch("onyx.chat.process_message.llm_loop_completion_handle"),
patch(
"onyx.chat.process_message.get_llm_token_counter",
return_value=lambda _: 0,
),
):
list(
_run_models(
setup, MagicMock(), MagicMock(), external_state_container=external
)
)
# The state_container kwarg passed to run_llm_loop must be the external one
call_kwargs = mock_llm.call_args.kwargs
assert call_kwargs["state_container"] is external