Mirror of https://github.com/onyx-dot-app/onyx.git (synced 2026-04-09 08:52:42 +00:00)

Compare commits: edge...richard/sp (1 commit)

Commit: 44a87dae35
@@ -10,11 +10,9 @@ from uuid import UUID

from agents import Model
from agents import ModelSettings
-from agents.models.openai_responses import OpenAIResponsesModel
from redis.client import Redis
from sqlalchemy.orm import Session

-from onyx.agents.agent_sdk.message_format import base_messages_to_agent_sdk_msgs
from onyx.chat.answer import Answer
from onyx.chat.chat_utils import create_chat_chain
from onyx.chat.chat_utils import create_temporary_persona
@@ -91,6 +89,7 @@ from onyx.llm.factory import get_llm_model_and_settings_for_persona
from onyx.llm.factory import get_llms_for_persona
from onyx.llm.factory import get_main_llm_from_tuple
from onyx.llm.interfaces import LLM
+from onyx.llm.message_format import base_messages_to_chat_completion_msgs
from onyx.llm.models import PreviousMessage
from onyx.llm.utils import litellm_exception_to_error_msg
from onyx.natural_language_processing.utils import get_tokenizer
@@ -934,24 +933,19 @@ def _fast_message_stream(
    model_settings: ModelSettings,
    user_or_none: User | None,
) -> Generator[Packet, None, None]:
-    # TODO: clean up this jank
-    is_responses_api = isinstance(llm_model, OpenAIResponsesModel)
    prompt_builder = answer.graph_inputs.prompt_builder
    primary_llm = answer.graph_tooling.primary_llm
    if prompt_builder and primary_llm:
        _reserve_prompt_tokens_for_agent_overhead(
            prompt_builder, primary_llm, tools, prompt_config
        )
-    messages = base_messages_to_agent_sdk_msgs(
-        answer.graph_inputs.prompt_builder.build(), is_responses_api=is_responses_api
+    messages = base_messages_to_chat_completion_msgs(
+        answer.graph_inputs.prompt_builder.build()
    )
    emitter = get_default_emitter()
    return fast_chat_turn.fast_chat_turn(
        messages=messages,
-        # TODO: Maybe we can use some DI framework here?
        dependencies=ChatTurnDependencies(
-            llm_model=llm_model,
-            model_settings=model_settings,
            llm=answer.graph_tooling.primary_llm,
            tools=tools,
            db_session=db_session,
@@ -959,6 +953,7 @@ def _fast_message_stream(
            emitter=emitter,
            user_or_none=user_or_none,
            prompt_config=prompt_config,
+            messages=messages,
        ),
        chat_session_id=chat_session_id,
        message_id=reserved_message_id,
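
This call-site swap is the heart of the commit: `_fast_message_stream` now hands the turn plain chat-completion dicts instead of Agent SDK input items, which is why the `is_responses_api` branch disappears. A minimal sketch of the two shapes, mirroring the test fixture further down (the real onyx TypedDicts may carry extra fields):

```python
# Old path: Agent SDK input items, a role plus a list of typed content parts.
agent_sdk_style = {
    "role": "user",
    "content": [{"type": "input_text", "text": "hi"}],
}

# New path: chat-completion dicts, a role plus a plain string.
chat_completion_style = {
    "role": "user",
    "content": "hi",
}
```
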
@@ -1,5 +1,6 @@
from collections.abc import Sequence
from dataclasses import replace
from typing import Any
from typing import cast
from typing import TYPE_CHECKING
from uuid import UUID
@@ -10,6 +11,9 @@ from agents import RunResultStreaming
from agents import ToolCallItem
from agents.tracing import trace

+from onyx.agents.agent_framework.models import ModelResponseStream
+from onyx.agents.agent_framework.models import RunItemStreamEvent
+from onyx.agents.agent_framework.query import query
from onyx.agents.agent_sdk.message_types import AgentSDKMessage
from onyx.agents.agent_sdk.message_types import InputTextContent
from onyx.agents.agent_sdk.message_types import SystemMessage
@@ -42,6 +46,9 @@ from onyx.chat.turn.models import ChatTurnDependencies
from onyx.chat.turn.prompts.custom_instruction import build_custom_instructions
from onyx.chat.turn.save_turn import extract_final_answer_from_packets
from onyx.chat.turn.save_turn import save_turn
+from onyx.llm.message_types import AssistantMessage
+from onyx.llm.message_types import ChatCompletionMessage
+from onyx.llm.message_types import ToolMessage
from onyx.server.query_and_chat.streaming_models import CitationDelta
from onyx.server.query_and_chat.streaming_models import CitationInfo
from onyx.server.query_and_chat.streaming_models import CitationStart
@@ -125,6 +132,7 @@ def _run_agent_loop(
        if iteration_count == 0 and force_use_tool
        else None
    ) or "auto"

    model_settings = replace(dependencies.model_settings, tool_choice=tool_choice)

    agent = Agent(
@@ -190,8 +198,253 @@ def _run_agent_loop(
        iteration_count += 1


+def _run_agent_loop_v2(
+    messages: list[ChatCompletionMessage],
+    dependencies: ChatTurnDependencies,
+    chat_session_id: UUID,
+    ctx: ChatTurnContext,
+    prompt_config: PromptConfig,
+    force_use_tool: ForceUseTool | None = None,
+) -> None:
+    """Refactored agent loop using new query interface and ChatCompletionMessage format.
+
+    This version dramatically simplifies packet processing by using the query() function
+    which handles reasoning start/stop events automatically.
+    """
+    # This should have already been called, but call it again here for good measure.
+    from onyx.llm.litellm_singleton.config import initialize_litellm
+
+    initialize_litellm()
+    chat_history = messages[1:-1]
+    current_user_message = messages[-1]
+
+    agent_turn_messages: list[ChatCompletionMessage] = []
+    last_call_is_final = False
+    iteration_count = 0
+
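
The slicing above assumes the caller always sends `[system, *history, current_user]`. A toy illustration with invented values:

```python
messages = [
    {"role": "system", "content": "You are a highly capable assistant"},
    {"role": "user", "content": "earlier question"},      # history
    {"role": "assistant", "content": "earlier answer"},   # history
    {"role": "user", "content": "hi"},                    # current turn
]
chat_history = messages[1:-1]        # the two middle messages
current_user_message = messages[-1]  # {"role": "user", "content": "hi"}
```
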
+    while not last_call_is_final:
+        available_tools: Sequence[Tool] = (
+            dependencies.tools if iteration_count < MAX_ITERATIONS else []
+        )
+
+        memories = get_memories(dependencies.user_or_none, dependencies.db_session)
+
+        # Build system message
+        langchain_system_message = default_build_system_message_v2(
+            dependencies.prompt_config,
+            dependencies.llm.config,
+            memories,
+            available_tools,
+            ctx.should_cite_documents,
+        )
+
+        new_system_message: ChatCompletionMessage = {
+            "role": "system",
+            "content": str(langchain_system_message.content),
+        }
+
+        # Build custom instructions as user messages
+        custom_instructions: list[ChatCompletionMessage] = []
+        if prompt_config.custom_instructions:
+            custom_instructions.append(
+                {
+                    "role": "user",
+                    "content": f"Custom Instructions: {prompt_config.custom_instructions}",
+                }
+            )
+
+        # Construct full message list
+        previous_messages = (
+            [new_system_message]
+            + chat_history
+            + custom_instructions
+            + [current_user_message]
+        )
+        current_messages = previous_messages + agent_turn_messages
+
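
Note that the prompt is rebuilt from scratch on every iteration (the system message can change once tools are withdrawn past MAX_ITERATIONS), with this turn's tool traffic appended at the end. A self-contained sketch of the resulting order, values invented:

```python
new_system_message = {"role": "system", "content": "..."}
chat_history = [{"role": "user", "content": "earlier question"}]
custom_instructions = [{"role": "user", "content": "Custom Instructions: be terse"}]
current_user_message = {"role": "user", "content": "hi"}
agent_turn_messages = [  # accumulated from earlier iterations of this turn
    {"role": "assistant", "content": None, "tool_calls": []},
    {"role": "tool", "content": "...", "tool_call_id": "call_0"},
]

current_messages = (
    [new_system_message]
    + chat_history
    + custom_instructions
    + [current_user_message]
    + agent_turn_messages
)
```
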
+        # Determine tool choice
+        if not available_tools:
+            tool_choice = None
+        else:
+            tool_choice = (
+                force_use_tool_to_function_tool_names(force_use_tool, available_tools)
+                if iteration_count == 0 and force_use_tool
+                else None
+            ) or "auto"
+
+        # Process the stream from query()
+        assistant_content = ""
+        tool_calls_dict: dict[str, dict[str, Any]] = {}
+        tool_call_outputs: dict[str, str] = {}
+        reasoning_content = ""
+
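
The accumulators exist because `query()` streams tool calls and their outputs as separate events keyed by `call_id`; they get stitched back together after the stream ends. By the end of an iteration they look roughly like this (tool name and payloads invented):

```python
tool_calls_dict = {
    "call_abc123": {
        "id": "call_abc123",
        "type": "function",
        "function": {"name": "internal_search", "arguments": '{"query": "..."}'},
    }
}
tool_call_outputs = {"call_abc123": "tool result text"}
```
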
+        for event in query(
+            llm_with_default_settings=dependencies.llm,
+            messages=current_messages,
+            tools=available_tools,
+            context=ctx,
+            tool_choice=tool_choice,
+        ):
+            # Check for cancellation
+            connected = is_connected(
+                chat_session_id,
+                dependencies.redis_client,
+            )
+            if not connected:
+                _emit_clean_up_packets(dependencies, ctx)
+                break
+
+            # Process the event
+            if isinstance(event, RunItemStreamEvent):
+                # Handle structured events
+                if event.type == "reasoning_start":
+                    dependencies.emitter.emit(
+                        Packet(ind=ctx.current_run_step, obj=ReasoningStart())
+                    )
+                    reasoning_content = ""
+                elif event.type == "reasoning_done":
+                    ctx.current_run_step += 1
+                    dependencies.emitter.emit(
+                        Packet(
+                            ind=ctx.current_run_step, obj=SectionEnd(type="section_end")
+                        )
+                    )
+                elif event.type == "message_start":
+                    llm_docs_for_message_start = llm_docs_from_fetched_documents_cache(
+                        ctx.fetched_documents_cache
+                    )
+                    retrieved_search_docs = saved_search_docs_from_llm_docs(
+                        llm_docs_for_message_start
+                    )
+                    ctx.current_run_step += 1
+                    dependencies.emitter.emit(
+                        Packet(
+                            ind=ctx.current_run_step,
+                            obj=MessageStart(
+                                content="", final_documents=retrieved_search_docs
+                            ),
+                        )
+                    )
+                elif event.type == "message_done":
+                    dependencies.emitter.emit(
+                        Packet(
+                            ind=ctx.current_run_step, obj=SectionEnd(type="section_end")
+                        )
+                    )
+                elif event.type == "tool_call" and event.details:
+                    tool_call_item = event.details
+                    if tool_call_item.call_id and tool_call_item.name:
+                        tool_calls_dict[tool_call_item.call_id] = {
+                            "id": tool_call_item.call_id,
+                            "type": "function",
+                            "function": {
+                                "name": tool_call_item.name,
+                                "arguments": tool_call_item.arguments or "",
+                            },
+                        }
+                elif event.type == "tool_call_output" and event.details:
+                    output_item = event.details
+                    if output_item.call_id:
+                        tool_call_outputs[output_item.call_id] = str(output_item.output)
+
+            elif isinstance(event, ModelResponseStream):
+                # Handle raw model response chunks
+                delta = event.choice.delta
+
+                if delta.reasoning_content:
+                    reasoning_content += delta.reasoning_content
+                    dependencies.emitter.emit(
+                        Packet(
+                            ind=ctx.current_run_step,
+                            obj=ReasoningDelta(reasoning=delta.reasoning_content),
+                        )
+                    )
+
+                if delta.content:
+                    # Process content through citation processor if available
+                    llm_docs = llm_docs_from_fetched_documents_cache(
+                        ctx.fetched_documents_cache
+                    )
+                    if llm_docs:
+                        mapping = map_document_id_order_v2(llm_docs)
+                        processor = CitationProcessor(
+                            context_docs=llm_docs,
+                            doc_id_to_rank_map=mapping,
+                            stop_stream=None,
+                        )
+                        final_answer_piece = ""
+                        for response_part in processor.process_token(delta.content):
+                            if isinstance(response_part, CitationInfo):
+                                ctx.citations.append(response_part)
+                            else:
+                                final_answer_piece += response_part.answer_piece or ""
+                        assistant_content += final_answer_piece
+                        dependencies.emitter.emit(
+                            Packet(
+                                ind=ctx.current_run_step,
+                                obj=MessageDelta(content=final_answer_piece),
+                            )
+                        )
+                    else:
+                        assistant_content += delta.content
+                        dependencies.emitter.emit(
+                            Packet(
+                                ind=ctx.current_run_step,
+                                obj=MessageDelta(content=delta.content),
+                            )
+                        )
+
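
The citation branch above pushes every content token through `CitationProcessor.process_token`, collecting `CitationInfo` objects into `ctx.citations` and re-emitting the remaining text as `MessageDelta`s. A self-contained toy analogue of that split, with a made-up bracket syntax standing in for onyx's actual citation parsing:

```python
from dataclasses import dataclass


@dataclass
class ToyCitation:
    doc_rank: int


def process_token(token: str):
    """Yield ToyCitation for tokens like "[1]"; pass other tokens through."""
    if token.startswith("[") and token.endswith("]") and token[1:-1].isdigit():
        yield ToyCitation(doc_rank=int(token[1:-1]))
    else:
        yield token


answer, citations = "", []
for tok in ["Mirrors ", "are ", "supported", "[1]", "."]:
    for part in process_token(tok):
        if isinstance(part, ToyCitation):
            citations.append(part)
        else:
            answer += part

assert answer == "Mirrors are supported."
assert citations == [ToyCitation(doc_rank=1)]
```
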
+        # Build new messages from this iteration
+        new_messages: list[ChatCompletionMessage] = []
+
+        # Add assistant message (either with content or tool calls)
+        if tool_calls_dict:
+            # Assistant made tool calls
+            tool_calls_list = list(tool_calls_dict.values())
+            assistant_msg: AssistantMessage = {
+                "role": "assistant",
+                "content": assistant_content if assistant_content else None,
+                "tool_calls": tool_calls_list,
+            }
+            new_messages.append(assistant_msg)
+
+            # Add tool response messages
+            for tool_call in tool_calls_list:
+                call_id = tool_call["id"]
+                if call_id in tool_call_outputs:
+                    tool_msg: ToolMessage = {
+                        "role": "tool",
+                        "content": tool_call_outputs[call_id],
+                        "tool_call_id": call_id,
+                    }
+                    new_messages.append(tool_msg)
+        elif assistant_content:
+            # Assistant responded with text
+            assistant_msg = {
+                "role": "assistant",
+                "content": assistant_content,
+            }
+            new_messages.append(assistant_msg)
+
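
After an iteration that called a tool, `new_messages` therefore holds an OpenAI-style assistant/tool pair, which is what lets the next iteration's prompt replay the exchange. Roughly (payloads invented):

```python
new_messages = [
    {
        "role": "assistant",
        "content": None,  # no prose this round, just the call
        "tool_calls": [
            {
                "id": "call_abc123",
                "type": "function",
                "function": {"name": "internal_search", "arguments": '{"query": "..."}'},
            }
        ],
    },
    {"role": "tool", "content": "tool result text", "tool_call_id": "call_abc123"},
]
```
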
+        # Apply context handlers
+        # TODO: Port context handlers to work with ChatCompletionMessage format
+        # For now, track iteration metrics
+        agent_turn_messages.extend(new_messages)
+
+        # Determine if we should continue
+        # TODO: Make this configurable on OnyxAgent level
+        stopping_tools = ["image_generation"]
+        if not tool_calls_dict or any(
+            tc["function"]["name"] in stopping_tools for tc in tool_calls_dict.values()
+        ):
+            last_call_is_final = True
+
+        iteration_count += 1
+
+
def _fast_chat_turn_core(
-    messages: list[AgentSDKMessage],
+    messages: list[ChatCompletionMessage],
    dependencies: ChatTurnDependencies,
    chat_session_id: UUID,
    message_id: int,
@@ -204,13 +457,12 @@ def _fast_chat_turn_core(
    """Core fast chat turn logic that allows overriding global_iteration_responses for testing.

    Args:
-        messages: List of chat messages
+        messages: List of chat messages in ChatCompletionMessage format
        dependencies: Chat turn dependencies
        chat_session_id: Chat session ID
        message_id: Message ID
        research_type: Research type
        global_iteration_responses: Optional list of iteration answers to inject for testing
        cited_documents: Optional list of cited documents to inject for testing
        starter_context: Optional starter context for testing
    """
    reset_cancel_status(
        chat_session_id,
@@ -223,7 +475,7 @@ def _fast_chat_turn_core(
        research_type=research_type,
    )
    with trace("fast_chat_turn"):
-        _run_agent_loop(
+        _run_agent_loop_v2(
            messages=messages,
            dependencies=dependencies,
            chat_session_id=chat_session_id,
@@ -270,7 +522,7 @@ def _fast_chat_turn_core(

@unified_event_stream
def fast_chat_turn(
-    messages: list[AgentSDKMessage],
+    messages: list[ChatCompletionMessage],
    dependencies: ChatTurnDependencies,
    chat_session_id: UUID,
    message_id: int,
@@ -278,7 +530,6 @@ def fast_chat_turn(
    prompt_config: PromptConfig,
    force_use_tool: ForceUseTool | None = None,
) -> None:
    """Main fast chat turn function that calls the core logic with default parameters."""
    _fast_chat_turn_core(
        messages,
        dependencies,
@@ -10,8 +10,6 @@ from agents import FunctionTool
from agents import HostedMCPTool
from agents import ImageGenerationTool as AgentsImageGenerationTool
from agents import LocalShellTool
-from agents import Model
-from agents import ModelSettings
from agents import WebSearchTool
from pydantic import BaseModel
from redis.client import Redis
@@ -25,6 +23,7 @@ from onyx.chat.turn.infra.emitter import Emitter
from onyx.context.search.models import InferenceSection
from onyx.db.models import User
from onyx.llm.interfaces import LLM
+from onyx.llm.message_types import ChatCompletionMessage
from onyx.server.query_and_chat.streaming_models import CitationInfo
from onyx.tools.tool import Tool
@@ -43,9 +42,6 @@ AgentToolType = (

@dataclass
class ChatTurnDependencies:
-    llm_model: Model
-    model_settings: ModelSettings
-    # TODO we can delete this field (combine them)
    llm: LLM
    db_session: Session
    tools: Sequence[Tool]
@@ -53,6 +49,7 @@ class ChatTurnDependencies:
    emitter: Emitter
    user_or_none: User | None
    prompt_config: PromptConfig
+    messages: Sequence[ChatCompletionMessage]


class FetchedDocumentCacheEntry(BaseModel):
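
The dataclass change is the dependency-side mirror of the loop rewrite: the Agent SDK `llm_model`/`model_settings` pair collapses into the single `llm` handle, and the turn's message list now rides along with the other dependencies. A self-contained toy with stand-in types (the real class has more fields, e.g. the `redis_client` used earlier):

```python
from collections.abc import Sequence
from dataclasses import dataclass
from typing import Any


@dataclass
class ToyChatTurnDependencies:
    llm: Any                            # replaces llm_model + model_settings
    db_session: Any
    tools: Sequence[Any]
    emitter: Any
    user_or_none: Any
    prompt_config: Any
    messages: Sequence[dict[str, Any]]  # new in this commit


deps = ToyChatTurnDependencies(
    llm=object(),
    db_session=object(),
    tools=[],
    emitter=object(),
    user_or_none=None,
    prompt_config=object(),
    messages=[{"role": "user", "content": "hi"}],
)
```
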
@@ -31,12 +31,13 @@ from openai.types.responses.response_stream_event import ResponseTextDeltaEvent
from onyx.agents.agent_sdk.message_types import AgentSDKMessage
from onyx.agents.agent_sdk.message_types import AssistantMessageWithContent
from onyx.agents.agent_sdk.message_types import InputTextContent
-from onyx.agents.agent_sdk.message_types import SystemMessage
from onyx.agents.agent_sdk.message_types import UserMessage
from onyx.agents.agent_search.dr.enums import ResearchType
from onyx.chat.models import PromptConfig
from onyx.chat.turn.models import ChatTurnContext
from onyx.chat.turn.models import ChatTurnDependencies
+from onyx.llm.message_types import ChatCompletionMessage
+from onyx.llm.message_types import SystemMessage
from onyx.server.query_and_chat.streaming_models import CitationDelta
from onyx.server.query_and_chat.streaming_models import CitationStart
from onyx.server.query_and_chat.streaming_models import OverallStop
@@ -92,7 +93,7 @@ class CancellationMixin:


def run_fast_chat_turn(
-    sample_messages: list[AgentSDKMessage],
+    sample_messages: list[ChatCompletionMessage],
    chat_turn_dependencies: ChatTurnDependencies,
    chat_session_id: UUID,
    message_id: int,
@@ -361,26 +362,16 @@ def fake_tool_call_model() -> Model:


@pytest.fixture
-def sample_messages() -> list[AgentSDKMessage]:
+def sample_messages() -> list[ChatCompletionMessage]:
    return [
-        SystemMessage(
-            role="system",
-            content=[
-                InputTextContent(
-                    type="input_text",
-                    text="You are a highly capable assistant",
-                )
-            ],
-        ),
-        UserMessage(
-            role="user",
-            content=[
-                InputTextContent(
-                    type="input_text",
-                    text="hi",
-                )
-            ],
-        ),
+        {
+            "role": "system",
+            "content": "You are a highly capable assistant",
+        },
+        {
+            "role": "user",
+            "content": "hi",
+        },
    ]
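
With the fixture reduced to plain dicts, tests no longer need the Agent SDK wrapper classes. A hypothetical smoke test over the new shape (not part of this diff):

```python
def test_sample_messages_shape(sample_messages: list) -> None:
    assert [m["role"] for m in sample_messages] == ["system", "user"]
    assert all(isinstance(m["content"], str) for m in sample_messages)
```
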