mirror of
https://github.com/onyx-dot-app/onyx.git
synced 2026-04-05 23:12:43 +00:00
Compare commits
12 Commits
v3.1.0-clo
...
multi-mode
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
c8e565fa75 | ||
|
|
bab95d8bf0 | ||
|
|
eb7bc74e1b | ||
|
|
29da0aefb5 | ||
|
|
6c86301c51 | ||
|
|
631146f48f | ||
|
|
f327278506 | ||
|
|
c7cc439862 | ||
|
|
3365a369e2 | ||
|
|
470bda3fb5 | ||
|
|
13f511e209 | ||
|
|
c5e8ba1eab |
@@ -1,19 +1,8 @@
|
||||
import threading
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Generator
|
||||
from queue import Empty
|
||||
|
||||
from onyx.chat.citation_processor import CitationMapping
|
||||
from onyx.chat.emitter import Emitter
|
||||
from onyx.context.search.models import SearchDoc
|
||||
from onyx.server.query_and_chat.placement import Placement
|
||||
from onyx.server.query_and_chat.streaming_models import OverallStop
|
||||
from onyx.server.query_and_chat.streaming_models import Packet
|
||||
from onyx.server.query_and_chat.streaming_models import PacketException
|
||||
from onyx.tools.models import ToolCallInfo
|
||||
from onyx.utils.threadpool_concurrency import run_in_background
|
||||
from onyx.utils.threadpool_concurrency import wait_on_background
|
||||
|
||||
# Type alias for search doc deduplication key
|
||||
# Simple key: just document_id (str)
|
||||
@@ -159,114 +148,3 @@ class ChatStateContainer:
|
||||
"""Thread-safe getter for emitted citations (returns a copy)."""
|
||||
with self._lock:
|
||||
return self._emitted_citations.copy()
|
||||
|
||||
|
||||
def run_chat_loop_with_state_containers(
|
||||
chat_loop_func: Callable[[Emitter, ChatStateContainer], None],
|
||||
completion_callback: Callable[[ChatStateContainer], None],
|
||||
is_connected: Callable[[], bool],
|
||||
emitter: Emitter,
|
||||
state_container: ChatStateContainer,
|
||||
) -> Generator[Packet, None]:
|
||||
"""
|
||||
Explicit wrapper function that runs a function in a background thread
|
||||
with event streaming capabilities.
|
||||
|
||||
The wrapped function should accept emitter as first arg and use it to emit
|
||||
Packet objects. This wrapper polls every 300ms to check if stop signal is set.
|
||||
|
||||
Args:
|
||||
func: The function to wrap (should accept emitter and state_container as first and second args)
|
||||
completion_callback: Callback function to call when the function completes
|
||||
emitter: Emitter instance for sending packets
|
||||
state_container: ChatStateContainer instance for accumulating state
|
||||
is_connected: Callable that returns False when stop signal is set
|
||||
|
||||
Usage:
|
||||
packets = run_chat_loop_with_state_containers(
|
||||
my_func,
|
||||
completion_callback=completion_callback,
|
||||
emitter=emitter,
|
||||
state_container=state_container,
|
||||
is_connected=check_func,
|
||||
)
|
||||
for packet in packets:
|
||||
# Process packets
|
||||
pass
|
||||
"""
|
||||
|
||||
def run_with_exception_capture() -> None:
|
||||
try:
|
||||
chat_loop_func(emitter, state_container)
|
||||
except Exception as e:
|
||||
# If execution fails, emit an exception packet
|
||||
emitter.emit(
|
||||
Packet(
|
||||
placement=Placement(turn_index=0),
|
||||
obj=PacketException(type="error", exception=e),
|
||||
)
|
||||
)
|
||||
|
||||
# Run the function in a background thread
|
||||
thread = run_in_background(run_with_exception_capture)
|
||||
|
||||
pkt: Packet | None = None
|
||||
last_turn_index = 0 # Track the highest turn_index seen for stop packet
|
||||
last_cancel_check = time.monotonic()
|
||||
cancel_check_interval = 0.3 # Check for cancellation every 300ms
|
||||
try:
|
||||
while True:
|
||||
# Poll queue with 300ms timeout for natural stop signal checking
|
||||
# the 300ms timeout is to avoid busy-waiting and to allow the stop signal to be checked regularly
|
||||
try:
|
||||
pkt = emitter.bus.get(timeout=0.3)
|
||||
except Empty:
|
||||
if not is_connected():
|
||||
# Stop signal detected
|
||||
yield Packet(
|
||||
placement=Placement(turn_index=last_turn_index + 1),
|
||||
obj=OverallStop(type="stop", stop_reason="user_cancelled"),
|
||||
)
|
||||
break
|
||||
last_cancel_check = time.monotonic()
|
||||
continue
|
||||
|
||||
if pkt is not None:
|
||||
# Track the highest turn_index for the stop packet
|
||||
if pkt.placement and pkt.placement.turn_index > last_turn_index:
|
||||
last_turn_index = pkt.placement.turn_index
|
||||
|
||||
if isinstance(pkt.obj, OverallStop):
|
||||
yield pkt
|
||||
break
|
||||
elif isinstance(pkt.obj, PacketException):
|
||||
raise pkt.obj.exception
|
||||
else:
|
||||
yield pkt
|
||||
|
||||
# Check for cancellation periodically even when packets are flowing
|
||||
# This ensures stop signal is checked during active streaming
|
||||
current_time = time.monotonic()
|
||||
if current_time - last_cancel_check >= cancel_check_interval:
|
||||
if not is_connected():
|
||||
# Stop signal detected during streaming
|
||||
yield Packet(
|
||||
placement=Placement(turn_index=last_turn_index + 1),
|
||||
obj=OverallStop(type="stop", stop_reason="user_cancelled"),
|
||||
)
|
||||
break
|
||||
last_cancel_check = current_time
|
||||
finally:
|
||||
# Wait for thread to complete on normal exit to propagate exceptions and ensure cleanup.
|
||||
# Skip waiting if user disconnected to exit quickly.
|
||||
if is_connected():
|
||||
wait_on_background(thread)
|
||||
try:
|
||||
completion_callback(state_container)
|
||||
except Exception as e:
|
||||
emitter.emit(
|
||||
Packet(
|
||||
placement=Placement(turn_index=last_turn_index + 1),
|
||||
obj=PacketException(type="error", exception=e),
|
||||
)
|
||||
)
|
||||
|
||||
@@ -1,19 +1,40 @@
|
||||
import threading
|
||||
from queue import Queue
|
||||
|
||||
from onyx.server.query_and_chat.placement import Placement
|
||||
from onyx.server.query_and_chat.streaming_models import Packet
|
||||
|
||||
|
||||
class Emitter:
|
||||
"""Use this inside tools to emit arbitrary UI progress."""
|
||||
"""Routes packets from LLM/tool execution to the ``_run_models`` drain loop.
|
||||
|
||||
def __init__(self, bus: Queue):
|
||||
self.bus = bus
|
||||
Tags every packet with ``model_index`` and places it on ``merged_queue``
|
||||
as a ``(model_idx, packet)`` tuple for ordered consumption downstream.
|
||||
|
||||
Args:
|
||||
merged_queue: Shared queue owned by ``_run_models``.
|
||||
model_idx: Index embedded in packet placements (``0`` for N=1 runs).
|
||||
drain_done: Optional event set by ``_run_models`` when the drain loop
|
||||
exits early (e.g. HTTP disconnect). When set, ``emit`` returns
|
||||
immediately so worker threads can exit fast.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
merged_queue: Queue[tuple[int, Packet | Exception | object]],
|
||||
model_idx: int = 0,
|
||||
drain_done: threading.Event | None = None,
|
||||
) -> None:
|
||||
self._model_idx = model_idx
|
||||
self._merged_queue = merged_queue
|
||||
self._drain_done = drain_done
|
||||
|
||||
def emit(self, packet: Packet) -> None:
|
||||
self.bus.put(packet) # Thread-safe
|
||||
|
||||
|
||||
def get_default_emitter() -> Emitter:
|
||||
bus: Queue[Packet] = Queue()
|
||||
emitter = Emitter(bus)
|
||||
return emitter
|
||||
if self._drain_done is not None and self._drain_done.is_set():
|
||||
return
|
||||
base = packet.placement or Placement(turn_index=0)
|
||||
tagged = Packet(
|
||||
placement=base.model_copy(update={"model_index": self._model_idx}),
|
||||
obj=packet.obj,
|
||||
)
|
||||
self._merged_queue.put((self._model_idx, tagged))
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -617,6 +617,92 @@ def reserve_message_id(
|
||||
return empty_message
|
||||
|
||||
|
||||
def reserve_multi_model_message_ids(
|
||||
db_session: Session,
|
||||
chat_session_id: UUID,
|
||||
parent_message_id: int,
|
||||
model_display_names: list[str],
|
||||
) -> list[ChatMessage]:
|
||||
"""Reserve N assistant message placeholders for multi-model parallel streaming.
|
||||
|
||||
All messages share the same parent (the user message). The parent's
|
||||
latest_child_message_id points to the LAST reserved message so that the
|
||||
default history-chain walker picks it up.
|
||||
"""
|
||||
reserved: list[ChatMessage] = []
|
||||
for display_name in model_display_names:
|
||||
msg = ChatMessage(
|
||||
chat_session_id=chat_session_id,
|
||||
parent_message_id=parent_message_id,
|
||||
latest_child_message_id=None,
|
||||
message="Response was terminated prior to completion, try regenerating.",
|
||||
token_count=15, # placeholder; updated on completion by llm_loop_completion_handle
|
||||
message_type=MessageType.ASSISTANT,
|
||||
model_display_name=display_name,
|
||||
)
|
||||
db_session.add(msg)
|
||||
reserved.append(msg)
|
||||
|
||||
# Flush to assign IDs without committing yet
|
||||
db_session.flush()
|
||||
|
||||
# Point parent's latest_child to the last reserved message
|
||||
parent = (
|
||||
db_session.query(ChatMessage)
|
||||
.filter(ChatMessage.id == parent_message_id)
|
||||
.first()
|
||||
)
|
||||
if parent:
|
||||
parent.latest_child_message_id = reserved[-1].id
|
||||
|
||||
db_session.commit()
|
||||
return reserved
|
||||
|
||||
|
||||
def set_preferred_response(
|
||||
db_session: Session,
|
||||
user_message_id: int,
|
||||
preferred_assistant_message_id: int,
|
||||
) -> None:
|
||||
"""Mark one assistant response as the user's preferred choice in a multi-model turn.
|
||||
|
||||
Also advances ``latest_child_message_id`` so the preferred response becomes
|
||||
the active branch for any subsequent messages in the conversation.
|
||||
|
||||
Args:
|
||||
db_session: Active database session.
|
||||
user_message_id: Primary key of the ``USER``-type ``ChatMessage`` whose
|
||||
preferred response is being set.
|
||||
preferred_assistant_message_id: Primary key of the ``ASSISTANT``-type
|
||||
``ChatMessage`` to prefer. Must be a direct child of ``user_message_id``.
|
||||
|
||||
Raises:
|
||||
ValueError: If either message is not found, if ``user_message_id`` does not
|
||||
refer to a USER message, or if the assistant message is not a direct child
|
||||
of the user message.
|
||||
"""
|
||||
user_msg = db_session.get(ChatMessage, user_message_id)
|
||||
if user_msg is None:
|
||||
raise ValueError(f"User message {user_message_id} not found")
|
||||
if user_msg.message_type != MessageType.USER:
|
||||
raise ValueError(f"Message {user_message_id} is not a user message")
|
||||
|
||||
assistant_msg = db_session.get(ChatMessage, preferred_assistant_message_id)
|
||||
if assistant_msg is None:
|
||||
raise ValueError(
|
||||
f"Assistant message {preferred_assistant_message_id} not found"
|
||||
)
|
||||
if assistant_msg.parent_message_id != user_message_id:
|
||||
raise ValueError(
|
||||
f"Assistant message {preferred_assistant_message_id} is not a child "
|
||||
f"of user message {user_message_id}"
|
||||
)
|
||||
|
||||
user_msg.preferred_response_id = preferred_assistant_message_id
|
||||
user_msg.latest_child_message_id = preferred_assistant_message_id
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def create_new_chat_message(
|
||||
chat_session_id: UUID,
|
||||
parent_message: ChatMessage,
|
||||
@@ -839,6 +925,8 @@ def translate_db_message_to_chat_message_detail(
|
||||
error=chat_message.error,
|
||||
current_feedback=current_feedback,
|
||||
processing_duration_seconds=chat_message.processing_duration_seconds,
|
||||
preferred_response_id=chat_message.preferred_response_id,
|
||||
model_display_name=chat_message.model_display_name,
|
||||
)
|
||||
|
||||
return chat_msg_detail
|
||||
|
||||
@@ -8,6 +8,24 @@ from pydantic import BaseModel
|
||||
|
||||
|
||||
class LLMOverride(BaseModel):
|
||||
"""Per-request LLM settings that override persona defaults.
|
||||
|
||||
All fields are optional — only the fields that differ from the persona's
|
||||
configured LLM need to be supplied. Used both over the wire (API requests)
|
||||
and for multi-model comparison, where one override is supplied per model.
|
||||
|
||||
Attributes:
|
||||
model_provider: LLM provider slug (e.g. ``"openai"``, ``"anthropic"``).
|
||||
When ``None``, the persona's default provider is used.
|
||||
model_version: Specific model version string (e.g. ``"gpt-4o"``).
|
||||
When ``None``, the persona's default model is used.
|
||||
temperature: Sampling temperature in ``[0, 2]``. When ``None``, the
|
||||
persona's default temperature is used.
|
||||
display_name: Human-readable label shown in the UI for this model,
|
||||
e.g. ``"GPT-4 Turbo"``. Optional; falls back to ``model_version``
|
||||
when not set.
|
||||
"""
|
||||
|
||||
model_provider: str | None = None
|
||||
model_version: str | None = None
|
||||
temperature: float | None = None
|
||||
|
||||
@@ -28,6 +28,7 @@ from onyx.chat.chat_utils import extract_headers
|
||||
from onyx.chat.models import ChatFullResponse
|
||||
from onyx.chat.models import CreateChatSessionID
|
||||
from onyx.chat.process_message import gather_stream_full
|
||||
from onyx.chat.process_message import handle_multi_model_stream
|
||||
from onyx.chat.process_message import handle_stream_message_objects
|
||||
from onyx.chat.prompt_utils import get_default_base_system_prompt
|
||||
from onyx.chat.stop_signal_checker import set_fence
|
||||
@@ -46,6 +47,7 @@ from onyx.db.chat import get_chat_messages_by_session
|
||||
from onyx.db.chat import get_chat_session_by_id
|
||||
from onyx.db.chat import get_chat_sessions_by_user
|
||||
from onyx.db.chat import set_as_latest_chat_message
|
||||
from onyx.db.chat import set_preferred_response
|
||||
from onyx.db.chat import translate_db_message_to_chat_message_detail
|
||||
from onyx.db.chat import update_chat_session
|
||||
from onyx.db.chat_search import search_chat_sessions
|
||||
@@ -60,6 +62,8 @@ from onyx.db.persona import get_persona_by_id
|
||||
from onyx.db.usage import increment_usage
|
||||
from onyx.db.usage import UsageType
|
||||
from onyx.db.user_file import get_file_id_by_user_file_id
|
||||
from onyx.error_handling.error_codes import OnyxErrorCode
|
||||
from onyx.error_handling.exceptions import OnyxError
|
||||
from onyx.file_store.file_store import get_default_file_store
|
||||
from onyx.llm.constants import LlmProviderNames
|
||||
from onyx.llm.factory import get_default_llm
|
||||
@@ -81,6 +85,7 @@ from onyx.server.query_and_chat.models import ChatSessionUpdateRequest
|
||||
from onyx.server.query_and_chat.models import MessageOrigin
|
||||
from onyx.server.query_and_chat.models import RenameChatSessionResponse
|
||||
from onyx.server.query_and_chat.models import SendMessageRequest
|
||||
from onyx.server.query_and_chat.models import SetPreferredResponseRequest
|
||||
from onyx.server.query_and_chat.models import UpdateChatSessionTemperatureRequest
|
||||
from onyx.server.query_and_chat.models import UpdateChatSessionThreadRequest
|
||||
from onyx.server.query_and_chat.session_loading import (
|
||||
@@ -570,6 +575,46 @@ def handle_send_chat_message(
|
||||
if get_hashed_api_key_from_request(request) or get_hashed_pat_from_request(request):
|
||||
chat_message_req.origin = MessageOrigin.API
|
||||
|
||||
# Multi-model streaming path: 2-3 LLMs in parallel (streaming only)
|
||||
is_multi_model = (
|
||||
chat_message_req.llm_overrides is not None
|
||||
and len(chat_message_req.llm_overrides) > 1
|
||||
)
|
||||
if is_multi_model and chat_message_req.stream:
|
||||
# Narrowed here; is_multi_model already checked llm_overrides is not None
|
||||
llm_overrides = chat_message_req.llm_overrides or []
|
||||
|
||||
def multi_model_stream_generator() -> Generator[str, None, None]:
|
||||
try:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
for obj in handle_multi_model_stream(
|
||||
new_msg_req=chat_message_req,
|
||||
user=user,
|
||||
db_session=db_session,
|
||||
llm_overrides=llm_overrides,
|
||||
litellm_additional_headers=extract_headers(
|
||||
request.headers, LITELLM_PASS_THROUGH_HEADERS
|
||||
),
|
||||
custom_tool_additional_headers=get_custom_tool_additional_request_headers(
|
||||
request.headers
|
||||
),
|
||||
mcp_headers=chat_message_req.mcp_headers,
|
||||
):
|
||||
yield get_json_line(obj.model_dump())
|
||||
except Exception as e:
|
||||
logger.exception("Error in multi-model streaming")
|
||||
yield json.dumps({"error": str(e)})
|
||||
|
||||
return StreamingResponse(
|
||||
multi_model_stream_generator(), media_type="text/event-stream"
|
||||
)
|
||||
|
||||
if is_multi_model and not chat_message_req.stream:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.INVALID_INPUT,
|
||||
"Multi-model mode (llm_overrides with >1 entry) requires stream=True.",
|
||||
)
|
||||
|
||||
# Non-streaming path: consume all packets and return complete response
|
||||
if not chat_message_req.stream:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
@@ -660,6 +705,30 @@ def set_message_as_latest(
|
||||
)
|
||||
|
||||
|
||||
@router.put("/set-preferred-response")
|
||||
def set_preferred_response_endpoint(
|
||||
request_body: SetPreferredResponseRequest,
|
||||
user: User | None = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> None:
|
||||
"""Set the preferred assistant response for a multi-model turn."""
|
||||
try:
|
||||
# Ownership check: get_chat_message raises ValueError if the message
|
||||
# doesn't belong to this user, preventing cross-user mutation.
|
||||
get_chat_message(
|
||||
chat_message_id=request_body.user_message_id,
|
||||
user_id=user.id if user else None,
|
||||
db_session=db_session,
|
||||
)
|
||||
set_preferred_response(
|
||||
db_session=db_session,
|
||||
user_message_id=request_body.user_message_id,
|
||||
preferred_assistant_message_id=request_body.preferred_response_id,
|
||||
)
|
||||
except ValueError as e:
|
||||
raise OnyxError(OnyxErrorCode.INVALID_INPUT, str(e))
|
||||
|
||||
|
||||
@router.post("/create-chat-message-feedback")
|
||||
def create_chat_feedback(
|
||||
feedback: ChatFeedbackRequest,
|
||||
|
||||
@@ -2,11 +2,25 @@ from pydantic import BaseModel
|
||||
|
||||
|
||||
class Placement(BaseModel):
|
||||
# Which iterative block in the UI is this part of, these are ordered and smaller ones happened first
|
||||
"""Coordinates that identify where a streaming packet belongs in the UI.
|
||||
|
||||
The frontend uses these fields to route each packet to the correct turn,
|
||||
tool tab, agent sub-turn, and (in multi-model mode) response column.
|
||||
|
||||
Attributes:
|
||||
turn_index: Monotonically increasing index of the iterative reasoning block
|
||||
(e.g. tool call round) within this chat message. Lower values happened first.
|
||||
tab_index: Disambiguates parallel tool calls within the same turn so each
|
||||
tool's output can be displayed in its own tab.
|
||||
sub_turn_index: Nesting level for tools that invoke other tools. ``None`` for
|
||||
top-level packets; an integer for tool-within-tool output.
|
||||
model_index: Which model this packet belongs to. ``0`` for single-model
|
||||
responses; ``0``, ``1``, or ``2`` for multi-model comparison. ``None``
|
||||
for pre-LLM setup packets (e.g. message ID info) that are yielded
|
||||
before any Emitter runs.
|
||||
"""
|
||||
|
||||
turn_index: int
|
||||
# For parallel tool calls to preserve order of execution
|
||||
tab_index: int = 0
|
||||
# Used for tools/agents that call other tools, this currently doesn't support nested agents but can be added later
|
||||
sub_turn_index: int | None = None
|
||||
# For multi-model streaming: identifies which model (0, 1, 2) this packet belongs to.
|
||||
model_index: int | None = None
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import queue
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
@@ -708,7 +709,6 @@ def run_research_agent_calls(
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from queue import Queue
|
||||
from uuid import uuid4
|
||||
|
||||
from onyx.chat.chat_state import ChatStateContainer
|
||||
@@ -744,8 +744,8 @@ if __name__ == "__main__":
|
||||
if user is None:
|
||||
raise ValueError("No users found in database. Please create a user first.")
|
||||
|
||||
bus: Queue[Packet] = Queue()
|
||||
emitter = Emitter(bus)
|
||||
emitter_queue: queue.Queue = queue.Queue()
|
||||
emitter = Emitter(merged_queue=emitter_queue)
|
||||
state_container = ChatStateContainer()
|
||||
|
||||
tool_dict = construct_tools(
|
||||
@@ -792,4 +792,4 @@ if __name__ == "__main__":
|
||||
print(result.intermediate_report)
|
||||
print("=" * 80)
|
||||
print(f"Citations: {result.citation_mapping}")
|
||||
print(f"Total packets emitted: {bus.qsize()}")
|
||||
print(f"Total packets emitted: {emitter_queue.qsize()}")
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import csv
|
||||
import json
|
||||
import queue
|
||||
import uuid
|
||||
from io import BytesIO
|
||||
from io import StringIO
|
||||
@@ -11,7 +12,6 @@ import requests
|
||||
from requests import JSONDecodeError
|
||||
|
||||
from onyx.chat.emitter import Emitter
|
||||
from onyx.chat.emitter import get_default_emitter
|
||||
from onyx.configs.constants import FileOrigin
|
||||
from onyx.file_store.file_store import get_default_file_store
|
||||
from onyx.server.query_and_chat.placement import Placement
|
||||
@@ -296,9 +296,9 @@ def build_custom_tools_from_openapi_schema_and_headers(
|
||||
url = openapi_to_url(openapi_schema)
|
||||
method_specs = openapi_to_method_specs(openapi_schema)
|
||||
|
||||
# Use default emitter if none provided
|
||||
# Use a discard emitter if none provided (packets go nowhere)
|
||||
if emitter is None:
|
||||
emitter = get_default_emitter()
|
||||
emitter = Emitter(merged_queue=queue.Queue())
|
||||
|
||||
return [
|
||||
CustomTool(
|
||||
@@ -367,7 +367,7 @@ if __name__ == "__main__":
|
||||
tools = build_custom_tools_from_openapi_schema_and_headers(
|
||||
tool_id=0, # dummy tool id
|
||||
openapi_schema=openapi_schema,
|
||||
emitter=get_default_emitter(),
|
||||
emitter=Emitter(merged_queue=queue.Queue()),
|
||||
dynamic_schema_info=None,
|
||||
)
|
||||
|
||||
|
||||
@@ -27,11 +27,13 @@ def create_placement(
|
||||
turn_index: int,
|
||||
tab_index: int = 0,
|
||||
sub_turn_index: int | None = None,
|
||||
model_index: int | None = 0,
|
||||
) -> Placement:
|
||||
return Placement(
|
||||
turn_index=turn_index,
|
||||
tab_index=tab_index,
|
||||
sub_turn_index=sub_turn_index,
|
||||
model_index=model_index,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -13,6 +13,7 @@ This test:
|
||||
All external HTTP calls are mocked, but Postgres and Redis are running.
|
||||
"""
|
||||
|
||||
import queue
|
||||
from typing import Any
|
||||
from unittest.mock import patch
|
||||
from uuid import uuid4
|
||||
@@ -20,7 +21,7 @@ from uuid import uuid4
|
||||
import pytest
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.chat.emitter import get_default_emitter
|
||||
from onyx.chat.emitter import Emitter
|
||||
from onyx.db.enums import MCPAuthenticationPerformer
|
||||
from onyx.db.enums import MCPAuthenticationType
|
||||
from onyx.db.enums import MCPTransport
|
||||
@@ -137,7 +138,7 @@ class TestMCPPassThroughOAuth:
|
||||
tool_dict = construct_tools(
|
||||
persona=persona,
|
||||
db_session=db_session,
|
||||
emitter=get_default_emitter(),
|
||||
emitter=Emitter(merged_queue=queue.Queue()),
|
||||
user=user,
|
||||
llm=llm,
|
||||
search_tool_config=search_tool_config,
|
||||
@@ -200,7 +201,7 @@ class TestMCPPassThroughOAuth:
|
||||
tool_dict = construct_tools(
|
||||
persona=persona,
|
||||
db_session=db_session,
|
||||
emitter=get_default_emitter(),
|
||||
emitter=Emitter(merged_queue=queue.Queue()),
|
||||
user=user,
|
||||
llm=llm,
|
||||
search_tool_config=SearchToolConfig(),
|
||||
@@ -275,7 +276,7 @@ class TestMCPPassThroughOAuth:
|
||||
tool_dict = construct_tools(
|
||||
persona=persona,
|
||||
db_session=db_session,
|
||||
emitter=get_default_emitter(),
|
||||
emitter=Emitter(merged_queue=queue.Queue()),
|
||||
user=user,
|
||||
llm=llm,
|
||||
search_tool_config=SearchToolConfig(),
|
||||
@@ -350,7 +351,7 @@ class TestMCPPassThroughOAuth:
|
||||
tool_dict = construct_tools(
|
||||
persona=persona,
|
||||
db_session=db_session,
|
||||
emitter=get_default_emitter(),
|
||||
emitter=Emitter(merged_queue=queue.Queue()),
|
||||
user=user,
|
||||
llm=llm,
|
||||
search_tool_config=SearchToolConfig(),
|
||||
@@ -458,7 +459,7 @@ class TestMCPPassThroughOAuth:
|
||||
tool_dict = construct_tools(
|
||||
persona=persona,
|
||||
db_session=db_session,
|
||||
emitter=get_default_emitter(),
|
||||
emitter=Emitter(merged_queue=queue.Queue()),
|
||||
user=user,
|
||||
llm=llm,
|
||||
search_tool_config=SearchToolConfig(),
|
||||
@@ -541,7 +542,7 @@ class TestMCPPassThroughOAuth:
|
||||
tool_dict = construct_tools(
|
||||
persona=persona,
|
||||
db_session=db_session,
|
||||
emitter=get_default_emitter(),
|
||||
emitter=Emitter(merged_queue=queue.Queue()),
|
||||
user=user,
|
||||
llm=llm,
|
||||
search_tool_config=SearchToolConfig(),
|
||||
|
||||
@@ -8,6 +8,7 @@ Tests the priority logic for OAuth tokens when constructing custom tools:
|
||||
All external HTTP calls are mocked, but Postgres and Redis are running.
|
||||
"""
|
||||
|
||||
import queue
|
||||
from typing import Any
|
||||
from unittest.mock import Mock
|
||||
from unittest.mock import patch
|
||||
@@ -16,7 +17,7 @@ from uuid import uuid4
|
||||
import pytest
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.chat.emitter import get_default_emitter
|
||||
from onyx.chat.emitter import Emitter
|
||||
from onyx.db.models import OAuthAccount
|
||||
from onyx.db.models import OAuthConfig
|
||||
from onyx.db.models import Persona
|
||||
@@ -174,7 +175,7 @@ class TestOAuthToolIntegrationPriority:
|
||||
tool_dict = construct_tools(
|
||||
persona=persona,
|
||||
db_session=db_session,
|
||||
emitter=get_default_emitter(),
|
||||
emitter=Emitter(merged_queue=queue.Queue()),
|
||||
user=user,
|
||||
llm=llm,
|
||||
search_tool_config=search_tool_config,
|
||||
@@ -232,7 +233,7 @@ class TestOAuthToolIntegrationPriority:
|
||||
tool_dict = construct_tools(
|
||||
persona=persona,
|
||||
db_session=db_session,
|
||||
emitter=get_default_emitter(),
|
||||
emitter=Emitter(merged_queue=queue.Queue()),
|
||||
user=user,
|
||||
llm=llm,
|
||||
)
|
||||
@@ -284,7 +285,7 @@ class TestOAuthToolIntegrationPriority:
|
||||
tool_dict = construct_tools(
|
||||
persona=persona,
|
||||
db_session=db_session,
|
||||
emitter=get_default_emitter(),
|
||||
emitter=Emitter(merged_queue=queue.Queue()),
|
||||
user=user,
|
||||
llm=llm,
|
||||
)
|
||||
@@ -345,7 +346,7 @@ class TestOAuthToolIntegrationPriority:
|
||||
tool_dict = construct_tools(
|
||||
persona=persona,
|
||||
db_session=db_session,
|
||||
emitter=get_default_emitter(),
|
||||
emitter=Emitter(merged_queue=queue.Queue()),
|
||||
user=user,
|
||||
llm=llm,
|
||||
)
|
||||
@@ -416,7 +417,7 @@ class TestOAuthToolIntegrationPriority:
|
||||
tool_dict = construct_tools(
|
||||
persona=persona,
|
||||
db_session=db_session,
|
||||
emitter=get_default_emitter(),
|
||||
emitter=Emitter(merged_queue=queue.Queue()),
|
||||
user=user,
|
||||
llm=llm,
|
||||
)
|
||||
@@ -483,7 +484,7 @@ class TestOAuthToolIntegrationPriority:
|
||||
tool_dict = construct_tools(
|
||||
persona=persona,
|
||||
db_session=db_session,
|
||||
emitter=get_default_emitter(),
|
||||
emitter=Emitter(merged_queue=queue.Queue()),
|
||||
user=user,
|
||||
llm=llm,
|
||||
)
|
||||
@@ -536,7 +537,7 @@ class TestOAuthToolIntegrationPriority:
|
||||
tool_dict = construct_tools(
|
||||
persona=persona,
|
||||
db_session=db_session,
|
||||
emitter=get_default_emitter(),
|
||||
emitter=Emitter(merged_queue=queue.Queue()),
|
||||
user=user,
|
||||
llm=llm,
|
||||
)
|
||||
|
||||
173
backend/tests/unit/onyx/chat/test_emitter.py
Normal file
173
backend/tests/unit/onyx/chat/test_emitter.py
Normal file
@@ -0,0 +1,173 @@
|
||||
"""Unit tests for the Emitter class.
|
||||
|
||||
All tests use the streaming mode (merged_queue required). Emitter has a single
|
||||
code path — no standalone bus.
|
||||
"""
|
||||
|
||||
import queue
|
||||
|
||||
from onyx.chat.emitter import Emitter
|
||||
from onyx.server.query_and_chat.placement import Placement
|
||||
from onyx.server.query_and_chat.streaming_models import OverallStop
|
||||
from onyx.server.query_and_chat.streaming_models import Packet
|
||||
from onyx.server.query_and_chat.streaming_models import ReasoningStart
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _placement(
|
||||
turn_index: int = 0,
|
||||
tab_index: int = 0,
|
||||
sub_turn_index: int | None = None,
|
||||
) -> Placement:
|
||||
return Placement(
|
||||
turn_index=turn_index,
|
||||
tab_index=tab_index,
|
||||
sub_turn_index=sub_turn_index,
|
||||
)
|
||||
|
||||
|
||||
def _packet(
|
||||
turn_index: int = 0,
|
||||
tab_index: int = 0,
|
||||
sub_turn_index: int | None = None,
|
||||
) -> Packet:
|
||||
"""Build a minimal valid packet with an OverallStop payload."""
|
||||
return Packet(
|
||||
placement=_placement(turn_index, tab_index, sub_turn_index),
|
||||
obj=OverallStop(stop_reason="test"),
|
||||
)
|
||||
|
||||
|
||||
def _make_emitter(model_idx: int = 0) -> tuple["Emitter", "queue.Queue"]:
|
||||
"""Return (emitter, queue) wired together."""
|
||||
mq: queue.Queue = queue.Queue()
|
||||
return Emitter(merged_queue=mq, model_idx=model_idx), mq
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Queue routing
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestEmitterQueueRouting:
|
||||
def test_emit_lands_on_merged_queue(self) -> None:
|
||||
emitter, mq = _make_emitter()
|
||||
emitter.emit(_packet())
|
||||
assert not mq.empty()
|
||||
|
||||
def test_queue_item_is_tuple_of_key_and_packet(self) -> None:
|
||||
emitter, mq = _make_emitter(model_idx=1)
|
||||
emitter.emit(_packet())
|
||||
item = mq.get_nowait()
|
||||
assert isinstance(item, tuple)
|
||||
assert len(item) == 2
|
||||
|
||||
def test_multiple_packets_delivered_fifo(self) -> None:
|
||||
emitter, mq = _make_emitter()
|
||||
p1 = _packet(turn_index=0)
|
||||
p2 = _packet(turn_index=1)
|
||||
emitter.emit(p1)
|
||||
emitter.emit(p2)
|
||||
_, t1 = mq.get_nowait()
|
||||
_, t2 = mq.get_nowait()
|
||||
assert t1.placement.turn_index == 0
|
||||
assert t2.placement.turn_index == 1
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# model_index tagging
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestEmitterModelIndexTagging:
|
||||
def test_n1_default_model_idx_tags_model_index_zero(self) -> None:
|
||||
"""N=1: default model_idx=0, so packet gets model_index=0."""
|
||||
emitter, mq = _make_emitter(model_idx=0)
|
||||
emitter.emit(_packet())
|
||||
_key, tagged = mq.get_nowait()
|
||||
assert tagged.placement.model_index == 0
|
||||
|
||||
def test_model_idx_one_tags_packet(self) -> None:
|
||||
emitter, mq = _make_emitter(model_idx=1)
|
||||
emitter.emit(_packet())
|
||||
_key, tagged = mq.get_nowait()
|
||||
assert tagged.placement.model_index == 1
|
||||
|
||||
def test_model_idx_two_tags_packet(self) -> None:
|
||||
"""Boundary: third model in a 3-model run."""
|
||||
emitter, mq = _make_emitter(model_idx=2)
|
||||
emitter.emit(_packet())
|
||||
_key, tagged = mq.get_nowait()
|
||||
assert tagged.placement.model_index == 2
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Queue key
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestEmitterQueueKey:
|
||||
def test_key_equals_model_idx(self) -> None:
|
||||
"""Drain loop uses the key to route packets; it must match model_idx."""
|
||||
emitter, mq = _make_emitter(model_idx=2)
|
||||
emitter.emit(_packet())
|
||||
key, _ = mq.get_nowait()
|
||||
assert key == 2
|
||||
|
||||
def test_n1_key_is_zero(self) -> None:
|
||||
emitter, mq = _make_emitter(model_idx=0)
|
||||
emitter.emit(_packet())
|
||||
key, _ = mq.get_nowait()
|
||||
assert key == 0
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Placement field preservation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestEmitterPlacementPreservation:
|
||||
def test_turn_index_is_preserved(self) -> None:
|
||||
emitter, mq = _make_emitter()
|
||||
emitter.emit(_packet(turn_index=5))
|
||||
_, tagged = mq.get_nowait()
|
||||
assert tagged.placement.turn_index == 5
|
||||
|
||||
def test_tab_index_is_preserved(self) -> None:
|
||||
emitter, mq = _make_emitter()
|
||||
emitter.emit(_packet(tab_index=3))
|
||||
_, tagged = mq.get_nowait()
|
||||
assert tagged.placement.tab_index == 3
|
||||
|
||||
def test_sub_turn_index_is_preserved(self) -> None:
|
||||
emitter, mq = _make_emitter()
|
||||
emitter.emit(_packet(sub_turn_index=2))
|
||||
_, tagged = mq.get_nowait()
|
||||
assert tagged.placement.sub_turn_index == 2
|
||||
|
||||
def test_sub_turn_index_none_is_preserved(self) -> None:
|
||||
emitter, mq = _make_emitter()
|
||||
emitter.emit(_packet(sub_turn_index=None))
|
||||
_, tagged = mq.get_nowait()
|
||||
assert tagged.placement.sub_turn_index is None
|
||||
|
||||
def test_packet_obj_is_not_modified(self) -> None:
|
||||
"""The payload object must survive tagging untouched."""
|
||||
emitter, mq = _make_emitter()
|
||||
original_obj = OverallStop(stop_reason="sentinel")
|
||||
pkt = Packet(placement=_placement(), obj=original_obj)
|
||||
emitter.emit(pkt)
|
||||
_, tagged = mq.get_nowait()
|
||||
assert tagged.obj is original_obj
|
||||
|
||||
def test_different_obj_types_are_handled(self) -> None:
|
||||
"""Any valid PacketObj type passes through correctly."""
|
||||
emitter, mq = _make_emitter()
|
||||
pkt = Packet(placement=_placement(), obj=ReasoningStart())
|
||||
emitter.emit(pkt)
|
||||
_, tagged = mq.get_nowait()
|
||||
assert isinstance(tagged.obj, ReasoningStart)
|
||||
768
backend/tests/unit/onyx/chat/test_multi_model_streaming.py
Normal file
768
backend/tests/unit/onyx/chat/test_multi_model_streaming.py
Normal file
@@ -0,0 +1,768 @@
|
||||
"""Unit tests for multi-model streaming validation and DB helpers.
|
||||
|
||||
These are pure unit tests — no real database or LLM calls required.
|
||||
The validation logic in handle_multi_model_stream fires before any external
|
||||
calls, so we can trigger it with lightweight mocks.
|
||||
"""
|
||||
|
||||
import time
|
||||
from collections.abc import Generator
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from onyx.chat.models import StreamingError
|
||||
from onyx.configs.constants import MessageType
|
||||
from onyx.db.chat import set_preferred_response
|
||||
from onyx.llm.override_models import LLMOverride
|
||||
from onyx.server.query_and_chat.models import SendMessageRequest
|
||||
from onyx.server.query_and_chat.placement import Placement
|
||||
from onyx.server.query_and_chat.streaming_models import OverallStop
|
||||
from onyx.server.query_and_chat.streaming_models import Packet
|
||||
from onyx.server.query_and_chat.streaming_models import ReasoningStart
|
||||
from onyx.utils.variable_functionality import global_version
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _restore_ee_version() -> Generator[None, None, None]:
|
||||
"""Reset EE global state after each test.
|
||||
|
||||
Importing onyx.chat.process_message triggers set_is_ee_based_on_env_variable()
|
||||
(via the celery client import chain). Without this fixture, the EE flag stays
|
||||
True for the rest of the session and breaks unrelated tests that mock Confluence
|
||||
or other connectors and assume EE is disabled.
|
||||
"""
|
||||
original = global_version._is_ee
|
||||
yield
|
||||
global_version._is_ee = original
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_request(**kwargs: Any) -> SendMessageRequest:
|
||||
defaults: dict[str, Any] = {
|
||||
"message": "hello",
|
||||
"chat_session_id": uuid4(),
|
||||
}
|
||||
defaults.update(kwargs)
|
||||
return SendMessageRequest(**defaults)
|
||||
|
||||
|
||||
def _make_override(provider: str = "openai", version: str = "gpt-4") -> LLMOverride:
|
||||
return LLMOverride(model_provider=provider, model_version=version)
|
||||
|
||||
|
||||
def _first_from_stream(req: SendMessageRequest, overrides: list[LLMOverride]) -> Any:
|
||||
"""Return the first item yielded by handle_multi_model_stream."""
|
||||
from onyx.chat.process_message import handle_multi_model_stream
|
||||
|
||||
user = MagicMock()
|
||||
user.is_anonymous = False
|
||||
user.email = "test@example.com"
|
||||
db = MagicMock()
|
||||
|
||||
gen = handle_multi_model_stream(req, user, db, overrides)
|
||||
return next(gen)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# handle_multi_model_stream — validation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestRunMultiModelStreamValidation:
|
||||
def test_single_override_yields_error(self) -> None:
|
||||
"""Exactly 1 override is not multi-model — yields StreamingError."""
|
||||
req = _make_request()
|
||||
result = _first_from_stream(req, [_make_override()])
|
||||
assert isinstance(result, StreamingError)
|
||||
assert "2-3" in result.error
|
||||
|
||||
def test_four_overrides_yields_error(self) -> None:
|
||||
"""4 overrides exceeds maximum — yields StreamingError."""
|
||||
req = _make_request()
|
||||
result = _first_from_stream(
|
||||
req,
|
||||
[
|
||||
_make_override("openai", "gpt-4"),
|
||||
_make_override("anthropic", "claude-3"),
|
||||
_make_override("google", "gemini-pro"),
|
||||
_make_override("cohere", "command-r"),
|
||||
],
|
||||
)
|
||||
assert isinstance(result, StreamingError)
|
||||
assert "2-3" in result.error
|
||||
|
||||
def test_zero_overrides_yields_error(self) -> None:
|
||||
"""Empty override list yields StreamingError."""
|
||||
req = _make_request()
|
||||
result = _first_from_stream(req, [])
|
||||
assert isinstance(result, StreamingError)
|
||||
assert "2-3" in result.error
|
||||
|
||||
def test_deep_research_yields_error(self) -> None:
|
||||
"""deep_research=True is incompatible with multi-model — yields StreamingError."""
|
||||
req = _make_request(deep_research=True)
|
||||
result = _first_from_stream(
|
||||
req, [_make_override(), _make_override("anthropic", "claude-3")]
|
||||
)
|
||||
assert isinstance(result, StreamingError)
|
||||
assert "not supported" in result.error
|
||||
|
||||
def test_exactly_two_overrides_is_minimum(self) -> None:
|
||||
"""Boundary: 1 override yields error, 2 overrides passes validation."""
|
||||
req = _make_request()
|
||||
# 1 override must yield a StreamingError
|
||||
result = _first_from_stream(req, [_make_override()])
|
||||
assert isinstance(
|
||||
result, StreamingError
|
||||
), "1 override should yield StreamingError"
|
||||
# 2 overrides must NOT yield a validation StreamingError (may raise later due to
|
||||
# missing session, that's OK — validation itself passed)
|
||||
try:
|
||||
result2 = _first_from_stream(
|
||||
req, [_make_override(), _make_override("anthropic", "claude-3")]
|
||||
)
|
||||
if isinstance(result2, StreamingError) and "2-3" in result2.error:
|
||||
pytest.fail(
|
||||
f"2 overrides should pass validation, got StreamingError: {result2.error}"
|
||||
)
|
||||
except Exception:
|
||||
pass # Any non-validation error means validation passed
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# set_preferred_response — validation (mocked db)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSetPreferredResponseValidation:
|
||||
def test_user_message_not_found(self) -> None:
|
||||
db = MagicMock()
|
||||
db.get.return_value = None
|
||||
|
||||
with pytest.raises(ValueError, match="not found"):
|
||||
set_preferred_response(
|
||||
db, user_message_id=999, preferred_assistant_message_id=1
|
||||
)
|
||||
|
||||
def test_wrong_message_type(self) -> None:
|
||||
"""Cannot set preferred response on a non-USER message."""
|
||||
db = MagicMock()
|
||||
user_msg = MagicMock()
|
||||
user_msg.message_type = MessageType.ASSISTANT # wrong type
|
||||
|
||||
db.get.return_value = user_msg
|
||||
|
||||
with pytest.raises(ValueError, match="not a user message"):
|
||||
set_preferred_response(
|
||||
db, user_message_id=1, preferred_assistant_message_id=2
|
||||
)
|
||||
|
||||
def test_assistant_message_not_found(self) -> None:
|
||||
db = MagicMock()
|
||||
user_msg = MagicMock()
|
||||
user_msg.message_type = MessageType.USER
|
||||
|
||||
# First call returns user_msg, second call (for assistant) returns None
|
||||
db.get.side_effect = [user_msg, None]
|
||||
|
||||
with pytest.raises(ValueError, match="not found"):
|
||||
set_preferred_response(
|
||||
db, user_message_id=1, preferred_assistant_message_id=2
|
||||
)
|
||||
|
||||
def test_assistant_not_child_of_user(self) -> None:
|
||||
db = MagicMock()
|
||||
user_msg = MagicMock()
|
||||
user_msg.message_type = MessageType.USER
|
||||
|
||||
assistant_msg = MagicMock()
|
||||
assistant_msg.parent_message_id = 999 # different parent
|
||||
|
||||
db.get.side_effect = [user_msg, assistant_msg]
|
||||
|
||||
with pytest.raises(ValueError, match="not a child"):
|
||||
set_preferred_response(
|
||||
db, user_message_id=1, preferred_assistant_message_id=2
|
||||
)
|
||||
|
||||
def test_valid_call_sets_preferred_response_id(self) -> None:
|
||||
db = MagicMock()
|
||||
user_msg = MagicMock()
|
||||
user_msg.message_type = MessageType.USER
|
||||
|
||||
assistant_msg = MagicMock()
|
||||
assistant_msg.parent_message_id = 1 # correct parent
|
||||
|
||||
db.get.side_effect = [user_msg, assistant_msg]
|
||||
|
||||
set_preferred_response(db, user_message_id=1, preferred_assistant_message_id=2)
|
||||
|
||||
assert user_msg.preferred_response_id == 2
|
||||
assert user_msg.latest_child_message_id == 2
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# LLMOverride — display_name field
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestLLMOverrideDisplayName:
|
||||
def test_display_name_defaults_none(self) -> None:
|
||||
override = LLMOverride(model_provider="openai", model_version="gpt-4")
|
||||
assert override.display_name is None
|
||||
|
||||
def test_display_name_set(self) -> None:
|
||||
override = LLMOverride(
|
||||
model_provider="openai",
|
||||
model_version="gpt-4",
|
||||
display_name="GPT-4 Turbo",
|
||||
)
|
||||
assert override.display_name == "GPT-4 Turbo"
|
||||
|
||||
def test_display_name_serializes(self) -> None:
|
||||
override = LLMOverride(
|
||||
model_provider="anthropic",
|
||||
model_version="claude-opus-4-6",
|
||||
display_name="Claude Opus",
|
||||
)
|
||||
d = override.model_dump()
|
||||
assert d["display_name"] == "Claude Opus"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _run_models — drain loop behaviour
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_setup(n_models: int = 1) -> MagicMock:
|
||||
"""Minimal ChatTurnSetup mock whose fields pass Pydantic validation in _run_model."""
|
||||
setup = MagicMock()
|
||||
setup.llms = [MagicMock() for _ in range(n_models)]
|
||||
setup.model_display_names = [f"model-{i}" for i in range(n_models)]
|
||||
setup.check_is_connected = MagicMock(return_value=True)
|
||||
setup.reserved_messages = [MagicMock() for _ in range(n_models)]
|
||||
setup.reserved_token_count = 100
|
||||
# Fields consumed by SearchToolConfig / CustomToolConfig / FileReaderToolConfig
|
||||
# constructors inside _run_model — must be typed correctly for Pydantic.
|
||||
setup.new_msg_req.deep_research = False
|
||||
setup.new_msg_req.internal_search_filters = None
|
||||
setup.new_msg_req.allowed_tool_ids = None
|
||||
setup.new_msg_req.include_citations = True
|
||||
setup.search_params.project_id_filter = None
|
||||
setup.search_params.persona_id_filter = None
|
||||
setup.bypass_acl = False
|
||||
setup.slack_context = None
|
||||
setup.available_files.user_file_ids = []
|
||||
setup.available_files.chat_file_ids = []
|
||||
setup.forced_tool_id = None
|
||||
setup.simple_chat_history = []
|
||||
setup.chat_session.id = uuid4()
|
||||
setup.user_message.id = None
|
||||
setup.custom_tool_additional_headers = None
|
||||
setup.mcp_headers = None
|
||||
return setup
|
||||
|
||||
|
||||
def _run_models_collect(setup: MagicMock) -> list:
|
||||
"""Drive _run_models to completion and return all yielded items."""
|
||||
from onyx.chat.process_message import _run_models
|
||||
|
||||
return list(_run_models(setup, MagicMock(), MagicMock()))
|
||||
|
||||
|
||||
class TestRunModels:
|
||||
"""Tests for the _run_models worker-thread drain loop.
|
||||
|
||||
All external dependencies (LLM, DB, tools) are patched out. Worker threads
|
||||
still run but return immediately since run_llm_loop is mocked.
|
||||
"""
|
||||
|
||||
def test_n1_overall_stop_from_llm_loop_passes_through(self) -> None:
|
||||
"""OverallStop emitted by run_llm_loop is passed through the drain loop unchanged."""
|
||||
|
||||
def emit_stop(**kwargs: Any) -> None:
|
||||
kwargs["emitter"].emit(
|
||||
Packet(
|
||||
placement=Placement(turn_index=0),
|
||||
obj=OverallStop(stop_reason="complete"),
|
||||
)
|
||||
)
|
||||
|
||||
with (
|
||||
patch("onyx.chat.process_message.run_llm_loop", side_effect=emit_stop),
|
||||
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
|
||||
patch("onyx.chat.process_message.construct_tools", return_value={}),
|
||||
patch("onyx.chat.process_message.get_session_with_current_tenant"),
|
||||
patch("onyx.chat.process_message.llm_loop_completion_handle"),
|
||||
patch(
|
||||
"onyx.chat.process_message.get_llm_token_counter",
|
||||
return_value=lambda _: 0,
|
||||
),
|
||||
):
|
||||
packets = _run_models_collect(_make_setup(n_models=1))
|
||||
|
||||
stops = [
|
||||
p
|
||||
for p in packets
|
||||
if isinstance(p, Packet) and isinstance(p.obj, OverallStop)
|
||||
]
|
||||
assert len(stops) == 1
|
||||
stop_obj = stops[0].obj
|
||||
assert isinstance(stop_obj, OverallStop)
|
||||
assert stop_obj.stop_reason == "complete"
|
||||
|
||||
def test_n1_emitted_packet_has_model_index_zero(self) -> None:
|
||||
"""Single-model path: model_index is 0 (Emitter defaults model_idx=0)."""
|
||||
|
||||
def emit_one(**kwargs: Any) -> None:
|
||||
kwargs["emitter"].emit(
|
||||
Packet(placement=Placement(turn_index=0), obj=ReasoningStart())
|
||||
)
|
||||
|
||||
with (
|
||||
patch("onyx.chat.process_message.run_llm_loop", side_effect=emit_one),
|
||||
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
|
||||
patch("onyx.chat.process_message.construct_tools", return_value={}),
|
||||
patch("onyx.chat.process_message.get_session_with_current_tenant"),
|
||||
patch("onyx.chat.process_message.llm_loop_completion_handle"),
|
||||
patch(
|
||||
"onyx.chat.process_message.get_llm_token_counter",
|
||||
return_value=lambda _: 0,
|
||||
),
|
||||
):
|
||||
packets = _run_models_collect(_make_setup(n_models=1))
|
||||
|
||||
reasoning = [
|
||||
p
|
||||
for p in packets
|
||||
if isinstance(p, Packet) and isinstance(p.obj, ReasoningStart)
|
||||
]
|
||||
assert len(reasoning) == 1
|
||||
assert reasoning[0].placement.model_index == 0
|
||||
|
||||
def test_n2_each_model_packet_tagged_with_its_index(self) -> None:
|
||||
"""Multi-model path: packets from model 0 get index=0, model 1 gets index=1."""
|
||||
|
||||
def emit_one(**kwargs: Any) -> None:
|
||||
# _model_idx is set by _run_model based on position in setup.llms
|
||||
emitter = kwargs["emitter"]
|
||||
emitter.emit(
|
||||
Packet(placement=Placement(turn_index=0), obj=ReasoningStart())
|
||||
)
|
||||
|
||||
with (
|
||||
patch("onyx.chat.process_message.run_llm_loop", side_effect=emit_one),
|
||||
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
|
||||
patch("onyx.chat.process_message.construct_tools", return_value={}),
|
||||
patch("onyx.chat.process_message.get_session_with_current_tenant"),
|
||||
patch("onyx.chat.process_message.llm_loop_completion_handle"),
|
||||
patch(
|
||||
"onyx.chat.process_message.get_llm_token_counter",
|
||||
return_value=lambda _: 0,
|
||||
),
|
||||
):
|
||||
packets = _run_models_collect(_make_setup(n_models=2))
|
||||
|
||||
reasoning = [
|
||||
p
|
||||
for p in packets
|
||||
if isinstance(p, Packet) and isinstance(p.obj, ReasoningStart)
|
||||
]
|
||||
assert len(reasoning) == 2
|
||||
indices = {p.placement.model_index for p in reasoning}
|
||||
assert indices == {0, 1}
|
||||
|
||||
def test_model_error_yields_streaming_error(self) -> None:
|
||||
"""An exception inside a worker thread is surfaced as a StreamingError."""
|
||||
|
||||
def always_fail(**_kwargs: Any) -> None:
|
||||
raise RuntimeError("intentional test failure")
|
||||
|
||||
with (
|
||||
patch("onyx.chat.process_message.run_llm_loop", side_effect=always_fail),
|
||||
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
|
||||
patch("onyx.chat.process_message.construct_tools", return_value={}),
|
||||
patch("onyx.chat.process_message.get_session_with_current_tenant"),
|
||||
patch("onyx.chat.process_message.llm_loop_completion_handle"),
|
||||
patch(
|
||||
"onyx.chat.process_message.get_llm_token_counter",
|
||||
return_value=lambda _: 0,
|
||||
),
|
||||
):
|
||||
packets = _run_models_collect(_make_setup(n_models=1))
|
||||
|
||||
errors = [p for p in packets if isinstance(p, StreamingError)]
|
||||
assert len(errors) == 1
|
||||
assert errors[0].error_code == "MODEL_ERROR"
|
||||
assert "intentional test failure" in errors[0].error
|
||||
|
||||
def test_one_model_error_does_not_stop_other_models(self) -> None:
|
||||
"""A failing model yields StreamingError; the surviving model's packets still arrive."""
|
||||
setup = _make_setup(n_models=2)
|
||||
|
||||
def fail_model_0_succeed_model_1(**kwargs: Any) -> None:
|
||||
if kwargs["llm"] is setup.llms[0]:
|
||||
raise RuntimeError("model 0 failed")
|
||||
kwargs["emitter"].emit(
|
||||
Packet(placement=Placement(turn_index=0), obj=ReasoningStart())
|
||||
)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"onyx.chat.process_message.run_llm_loop",
|
||||
side_effect=fail_model_0_succeed_model_1,
|
||||
),
|
||||
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
|
||||
patch("onyx.chat.process_message.construct_tools", return_value={}),
|
||||
patch("onyx.chat.process_message.get_session_with_current_tenant"),
|
||||
patch("onyx.chat.process_message.llm_loop_completion_handle"),
|
||||
patch(
|
||||
"onyx.chat.process_message.get_llm_token_counter",
|
||||
return_value=lambda _: 0,
|
||||
),
|
||||
):
|
||||
packets = _run_models_collect(setup)
|
||||
|
||||
errors = [p for p in packets if isinstance(p, StreamingError)]
|
||||
assert len(errors) == 1
|
||||
|
||||
reasoning = [
|
||||
p
|
||||
for p in packets
|
||||
if isinstance(p, Packet) and isinstance(p.obj, ReasoningStart)
|
||||
]
|
||||
assert len(reasoning) == 1
|
||||
assert reasoning[0].placement.model_index == 1
|
||||
|
||||
def test_cancellation_yields_user_cancelled_stop(self) -> None:
|
||||
"""If check_is_connected returns False, drain loop emits user_cancelled."""
|
||||
|
||||
def slow_llm(**_kwargs: Any) -> None:
|
||||
time.sleep(0.3) # Outlasts the 50 ms queue-poll interval
|
||||
|
||||
setup = _make_setup(n_models=1)
|
||||
setup.check_is_connected = MagicMock(return_value=False)
|
||||
|
||||
with (
|
||||
patch("onyx.chat.process_message.run_llm_loop", side_effect=slow_llm),
|
||||
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
|
||||
patch("onyx.chat.process_message.construct_tools", return_value={}),
|
||||
patch("onyx.chat.process_message.get_session_with_current_tenant"),
|
||||
patch("onyx.chat.process_message.llm_loop_completion_handle"),
|
||||
patch(
|
||||
"onyx.chat.process_message.get_llm_token_counter",
|
||||
return_value=lambda _: 0,
|
||||
),
|
||||
):
|
||||
packets = _run_models_collect(setup)
|
||||
|
||||
stops = [
|
||||
p
|
||||
for p in packets
|
||||
if isinstance(p, Packet) and isinstance(p.obj, OverallStop)
|
||||
]
|
||||
assert any(
|
||||
isinstance(s.obj, OverallStop) and s.obj.stop_reason == "user_cancelled"
|
||||
for s in stops
|
||||
)
|
||||
|
||||
def test_stop_button_calls_completion_for_all_models(self) -> None:
|
||||
"""llm_loop_completion_handle must be called for all models when the stop button fires.
|
||||
|
||||
Regression test for the disconnect-cleanup bug: the old
|
||||
run_chat_loop_with_state_containers always called completion_callback in
|
||||
its finally block (even on disconnect) so the DB message was updated from
|
||||
the TERMINATED placeholder to a partial answer. The new _run_models must
|
||||
replicate this — otherwise the integration test
|
||||
test_send_message_disconnect_and_cleanup fails because the message stays
|
||||
as "Response was terminated prior to completion, try regenerating."
|
||||
"""
|
||||
|
||||
def slow_llm(**_kwargs: Any) -> None:
|
||||
time.sleep(0.3)
|
||||
|
||||
setup = _make_setup(n_models=2)
|
||||
setup.check_is_connected = MagicMock(return_value=False)
|
||||
|
||||
with (
|
||||
patch("onyx.chat.process_message.run_llm_loop", side_effect=slow_llm),
|
||||
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
|
||||
patch("onyx.chat.process_message.construct_tools", return_value={}),
|
||||
patch("onyx.chat.process_message.get_session_with_current_tenant"),
|
||||
patch(
|
||||
"onyx.chat.process_message.llm_loop_completion_handle"
|
||||
) as mock_handle,
|
||||
patch(
|
||||
"onyx.chat.process_message.get_llm_token_counter",
|
||||
return_value=lambda _: 0,
|
||||
),
|
||||
):
|
||||
_run_models_collect(setup)
|
||||
|
||||
# Must be called once per model, not zero times
|
||||
assert mock_handle.call_count == 2
|
||||
|
||||
def test_completion_handle_called_for_each_successful_model(self) -> None:
|
||||
"""llm_loop_completion_handle must be called once per model that succeeded."""
|
||||
setup = _make_setup(n_models=2)
|
||||
|
||||
with (
|
||||
patch("onyx.chat.process_message.run_llm_loop"),
|
||||
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
|
||||
patch("onyx.chat.process_message.construct_tools", return_value={}),
|
||||
patch("onyx.chat.process_message.get_session_with_current_tenant"),
|
||||
patch(
|
||||
"onyx.chat.process_message.llm_loop_completion_handle"
|
||||
) as mock_handle,
|
||||
patch(
|
||||
"onyx.chat.process_message.get_llm_token_counter",
|
||||
return_value=lambda _: 0,
|
||||
),
|
||||
):
|
||||
_run_models_collect(setup)
|
||||
|
||||
assert mock_handle.call_count == 2
|
||||
|
||||
def test_completion_handle_not_called_for_failed_model(self) -> None:
|
||||
"""llm_loop_completion_handle must be skipped for a model that raised."""
|
||||
|
||||
def always_fail(**_kwargs: Any) -> None:
|
||||
raise RuntimeError("fail")
|
||||
|
||||
with (
|
||||
patch("onyx.chat.process_message.run_llm_loop", side_effect=always_fail),
|
||||
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
|
||||
patch("onyx.chat.process_message.construct_tools", return_value={}),
|
||||
patch("onyx.chat.process_message.get_session_with_current_tenant"),
|
||||
patch(
|
||||
"onyx.chat.process_message.llm_loop_completion_handle"
|
||||
) as mock_handle,
|
||||
patch(
|
||||
"onyx.chat.process_message.get_llm_token_counter",
|
||||
return_value=lambda _: 0,
|
||||
),
|
||||
):
|
||||
_run_models_collect(_make_setup(n_models=1))
|
||||
|
||||
mock_handle.assert_not_called()
|
||||
|
||||
def test_http_disconnect_completion_via_generator_exit(self) -> None:
|
||||
"""GeneratorExit from HTTP disconnect triggers worker self-completion.
|
||||
|
||||
When the HTTP client closes the connection, Starlette throws GeneratorExit
|
||||
into the stream generator. The finally block sets drain_done (signalling
|
||||
emitters to stop blocking) and calls executor.shutdown(wait=False) so the
|
||||
server thread is never blocked. Worker threads detect drain_done.is_set()
|
||||
after run_llm_loop completes and self-persist the result via
|
||||
llm_loop_completion_handle using their own DB session.
|
||||
|
||||
This is the primary regression for test_send_message_disconnect_and_cleanup:
|
||||
the integration test disconnects mid-stream and expects the DB message to be
|
||||
updated from the TERMINATED placeholder to the real response.
|
||||
"""
|
||||
import threading
|
||||
|
||||
# Signals the worker to unblock from run_llm_loop after gen.close() returns.
|
||||
# This guarantees drain_done is set BEFORE the worker returns from run_llm_loop,
|
||||
# so the self-completion path (drain_done.is_set() check) is always taken.
|
||||
disconnect_received = threading.Event()
|
||||
# Set by the llm_loop_completion_handle mock when called.
|
||||
completion_called = threading.Event()
|
||||
|
||||
def emit_then_complete(**kwargs: Any) -> None:
|
||||
"""Emit one packet (to give the drain loop a yield point), then block
|
||||
until the main thread signals that gen.close() has been called. This
|
||||
ensures drain_done is set before we return so model_succeeded is checked
|
||||
against a set drain_done — no race condition.
|
||||
"""
|
||||
emitter = kwargs["emitter"]
|
||||
emitter.emit(
|
||||
Packet(placement=Placement(turn_index=0), obj=ReasoningStart())
|
||||
)
|
||||
disconnect_received.wait(timeout=5)
|
||||
|
||||
setup = _make_setup(n_models=1)
|
||||
# is_connected() always True — HTTP disconnect does NOT set the Redis stop fence.
|
||||
setup.check_is_connected = MagicMock(return_value=True)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"onyx.chat.process_message.run_llm_loop",
|
||||
side_effect=emit_then_complete,
|
||||
),
|
||||
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
|
||||
patch("onyx.chat.process_message.construct_tools", return_value={}),
|
||||
patch("onyx.chat.process_message.get_session_with_current_tenant"),
|
||||
patch(
|
||||
"onyx.chat.process_message.llm_loop_completion_handle",
|
||||
side_effect=lambda *_, **__: completion_called.set(),
|
||||
) as mock_handle,
|
||||
patch(
|
||||
"onyx.chat.process_message.get_llm_token_counter",
|
||||
return_value=lambda _: 0,
|
||||
),
|
||||
):
|
||||
from onyx.chat.process_message import _run_models
|
||||
|
||||
# cast to Generator so .close() is available; _run_models returns
|
||||
# AnswerStream (= Iterator) but the actual object is always a generator.
|
||||
gen = cast(Generator, _run_models(setup, MagicMock(), MagicMock()))
|
||||
# Advance to the first yielded packet — generator suspends at `yield item`.
|
||||
first = next(gen)
|
||||
assert isinstance(first, Packet)
|
||||
# Simulate Starlette closing the stream on HTTP client disconnect.
|
||||
# GeneratorExit is thrown at the `yield item` suspension point.
|
||||
gen.close()
|
||||
# Unblock the worker now that drain_done has been set by gen.close().
|
||||
disconnect_received.set()
|
||||
|
||||
# Worker self-completes asynchronously (executor.shutdown(wait=False)).
|
||||
# Wait here, inside the patch context, so that get_session_with_current_tenant
|
||||
# and llm_loop_completion_handle mocks are still active when the worker calls them.
|
||||
assert completion_called.wait(
|
||||
timeout=5
|
||||
), "worker must self-complete via drain_done within 5 seconds"
|
||||
assert (
|
||||
mock_handle.call_count == 1
|
||||
), "completion handle must be called once for the successful model"
|
||||
|
||||
def test_b1_race_disconnect_handler_completes_already_finished_model(self) -> None:
|
||||
"""B1 regression: model finishes BEFORE GeneratorExit fires.
|
||||
|
||||
The worker exits _run_model with drain_done.is_set()=False and skips
|
||||
self-completion. When gen.close() fires afterward, the finally else-branch
|
||||
must detect model_succeeded=True and call llm_loop_completion_handle itself.
|
||||
|
||||
Contrast with test_http_disconnect_completion_via_generator_exit, which
|
||||
tests the opposite ordering (worker finishes AFTER disconnect).
|
||||
"""
|
||||
import threading
|
||||
import time
|
||||
|
||||
completion_called = threading.Event()
|
||||
|
||||
def emit_and_return_immediately(**kwargs: Any) -> None:
|
||||
# Emit one packet so the drain loop has something to yield, then return
|
||||
# immediately — no blocking. The worker will be done in microseconds.
|
||||
kwargs["emitter"].emit(
|
||||
Packet(placement=Placement(turn_index=0), obj=ReasoningStart())
|
||||
)
|
||||
|
||||
setup = _make_setup(n_models=1)
|
||||
setup.check_is_connected = MagicMock(return_value=True)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"onyx.chat.process_message.run_llm_loop",
|
||||
side_effect=emit_and_return_immediately,
|
||||
),
|
||||
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
|
||||
patch("onyx.chat.process_message.construct_tools", return_value={}),
|
||||
patch("onyx.chat.process_message.get_session_with_current_tenant"),
|
||||
patch(
|
||||
"onyx.chat.process_message.llm_loop_completion_handle",
|
||||
side_effect=lambda *_, **__: completion_called.set(),
|
||||
) as mock_handle,
|
||||
patch(
|
||||
"onyx.chat.process_message.get_llm_token_counter",
|
||||
return_value=lambda _: 0,
|
||||
),
|
||||
):
|
||||
from onyx.chat.process_message import _run_models
|
||||
|
||||
gen = cast(Generator, _run_models(setup, MagicMock(), MagicMock()))
|
||||
first = next(gen)
|
||||
assert isinstance(first, Packet)
|
||||
|
||||
# Give the worker thread time to finish completely (emit + return +
|
||||
# finally + self-completion check). It does almost no work, so 100 ms
|
||||
# is far more than enough while still keeping the test fast.
|
||||
time.sleep(0.1)
|
||||
|
||||
# Now close — worker is already done, so else-branch handles completion.
|
||||
gen.close()
|
||||
|
||||
assert completion_called.wait(
|
||||
timeout=5
|
||||
), "disconnect handler must call completion for a model that already finished"
|
||||
assert mock_handle.call_count == 1, "completion must be called exactly once"
|
||||
|
||||
def test_stop_button_does_not_call_completion_for_errored_model(self) -> None:
|
||||
"""B2 regression: stop-button must NOT call completion for an errored model.
|
||||
|
||||
When model 0 raises an exception, its reserved ChatMessage must not be
|
||||
saved with 'stopped by user' — that message is wrong for a model that
|
||||
errored. llm_loop_completion_handle must only be called for non-errored
|
||||
models when the stop button fires.
|
||||
"""
|
||||
|
||||
def fail_model_0(**kwargs: Any) -> None:
|
||||
if kwargs["llm"] is setup.llms[0]:
|
||||
raise RuntimeError("model 0 errored")
|
||||
# Model 1: run forever (stop button fires before it finishes)
|
||||
time.sleep(10)
|
||||
|
||||
setup = _make_setup(n_models=2)
|
||||
# Return False immediately so the stop-button path fires while model 1
|
||||
# is still sleeping (model 0 has already errored by then).
|
||||
setup.check_is_connected = lambda: False
|
||||
|
||||
with (
|
||||
patch("onyx.chat.process_message.run_llm_loop", side_effect=fail_model_0),
|
||||
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
|
||||
patch("onyx.chat.process_message.construct_tools", return_value={}),
|
||||
patch("onyx.chat.process_message.get_session_with_current_tenant"),
|
||||
patch(
|
||||
"onyx.chat.process_message.llm_loop_completion_handle"
|
||||
) as mock_handle,
|
||||
patch(
|
||||
"onyx.chat.process_message.get_llm_token_counter",
|
||||
return_value=lambda _: 0,
|
||||
),
|
||||
):
|
||||
_run_models_collect(setup)
|
||||
|
||||
# Completion must NOT be called for model 0 (it errored).
|
||||
# It MAY be called for model 1 (still in-flight when stop fired).
|
||||
for call in mock_handle.call_args_list:
|
||||
assert (
|
||||
call.kwargs.get("llm") is not setup.llms[0]
|
||||
), "llm_loop_completion_handle must not be called for the errored model"
|
||||
|
||||
def test_external_state_container_used_for_model_zero(self) -> None:
|
||||
"""When provided, external_state_container is used as state_containers[0]."""
|
||||
from onyx.chat.chat_state import ChatStateContainer
|
||||
from onyx.chat.process_message import _run_models
|
||||
|
||||
external = ChatStateContainer()
|
||||
setup = _make_setup(n_models=1)
|
||||
|
||||
with (
|
||||
patch("onyx.chat.process_message.run_llm_loop") as mock_llm,
|
||||
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
|
||||
patch("onyx.chat.process_message.construct_tools", return_value={}),
|
||||
patch("onyx.chat.process_message.get_session_with_current_tenant"),
|
||||
patch("onyx.chat.process_message.llm_loop_completion_handle"),
|
||||
patch(
|
||||
"onyx.chat.process_message.get_llm_token_counter",
|
||||
return_value=lambda _: 0,
|
||||
),
|
||||
):
|
||||
list(
|
||||
_run_models(
|
||||
setup, MagicMock(), MagicMock(), external_state_container=external
|
||||
)
|
||||
)
|
||||
|
||||
# The state_container kwarg passed to run_llm_loop must be the external one
|
||||
call_kwargs = mock_llm.call_args.kwargs
|
||||
assert call_kwargs["state_container"] is external
|
||||
@@ -1,6 +1,6 @@
|
||||
"""Tests for memory tool streaming packet emissions."""
|
||||
|
||||
from queue import Queue
|
||||
import queue
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
|
||||
@@ -18,9 +18,13 @@ from onyx.tools.tool_implementations.memory.models import MemoryToolResponse
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def emitter() -> Emitter:
|
||||
bus: Queue = Queue()
|
||||
return Emitter(bus)
|
||||
def emitter_queue() -> queue.Queue:
|
||||
return queue.Queue()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def emitter(emitter_queue: queue.Queue) -> Emitter:
|
||||
return Emitter(merged_queue=emitter_queue)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -53,24 +57,27 @@ class TestMemoryToolEmitStart:
|
||||
def test_emit_start_emits_memory_tool_start_packet(
|
||||
self,
|
||||
memory_tool: MemoryTool,
|
||||
emitter: Emitter,
|
||||
emitter_queue: queue.Queue,
|
||||
placement: Placement,
|
||||
) -> None:
|
||||
memory_tool.emit_start(placement)
|
||||
|
||||
packet = emitter.bus.get_nowait()
|
||||
_key, packet = emitter_queue.get_nowait()
|
||||
assert isinstance(packet.obj, MemoryToolStart)
|
||||
assert packet.placement == placement
|
||||
assert packet.placement is not None
|
||||
assert packet.placement.turn_index == placement.turn_index
|
||||
assert packet.placement.tab_index == placement.tab_index
|
||||
assert packet.placement.model_index == 0 # emitter stamps model_index=0
|
||||
|
||||
def test_emit_start_with_different_placement(
|
||||
self,
|
||||
memory_tool: MemoryTool,
|
||||
emitter: Emitter,
|
||||
emitter_queue: queue.Queue,
|
||||
) -> None:
|
||||
placement = Placement(turn_index=2, tab_index=1)
|
||||
memory_tool.emit_start(placement)
|
||||
|
||||
packet = emitter.bus.get_nowait()
|
||||
_key, packet = emitter_queue.get_nowait()
|
||||
assert packet.placement.turn_index == 2
|
||||
assert packet.placement.tab_index == 1
|
||||
|
||||
@@ -81,7 +88,7 @@ class TestMemoryToolRun:
|
||||
self,
|
||||
mock_process: MagicMock,
|
||||
memory_tool: MemoryTool,
|
||||
emitter: Emitter,
|
||||
emitter_queue: queue.Queue,
|
||||
placement: Placement,
|
||||
override_kwargs: MemoryToolOverrideKwargs,
|
||||
) -> None:
|
||||
@@ -93,21 +100,19 @@ class TestMemoryToolRun:
|
||||
memory="User prefers Python",
|
||||
)
|
||||
|
||||
# The delta packet should be in the queue
|
||||
packet = emitter.bus.get_nowait()
|
||||
_key, packet = emitter_queue.get_nowait()
|
||||
assert isinstance(packet.obj, MemoryToolDelta)
|
||||
assert packet.obj.memory_text == "User prefers Python"
|
||||
assert packet.obj.operation == "add"
|
||||
assert packet.obj.memory_id is None
|
||||
assert packet.obj.index is None
|
||||
assert packet.placement == placement
|
||||
|
||||
@patch("onyx.tools.tool_implementations.memory.memory_tool.process_memory_update")
|
||||
def test_run_emits_delta_for_update_operation(
|
||||
self,
|
||||
mock_process: MagicMock,
|
||||
memory_tool: MemoryTool,
|
||||
emitter: Emitter,
|
||||
emitter_queue: queue.Queue,
|
||||
placement: Placement,
|
||||
override_kwargs: MemoryToolOverrideKwargs,
|
||||
) -> None:
|
||||
@@ -119,7 +124,7 @@ class TestMemoryToolRun:
|
||||
memory="User prefers light mode",
|
||||
)
|
||||
|
||||
packet = emitter.bus.get_nowait()
|
||||
_key, packet = emitter_queue.get_nowait()
|
||||
assert isinstance(packet.obj, MemoryToolDelta)
|
||||
assert packet.obj.memory_text == "User prefers light mode"
|
||||
assert packet.obj.operation == "update"
|
||||
|
||||
@@ -182,7 +182,8 @@ export async function* sendMessage({
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
throw new Error(`HTTP error! status: ${response.status}`);
|
||||
const data = await response.json().catch(() => ({}));
|
||||
throw new Error(data.detail ?? `HTTP error! status: ${response.status}`);
|
||||
}
|
||||
|
||||
yield* handleSSEStream<PacketType>(response, signal);
|
||||
|
||||
@@ -901,6 +901,11 @@ export default function useChatController({
|
||||
});
|
||||
}
|
||||
}
|
||||
// Surface FIFO errors (e.g. 429 before any packets arrive) so the
|
||||
// catch block replaces the thinking placeholder with an error message.
|
||||
if (stack.error) {
|
||||
throw new Error(stack.error);
|
||||
}
|
||||
} catch (e: any) {
|
||||
console.log("Error:", e);
|
||||
const errorMsg = e.message;
|
||||
|
||||
Reference in New Issue
Block a user