mirror of
https://github.com/onyx-dot-app/onyx.git
synced 2026-04-01 04:52:43 +00:00
refactor(chat): unify N=1/N>1 path, simplify Emitter, tighten timeouts
- handle_stream_message_objects gains llm_overrides param; handle_multi_model_stream becomes a thin validation wrapper that delegates to it (Evan's unification request) - Emitter.model_idx defaults to 0 instead of None; removes None-branch in emit() - All merged_queue.put() timeouts unified at 3s (was 1s in Emitter, 5s in _run_model) - Add comment warning against shared DB state writes inside _run_model worker threads - Update unit tests to reflect model_index=0 for N=1 and StreamingError validation shape
This commit is contained in:
@@ -22,9 +22,8 @@ class Emitter:
|
||||
(tests, eval harnesses, custom-tool scripts) work without modification.
|
||||
|
||||
Args:
|
||||
model_idx: Index embedded in packet placements. Pass ``None`` for single-model
|
||||
runs to preserve the backwards-compatible wire format (``model_index=None``
|
||||
in the packet); pass an integer for each model in a multi-model run.
|
||||
model_idx: Index embedded in packet placements. Defaults to ``0`` for N=1
|
||||
runs. Each model in a multi-model run receives its own index (0, 1, 2…).
|
||||
merged_queue: Shared queue owned by the ``_run_models`` drain loop. When set,
|
||||
all ``emit()`` calls route here instead of ``self.bus``.
|
||||
|
||||
@@ -42,7 +41,7 @@ class Emitter:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_idx: int | None = None,
|
||||
model_idx: int = 0,
|
||||
merged_queue: "queue.Queue | None" = None,
|
||||
) -> None:
|
||||
self._model_idx = model_idx
|
||||
@@ -70,9 +69,9 @@ class Emitter:
|
||||
model_index=self._model_idx,
|
||||
)
|
||||
tagged_packet = Packet(placement=tagged_placement, obj=packet.obj)
|
||||
key = self._model_idx if self._model_idx is not None else 0
|
||||
key = self._model_idx
|
||||
try:
|
||||
self._merged_queue.put((key, tagged_packet), timeout=1.0)
|
||||
self._merged_queue.put((key, tagged_packet), timeout=3.0)
|
||||
except queue.Full:
|
||||
# Drain loop is gone (e.g. GeneratorExit on disconnect); discard packet.
|
||||
pass
|
||||
|
||||
@@ -1021,16 +1021,17 @@ def _run_models(
|
||||
model_idx: Zero-based index of the model in ``setup.llms``. Determines
|
||||
which LLM and state container this thread operates on.
|
||||
"""
|
||||
# N=1: model_idx=None keeps model_index as None in packets (backwards compat).
|
||||
# N>1: model_idx=int tags packets with the model's index.
|
||||
model_emitter = Emitter(
|
||||
model_idx=model_idx if n_models > 1 else None,
|
||||
model_idx=model_idx,
|
||||
merged_queue=merged_queue,
|
||||
)
|
||||
sc = state_containers[model_idx]
|
||||
model_llm = setup.llms[model_idx]
|
||||
|
||||
try:
|
||||
# Each worker opens its own session — SQLAlchemy sessions are not thread-safe.
|
||||
# Do NOT write to the outer db_session (or any shared DB state) from here;
|
||||
# all DB writes in this thread must go through thread_db_session.
|
||||
with get_session_with_current_tenant() as thread_db_session:
|
||||
thread_tool_dict = construct_tools(
|
||||
persona=setup.persona,
|
||||
@@ -1118,13 +1119,13 @@ def _run_models(
|
||||
|
||||
except Exception as e:
|
||||
try:
|
||||
merged_queue.put((model_idx, e), timeout=5.0)
|
||||
merged_queue.put((model_idx, e), timeout=3.0)
|
||||
except queue.Full:
|
||||
pass # Drain loop gone (GeneratorExit); thread exits cleanly
|
||||
|
||||
finally:
|
||||
try:
|
||||
merged_queue.put((model_idx, _MODEL_DONE), timeout=5.0)
|
||||
merged_queue.put((model_idx, _MODEL_DONE), timeout=3.0)
|
||||
except queue.Full:
|
||||
pass # Drain loop gone (GeneratorExit); thread exits cleanly
|
||||
|
||||
@@ -1290,8 +1291,7 @@ def handle_stream_message_objects(
|
||||
new_msg_req: SendMessageRequest,
|
||||
user: User,
|
||||
db_session: Session,
|
||||
# if specified, uses the last user message and does not create a new user message based
|
||||
# on the `new_msg_req.message`. Currently, requires a state where the last message is a
|
||||
llm_overrides: list[LLMOverride] | None = None,
|
||||
litellm_additional_headers: dict[str, str] | None = None,
|
||||
custom_tool_additional_headers: dict[str, str] | None = None,
|
||||
mcp_headers: dict[str, str] | None = None,
|
||||
@@ -1306,7 +1306,7 @@ def handle_stream_message_objects(
|
||||
# Optional external state container for non-streaming access to accumulated state
|
||||
external_state_container: ChatStateContainer | None = None,
|
||||
) -> AnswerStream:
|
||||
"""Primary entrypoint for a single-model chat turn.
|
||||
"""Primary entrypoint for single-model and multi-model chat turns.
|
||||
|
||||
Builds the turn context via ``build_chat_turn``, then streams packets from
|
||||
``_run_models`` back to the caller. Handles setup errors, LLM errors, and
|
||||
@@ -1317,6 +1317,8 @@ def handle_stream_message_objects(
|
||||
new_msg_req: The incoming chat request from the user.
|
||||
user: Authenticated user; may be anonymous for public personas.
|
||||
db_session: Database session for this request.
|
||||
llm_overrides: ``None`` → single-model (persona default LLM).
|
||||
Non-empty list → multi-model (one LLM per override, 2–3 items).
|
||||
litellm_additional_headers: Extra headers forwarded to the LLM provider.
|
||||
custom_tool_additional_headers: Extra headers for custom tool HTTP calls.
|
||||
mcp_headers: Extra headers for MCP tool calls.
|
||||
@@ -1345,7 +1347,7 @@ def handle_stream_message_objects(
|
||||
new_msg_req=new_msg_req,
|
||||
user=user,
|
||||
db_session=db_session,
|
||||
llm_overrides=None,
|
||||
llm_overrides=llm_overrides,
|
||||
litellm_additional_headers=litellm_additional_headers,
|
||||
custom_tool_additional_headers=custom_tool_additional_headers,
|
||||
mcp_headers=mcp_headers,
|
||||
@@ -1478,13 +1480,10 @@ def handle_multi_model_stream(
|
||||
custom_tool_additional_headers: dict[str, str] | None = None,
|
||||
mcp_headers: dict[str, str] | None = None,
|
||||
) -> AnswerStream:
|
||||
"""Entrypoint for side-by-side multi-model comparison (2–3 models).
|
||||
"""Thin wrapper for side-by-side multi-model comparison (2–3 models).
|
||||
|
||||
Validates the override list, builds the turn context via ``build_chat_turn``
|
||||
(which reserves one ``ChatMessage`` row per model), then streams interleaved
|
||||
packets from ``_run_models`` back to the caller. Each packet carries a
|
||||
``model_index`` in its placement so the frontend can route it to the correct
|
||||
response column.
|
||||
Validates the override list and delegates to ``handle_stream_message_objects``,
|
||||
which handles both single-model and multi-model execution via the same path.
|
||||
|
||||
Args:
|
||||
new_msg_req: The incoming chat request. ``deep_research`` must be ``False``.
|
||||
@@ -1497,63 +1496,32 @@ def handle_multi_model_stream(
|
||||
|
||||
Returns:
|
||||
Generator yielding interleaved ``Packet`` objects from all models, each tagged
|
||||
with ``model_index`` in its placement. Terminates with one
|
||||
``ChatStateContainerPacket`` per model.
|
||||
|
||||
Raises:
|
||||
ValueError: (yielded as ``StreamingError``) if ``llm_overrides`` is not 2–3
|
||||
items, or if ``deep_research`` is ``True``.
|
||||
with ``model_index`` in its placement.
|
||||
"""
|
||||
n_models = len(llm_overrides)
|
||||
if n_models < 2 or n_models > 3:
|
||||
raise ValueError(f"Multi-model requires 2-3 overrides, got {n_models}")
|
||||
if new_msg_req.deep_research:
|
||||
raise ValueError("Multi-model is not supported with deep research")
|
||||
|
||||
setup: ChatTurnSetup | None = None
|
||||
try:
|
||||
setup = yield from build_chat_turn(
|
||||
new_msg_req=new_msg_req,
|
||||
user=user,
|
||||
db_session=db_session,
|
||||
llm_overrides=llm_overrides,
|
||||
litellm_additional_headers=litellm_additional_headers,
|
||||
custom_tool_additional_headers=custom_tool_additional_headers,
|
||||
mcp_headers=mcp_headers,
|
||||
)
|
||||
yield from _run_models(setup, user, db_session)
|
||||
|
||||
except ValueError as e:
|
||||
logger.exception("Failed to process multi-model chat message.")
|
||||
yield StreamingError(
|
||||
error=str(e),
|
||||
error=f"Multi-model requires 2-3 overrides, got {n_models}",
|
||||
error_code="VALIDATION_ERROR",
|
||||
is_retryable=True,
|
||||
)
|
||||
db_session.rollback()
|
||||
return
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(f"Failed multi-model chat: {e}")
|
||||
stack_trace = traceback.format_exc()
|
||||
if new_msg_req.deep_research:
|
||||
yield StreamingError(
|
||||
error=str(e),
|
||||
stack_trace=stack_trace,
|
||||
error_code="MULTI_MODEL_ERROR",
|
||||
error="Multi-model is not supported with deep research",
|
||||
error_code="VALIDATION_ERROR",
|
||||
is_retryable=True,
|
||||
)
|
||||
db_session.rollback()
|
||||
|
||||
finally:
|
||||
try:
|
||||
if setup is not None:
|
||||
set_processing_status(
|
||||
chat_session_id=setup.chat_session.id,
|
||||
cache=setup.cache,
|
||||
value=False,
|
||||
)
|
||||
except Exception:
|
||||
logger.exception("Error clearing processing status")
|
||||
return
|
||||
yield from handle_stream_message_objects(
|
||||
new_msg_req=new_msg_req,
|
||||
user=user,
|
||||
db_session=db_session,
|
||||
llm_overrides=llm_overrides,
|
||||
litellm_additional_headers=litellm_additional_headers,
|
||||
custom_tool_additional_headers=custom_tool_additional_headers,
|
||||
mcp_headers=mcp_headers,
|
||||
)
|
||||
|
||||
|
||||
def llm_loop_completion_handle(
|
||||
|
||||
@@ -112,15 +112,15 @@ class TestEmitterStreamingMode:
|
||||
|
||||
# --- model_index tagging ---
|
||||
|
||||
def test_model_idx_none_preserves_model_index_none(self) -> None:
|
||||
"""N=1 backwards-compat: model_index must stay None in the packet."""
|
||||
def test_n1_default_model_idx_tags_model_index_zero(self) -> None:
|
||||
"""N=1: default model_idx=0, so packet gets model_index=0."""
|
||||
import queue
|
||||
|
||||
mq: queue.Queue = queue.Queue()
|
||||
emitter = Emitter(model_idx=None, merged_queue=mq)
|
||||
emitter = Emitter(merged_queue=mq) # model_idx defaults to 0
|
||||
emitter.emit(_packet())
|
||||
_key, tagged = mq.get_nowait()
|
||||
assert tagged.placement.model_index is None
|
||||
assert tagged.placement.model_index == 0
|
||||
|
||||
def test_model_idx_zero_tags_packet(self) -> None:
|
||||
import queue
|
||||
@@ -162,12 +162,12 @@ class TestEmitterStreamingMode:
|
||||
key, _ = mq.get_nowait()
|
||||
assert key == 2
|
||||
|
||||
def test_key_is_zero_when_model_idx_none(self) -> None:
|
||||
"""N=1: key defaults to 0 (single slot in the drain loop)."""
|
||||
def test_n1_key_is_zero_when_model_idx_default(self) -> None:
|
||||
"""N=1: default model_idx=0, so drain loop key is 0."""
|
||||
import queue
|
||||
|
||||
mq: queue.Queue = queue.Queue()
|
||||
emitter = Emitter(model_idx=None, merged_queue=mq)
|
||||
emitter = Emitter(merged_queue=mq) # model_idx defaults to 0
|
||||
emitter.emit(_packet())
|
||||
key, _ = mq.get_nowait()
|
||||
assert key == 0
|
||||
|
||||
@@ -44,8 +44,8 @@ def _make_override(provider: str = "openai", version: str = "gpt-4") -> LLMOverr
|
||||
return LLMOverride(model_provider=provider, model_version=version)
|
||||
|
||||
|
||||
def _start_stream(req: SendMessageRequest, overrides: list[LLMOverride]) -> None:
|
||||
"""Advance the generator one step to trigger early validation."""
|
||||
def _first_from_stream(req: SendMessageRequest, overrides: list[LLMOverride]) -> Any:
|
||||
"""Return the first item yielded by handle_multi_model_stream."""
|
||||
from onyx.chat.process_message import handle_multi_model_stream
|
||||
|
||||
user = MagicMock()
|
||||
@@ -54,9 +54,7 @@ def _start_stream(req: SendMessageRequest, overrides: list[LLMOverride]) -> None
|
||||
db = MagicMock()
|
||||
|
||||
gen = handle_multi_model_stream(req, user, db, overrides)
|
||||
# Calling next() executes until the first yield OR raises.
|
||||
# Validation errors are raised before any yield.
|
||||
next(gen)
|
||||
return next(gen)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -65,55 +63,64 @@ def _start_stream(req: SendMessageRequest, overrides: list[LLMOverride]) -> None
|
||||
|
||||
|
||||
class TestRunMultiModelStreamValidation:
|
||||
def test_single_override_raises(self) -> None:
|
||||
"""Exactly 1 override is not multi-model — must raise."""
|
||||
def test_single_override_yields_error(self) -> None:
|
||||
"""Exactly 1 override is not multi-model — yields StreamingError."""
|
||||
req = _make_request()
|
||||
with pytest.raises(ValueError, match="2-3"):
|
||||
_start_stream(req, [_make_override()])
|
||||
result = _first_from_stream(req, [_make_override()])
|
||||
assert isinstance(result, StreamingError)
|
||||
assert "2-3" in result.error
|
||||
|
||||
def test_four_overrides_raises(self) -> None:
|
||||
"""4 overrides exceeds maximum — must raise."""
|
||||
def test_four_overrides_yields_error(self) -> None:
|
||||
"""4 overrides exceeds maximum — yields StreamingError."""
|
||||
req = _make_request()
|
||||
with pytest.raises(ValueError, match="2-3"):
|
||||
_start_stream(
|
||||
req,
|
||||
[
|
||||
_make_override("openai", "gpt-4"),
|
||||
_make_override("anthropic", "claude-3"),
|
||||
_make_override("google", "gemini-pro"),
|
||||
_make_override("cohere", "command-r"),
|
||||
],
|
||||
)
|
||||
result = _first_from_stream(
|
||||
req,
|
||||
[
|
||||
_make_override("openai", "gpt-4"),
|
||||
_make_override("anthropic", "claude-3"),
|
||||
_make_override("google", "gemini-pro"),
|
||||
_make_override("cohere", "command-r"),
|
||||
],
|
||||
)
|
||||
assert isinstance(result, StreamingError)
|
||||
assert "2-3" in result.error
|
||||
|
||||
def test_zero_overrides_raises(self) -> None:
|
||||
"""Empty override list raises."""
|
||||
def test_zero_overrides_yields_error(self) -> None:
|
||||
"""Empty override list yields StreamingError."""
|
||||
req = _make_request()
|
||||
with pytest.raises(ValueError, match="2-3"):
|
||||
_start_stream(req, [])
|
||||
result = _first_from_stream(req, [])
|
||||
assert isinstance(result, StreamingError)
|
||||
assert "2-3" in result.error
|
||||
|
||||
def test_deep_research_raises(self) -> None:
|
||||
"""deep_research=True is incompatible with multi-model."""
|
||||
def test_deep_research_yields_error(self) -> None:
|
||||
"""deep_research=True is incompatible with multi-model — yields StreamingError."""
|
||||
req = _make_request(deep_research=True)
|
||||
with pytest.raises(ValueError, match="not supported"):
|
||||
_start_stream(
|
||||
req, [_make_override(), _make_override("anthropic", "claude-3")]
|
||||
)
|
||||
result = _first_from_stream(
|
||||
req, [_make_override(), _make_override("anthropic", "claude-3")]
|
||||
)
|
||||
assert isinstance(result, StreamingError)
|
||||
assert "not supported" in result.error
|
||||
|
||||
def test_exactly_two_overrides_is_minimum(self) -> None:
|
||||
"""Boundary: 1 override fails, 2 passes — ensures fence-post is correct."""
|
||||
"""Boundary: 1 override yields error, 2 overrides passes validation."""
|
||||
req = _make_request()
|
||||
# 1 override must fail
|
||||
with pytest.raises(ValueError, match="2-3"):
|
||||
_start_stream(req, [_make_override()])
|
||||
# 2 overrides must NOT raise ValueError (may raise later due to missing session, that's OK)
|
||||
# 1 override must yield a StreamingError
|
||||
result = _first_from_stream(req, [_make_override()])
|
||||
assert isinstance(
|
||||
result, StreamingError
|
||||
), "1 override should yield StreamingError"
|
||||
# 2 overrides must NOT yield a validation StreamingError (may raise later due to
|
||||
# missing session, that's OK — validation itself passed)
|
||||
try:
|
||||
_start_stream(
|
||||
result2 = _first_from_stream(
|
||||
req, [_make_override(), _make_override("anthropic", "claude-3")]
|
||||
)
|
||||
except ValueError as exc:
|
||||
pytest.fail(f"2 overrides should pass validation, got ValueError: {exc}")
|
||||
if isinstance(result2, StreamingError) and "2-3" in result2.error:
|
||||
pytest.fail(
|
||||
f"2 overrides should pass validation, got StreamingError: {result2.error}"
|
||||
)
|
||||
except Exception:
|
||||
pass # Any other error means validation passed
|
||||
pass # Any non-validation error means validation passed
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -308,8 +315,8 @@ class TestRunModels:
|
||||
assert isinstance(stop_obj, OverallStop)
|
||||
assert stop_obj.stop_reason == "complete"
|
||||
|
||||
def test_n1_emitted_packet_has_model_index_none(self) -> None:
|
||||
"""Single-model path: model_index stays None for wire backwards-compat."""
|
||||
def test_n1_emitted_packet_has_model_index_zero(self) -> None:
|
||||
"""Single-model path: model_index is 0 (Emitter defaults model_idx=0)."""
|
||||
|
||||
def emit_one(**kwargs: Any) -> None:
|
||||
kwargs["emitter"].emit(
|
||||
@@ -335,7 +342,7 @@ class TestRunModels:
|
||||
if isinstance(p, Packet) and isinstance(p.obj, ReasoningStart)
|
||||
]
|
||||
assert len(reasoning) == 1
|
||||
assert reasoning[0].placement.model_index is None
|
||||
assert reasoning[0].placement.model_index == 0
|
||||
|
||||
def test_n2_each_model_packet_tagged_with_its_index(self) -> None:
|
||||
"""Multi-model path: packets from model 0 get index=0, model 1 gets index=1."""
|
||||
@@ -398,7 +405,7 @@ class TestRunModels:
|
||||
|
||||
def fail_model_0_succeed_model_1(**kwargs: Any) -> None:
|
||||
emitter = kwargs["emitter"]
|
||||
# _model_idx is None for N=1, int for N>1
|
||||
# _model_idx is always int (0 for N=1, 0/1/2… for N>1)
|
||||
if emitter._model_idx == 0:
|
||||
raise RuntimeError("model 0 failed")
|
||||
emitter.emit(
|
||||
|
||||
Reference in New Issue
Block a user