Compare commits

...

12 Commits

Author SHA1 Message Date
Nik
f2ef7f2f83 fix(chat): use FILE_READER_TOOL_ID constant in has_file_reader_tool check
The check was comparing against the hardcoded string "file_reader" but
FILE_READER_TOOL_ID = "FileReaderTool". This meant has_file_reader_tool
was always False, silently disabling the forgotten-file injection
mechanism for every request routed through build_chat_turn.
2026-03-27 14:04:59 -07:00
Nik
cd561b21a2 fix(streaming): unblock executor shutdown on HTTP disconnect
Worker threads blocked indefinitely on merged_queue.put() when the
drain loop exited via GeneratorExit, causing executor.shutdown(wait=True)
to deadlock and the test to hang for the full CI timeout.

Fix: add a timeout to every merged_queue.put() call in the producer
path (Emitter.emit, _run_model exception handler, _run_model finally
sentinel). On queue.Full the thread discards the packet and exits
cleanly, allowing executor.shutdown(wait=True) to return and the
GeneratorExit finally block to call llm_loop_completion_handle.
2026-03-27 12:12:36 -07:00
Nik
7ca31368bb fix(chat): open fresh DB session in GeneratorExit cleanup to avoid stale session
When the HTTP client disconnects, CPython's reference counting fires
_run_models.close() only AFTER stream_generator's `with` block has
already closed its db_session. The previous fix used that stale session
in the finally block, causing llm_loop_completion_handle to fail
silently and leave the DB record as TERMINATED_RESPONSE_MESSAGE.

Fix: when _completion_done=False (HTTP disconnect path), open a fresh
get_session_with_current_tenant() session and re-load each ChatMessage
by ID before calling llm_loop_completion_handle.
2026-03-27 11:31:41 -07:00
Nik
6dcbde2c03 fix(tests): cast _run_models result to Generator to satisfy mypy 2026-03-27 10:24:31 -07:00
Nik
9df6b6183a fix(chat): wait for LLM threads and save on HTTP disconnect (GeneratorExit)
The previous fix only handled the stop-button path (is_connected()=False),
but test_send_message_disconnect_and_cleanup uses HTTP client disconnect,
which does NOT set the Redis stop fence. HTTP disconnect causes Starlette
to throw GeneratorExit into the generator, not a Redis fence signal.

Root cause: executor.shutdown(wait=False) in the finally block let LLM
threads run in the background without ever calling llm_loop_completion_handle,
leaving the DB message as the TERMINATED placeholder permanently.

Fix: track _completion_done in the outer try block. When the finally block
runs with _completion_done=False (GeneratorExit from HTTP disconnect, or an
unhandled exception), call executor.shutdown(wait=True) to wait for all LLM
threads to finish, then persist each successful model's response via
llm_loop_completion_handle. This mirrors the old run_chat_loop_with_state_containers
behavior: if is_connected(): wait_on_background(thread); completion_callback().

Add test_http_disconnect_completion_via_generator_exit which simulates the
exact Starlette disconnect mechanism: advance generator to a yield point,
call gen.close() (throws GeneratorExit), verify completion handle is called
and the LLM thread completed before gen.close() returns.
2026-03-27 01:52:44 -07:00
Nik
8fb7cd6189 fix(chat): call completion handler on disconnect to clear TERMINATED placeholder
When the client disconnects mid-stream, _run_models returns early from the
drain loop before the llm_loop_completion_handle loop is reached, leaving
the reserved DB message stuck as "Response was terminated prior to
completion, try regenerating."

The old run_chat_loop_with_state_containers always called completion_callback
in its finally block (even on disconnect) with whatever state had accumulated.
This fix replicates that: before returning early on disconnect, we call
llm_loop_completion_handle for all models. The handler detects
is_connected()=False and writes "The generation was stopped by the user."
(or a partial answer) instead of the TERMINATED placeholder.

Also adds regression test: test_completion_handle_called_on_disconnect.
2026-03-27 00:27:30 -07:00
Nik
2724c61c95 fix(chat): fix Emitter API callers, drain loop mypy errors, and redundant OverallStop
- Remove terminal OverallStop from _run_models; run_llm_loop already emits one
  via Emitter — the extra packet was breaking existing streaming tests that
  expected StopIteration after consuming the LLM loop's OverallStop
- Restructure drain loop with try/else to fix mypy possibly-undefined errors
  on item/model_idx (variables assigned in try, used after except/continue)
- Fix research_agent.py and test_memory_tool_packets.py: old Emitter(bus)
  positional API no longer valid; use Emitter() and read from emitter.bus
- Fix test union-attr errors: narrow Packet.obj to OverallStop before
  accessing stop_reason
2026-03-26 23:45:30 -07:00
Nik
a23ee85039 docs(chat): add docstrings, unit tests, and Emitter refactor for multi-model streaming
- Refactor _run_models to use a single unified execution path for N=1 and N>1
  via ThreadPoolExecutor + merged_queue drain loop
- Introduce Emitter class with standalone/streaming modes and model_index tagging
- Add comprehensive docstrings to Emitter, _run_models, _run_model, handle_multi_model_stream,
  handle_stream_message_objects, Placement, LLMOverride, and set_preferred_response
- Fix stale comment referencing deleted _ModelIndexEmitter class
- Add 21 unit tests for Emitter (test_emitter.py)
- Add 9 unit tests for _run_models drain loop (test_multi_model_streaming.py)
  covering: N=1 backwards-compat, N>1 model_index tagging, error isolation,
  cancellation, completion handles, external_state_container routing
2026-03-26 23:07:02 -07:00
Nik
d4d0f3c612 restore arg comments to build_chat_turn signature 2026-03-26 18:19:37 -07:00
Nik
8e1ad517e9 restore meaningful inline comments lost during build_chat_turn extraction 2026-03-26 18:14:45 -07:00
Nik
4a9c8b6fbf refactor(chat): phase-based multi-model architecture for process_message
Extracts ~280 lines of duplicated setup code from handle_stream_message_objects
and run_multi_model_stream into a shared build_chat_turn() generator that
returns a ChatTurnSetup frozen dataclass.

New structure:
- ChatTurnSetup: immutable context passed between setup and execution phases
- build_chat_turn(): shared setup generator (session, LLM, history, files,
  tokens, search params, message reservation) — runs for both paths
- _run_single_model(): execution strategy for single-LLM path
- _run_multi_model(): execution strategy for parallel-LLM path (bounded
  queue, contextvars.copy_context() per thread, per-thread history copy)
- handle_stream_message_objects: thin orchestrator (~40 lines vs 630)
- handle_multi_model_stream: new thin orchestrator replacing run_multi_model_stream

Bug fixes:
- Query Processing hook now runs in build_chat_turn for both paths
  (was missing from multi-model path)
- contextvars.copy_context() + ctx.run() used for thread submission
  (threads now inherit tenant context correctly)
- queue.Queue(maxsize=100) bounded for backpressure; timeout=0.05s (50ms)
- list(simple_chat_history) per-thread copy (run_llm_loop mutates the list)
2026-03-26 17:56:06 -07:00
Nik
a49edf3e18 feat(chat): add multi-model parallel streaming backend 2026-03-25 20:17:38 -07:00
11 changed files with 2148 additions and 641 deletions

View File

@@ -1,19 +1,8 @@
import threading
import time
from collections.abc import Callable
from collections.abc import Generator
from queue import Empty
from onyx.chat.citation_processor import CitationMapping
from onyx.chat.emitter import Emitter
from onyx.context.search.models import SearchDoc
from onyx.server.query_and_chat.placement import Placement
from onyx.server.query_and_chat.streaming_models import OverallStop
from onyx.server.query_and_chat.streaming_models import Packet
from onyx.server.query_and_chat.streaming_models import PacketException
from onyx.tools.models import ToolCallInfo
from onyx.utils.threadpool_concurrency import run_in_background
from onyx.utils.threadpool_concurrency import wait_on_background
# Type alias for search doc deduplication key
# Simple key: just document_id (str)
@@ -159,114 +148,3 @@ class ChatStateContainer:
"""Thread-safe getter for emitted citations (returns a copy)."""
with self._lock:
return self._emitted_citations.copy()
def run_chat_loop_with_state_containers(
chat_loop_func: Callable[[Emitter, ChatStateContainer], None],
completion_callback: Callable[[ChatStateContainer], None],
is_connected: Callable[[], bool],
emitter: Emitter,
state_container: ChatStateContainer,
) -> Generator[Packet, None]:
"""
Explicit wrapper function that runs a function in a background thread
with event streaming capabilities.
The wrapped function should accept emitter as first arg and use it to emit
Packet objects. This wrapper polls every 300ms to check if stop signal is set.
Args:
func: The function to wrap (should accept emitter and state_container as first and second args)
completion_callback: Callback function to call when the function completes
emitter: Emitter instance for sending packets
state_container: ChatStateContainer instance for accumulating state
is_connected: Callable that returns False when stop signal is set
Usage:
packets = run_chat_loop_with_state_containers(
my_func,
completion_callback=completion_callback,
emitter=emitter,
state_container=state_container,
is_connected=check_func,
)
for packet in packets:
# Process packets
pass
"""
def run_with_exception_capture() -> None:
try:
chat_loop_func(emitter, state_container)
except Exception as e:
# If execution fails, emit an exception packet
emitter.emit(
Packet(
placement=Placement(turn_index=0),
obj=PacketException(type="error", exception=e),
)
)
# Run the function in a background thread
thread = run_in_background(run_with_exception_capture)
pkt: Packet | None = None
last_turn_index = 0 # Track the highest turn_index seen for stop packet
last_cancel_check = time.monotonic()
cancel_check_interval = 0.3 # Check for cancellation every 300ms
try:
while True:
# Poll queue with 300ms timeout for natural stop signal checking
# the 300ms timeout is to avoid busy-waiting and to allow the stop signal to be checked regularly
try:
pkt = emitter.bus.get(timeout=0.3)
except Empty:
if not is_connected():
# Stop signal detected
yield Packet(
placement=Placement(turn_index=last_turn_index + 1),
obj=OverallStop(type="stop", stop_reason="user_cancelled"),
)
break
last_cancel_check = time.monotonic()
continue
if pkt is not None:
# Track the highest turn_index for the stop packet
if pkt.placement and pkt.placement.turn_index > last_turn_index:
last_turn_index = pkt.placement.turn_index
if isinstance(pkt.obj, OverallStop):
yield pkt
break
elif isinstance(pkt.obj, PacketException):
raise pkt.obj.exception
else:
yield pkt
# Check for cancellation periodically even when packets are flowing
# This ensures stop signal is checked during active streaming
current_time = time.monotonic()
if current_time - last_cancel_check >= cancel_check_interval:
if not is_connected():
# Stop signal detected during streaming
yield Packet(
placement=Placement(turn_index=last_turn_index + 1),
obj=OverallStop(type="stop", stop_reason="user_cancelled"),
)
break
last_cancel_check = current_time
finally:
# Wait for thread to complete on normal exit to propagate exceptions and ensure cleanup.
# Skip waiting if user disconnected to exit quickly.
if is_connected():
wait_on_background(thread)
try:
completion_callback(state_container)
except Exception as e:
emitter.emit(
Packet(
placement=Placement(turn_index=last_turn_index + 1),
obj=PacketException(type="error", exception=e),
)
)

View File

@@ -1,19 +1,84 @@
import queue
from queue import Queue
from onyx.server.query_and_chat.placement import Placement
from onyx.server.query_and_chat.streaming_models import Packet
class Emitter:
"""Use this inside tools to emit arbitrary UI progress."""
"""Routes packets produced during tool and LLM execution to the right destination.
def __init__(self, bus: Queue):
self.bus = bus
Operates in one of two modes determined by whether ``merged_queue`` is supplied:
**Standalone** (no ``merged_queue``): packets land on ``self.bus``. Used by tests,
custom tools, and any caller that reads the emitter directly after execution.
**Streaming** (``merged_queue`` provided): packets are tagged with ``model_index``
and placed as ``(key, packet)`` tuples on the shared queue for the
``_run_models`` drain loop to consume and yield downstream.
Attributes:
bus: Fallback queue for standalone mode. Always created so existing callers
(tests, eval harnesses, custom-tool scripts) work without modification.
Args:
model_idx: Index embedded in packet placements. Pass ``None`` for single-model
runs to preserve the backwards-compatible wire format (``model_index=None``
in the packet); pass an integer for each model in a multi-model run.
merged_queue: Shared queue owned by the ``_run_models`` drain loop. When set,
all ``emit()`` calls route here instead of ``self.bus``.
Example::
# Standalone — read from bus after the fact (tests, evals)
emitter = Emitter()
emitter.emit(packet)
result = emitter.bus.get()
# Streaming — wired into _run_models (production path)
emitter = Emitter(model_idx=0, merged_queue=merged_queue)
emitter.emit(packet) # places (0, tagged_packet) on merged_queue
"""
def __init__(
self,
model_idx: int | None = None,
merged_queue: "queue.Queue | None" = None,
) -> None:
self._model_idx = model_idx
self._merged_queue = merged_queue
# Always created for backwards compatibility (tests, custom_tool, customer scripts, etc.)
self.bus: Queue[Packet] = Queue()
def emit(self, packet: Packet) -> None:
self.bus.put(packet) # Thread-safe
"""Emit a packet, routing it to the merged queue or the local bus.
In streaming mode, stamps the packet's placement with ``model_index`` before
forwarding so the drain loop can attribute it to the correct model. In
standalone mode, places the packet on ``self.bus`` unchanged.
Args:
packet: The packet to emit.
"""
if self._merged_queue is not None:
tagged_placement = Placement(
turn_index=packet.placement.turn_index if packet.placement else 0,
tab_index=packet.placement.tab_index if packet.placement else 0,
sub_turn_index=(
packet.placement.sub_turn_index if packet.placement else None
),
model_index=self._model_idx,
)
tagged_packet = Packet(placement=tagged_placement, obj=packet.obj)
key = self._model_idx if self._model_idx is not None else 0
try:
self._merged_queue.put((key, tagged_packet), timeout=1.0)
except queue.Full:
# Drain loop is gone (e.g. GeneratorExit on disconnect); discard packet.
pass
else:
self.bus.put(packet)
def get_default_emitter() -> Emitter:
bus: Queue[Packet] = Queue()
emitter = Emitter(bus)
return emitter
return Emitter()

File diff suppressed because it is too large Load Diff

View File

@@ -617,6 +617,92 @@ def reserve_message_id(
return empty_message
def reserve_multi_model_message_ids(
db_session: Session,
chat_session_id: UUID,
parent_message_id: int,
model_display_names: list[str],
) -> list[ChatMessage]:
"""Reserve N assistant message placeholders for multi-model parallel streaming.
All messages share the same parent (the user message). The parent's
latest_child_message_id points to the LAST reserved message so that the
default history-chain walker picks it up.
"""
reserved: list[ChatMessage] = []
for display_name in model_display_names:
msg = ChatMessage(
chat_session_id=chat_session_id,
parent_message_id=parent_message_id,
latest_child_message_id=None,
message="Response was terminated prior to completion, try regenerating.",
token_count=15, # placeholder; updated on completion by llm_loop_completion_handle
message_type=MessageType.ASSISTANT,
model_display_name=display_name,
)
db_session.add(msg)
reserved.append(msg)
# Flush to assign IDs without committing yet
db_session.flush()
# Point parent's latest_child to the last reserved message
parent = (
db_session.query(ChatMessage)
.filter(ChatMessage.id == parent_message_id)
.first()
)
if parent:
parent.latest_child_message_id = reserved[-1].id
db_session.commit()
return reserved
def set_preferred_response(
db_session: Session,
user_message_id: int,
preferred_assistant_message_id: int,
) -> None:
"""Mark one assistant response as the user's preferred choice in a multi-model turn.
Also advances ``latest_child_message_id`` so the preferred response becomes
the active branch for any subsequent messages in the conversation.
Args:
db_session: Active database session.
user_message_id: Primary key of the ``USER``-type ``ChatMessage`` whose
preferred response is being set.
preferred_assistant_message_id: Primary key of the ``ASSISTANT``-type
``ChatMessage`` to prefer. Must be a direct child of ``user_message_id``.
Raises:
ValueError: If either message is not found, if ``user_message_id`` does not
refer to a USER message, or if the assistant message is not a direct child
of the user message.
"""
user_msg = db_session.get(ChatMessage, user_message_id)
if user_msg is None:
raise ValueError(f"User message {user_message_id} not found")
if user_msg.message_type != MessageType.USER:
raise ValueError(f"Message {user_message_id} is not a user message")
assistant_msg = db_session.get(ChatMessage, preferred_assistant_message_id)
if assistant_msg is None:
raise ValueError(
f"Assistant message {preferred_assistant_message_id} not found"
)
if assistant_msg.parent_message_id != user_message_id:
raise ValueError(
f"Assistant message {preferred_assistant_message_id} is not a child "
f"of user message {user_message_id}"
)
user_msg.preferred_response_id = preferred_assistant_message_id
user_msg.latest_child_message_id = preferred_assistant_message_id
db_session.commit()
def create_new_chat_message(
chat_session_id: UUID,
parent_message: ChatMessage,
@@ -839,6 +925,8 @@ def translate_db_message_to_chat_message_detail(
error=chat_message.error,
current_feedback=current_feedback,
processing_duration_seconds=chat_message.processing_duration_seconds,
preferred_response_id=chat_message.preferred_response_id,
model_display_name=chat_message.model_display_name,
)
return chat_msg_detail

View File

@@ -8,6 +8,24 @@ from pydantic import BaseModel
class LLMOverride(BaseModel):
"""Per-request LLM settings that override persona defaults.
All fields are optional — only the fields that differ from the persona's
configured LLM need to be supplied. Used both over the wire (API requests)
and for multi-model comparison, where one override is supplied per model.
Attributes:
model_provider: LLM provider slug (e.g. ``"openai"``, ``"anthropic"``).
When ``None``, the persona's default provider is used.
model_version: Specific model version string (e.g. ``"gpt-4o"``).
When ``None``, the persona's default model is used.
temperature: Sampling temperature in ``[0, 2]``. When ``None``, the
persona's default temperature is used.
display_name: Human-readable label shown in the UI for this model,
e.g. ``"GPT-4 Turbo"``. Optional; falls back to ``model_version``
when not set.
"""
model_provider: str | None = None
model_version: str | None = None
temperature: float | None = None

View File

@@ -28,6 +28,7 @@ from onyx.chat.chat_utils import extract_headers
from onyx.chat.models import ChatFullResponse
from onyx.chat.models import CreateChatSessionID
from onyx.chat.process_message import gather_stream_full
from onyx.chat.process_message import handle_multi_model_stream
from onyx.chat.process_message import handle_stream_message_objects
from onyx.chat.prompt_utils import get_default_base_system_prompt
from onyx.chat.stop_signal_checker import set_fence
@@ -46,6 +47,7 @@ from onyx.db.chat import get_chat_messages_by_session
from onyx.db.chat import get_chat_session_by_id
from onyx.db.chat import get_chat_sessions_by_user
from onyx.db.chat import set_as_latest_chat_message
from onyx.db.chat import set_preferred_response
from onyx.db.chat import translate_db_message_to_chat_message_detail
from onyx.db.chat import update_chat_session
from onyx.db.chat_search import search_chat_sessions
@@ -60,6 +62,8 @@ from onyx.db.persona import get_persona_by_id
from onyx.db.usage import increment_usage
from onyx.db.usage import UsageType
from onyx.db.user_file import get_file_id_by_user_file_id
from onyx.error_handling.error_codes import OnyxErrorCode
from onyx.error_handling.exceptions import OnyxError
from onyx.file_store.file_store import get_default_file_store
from onyx.llm.constants import LlmProviderNames
from onyx.llm.factory import get_default_llm
@@ -81,6 +85,7 @@ from onyx.server.query_and_chat.models import ChatSessionUpdateRequest
from onyx.server.query_and_chat.models import MessageOrigin
from onyx.server.query_and_chat.models import RenameChatSessionResponse
from onyx.server.query_and_chat.models import SendMessageRequest
from onyx.server.query_and_chat.models import SetPreferredResponseRequest
from onyx.server.query_and_chat.models import UpdateChatSessionTemperatureRequest
from onyx.server.query_and_chat.models import UpdateChatSessionThreadRequest
from onyx.server.query_and_chat.session_loading import (
@@ -570,6 +575,46 @@ def handle_send_chat_message(
if get_hashed_api_key_from_request(request) or get_hashed_pat_from_request(request):
chat_message_req.origin = MessageOrigin.API
# Multi-model streaming path: 2-3 LLMs in parallel (streaming only)
is_multi_model = (
chat_message_req.llm_overrides is not None
and len(chat_message_req.llm_overrides) > 1
)
if is_multi_model and chat_message_req.stream:
# Narrowed here; is_multi_model already checked llm_overrides is not None
llm_overrides = chat_message_req.llm_overrides or []
def multi_model_stream_generator() -> Generator[str, None, None]:
try:
with get_session_with_current_tenant() as db_session:
for obj in handle_multi_model_stream(
new_msg_req=chat_message_req,
user=user,
db_session=db_session,
llm_overrides=llm_overrides,
litellm_additional_headers=extract_headers(
request.headers, LITELLM_PASS_THROUGH_HEADERS
),
custom_tool_additional_headers=get_custom_tool_additional_request_headers(
request.headers
),
mcp_headers=chat_message_req.mcp_headers,
):
yield get_json_line(obj.model_dump())
except Exception as e:
logger.exception("Error in multi-model streaming")
yield json.dumps({"error": str(e)})
return StreamingResponse(
multi_model_stream_generator(), media_type="text/event-stream"
)
if is_multi_model and not chat_message_req.stream:
raise OnyxError(
OnyxErrorCode.INVALID_INPUT,
"Multi-model mode (llm_overrides with >1 entry) requires stream=True.",
)
# Non-streaming path: consume all packets and return complete response
if not chat_message_req.stream:
with get_session_with_current_tenant() as db_session:
@@ -660,6 +705,30 @@ def set_message_as_latest(
)
@router.put("/set-preferred-response")
def set_preferred_response_endpoint(
request_body: SetPreferredResponseRequest,
user: User | None = Depends(current_user),
db_session: Session = Depends(get_session),
) -> None:
"""Set the preferred assistant response for a multi-model turn."""
try:
# Ownership check: get_chat_message raises ValueError if the message
# doesn't belong to this user, preventing cross-user mutation.
get_chat_message(
chat_message_id=request_body.user_message_id,
user_id=user.id if user else None,
db_session=db_session,
)
set_preferred_response(
db_session=db_session,
user_message_id=request_body.user_message_id,
preferred_assistant_message_id=request_body.preferred_response_id,
)
except ValueError as e:
raise OnyxError(OnyxErrorCode.INVALID_INPUT, str(e))
@router.post("/create-chat-message-feedback")
def create_chat_feedback(
feedback: ChatFeedbackRequest,

View File

@@ -2,11 +2,24 @@ from pydantic import BaseModel
class Placement(BaseModel):
# Which iterative block in the UI is this part of, these are ordered and smaller ones happened first
"""Coordinates that identify where a streaming packet belongs in the UI.
The frontend uses these fields to route each packet to the correct turn,
tool tab, agent sub-turn, and (in multi-model mode) response column.
Attributes:
turn_index: Monotonically increasing index of the iterative reasoning block
(e.g. tool call round) within this chat message. Lower values happened first.
tab_index: Disambiguates parallel tool calls within the same turn so each
tool's output can be displayed in its own tab.
sub_turn_index: Nesting level for tools that invoke other tools. ``None`` for
top-level packets; an integer for tool-within-tool output.
model_index: Which model this packet belongs to in a multi-model comparison
(0, 1, or 2). ``None`` for single-model responses, preserving the
backwards-compatible wire format for existing API consumers.
"""
turn_index: int
# For parallel tool calls to preserve order of execution
tab_index: int = 0
# Used for tools/agents that call other tools, this currently doesn't support nested agents but can be added later
sub_turn_index: int | None = None
# For multi-model streaming: identifies which model (0, 1, 2) this packet belongs to.
model_index: int | None = None

View File

@@ -708,7 +708,6 @@ def run_research_agent_calls(
if __name__ == "__main__":
from queue import Queue
from uuid import uuid4
from onyx.chat.chat_state import ChatStateContainer
@@ -744,8 +743,7 @@ if __name__ == "__main__":
if user is None:
raise ValueError("No users found in database. Please create a user first.")
bus: Queue[Packet] = Queue()
emitter = Emitter(bus)
emitter = Emitter()
state_container = ChatStateContainer()
tool_dict = construct_tools(
@@ -792,4 +790,4 @@ if __name__ == "__main__":
print(result.intermediate_report)
print("=" * 80)
print(f"Citations: {result.citation_mapping}")
print(f"Total packets emitted: {bus.qsize()}")
print(f"Total packets emitted: {emitter.bus.qsize()}")

View File

@@ -0,0 +1,253 @@
"""Unit tests for the Emitter class.
Covers both modes (standalone and streaming) without any real database,
LLM, or queue infrastructure beyond the stdlib Queue.
"""
from onyx.chat.emitter import Emitter
from onyx.chat.emitter import get_default_emitter
from onyx.server.query_and_chat.placement import Placement
from onyx.server.query_and_chat.streaming_models import OverallStop
from onyx.server.query_and_chat.streaming_models import Packet
from onyx.server.query_and_chat.streaming_models import ReasoningStart
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _placement(
turn_index: int = 0,
tab_index: int = 0,
sub_turn_index: int | None = None,
) -> Placement:
return Placement(
turn_index=turn_index,
tab_index=tab_index,
sub_turn_index=sub_turn_index,
)
def _packet(
turn_index: int = 0,
tab_index: int = 0,
sub_turn_index: int | None = None,
) -> Packet:
"""Build a minimal valid packet with an OverallStop payload."""
return Packet(
placement=_placement(turn_index, tab_index, sub_turn_index),
obj=OverallStop(stop_reason="test"),
)
# ---------------------------------------------------------------------------
# Standalone mode (no merged_queue)
# ---------------------------------------------------------------------------
class TestEmitterStandaloneMode:
def test_emitted_packet_arrives_on_bus(self) -> None:
emitter = Emitter()
pkt = _packet()
emitter.emit(pkt)
assert emitter.bus.get_nowait() is pkt
def test_bus_is_empty_before_emit(self) -> None:
emitter = Emitter()
assert emitter.bus.empty()
def test_multiple_packets_delivered_fifo(self) -> None:
emitter = Emitter()
p1 = _packet(turn_index=0)
p2 = _packet(turn_index=1)
emitter.emit(p1)
emitter.emit(p2)
assert emitter.bus.get_nowait() is p1
assert emitter.bus.get_nowait() is p2
def test_packet_not_modified(self) -> None:
"""Standalone mode must not wrap or mutate the packet."""
emitter = Emitter()
pkt = _packet(turn_index=7, tab_index=3)
emitter.emit(pkt)
retrieved = emitter.bus.get_nowait()
assert retrieved.placement.turn_index == 7
assert retrieved.placement.tab_index == 3
def test_get_default_emitter_is_standalone(self) -> None:
emitter = get_default_emitter()
pkt = _packet()
emitter.emit(pkt)
# Packet lands on the bus, not a shared queue
assert emitter.bus.get_nowait() is pkt
# ---------------------------------------------------------------------------
# Streaming mode (merged_queue provided)
# ---------------------------------------------------------------------------
class TestEmitterStreamingMode:
# --- Queue routing ---
def test_packet_goes_to_merged_queue_not_bus(self) -> None:
import queue
mq: queue.Queue = queue.Queue()
emitter = Emitter(model_idx=0, merged_queue=mq)
emitter.emit(_packet())
assert not mq.empty()
assert emitter.bus.empty()
def test_queue_item_is_tuple_of_key_and_packet(self) -> None:
import queue
mq: queue.Queue = queue.Queue()
emitter = Emitter(model_idx=1, merged_queue=mq)
emitter.emit(_packet())
item = mq.get_nowait()
assert isinstance(item, tuple)
assert len(item) == 2
# --- model_index tagging ---
def test_model_idx_none_preserves_model_index_none(self) -> None:
"""N=1 backwards-compat: model_index must stay None in the packet."""
import queue
mq: queue.Queue = queue.Queue()
emitter = Emitter(model_idx=None, merged_queue=mq)
emitter.emit(_packet())
_key, tagged = mq.get_nowait()
assert tagged.placement.model_index is None
def test_model_idx_zero_tags_packet(self) -> None:
import queue
mq: queue.Queue = queue.Queue()
emitter = Emitter(model_idx=0, merged_queue=mq)
emitter.emit(_packet())
_key, tagged = mq.get_nowait()
assert tagged.placement.model_index == 0
def test_model_idx_one_tags_packet(self) -> None:
import queue
mq: queue.Queue = queue.Queue()
emitter = Emitter(model_idx=1, merged_queue=mq)
emitter.emit(_packet())
_key, tagged = mq.get_nowait()
assert tagged.placement.model_index == 1
def test_model_idx_two_tags_packet(self) -> None:
"""Boundary: third model in a 3-model run."""
import queue
mq: queue.Queue = queue.Queue()
emitter = Emitter(model_idx=2, merged_queue=mq)
emitter.emit(_packet())
_key, tagged = mq.get_nowait()
assert tagged.placement.model_index == 2
# --- Queue key ---
def test_key_equals_model_idx_when_set(self) -> None:
"""Drain loop uses the key to route packets; it must match model_idx."""
import queue
mq: queue.Queue = queue.Queue()
emitter = Emitter(model_idx=2, merged_queue=mq)
emitter.emit(_packet())
key, _ = mq.get_nowait()
assert key == 2
def test_key_is_zero_when_model_idx_none(self) -> None:
"""N=1: key defaults to 0 (single slot in the drain loop)."""
import queue
mq: queue.Queue = queue.Queue()
emitter = Emitter(model_idx=None, merged_queue=mq)
emitter.emit(_packet())
key, _ = mq.get_nowait()
assert key == 0
# --- Placement field preservation ---
def test_turn_index_is_preserved(self) -> None:
import queue
mq: queue.Queue = queue.Queue()
emitter = Emitter(model_idx=0, merged_queue=mq)
emitter.emit(_packet(turn_index=5))
_, tagged = mq.get_nowait()
assert tagged.placement.turn_index == 5
def test_tab_index_is_preserved(self) -> None:
import queue
mq: queue.Queue = queue.Queue()
emitter = Emitter(model_idx=0, merged_queue=mq)
emitter.emit(_packet(tab_index=3))
_, tagged = mq.get_nowait()
assert tagged.placement.tab_index == 3
def test_sub_turn_index_is_preserved(self) -> None:
import queue
mq: queue.Queue = queue.Queue()
emitter = Emitter(model_idx=0, merged_queue=mq)
emitter.emit(_packet(sub_turn_index=2))
_, tagged = mq.get_nowait()
assert tagged.placement.sub_turn_index == 2
def test_sub_turn_index_none_is_preserved(self) -> None:
import queue
mq: queue.Queue = queue.Queue()
emitter = Emitter(model_idx=0, merged_queue=mq)
emitter.emit(_packet(sub_turn_index=None))
_, tagged = mq.get_nowait()
assert tagged.placement.sub_turn_index is None
def test_packet_obj_is_not_modified(self) -> None:
"""The payload object must survive tagging untouched."""
import queue
mq: queue.Queue = queue.Queue()
emitter = Emitter(model_idx=0, merged_queue=mq)
original_obj = OverallStop(stop_reason="sentinel")
pkt = Packet(placement=_placement(), obj=original_obj)
emitter.emit(pkt)
_, tagged = mq.get_nowait()
assert tagged.obj is original_obj
def test_different_obj_types_are_handled(self) -> None:
"""Any valid PacketObj type passes through correctly."""
import queue
mq: queue.Queue = queue.Queue()
emitter = Emitter(model_idx=0, merged_queue=mq)
pkt = Packet(placement=_placement(), obj=ReasoningStart())
emitter.emit(pkt)
_, tagged = mq.get_nowait()
assert isinstance(tagged.obj, ReasoningStart)
# --- bus is always created ---
def test_bus_exists_in_streaming_mode(self) -> None:
"""bus must always be present for backwards-compat with existing callers."""
import queue
mq: queue.Queue = queue.Queue()
emitter = Emitter(model_idx=0, merged_queue=mq)
assert hasattr(emitter, "bus")
assert isinstance(emitter.bus, queue.Queue)
def test_bus_stays_empty_in_streaming_mode(self) -> None:
import queue
mq: queue.Queue = queue.Queue()
emitter = Emitter(model_idx=0, merged_queue=mq)
emitter.emit(_packet())
assert emitter.bus.empty()

View File

@@ -0,0 +1,640 @@
"""Unit tests for multi-model streaming validation and DB helpers.
These are pure unit tests — no real database or LLM calls required.
The validation logic in handle_multi_model_stream fires before any external
calls, so we can trigger it with lightweight mocks.
"""
import time
from collections.abc import Generator
from typing import Any
from typing import cast
from unittest.mock import MagicMock
from unittest.mock import patch
from uuid import uuid4
import pytest
from onyx.chat.models import StreamingError
from onyx.configs.constants import MessageType
from onyx.db.chat import set_preferred_response
from onyx.llm.override_models import LLMOverride
from onyx.server.query_and_chat.models import SendMessageRequest
from onyx.server.query_and_chat.placement import Placement
from onyx.server.query_and_chat.streaming_models import OverallStop
from onyx.server.query_and_chat.streaming_models import Packet
from onyx.server.query_and_chat.streaming_models import ReasoningStart
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _make_request(**kwargs: Any) -> SendMessageRequest:
defaults: dict[str, Any] = {
"message": "hello",
"chat_session_id": uuid4(),
}
defaults.update(kwargs)
return SendMessageRequest(**defaults)
def _make_override(provider: str = "openai", version: str = "gpt-4") -> LLMOverride:
return LLMOverride(model_provider=provider, model_version=version)
def _start_stream(req: SendMessageRequest, overrides: list[LLMOverride]) -> None:
"""Advance the generator one step to trigger early validation."""
from onyx.chat.process_message import handle_multi_model_stream
user = MagicMock()
user.is_anonymous = False
user.email = "test@example.com"
db = MagicMock()
gen = handle_multi_model_stream(req, user, db, overrides)
# Calling next() executes until the first yield OR raises.
# Validation errors are raised before any yield.
next(gen)
# ---------------------------------------------------------------------------
# handle_multi_model_stream — validation
# ---------------------------------------------------------------------------
class TestRunMultiModelStreamValidation:
def test_single_override_raises(self) -> None:
"""Exactly 1 override is not multi-model — must raise."""
req = _make_request()
with pytest.raises(ValueError, match="2-3"):
_start_stream(req, [_make_override()])
def test_four_overrides_raises(self) -> None:
"""4 overrides exceeds maximum — must raise."""
req = _make_request()
with pytest.raises(ValueError, match="2-3"):
_start_stream(
req,
[
_make_override("openai", "gpt-4"),
_make_override("anthropic", "claude-3"),
_make_override("google", "gemini-pro"),
_make_override("cohere", "command-r"),
],
)
def test_zero_overrides_raises(self) -> None:
"""Empty override list raises."""
req = _make_request()
with pytest.raises(ValueError, match="2-3"):
_start_stream(req, [])
def test_deep_research_raises(self) -> None:
"""deep_research=True is incompatible with multi-model."""
req = _make_request(deep_research=True)
with pytest.raises(ValueError, match="not supported"):
_start_stream(
req, [_make_override(), _make_override("anthropic", "claude-3")]
)
def test_exactly_two_overrides_is_minimum(self) -> None:
"""Boundary: 1 override fails, 2 passes — ensures fence-post is correct."""
req = _make_request()
# 1 override must fail
with pytest.raises(ValueError, match="2-3"):
_start_stream(req, [_make_override()])
# 2 overrides must NOT raise ValueError (may raise later due to missing session, that's OK)
try:
_start_stream(
req, [_make_override(), _make_override("anthropic", "claude-3")]
)
except ValueError as exc:
pytest.fail(f"2 overrides should pass validation, got ValueError: {exc}")
except Exception:
pass # Any other error means validation passed
# ---------------------------------------------------------------------------
# set_preferred_response — validation (mocked db)
# ---------------------------------------------------------------------------
class TestSetPreferredResponseValidation:
def test_user_message_not_found(self) -> None:
db = MagicMock()
db.get.return_value = None
with pytest.raises(ValueError, match="not found"):
set_preferred_response(
db, user_message_id=999, preferred_assistant_message_id=1
)
def test_wrong_message_type(self) -> None:
"""Cannot set preferred response on a non-USER message."""
db = MagicMock()
user_msg = MagicMock()
user_msg.message_type = MessageType.ASSISTANT # wrong type
db.get.return_value = user_msg
with pytest.raises(ValueError, match="not a user message"):
set_preferred_response(
db, user_message_id=1, preferred_assistant_message_id=2
)
def test_assistant_message_not_found(self) -> None:
db = MagicMock()
user_msg = MagicMock()
user_msg.message_type = MessageType.USER
# First call returns user_msg, second call (for assistant) returns None
db.get.side_effect = [user_msg, None]
with pytest.raises(ValueError, match="not found"):
set_preferred_response(
db, user_message_id=1, preferred_assistant_message_id=2
)
def test_assistant_not_child_of_user(self) -> None:
db = MagicMock()
user_msg = MagicMock()
user_msg.message_type = MessageType.USER
assistant_msg = MagicMock()
assistant_msg.parent_message_id = 999 # different parent
db.get.side_effect = [user_msg, assistant_msg]
with pytest.raises(ValueError, match="not a child"):
set_preferred_response(
db, user_message_id=1, preferred_assistant_message_id=2
)
def test_valid_call_sets_preferred_response_id(self) -> None:
db = MagicMock()
user_msg = MagicMock()
user_msg.message_type = MessageType.USER
assistant_msg = MagicMock()
assistant_msg.parent_message_id = 1 # correct parent
db.get.side_effect = [user_msg, assistant_msg]
set_preferred_response(db, user_message_id=1, preferred_assistant_message_id=2)
assert user_msg.preferred_response_id == 2
assert user_msg.latest_child_message_id == 2
# ---------------------------------------------------------------------------
# LLMOverride — display_name field
# ---------------------------------------------------------------------------
class TestLLMOverrideDisplayName:
def test_display_name_defaults_none(self) -> None:
override = LLMOverride(model_provider="openai", model_version="gpt-4")
assert override.display_name is None
def test_display_name_set(self) -> None:
override = LLMOverride(
model_provider="openai",
model_version="gpt-4",
display_name="GPT-4 Turbo",
)
assert override.display_name == "GPT-4 Turbo"
def test_display_name_serializes(self) -> None:
override = LLMOverride(
model_provider="anthropic",
model_version="claude-opus-4-6",
display_name="Claude Opus",
)
d = override.model_dump()
assert d["display_name"] == "Claude Opus"
# ---------------------------------------------------------------------------
# _run_models — drain loop behaviour
# ---------------------------------------------------------------------------
def _make_setup(n_models: int = 1) -> MagicMock:
"""Minimal ChatTurnSetup mock whose fields pass Pydantic validation in _run_model."""
setup = MagicMock()
setup.llms = [MagicMock() for _ in range(n_models)]
setup.model_display_names = [f"model-{i}" for i in range(n_models)]
setup.check_is_connected = MagicMock(return_value=True)
setup.reserved_messages = [MagicMock() for _ in range(n_models)]
setup.reserved_token_count = 100
# Fields consumed by SearchToolConfig / CustomToolConfig / FileReaderToolConfig
# constructors inside _run_model — must be typed correctly for Pydantic.
setup.new_msg_req.deep_research = False
setup.new_msg_req.internal_search_filters = None
setup.new_msg_req.allowed_tool_ids = None
setup.new_msg_req.include_citations = True
setup.search_params.project_id_filter = None
setup.search_params.persona_id_filter = None
setup.bypass_acl = False
setup.slack_context = None
setup.available_files.user_file_ids = []
setup.available_files.chat_file_ids = []
setup.forced_tool_id = None
setup.simple_chat_history = []
setup.chat_session.id = uuid4()
setup.user_message.id = None
setup.custom_tool_additional_headers = None
setup.mcp_headers = None
return setup
_RUN_MODELS_PATCHES = [
patch("onyx.chat.process_message.run_llm_loop"),
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
patch("onyx.chat.process_message.construct_tools", return_value={}),
patch("onyx.chat.process_message.get_session_with_current_tenant"),
patch("onyx.chat.process_message.llm_loop_completion_handle"),
patch("onyx.chat.process_message.get_llm_token_counter", return_value=lambda _: 0),
]
def _run_models_collect(setup: MagicMock) -> list:
"""Drive _run_models to completion and return all yielded items."""
from onyx.chat.process_message import _run_models
return list(_run_models(setup, MagicMock(), MagicMock()))
class TestRunModels:
"""Tests for the _run_models worker-thread drain loop.
All external dependencies (LLM, DB, tools) are patched out. Worker threads
still run but return immediately since run_llm_loop is mocked.
"""
def test_n1_overall_stop_from_llm_loop_passes_through(self) -> None:
"""OverallStop emitted by run_llm_loop is passed through the drain loop unchanged."""
def emit_stop(**kwargs: Any) -> None:
kwargs["emitter"].emit(
Packet(
placement=Placement(turn_index=0),
obj=OverallStop(stop_reason="complete"),
)
)
with (
patch("onyx.chat.process_message.run_llm_loop", side_effect=emit_stop),
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
patch("onyx.chat.process_message.construct_tools", return_value={}),
patch("onyx.chat.process_message.get_session_with_current_tenant"),
patch("onyx.chat.process_message.llm_loop_completion_handle"),
patch(
"onyx.chat.process_message.get_llm_token_counter",
return_value=lambda _: 0,
),
):
packets = _run_models_collect(_make_setup(n_models=1))
stops = [
p
for p in packets
if isinstance(p, Packet) and isinstance(p.obj, OverallStop)
]
assert len(stops) == 1
stop_obj = stops[0].obj
assert isinstance(stop_obj, OverallStop)
assert stop_obj.stop_reason == "complete"
def test_n1_emitted_packet_has_model_index_none(self) -> None:
"""Single-model path: model_index stays None for wire backwards-compat."""
def emit_one(**kwargs: Any) -> None:
kwargs["emitter"].emit(
Packet(placement=Placement(turn_index=0), obj=ReasoningStart())
)
with (
patch("onyx.chat.process_message.run_llm_loop", side_effect=emit_one),
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
patch("onyx.chat.process_message.construct_tools", return_value={}),
patch("onyx.chat.process_message.get_session_with_current_tenant"),
patch("onyx.chat.process_message.llm_loop_completion_handle"),
patch(
"onyx.chat.process_message.get_llm_token_counter",
return_value=lambda _: 0,
),
):
packets = _run_models_collect(_make_setup(n_models=1))
reasoning = [
p
for p in packets
if isinstance(p, Packet) and isinstance(p.obj, ReasoningStart)
]
assert len(reasoning) == 1
assert reasoning[0].placement.model_index is None
def test_n2_each_model_packet_tagged_with_its_index(self) -> None:
"""Multi-model path: packets from model 0 get index=0, model 1 gets index=1."""
def emit_one(**kwargs: Any) -> None:
# _model_idx is set by _run_model based on position in setup.llms
emitter = kwargs["emitter"]
emitter.emit(
Packet(placement=Placement(turn_index=0), obj=ReasoningStart())
)
with (
patch("onyx.chat.process_message.run_llm_loop", side_effect=emit_one),
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
patch("onyx.chat.process_message.construct_tools", return_value={}),
patch("onyx.chat.process_message.get_session_with_current_tenant"),
patch("onyx.chat.process_message.llm_loop_completion_handle"),
patch(
"onyx.chat.process_message.get_llm_token_counter",
return_value=lambda _: 0,
),
):
packets = _run_models_collect(_make_setup(n_models=2))
reasoning = [
p
for p in packets
if isinstance(p, Packet) and isinstance(p.obj, ReasoningStart)
]
assert len(reasoning) == 2
indices = {p.placement.model_index for p in reasoning}
assert indices == {0, 1}
def test_model_error_yields_streaming_error(self) -> None:
"""An exception inside a worker thread is surfaced as a StreamingError."""
def always_fail(**_kwargs: Any) -> None:
raise RuntimeError("intentional test failure")
with (
patch("onyx.chat.process_message.run_llm_loop", side_effect=always_fail),
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
patch("onyx.chat.process_message.construct_tools", return_value={}),
patch("onyx.chat.process_message.get_session_with_current_tenant"),
patch("onyx.chat.process_message.llm_loop_completion_handle"),
patch(
"onyx.chat.process_message.get_llm_token_counter",
return_value=lambda _: 0,
),
):
packets = _run_models_collect(_make_setup(n_models=1))
errors = [p for p in packets if isinstance(p, StreamingError)]
assert len(errors) == 1
assert errors[0].error_code == "MODEL_ERROR"
assert "intentional test failure" in errors[0].error
def test_one_model_error_does_not_stop_other_models(self) -> None:
"""A failing model yields StreamingError; the surviving model's packets still arrive."""
def fail_model_0_succeed_model_1(**kwargs: Any) -> None:
emitter = kwargs["emitter"]
# _model_idx is None for N=1, int for N>1
if emitter._model_idx == 0:
raise RuntimeError("model 0 failed")
emitter.emit(
Packet(placement=Placement(turn_index=0), obj=ReasoningStart())
)
with (
patch(
"onyx.chat.process_message.run_llm_loop",
side_effect=fail_model_0_succeed_model_1,
),
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
patch("onyx.chat.process_message.construct_tools", return_value={}),
patch("onyx.chat.process_message.get_session_with_current_tenant"),
patch("onyx.chat.process_message.llm_loop_completion_handle"),
patch(
"onyx.chat.process_message.get_llm_token_counter",
return_value=lambda _: 0,
),
):
packets = _run_models_collect(_make_setup(n_models=2))
errors = [p for p in packets if isinstance(p, StreamingError)]
assert len(errors) == 1
reasoning = [
p
for p in packets
if isinstance(p, Packet) and isinstance(p.obj, ReasoningStart)
]
assert len(reasoning) == 1
assert reasoning[0].placement.model_index == 1
def test_cancellation_yields_user_cancelled_stop(self) -> None:
"""If check_is_connected returns False, drain loop emits user_cancelled."""
def slow_llm(**_kwargs: Any) -> None:
time.sleep(0.3) # Outlasts the 50 ms queue-poll interval
setup = _make_setup(n_models=1)
setup.check_is_connected = MagicMock(return_value=False)
with (
patch("onyx.chat.process_message.run_llm_loop", side_effect=slow_llm),
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
patch("onyx.chat.process_message.construct_tools", return_value={}),
patch("onyx.chat.process_message.get_session_with_current_tenant"),
patch("onyx.chat.process_message.llm_loop_completion_handle"),
patch(
"onyx.chat.process_message.get_llm_token_counter",
return_value=lambda _: 0,
),
):
packets = _run_models_collect(setup)
stops = [
p
for p in packets
if isinstance(p, Packet) and isinstance(p.obj, OverallStop)
]
assert any(
isinstance(s.obj, OverallStop) and s.obj.stop_reason == "user_cancelled"
for s in stops
)
def test_completion_handle_called_on_disconnect(self) -> None:
"""llm_loop_completion_handle must still be called even when user disconnects.
Regression test for the disconnect-cleanup bug: the old
run_chat_loop_with_state_containers always called completion_callback in
its finally block (even on disconnect) so the DB message was updated from
the TERMINATED placeholder to a partial answer. The new _run_models must
replicate this — otherwise the integration test
test_send_message_disconnect_and_cleanup fails because the message stays
as "Response was terminated prior to completion, try regenerating."
"""
def slow_llm(**_kwargs: Any) -> None:
time.sleep(0.3)
setup = _make_setup(n_models=2)
setup.check_is_connected = MagicMock(return_value=False)
with (
patch("onyx.chat.process_message.run_llm_loop", side_effect=slow_llm),
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
patch("onyx.chat.process_message.construct_tools", return_value={}),
patch("onyx.chat.process_message.get_session_with_current_tenant"),
patch(
"onyx.chat.process_message.llm_loop_completion_handle"
) as mock_handle,
patch(
"onyx.chat.process_message.get_llm_token_counter",
return_value=lambda _: 0,
),
):
_run_models_collect(setup)
# Must be called once per model, not zero times
assert mock_handle.call_count == 2
def test_completion_handle_called_for_each_successful_model(self) -> None:
"""llm_loop_completion_handle must be called once per model that succeeded."""
setup = _make_setup(n_models=2)
with (
patch("onyx.chat.process_message.run_llm_loop"),
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
patch("onyx.chat.process_message.construct_tools", return_value={}),
patch("onyx.chat.process_message.get_session_with_current_tenant"),
patch(
"onyx.chat.process_message.llm_loop_completion_handle"
) as mock_handle,
patch(
"onyx.chat.process_message.get_llm_token_counter",
return_value=lambda _: 0,
),
):
_run_models_collect(setup)
assert mock_handle.call_count == 2
def test_completion_handle_not_called_for_failed_model(self) -> None:
"""llm_loop_completion_handle must be skipped for a model that raised."""
def always_fail(**_kwargs: Any) -> None:
raise RuntimeError("fail")
with (
patch("onyx.chat.process_message.run_llm_loop", side_effect=always_fail),
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
patch("onyx.chat.process_message.construct_tools", return_value={}),
patch("onyx.chat.process_message.get_session_with_current_tenant"),
patch(
"onyx.chat.process_message.llm_loop_completion_handle"
) as mock_handle,
patch(
"onyx.chat.process_message.get_llm_token_counter",
return_value=lambda _: 0,
),
):
_run_models_collect(_make_setup(n_models=1))
mock_handle.assert_not_called()
def test_http_disconnect_completion_via_generator_exit(self) -> None:
"""GeneratorExit from HTTP disconnect triggers wait+completion in finally.
When the HTTP client closes the connection, Starlette throws GeneratorExit
into the stream generator, which propagates into _run_models. The finally
block must call executor.shutdown(wait=True) to wait for LLM threads to
finish, then persist their results via llm_loop_completion_handle.
This is the primary regression for test_send_message_disconnect_and_cleanup:
the integration test disconnects mid-stream and expects the DB message to be
updated from the TERMINATED placeholder to the real response.
"""
import threading
thread_completed = threading.Event()
def emit_then_complete(**kwargs: Any) -> None:
"""Emit one packet (to give generator a yield point), then finish."""
emitter = kwargs["emitter"]
emitter.emit(
Packet(placement=Placement(turn_index=0), obj=ReasoningStart())
)
# Small sleep so executor.shutdown(wait=True) in finally actually waits.
time.sleep(0.05)
thread_completed.set()
setup = _make_setup(n_models=1)
# is_connected() always True — HTTP disconnect does NOT set the Redis stop fence.
setup.check_is_connected = MagicMock(return_value=True)
with (
patch(
"onyx.chat.process_message.run_llm_loop",
side_effect=emit_then_complete,
),
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
patch("onyx.chat.process_message.construct_tools", return_value={}),
patch("onyx.chat.process_message.get_session_with_current_tenant"),
patch(
"onyx.chat.process_message.llm_loop_completion_handle"
) as mock_handle,
patch(
"onyx.chat.process_message.get_llm_token_counter",
return_value=lambda _: 0,
),
):
from onyx.chat.process_message import _run_models
# cast to Generator so .close() is available; _run_models returns
# AnswerStream (= Iterator) but the actual object is always a generator.
gen = cast(Generator, _run_models(setup, MagicMock(), MagicMock()))
# Advance to the first yielded packet — generator suspends at `yield item`.
first = next(gen)
assert isinstance(first, Packet)
# Simulate Starlette closing the stream on HTTP client disconnect.
# GeneratorExit is thrown at the `yield item` suspension point.
gen.close()
# Finally block must have waited for the thread and saved completion.
assert (
thread_completed.is_set()
), "LLM thread must complete before gen.close() returns"
assert (
mock_handle.call_count == 1
), "completion handle must be called for the successful model"
def test_external_state_container_used_for_model_zero(self) -> None:
"""When provided, external_state_container is used as state_containers[0]."""
from onyx.chat.chat_state import ChatStateContainer
from onyx.chat.process_message import _run_models
external = ChatStateContainer()
setup = _make_setup(n_models=1)
with (
patch("onyx.chat.process_message.run_llm_loop") as mock_llm,
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
patch("onyx.chat.process_message.construct_tools", return_value={}),
patch("onyx.chat.process_message.get_session_with_current_tenant"),
patch("onyx.chat.process_message.llm_loop_completion_handle"),
patch(
"onyx.chat.process_message.get_llm_token_counter",
return_value=lambda _: 0,
),
):
list(
_run_models(
setup, MagicMock(), MagicMock(), external_state_container=external
)
)
# The state_container kwarg passed to run_llm_loop must be the external one
call_kwargs = mock_llm.call_args.kwargs
assert call_kwargs["state_container"] is external

View File

@@ -1,6 +1,5 @@
"""Tests for memory tool streaming packet emissions."""
from queue import Queue
from unittest.mock import MagicMock
from unittest.mock import patch
@@ -19,8 +18,7 @@ from onyx.tools.tool_implementations.memory.models import MemoryToolResponse
@pytest.fixture
def emitter() -> Emitter:
bus: Queue = Queue()
return Emitter(bus)
return Emitter()
@pytest.fixture