refactor(chat): unify N=1/N>1 path, simplify Emitter, tighten timeouts

- handle_stream_message_objects gains llm_overrides param; handle_multi_model_stream
  becomes a thin validation wrapper that delegates to it (Evan's unification request)
- Emitter.model_idx defaults to 0 instead of None; removes None-branch in emit()
- All merged_queue.put() timeouts unified at 3s (was 1s in Emitter, 5s in _run_model)
- Add comment warning against shared DB state writes inside _run_model worker threads
- Update unit tests to reflect model_index=0 for N=1 and StreamingError validation shape
This commit is contained in:
Nik
2026-03-30 11:32:31 -07:00
parent f2ef7f2f83
commit 8b37e5c775
4 changed files with 92 additions and 118 deletions

View File

@@ -22,9 +22,8 @@ class Emitter:
(tests, eval harnesses, custom-tool scripts) work without modification.
Args:
model_idx: Index embedded in packet placements. Pass ``None`` for single-model
runs to preserve the backwards-compatible wire format (``model_index=None``
in the packet); pass an integer for each model in a multi-model run.
model_idx: Index embedded in packet placements. Defaults to ``0`` for N=1
runs. Each model in a multi-model run receives its own index (0, 1, 2…).
merged_queue: Shared queue owned by the ``_run_models`` drain loop. When set,
all ``emit()`` calls route here instead of ``self.bus``.
@@ -42,7 +41,7 @@ class Emitter:
def __init__(
self,
model_idx: int | None = None,
model_idx: int = 0,
merged_queue: "queue.Queue | None" = None,
) -> None:
self._model_idx = model_idx
@@ -70,9 +69,9 @@ class Emitter:
model_index=self._model_idx,
)
tagged_packet = Packet(placement=tagged_placement, obj=packet.obj)
key = self._model_idx if self._model_idx is not None else 0
key = self._model_idx
try:
self._merged_queue.put((key, tagged_packet), timeout=1.0)
self._merged_queue.put((key, tagged_packet), timeout=3.0)
except queue.Full:
# Drain loop is gone (e.g. GeneratorExit on disconnect); discard packet.
pass

View File

@@ -1021,16 +1021,17 @@ def _run_models(
model_idx: Zero-based index of the model in ``setup.llms``. Determines
which LLM and state container this thread operates on.
"""
# N=1: model_idx=None keeps model_index as None in packets (backwards compat).
# N>1: model_idx=int tags packets with the model's index.
model_emitter = Emitter(
model_idx=model_idx if n_models > 1 else None,
model_idx=model_idx,
merged_queue=merged_queue,
)
sc = state_containers[model_idx]
model_llm = setup.llms[model_idx]
try:
# Each worker opens its own session — SQLAlchemy sessions are not thread-safe.
# Do NOT write to the outer db_session (or any shared DB state) from here;
# all DB writes in this thread must go through thread_db_session.
with get_session_with_current_tenant() as thread_db_session:
thread_tool_dict = construct_tools(
persona=setup.persona,
@@ -1118,13 +1119,13 @@ def _run_models(
except Exception as e:
try:
merged_queue.put((model_idx, e), timeout=5.0)
merged_queue.put((model_idx, e), timeout=3.0)
except queue.Full:
pass # Drain loop gone (GeneratorExit); thread exits cleanly
finally:
try:
merged_queue.put((model_idx, _MODEL_DONE), timeout=5.0)
merged_queue.put((model_idx, _MODEL_DONE), timeout=3.0)
except queue.Full:
pass # Drain loop gone (GeneratorExit); thread exits cleanly
@@ -1290,8 +1291,7 @@ def handle_stream_message_objects(
new_msg_req: SendMessageRequest,
user: User,
db_session: Session,
# if specified, uses the last user message and does not create a new user message based
# on the `new_msg_req.message`. Currently, requires a state where the last message is a
llm_overrides: list[LLMOverride] | None = None,
litellm_additional_headers: dict[str, str] | None = None,
custom_tool_additional_headers: dict[str, str] | None = None,
mcp_headers: dict[str, str] | None = None,
@@ -1306,7 +1306,7 @@ def handle_stream_message_objects(
# Optional external state container for non-streaming access to accumulated state
external_state_container: ChatStateContainer | None = None,
) -> AnswerStream:
"""Primary entrypoint for a single-model chat turn.
"""Primary entrypoint for single-model and multi-model chat turns.
Builds the turn context via ``build_chat_turn``, then streams packets from
``_run_models`` back to the caller. Handles setup errors, LLM errors, and
@@ -1317,6 +1317,8 @@ def handle_stream_message_objects(
new_msg_req: The incoming chat request from the user.
user: Authenticated user; may be anonymous for public personas.
db_session: Database session for this request.
llm_overrides: ``None`` → single-model (persona default LLM).
Non-empty list → multi-model (one LLM per override, 2-3 items).
litellm_additional_headers: Extra headers forwarded to the LLM provider.
custom_tool_additional_headers: Extra headers for custom tool HTTP calls.
mcp_headers: Extra headers for MCP tool calls.
@@ -1345,7 +1347,7 @@ def handle_stream_message_objects(
new_msg_req=new_msg_req,
user=user,
db_session=db_session,
llm_overrides=None,
llm_overrides=llm_overrides,
litellm_additional_headers=litellm_additional_headers,
custom_tool_additional_headers=custom_tool_additional_headers,
mcp_headers=mcp_headers,
@@ -1478,13 +1480,10 @@ def handle_multi_model_stream(
custom_tool_additional_headers: dict[str, str] | None = None,
mcp_headers: dict[str, str] | None = None,
) -> AnswerStream:
"""Entrypoint for side-by-side multi-model comparison (2-3 models).
"""Thin wrapper for side-by-side multi-model comparison (2-3 models).
Validates the override list, builds the turn context via ``build_chat_turn``
(which reserves one ``ChatMessage`` row per model), then streams interleaved
packets from ``_run_models`` back to the caller. Each packet carries a
``model_index`` in its placement so the frontend can route it to the correct
response column.
Validates the override list and delegates to ``handle_stream_message_objects``,
which handles both single-model and multi-model execution via the same path.
Args:
new_msg_req: The incoming chat request. ``deep_research`` must be ``False``.
@@ -1497,63 +1496,32 @@ def handle_multi_model_stream(
Returns:
Generator yielding interleaved ``Packet`` objects from all models, each tagged
with ``model_index`` in its placement. Terminates with one
``ChatStateContainerPacket`` per model.
Raises:
ValueError: (yielded as ``StreamingError``) if ``llm_overrides`` is not 2-3
items, or if ``deep_research`` is ``True``.
with ``model_index`` in its placement.
"""
n_models = len(llm_overrides)
if n_models < 2 or n_models > 3:
raise ValueError(f"Multi-model requires 2-3 overrides, got {n_models}")
if new_msg_req.deep_research:
raise ValueError("Multi-model is not supported with deep research")
setup: ChatTurnSetup | None = None
try:
setup = yield from build_chat_turn(
new_msg_req=new_msg_req,
user=user,
db_session=db_session,
llm_overrides=llm_overrides,
litellm_additional_headers=litellm_additional_headers,
custom_tool_additional_headers=custom_tool_additional_headers,
mcp_headers=mcp_headers,
)
yield from _run_models(setup, user, db_session)
except ValueError as e:
logger.exception("Failed to process multi-model chat message.")
yield StreamingError(
error=str(e),
error=f"Multi-model requires 2-3 overrides, got {n_models}",
error_code="VALIDATION_ERROR",
is_retryable=True,
)
db_session.rollback()
return
except Exception as e:
logger.exception(f"Failed multi-model chat: {e}")
stack_trace = traceback.format_exc()
if new_msg_req.deep_research:
yield StreamingError(
error=str(e),
stack_trace=stack_trace,
error_code="MULTI_MODEL_ERROR",
error="Multi-model is not supported with deep research",
error_code="VALIDATION_ERROR",
is_retryable=True,
)
db_session.rollback()
finally:
try:
if setup is not None:
set_processing_status(
chat_session_id=setup.chat_session.id,
cache=setup.cache,
value=False,
)
except Exception:
logger.exception("Error clearing processing status")
return
yield from handle_stream_message_objects(
new_msg_req=new_msg_req,
user=user,
db_session=db_session,
llm_overrides=llm_overrides,
litellm_additional_headers=litellm_additional_headers,
custom_tool_additional_headers=custom_tool_additional_headers,
mcp_headers=mcp_headers,
)
def llm_loop_completion_handle(

View File

@@ -112,15 +112,15 @@ class TestEmitterStreamingMode:
# --- model_index tagging ---
def test_model_idx_none_preserves_model_index_none(self) -> None:
"""N=1 backwards-compat: model_index must stay None in the packet."""
def test_n1_default_model_idx_tags_model_index_zero(self) -> None:
"""N=1: default model_idx=0, so packet gets model_index=0."""
import queue
mq: queue.Queue = queue.Queue()
emitter = Emitter(model_idx=None, merged_queue=mq)
emitter = Emitter(merged_queue=mq) # model_idx defaults to 0
emitter.emit(_packet())
_key, tagged = mq.get_nowait()
assert tagged.placement.model_index is None
assert tagged.placement.model_index == 0
def test_model_idx_zero_tags_packet(self) -> None:
import queue
@@ -162,12 +162,12 @@ class TestEmitterStreamingMode:
key, _ = mq.get_nowait()
assert key == 2
def test_key_is_zero_when_model_idx_none(self) -> None:
"""N=1: key defaults to 0 (single slot in the drain loop)."""
def test_n1_key_is_zero_when_model_idx_default(self) -> None:
"""N=1: default model_idx=0, so drain loop key is 0."""
import queue
mq: queue.Queue = queue.Queue()
emitter = Emitter(model_idx=None, merged_queue=mq)
emitter = Emitter(merged_queue=mq) # model_idx defaults to 0
emitter.emit(_packet())
key, _ = mq.get_nowait()
assert key == 0

View File

@@ -44,8 +44,8 @@ def _make_override(provider: str = "openai", version: str = "gpt-4") -> LLMOverr
return LLMOverride(model_provider=provider, model_version=version)
def _start_stream(req: SendMessageRequest, overrides: list[LLMOverride]) -> None:
"""Advance the generator one step to trigger early validation."""
def _first_from_stream(req: SendMessageRequest, overrides: list[LLMOverride]) -> Any:
"""Return the first item yielded by handle_multi_model_stream."""
from onyx.chat.process_message import handle_multi_model_stream
user = MagicMock()
@@ -54,9 +54,7 @@ def _start_stream(req: SendMessageRequest, overrides: list[LLMOverride]) -> None
db = MagicMock()
gen = handle_multi_model_stream(req, user, db, overrides)
# Calling next() executes until the first yield OR raises.
# Validation errors are raised before any yield.
next(gen)
return next(gen)
# ---------------------------------------------------------------------------
@@ -65,55 +63,64 @@ def _start_stream(req: SendMessageRequest, overrides: list[LLMOverride]) -> None
class TestRunMultiModelStreamValidation:
def test_single_override_raises(self) -> None:
"""Exactly 1 override is not multi-model — must raise."""
def test_single_override_yields_error(self) -> None:
"""Exactly 1 override is not multi-model — yields StreamingError."""
req = _make_request()
with pytest.raises(ValueError, match="2-3"):
_start_stream(req, [_make_override()])
result = _first_from_stream(req, [_make_override()])
assert isinstance(result, StreamingError)
assert "2-3" in result.error
def test_four_overrides_raises(self) -> None:
"""4 overrides exceeds maximum — must raise."""
def test_four_overrides_yields_error(self) -> None:
"""4 overrides exceeds maximum — yields StreamingError."""
req = _make_request()
with pytest.raises(ValueError, match="2-3"):
_start_stream(
req,
[
_make_override("openai", "gpt-4"),
_make_override("anthropic", "claude-3"),
_make_override("google", "gemini-pro"),
_make_override("cohere", "command-r"),
],
)
result = _first_from_stream(
req,
[
_make_override("openai", "gpt-4"),
_make_override("anthropic", "claude-3"),
_make_override("google", "gemini-pro"),
_make_override("cohere", "command-r"),
],
)
assert isinstance(result, StreamingError)
assert "2-3" in result.error
def test_zero_overrides_raises(self) -> None:
"""Empty override list raises."""
def test_zero_overrides_yields_error(self) -> None:
"""Empty override list yields StreamingError."""
req = _make_request()
with pytest.raises(ValueError, match="2-3"):
_start_stream(req, [])
result = _first_from_stream(req, [])
assert isinstance(result, StreamingError)
assert "2-3" in result.error
def test_deep_research_raises(self) -> None:
"""deep_research=True is incompatible with multi-model."""
def test_deep_research_yields_error(self) -> None:
"""deep_research=True is incompatible with multi-model — yields StreamingError."""
req = _make_request(deep_research=True)
with pytest.raises(ValueError, match="not supported"):
_start_stream(
req, [_make_override(), _make_override("anthropic", "claude-3")]
)
result = _first_from_stream(
req, [_make_override(), _make_override("anthropic", "claude-3")]
)
assert isinstance(result, StreamingError)
assert "not supported" in result.error
def test_exactly_two_overrides_is_minimum(self) -> None:
"""Boundary: 1 override fails, 2 passes — ensures fence-post is correct."""
"""Boundary: 1 override yields error, 2 overrides passes validation."""
req = _make_request()
# 1 override must fail
with pytest.raises(ValueError, match="2-3"):
_start_stream(req, [_make_override()])
# 2 overrides must NOT raise ValueError (may raise later due to missing session, that's OK)
# 1 override must yield a StreamingError
result = _first_from_stream(req, [_make_override()])
assert isinstance(
result, StreamingError
), "1 override should yield StreamingError"
# 2 overrides must NOT yield a validation StreamingError (may raise later due to
# missing session, that's OK — validation itself passed)
try:
_start_stream(
result2 = _first_from_stream(
req, [_make_override(), _make_override("anthropic", "claude-3")]
)
except ValueError as exc:
pytest.fail(f"2 overrides should pass validation, got ValueError: {exc}")
if isinstance(result2, StreamingError) and "2-3" in result2.error:
pytest.fail(
f"2 overrides should pass validation, got StreamingError: {result2.error}"
)
except Exception:
pass # Any other error means validation passed
pass # Any non-validation error means validation passed
# ---------------------------------------------------------------------------
@@ -308,8 +315,8 @@ class TestRunModels:
assert isinstance(stop_obj, OverallStop)
assert stop_obj.stop_reason == "complete"
def test_n1_emitted_packet_has_model_index_none(self) -> None:
"""Single-model path: model_index stays None for wire backwards-compat."""
def test_n1_emitted_packet_has_model_index_zero(self) -> None:
"""Single-model path: model_index is 0 (Emitter defaults model_idx=0)."""
def emit_one(**kwargs: Any) -> None:
kwargs["emitter"].emit(
@@ -335,7 +342,7 @@ class TestRunModels:
if isinstance(p, Packet) and isinstance(p.obj, ReasoningStart)
]
assert len(reasoning) == 1
assert reasoning[0].placement.model_index is None
assert reasoning[0].placement.model_index == 0
def test_n2_each_model_packet_tagged_with_its_index(self) -> None:
"""Multi-model path: packets from model 0 get index=0, model 1 gets index=1."""
@@ -398,7 +405,7 @@ class TestRunModels:
def fail_model_0_succeed_model_1(**kwargs: Any) -> None:
emitter = kwargs["emitter"]
# _model_idx is None for N=1, int for N>1
# _model_idx is always int (0 for N=1, 0/1/2… for N>1)
if emitter._model_idx == 0:
raise RuntimeError("model 0 failed")
emitter.emit(