Compare commits

...

1 Commit

4 changed files with 276 additions and 42 deletions

View File

@@ -10,11 +10,9 @@ from uuid import UUID
from agents import Model
from agents import ModelSettings
-from agents.models.openai_responses import OpenAIResponsesModel
from redis.client import Redis
from sqlalchemy.orm import Session
-from onyx.agents.agent_sdk.message_format import base_messages_to_agent_sdk_msgs
from onyx.chat.answer import Answer
from onyx.chat.chat_utils import create_chat_chain
from onyx.chat.chat_utils import create_temporary_persona
@@ -91,6 +89,7 @@ from onyx.llm.factory import get_llm_model_and_settings_for_persona
from onyx.llm.factory import get_llms_for_persona
from onyx.llm.factory import get_main_llm_from_tuple
from onyx.llm.interfaces import LLM
+from onyx.llm.message_format import base_messages_to_chat_completion_msgs
from onyx.llm.models import PreviousMessage
from onyx.llm.utils import litellm_exception_to_error_msg
from onyx.natural_language_processing.utils import get_tokenizer
@@ -934,24 +933,19 @@ def _fast_message_stream(
    model_settings: ModelSettings,
    user_or_none: User | None,
) -> Generator[Packet, None, None]:
-    # TODO: clean up this jank
-    is_responses_api = isinstance(llm_model, OpenAIResponsesModel)
    prompt_builder = answer.graph_inputs.prompt_builder
    primary_llm = answer.graph_tooling.primary_llm
    if prompt_builder and primary_llm:
        _reserve_prompt_tokens_for_agent_overhead(
            prompt_builder, primary_llm, tools, prompt_config
        )
-    messages = base_messages_to_agent_sdk_msgs(
-        answer.graph_inputs.prompt_builder.build(), is_responses_api=is_responses_api
+    messages = base_messages_to_chat_completion_msgs(
+        answer.graph_inputs.prompt_builder.build()
    )
    emitter = get_default_emitter()
    return fast_chat_turn.fast_chat_turn(
        messages=messages,
-        # TODO: Maybe we can use some DI framework here?
        dependencies=ChatTurnDependencies(
-            llm_model=llm_model,
-            model_settings=model_settings,
            llm=answer.graph_tooling.primary_llm,
            tools=tools,
            db_session=db_session,
@@ -959,6 +953,7 @@ def _fast_message_stream(
            emitter=emitter,
            user_or_none=user_or_none,
            prompt_config=prompt_config,
+            messages=messages,
        ),
        chat_session_id=chat_session_id,
        message_id=reserved_message_id,
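
Note: base_messages_to_chat_completion_msgs replaces the agent-SDK conversion here, and the resulting messages now also ride along inside ChatTurnDependencies. A minimal sketch of the assumed output shape, inferred from the message dicts used elsewhere in this diff (the helper's own source is not shown):

messages = [
    {"role": "system", "content": "You are a highly capable assistant"},
    {"role": "user", "content": "hi"},
]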

View File

@@ -1,5 +1,6 @@
from collections.abc import Sequence
from dataclasses import replace
+from typing import Any
from typing import cast
from typing import TYPE_CHECKING
from uuid import UUID
@@ -10,6 +11,9 @@ from agents import RunResultStreaming
from agents import ToolCallItem
from agents.tracing import trace
+from onyx.agents.agent_framework.models import ModelResponseStream
+from onyx.agents.agent_framework.models import RunItemStreamEvent
+from onyx.agents.agent_framework.query import query
from onyx.agents.agent_sdk.message_types import AgentSDKMessage
from onyx.agents.agent_sdk.message_types import InputTextContent
from onyx.agents.agent_sdk.message_types import SystemMessage
@@ -42,6 +46,9 @@ from onyx.chat.turn.models import ChatTurnDependencies
from onyx.chat.turn.prompts.custom_instruction import build_custom_instructions
from onyx.chat.turn.save_turn import extract_final_answer_from_packets
from onyx.chat.turn.save_turn import save_turn
+from onyx.llm.message_types import AssistantMessage
+from onyx.llm.message_types import ChatCompletionMessage
+from onyx.llm.message_types import ToolMessage
from onyx.server.query_and_chat.streaming_models import CitationDelta
from onyx.server.query_and_chat.streaming_models import CitationInfo
from onyx.server.query_and_chat.streaming_models import CitationStart
@@ -125,6 +132,7 @@ def _run_agent_loop(
            if iteration_count == 0 and force_use_tool
            else None
        ) or "auto"
        model_settings = replace(dependencies.model_settings, tool_choice=tool_choice)
        agent = Agent(
@@ -190,8 +198,253 @@ def _run_agent_loop(
        iteration_count += 1


def _run_agent_loop_v2(
    messages: list[ChatCompletionMessage],
    dependencies: ChatTurnDependencies,
    chat_session_id: UUID,
    ctx: ChatTurnContext,
    prompt_config: PromptConfig,
    force_use_tool: ForceUseTool | None = None,
) -> None:
    """Refactored agent loop using the new query() interface and the
    ChatCompletionMessage format.

    This version dramatically simplifies packet processing by using the query()
    function, which handles reasoning start/stop events automatically.
    """
    # This should have already been called, but call it again here for good measure.
    from onyx.llm.litellm_singleton.config import initialize_litellm

    initialize_litellm()

    chat_history = messages[1:-1]
    current_user_message = messages[-1]
    agent_turn_messages: list[ChatCompletionMessage] = []
    last_call_is_final = False
    iteration_count = 0

    while not last_call_is_final:
        available_tools: Sequence[Tool] = (
            dependencies.tools if iteration_count < MAX_ITERATIONS else []
        )
        memories = get_memories(dependencies.user_or_none, dependencies.db_session)

        # Build system message
        langchain_system_message = default_build_system_message_v2(
            dependencies.prompt_config,
            dependencies.llm.config,
            memories,
            available_tools,
            ctx.should_cite_documents,
        )
        new_system_message: ChatCompletionMessage = {
            "role": "system",
            "content": str(langchain_system_message.content),
        }

        # Build custom instructions as user messages
        custom_instructions: list[ChatCompletionMessage] = []
        if prompt_config.custom_instructions:
            custom_instructions.append(
                {
                    "role": "user",
                    "content": f"Custom Instructions: {prompt_config.custom_instructions}",
                }
            )

        # Construct full message list
        previous_messages = (
            [new_system_message]
            + chat_history
            + custom_instructions
            + [current_user_message]
        )
        current_messages = previous_messages + agent_turn_messages
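        # Illustrative shape of current_messages on a later iteration (values
        # assumed, not taken from the diff):
        #   [system, *chat_history, *custom_instructions, user,
        #    assistant(tool_calls), tool(output), ...]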
        # Determine tool choice
        if not available_tools:
            tool_choice = None
        else:
            tool_choice = (
                force_use_tool_to_function_tool_names(force_use_tool, available_tools)
                if iteration_count == 0 and force_use_tool
                else None
            ) or "auto"

        # Process the stream from query()
        assistant_content = ""
        tool_calls_dict: dict[str, dict[str, Any]] = {}
        tool_call_outputs: dict[str, str] = {}
        reasoning_content = ""

        for event in query(
            llm_with_default_settings=dependencies.llm,
            messages=current_messages,
            tools=available_tools,
            context=ctx,
            tool_choice=tool_choice,
        ):
            # Check for cancellation
            connected = is_connected(
                chat_session_id,
                dependencies.redis_client,
            )
            if not connected:
                _emit_clean_up_packets(dependencies, ctx)
                break

            # Process the event
            if isinstance(event, RunItemStreamEvent):
                # Handle structured events
                if event.type == "reasoning_start":
                    dependencies.emitter.emit(
                        Packet(ind=ctx.current_run_step, obj=ReasoningStart())
                    )
                    reasoning_content = ""
                elif event.type == "reasoning_done":
                    ctx.current_run_step += 1
                    dependencies.emitter.emit(
                        Packet(
                            ind=ctx.current_run_step, obj=SectionEnd(type="section_end")
                        )
                    )
                elif event.type == "message_start":
                    llm_docs_for_message_start = llm_docs_from_fetched_documents_cache(
                        ctx.fetched_documents_cache
                    )
                    retrieved_search_docs = saved_search_docs_from_llm_docs(
                        llm_docs_for_message_start
                    )
                    ctx.current_run_step += 1
                    dependencies.emitter.emit(
                        Packet(
                            ind=ctx.current_run_step,
                            obj=MessageStart(
                                content="", final_documents=retrieved_search_docs
                            ),
                        )
                    )
                elif event.type == "message_done":
                    dependencies.emitter.emit(
                        Packet(
                            ind=ctx.current_run_step, obj=SectionEnd(type="section_end")
                        )
                    )
                elif event.type == "tool_call" and event.details:
                    tool_call_item = event.details
                    if tool_call_item.call_id and tool_call_item.name:
                        tool_calls_dict[tool_call_item.call_id] = {
                            "id": tool_call_item.call_id,
                            "type": "function",
                            "function": {
                                "name": tool_call_item.name,
                                "arguments": tool_call_item.arguments or "",
                            },
                        }
                elif event.type == "tool_call_output" and event.details:
                    output_item = event.details
                    if output_item.call_id:
                        tool_call_outputs[output_item.call_id] = str(output_item.output)
            elif isinstance(event, ModelResponseStream):
                # Handle raw model response chunks
                delta = event.choice.delta
                if delta.reasoning_content:
                    reasoning_content += delta.reasoning_content
                    dependencies.emitter.emit(
                        Packet(
                            ind=ctx.current_run_step,
                            obj=ReasoningDelta(reasoning=delta.reasoning_content),
                        )
                    )
                if delta.content:
                    # Process content through citation processor if available
                    llm_docs = llm_docs_from_fetched_documents_cache(
                        ctx.fetched_documents_cache
                    )
                    if llm_docs:
                        mapping = map_document_id_order_v2(llm_docs)
                        processor = CitationProcessor(
                            context_docs=llm_docs,
                            doc_id_to_rank_map=mapping,
                            stop_stream=None,
                        )
                        final_answer_piece = ""
                        for response_part in processor.process_token(delta.content):
                            if isinstance(response_part, CitationInfo):
                                ctx.citations.append(response_part)
                            else:
                                final_answer_piece += response_part.answer_piece or ""
                        assistant_content += final_answer_piece
                        dependencies.emitter.emit(
                            Packet(
                                ind=ctx.current_run_step,
                                obj=MessageDelta(content=final_answer_piece),
                            )
                        )
                    else:
                        assistant_content += delta.content
                        dependencies.emitter.emit(
                            Packet(
                                ind=ctx.current_run_step,
                                obj=MessageDelta(content=delta.content),
                            )
                        )

        # Build new messages from this iteration
        new_messages: list[ChatCompletionMessage] = []

        # Add assistant message (either with content or tool calls)
        if tool_calls_dict:
            # Assistant made tool calls
            tool_calls_list = list(tool_calls_dict.values())
            assistant_msg: AssistantMessage = {
                "role": "assistant",
                "content": assistant_content if assistant_content else None,
                "tool_calls": tool_calls_list,
            }
            new_messages.append(assistant_msg)

            # Add tool response messages
            for tool_call in tool_calls_list:
                call_id = tool_call["id"]
                if call_id in tool_call_outputs:
                    tool_msg: ToolMessage = {
                        "role": "tool",
                        "content": tool_call_outputs[call_id],
                        "tool_call_id": call_id,
                    }
                    new_messages.append(tool_msg)
        elif assistant_content:
            # Assistant responded with text
            assistant_msg: AssistantMessage = {
                "role": "assistant",
                "content": assistant_content,
            }
            new_messages.append(assistant_msg)

        # Apply context handlers
        # TODO: Port context handlers to work with ChatCompletionMessage format
        # For now, track iteration metrics
        agent_turn_messages.extend(new_messages)

        # Determine if we should continue
        # TODO: Make this configurable on OnyxAgent level
        stopping_tools = ["image_generation"]
        if not tool_calls_dict or any(
            tc["function"]["name"] in stopping_tools for tc in tool_calls_dict.values()
        ):
            last_call_is_final = True

        iteration_count += 1


def _fast_chat_turn_core(
-    messages: list[AgentSDKMessage],
+    messages: list[ChatCompletionMessage],
    dependencies: ChatTurnDependencies,
    chat_session_id: UUID,
    message_id: int,
@@ -204,13 +457,12 @@ def _fast_chat_turn_core(
"""Core fast chat turn logic that allows overriding global_iteration_responses for testing.
Args:
messages: List of chat messages
messages: List of chat messages in ChatCompletionMessage format
dependencies: Chat turn dependencies
chat_session_id: Chat session ID
message_id: Message ID
research_type: Research type
global_iteration_responses: Optional list of iteration answers to inject for testing
cited_documents: Optional list of cited documents to inject for testing
starter_context: Optional starter context for testing
"""
    reset_cancel_status(
        chat_session_id,
@@ -223,7 +475,7 @@ def _fast_chat_turn_core(
        research_type=research_type,
    )
    with trace("fast_chat_turn"):
-        _run_agent_loop(
+        _run_agent_loop_v2(
            messages=messages,
            dependencies=dependencies,
            chat_session_id=chat_session_id,
@@ -270,7 +522,7 @@ def _fast_chat_turn_core(
@unified_event_stream
def fast_chat_turn(
-    messages: list[AgentSDKMessage],
+    messages: list[ChatCompletionMessage],
    dependencies: ChatTurnDependencies,
    chat_session_id: UUID,
    message_id: int,
@@ -278,7 +530,6 @@ def fast_chat_turn(
    prompt_config: PromptConfig,
    force_use_tool: ForceUseTool | None = None,
) -> None:
    """Main fast chat turn function that calls the core logic with default parameters."""
    _fast_chat_turn_core(
        messages,
        dependencies,

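Each iteration of _run_agent_loop_v2 folds its results back into the transcript as plain chat-completions messages. A minimal sketch of agent_turn_messages after one tool round, assuming a hypothetical internal_search tool (the tool name, call id, and output are illustrative, not from the diff):

agent_turn_messages = [
    {
        "role": "assistant",
        "content": None,
        "tool_calls": [
            {
                "id": "call_1",
                "type": "function",
                "function": {
                    "name": "internal_search",  # hypothetical tool
                    "arguments": '{"query": "pto policy"}',
                },
            }
        ],
    },
    {"role": "tool", "content": "<tool output>", "tool_call_id": "call_1"},
]
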
View File

@@ -10,8 +10,6 @@ from agents import FunctionTool
from agents import HostedMCPTool
from agents import ImageGenerationTool as AgentsImageGenerationTool
from agents import LocalShellTool
-from agents import Model
-from agents import ModelSettings
from agents import WebSearchTool
from pydantic import BaseModel
from redis.client import Redis
@@ -25,6 +23,7 @@ from onyx.chat.turn.infra.emitter import Emitter
from onyx.context.search.models import InferenceSection
from onyx.db.models import User
from onyx.llm.interfaces import LLM
+from onyx.llm.message_types import ChatCompletionMessage
from onyx.server.query_and_chat.streaming_models import CitationInfo
from onyx.tools.tool import Tool
@@ -43,9 +42,6 @@ AgentToolType = (
@dataclass
class ChatTurnDependencies:
-    llm_model: Model
-    model_settings: ModelSettings
-    # TODO we can delete this field (combine them)
    llm: LLM
    db_session: Session
    tools: Sequence[Tool]
@@ -53,6 +49,7 @@ class ChatTurnDependencies:
    emitter: Emitter
    user_or_none: User | None
    prompt_config: PromptConfig
+    messages: Sequence[ChatCompletionMessage]


class FetchedDocumentCacheEntry(BaseModel):

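With llm_model and model_settings gone, constructing the dependencies reduces to a single LLM handle plus the preformatted messages. A sketch based on the fields visible in this file; redis_client is assumed from its use as dependencies.redis_client elsewhere in the diff, and all variable values are placeholders:

dependencies = ChatTurnDependencies(
    llm=llm,
    db_session=db_session,
    tools=tools,
    redis_client=redis_client,  # assumed field; accessed as dependencies.redis_client
    emitter=emitter,
    user_or_none=user,
    prompt_config=prompt_config,
    messages=messages,
)
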
View File

@@ -31,12 +31,13 @@ from openai.types.responses.response_stream_event import ResponseTextDeltaEvent
from onyx.agents.agent_sdk.message_types import AgentSDKMessage
from onyx.agents.agent_sdk.message_types import AssistantMessageWithContent
from onyx.agents.agent_sdk.message_types import InputTextContent
-from onyx.agents.agent_sdk.message_types import SystemMessage
from onyx.agents.agent_sdk.message_types import UserMessage
from onyx.agents.agent_search.dr.enums import ResearchType
from onyx.chat.models import PromptConfig
from onyx.chat.turn.models import ChatTurnContext
from onyx.chat.turn.models import ChatTurnDependencies
+from onyx.llm.message_types import ChatCompletionMessage
+from onyx.llm.message_types import SystemMessage
from onyx.server.query_and_chat.streaming_models import CitationDelta
from onyx.server.query_and_chat.streaming_models import CitationStart
from onyx.server.query_and_chat.streaming_models import OverallStop
@@ -92,7 +93,7 @@ class CancellationMixin:
def run_fast_chat_turn(
-    sample_messages: list[AgentSDKMessage],
+    sample_messages: list[ChatCompletionMessage],
    chat_turn_dependencies: ChatTurnDependencies,
    chat_session_id: UUID,
    message_id: int,
@@ -361,26 +362,16 @@ def fake_tool_call_model() -> Model:
@pytest.fixture
-def sample_messages() -> list[AgentSDKMessage]:
+def sample_messages() -> list[ChatCompletionMessage]:
    return [
-        SystemMessage(
-            role="system",
-            content=[
-                InputTextContent(
-                    type="input_text",
-                    text="You are a highly capable assistant",
-                )
-            ],
-        ),
-        UserMessage(
-            role="user",
-            content=[
-                InputTextContent(
-                    type="input_text",
-                    text="hi",
-                )
-            ],
-        ),
+        {
+            "role": "system",
+            "content": "You are a highly capable assistant",
+        },
+        {
+            "role": "user",
+            "content": "hi",
+        },
    ]
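
Plain dicts also make richer histories cheap to write in tests. An illustrative companion fixture, not part of this diff, reusing the file's imports:

@pytest.fixture
def multi_turn_messages() -> list[ChatCompletionMessage]:
    # Hypothetical fixture showing a multi-turn history in the new format.
    return [
        {"role": "system", "content": "You are a highly capable assistant"},
        {"role": "user", "content": "hi"},
        {"role": "assistant", "content": "Hello! How can I help?"},
        {"role": "user", "content": "summarize our docs"},
    ]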