Compare commits

...

1 Commits

Author SHA1 Message Date
Nik
6dbf26ca04 feat(chat): add DB schema and Pydantic models for multi-model answers - 1/8
Add the foundational schema and type changes needed for multi-model
answer generation, where users can compare responses from up to 3 LLMs
side-by-side.

- Alembic migration: add `preferred_response_id` FK and
  `model_display_name` columns to `chat_message`
- Extend `SendMessageRequest` with `llm_overrides: list[LLMOverride]`
- Add `model_index` to `Placement` for streaming packet routing
- New `MultiModelMessageResponseIDInfo` packet type
- Extend `ChatMessageDetail` with `preferred_response_id` and
  `model_display_name`
2026-03-12 12:21:23 -07:00
6 changed files with 201 additions and 1 deletion

View File

@@ -0,0 +1,42 @@
"""add multi-model columns to chat_message
Revision ID: a3f8b2c1d4e5
Revises: 27fb147a843f
Create Date: 2026-03-12 10:00:00.000000
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "a3f8b2c1d4e5"  # this migration's ID
down_revision = "27fb147a843f"  # parent revision in the migration chain
branch_labels = None
depends_on = None
def upgrade() -> None:
    """Add the multi-model columns to ``chat_message``.

    Both columns are nullable, so existing rows need no backfill.
    """
    # Self-referential FK: a user message can point to the assistant
    # response selected as the preferred one among parallel answers.
    op.add_column(
        "chat_message",
        sa.Column(
            "preferred_response_id",
            sa.Integer(),
            sa.ForeignKey("chat_message.id"),
            nullable=True,
        ),
    )
    # Human-readable label of the model that produced an assistant row.
    op.add_column(
        "chat_message",
        sa.Column("model_display_name", sa.String(), nullable=True),
    )
def downgrade() -> None:
    """Drop the multi-model columns, in reverse order of ``upgrade``."""
    for column_name in ("model_display_name", "preferred_response_id"):
        op.drop_column("chat_message", column_name)

View File

@@ -8,6 +8,7 @@ from onyx.configs.constants import MessageType
from onyx.context.search.models import SearchDoc
from onyx.file_store.models import InMemoryChatFile
from onyx.server.query_and_chat.models import MessageResponseIDInfo
from onyx.server.query_and_chat.models import MultiModelMessageResponseIDInfo
from onyx.server.query_and_chat.streaming_models import CitationInfo
from onyx.server.query_and_chat.streaming_models import GeneratedImage
from onyx.server.query_and_chat.streaming_models import Packet
@@ -35,7 +36,13 @@ class CreateChatSessionID(BaseModel):
chat_session_id: UUID
AnswerStreamPart = Packet | MessageResponseIDInfo | StreamingError | CreateChatSessionID
# Union of every packet type a chat answer stream may yield.
AnswerStreamPart = (
    Packet
    | MessageResponseIDInfo
    | MultiModelMessageResponseIDInfo
    | StreamingError
    | CreateChatSessionID
)
# An answer stream is an iterator over those parts.
AnswerStream = Iterator[AnswerStreamPart]

View File

@@ -2622,6 +2622,18 @@ class ChatMessage(Base):
ForeignKey("chat_message.id"), nullable=True
)
# For multi-model turns: the user message points to which assistant response
# was selected as the preferred one to continue the conversation with.
# Only set on user messages that triggered a multi-model generation.
preferred_response_id: Mapped[int | None] = mapped_column(
ForeignKey("chat_message.id"), nullable=True
)
# The display name of the model that generated this assistant message
# (e.g. "GPT-4", "Claude Opus"). Used on session reload to label
# multi-model response panels and <> navigation arrows.
model_display_name: Mapped[str | None] = mapped_column(String, nullable=True)
# Only set on summary messages - the ID of the last message included in this summary
# Used for chat history compression
last_summarized_message_id: Mapped[int | None] = mapped_column(
@@ -2696,6 +2708,12 @@ class ChatMessage(Base):
remote_side="ChatMessage.id",
)
preferred_response: Mapped["ChatMessage | None"] = relationship(
"ChatMessage",
foreign_keys=[preferred_response_id],
remote_side="ChatMessage.id",
)
# Chat messages only need to know their immediate tool call children
# If there are nested tool calls, they are stored in the tool_call_children relationship.
tool_calls: Mapped[list["ToolCall"] | None] = relationship(

View File

@@ -41,6 +41,16 @@ class MessageResponseIDInfo(BaseModel):
reserved_assistant_message_id: int
class MultiModelMessageResponseIDInfo(BaseModel):
    """Sent at the start of a multi-model streaming response.
    Contains the user message ID and the reserved assistant message IDs
    for each model being run in parallel."""

    # ID of the triggering user message; may be None.
    user_message_id: int | None
    # One pre-reserved assistant message ID per parallel model.
    reserved_assistant_message_ids: list[int]
    # Model display names; presumably parallel to
    # reserved_assistant_message_ids — confirm against the producer.
    model_names: list[str]
class SourceTag(Tag):
    # The document source this tag is associated with.
    source: DocumentSource
@@ -86,6 +96,10 @@ class SendMessageRequest(BaseModel):
message: str
llm_override: LLMOverride | None = None
# For multi-model mode: up to 3 LLM overrides to run in parallel.
# When provided with >1 entry, triggers multi-model streaming.
# Backward-compat: if only `llm_override` is set, single-model path is used.
llm_overrides: list[LLMOverride] | None = None
# Test-only override for deterministic LiteLLM mock responses.
mock_llm_response: str | None = None
@@ -211,6 +225,10 @@ class ChatMessageDetail(BaseModel):
error: str | None = None
current_feedback: str | None = None # "like" | "dislike" | null
processing_duration_seconds: float | None = None
# For multi-model turns: the preferred assistant response ID (set on user messages only)
preferred_response_id: int | None = None
# The display name of the model that generated this message (e.g. "GPT-4", "Claude Opus")
model_display_name: str | None = None
def model_dump(self, *args: list, **kwargs: dict[str, Any]) -> dict[str, Any]: # type: ignore
initial_dict = super().model_dump(mode="json", *args, **kwargs) # type: ignore

View File

@@ -8,3 +8,6 @@ class Placement(BaseModel):
tab_index: int = 0
# Used for tools/agents that call other tools, this currently doesn't support nested agents but can be added later
sub_turn_index: int | None = None
# For multi-model streaming: identifies which model (0, 1, 2) this packet belongs to.
# None for single-model (default) responses.
model_index: int | None = None

View File

@@ -0,0 +1,112 @@
"""Unit tests for multi-model schema and Pydantic model additions."""
from datetime import datetime
from onyx.configs.constants import MessageType
from onyx.llm.override_models import LLMOverride
from onyx.server.query_and_chat.models import ChatMessageDetail
from onyx.server.query_and_chat.models import MultiModelMessageResponseIDInfo
from onyx.server.query_and_chat.models import SendMessageRequest
from onyx.server.query_and_chat.placement import Placement
def test_placement_model_index_default_none() -> None:
    """model_index defaults to None (single-model placement)."""
    placement = Placement(turn_index=0)
    assert placement.model_index is None
def test_placement_model_index_set() -> None:
    """An explicitly supplied model_index is stored as given."""
    placement = Placement(turn_index=0, model_index=2)
    assert placement.model_index == 2
def test_placement_serialization_with_model_index() -> None:
    """model_index survives a dump / reconstruct round trip."""
    original = Placement(turn_index=1, tab_index=0, model_index=1)
    dumped = original.model_dump()
    assert dumped["model_index"] == 1
    round_tripped = Placement(**dumped)
    assert round_tripped.model_index == 1
def test_multi_model_message_response_id_info() -> None:
    """The packet carries one reserved ID and one name per model."""
    packet = MultiModelMessageResponseIDInfo(
        user_message_id=42,
        reserved_assistant_message_ids=[100, 101, 102],
        model_names=["gpt-4", "claude-3-opus", "gemini-pro"],
    )
    dumped = packet.model_dump()
    assert dumped["user_message_id"] == 42
    assert len(dumped["reserved_assistant_message_ids"]) == 3
    assert len(dumped["model_names"]) == 3
def test_multi_model_message_response_id_info_null_user() -> None:
    """user_message_id accepts an explicit None."""
    packet = MultiModelMessageResponseIDInfo(
        user_message_id=None,
        reserved_assistant_message_ids=[10],
        model_names=["gpt-4"],
    )
    assert packet.user_message_id is None
def test_send_message_request_llm_overrides_none_by_default() -> None:
    """Neither override field is populated unless the caller sets one."""
    request = SendMessageRequest(
        message="hello",
        chat_session_id="00000000-0000-0000-0000-000000000001",
    )
    assert request.llm_override is None
    assert request.llm_overrides is None
def test_send_message_request_with_llm_overrides() -> None:
    """A list of LLM overrides is accepted and its length preserved."""
    chosen_models = [
        LLMOverride(model_provider="openai", model_version="gpt-4"),
        LLMOverride(model_provider="anthropic", model_version="claude-3-opus"),
    ]
    request = SendMessageRequest(
        message="compare these",
        chat_session_id="00000000-0000-0000-0000-000000000001",
        llm_overrides=chosen_models,
    )
    assert request.llm_overrides is not None
    assert len(request.llm_overrides) == 2
def test_send_message_request_backward_compat_single_override() -> None:
    """Existing single llm_override still works alongside new llm_overrides field."""
    request = SendMessageRequest(
        message="single model",
        chat_session_id="00000000-0000-0000-0000-000000000001",
        llm_override=LLMOverride(model_provider="openai", model_version="gpt-4"),
    )
    # Setting the legacy field must not populate the new list field.
    assert request.llm_overrides is None
    assert request.llm_override is not None
def test_chat_message_detail_multi_model_fields_default_none() -> None:
    """The new multi-model fields default to None for ordinary messages."""
    message_detail = ChatMessageDetail(
        message_id=1,
        message="hello",
        message_type=MessageType.USER,
        time_sent=datetime.now(),
        files=[],
    )
    assert message_detail.model_display_name is None
    assert message_detail.preferred_response_id is None
def test_chat_message_detail_multi_model_fields_set() -> None:
    """Explicit multi-model values are stored and serialized."""
    message_detail = ChatMessageDetail(
        message_id=1,
        message="response from gpt-4",
        message_type=MessageType.ASSISTANT,
        time_sent=datetime.now(),
        files=[],
        preferred_response_id=42,
        model_display_name="GPT-4",
    )
    # Attribute access reflects the inputs...
    assert message_detail.preferred_response_id == 42
    assert message_detail.model_display_name == "GPT-4"
    # ...and so does the serialized form.
    serialized = message_detail.model_dump()
    assert serialized["preferred_response_id"] == 42
    assert serialized["model_display_name"] == "GPT-4"