Compare commits

...

3 Commits

Author SHA1 Message Date
Richard Guan
764e7f44f6 Merge branch 'main' into richard/selective-migration 2025-12-01 10:08:39 -08:00
Richard Guan
1c376c66d4 . 2025-11-25 17:55:59 -08:00
Richard Guan
735b3c4c02 . 2025-11-25 17:44:14 -08:00
35 changed files with 3790 additions and 180 deletions

View File

@@ -1,4 +1,5 @@
import json
import traceback
from collections.abc import Callable
from collections.abc import Iterator
from collections.abc import Sequence
@@ -312,11 +313,15 @@ def query(
_error_tracing.attach_error_to_current_span(
SpanError(
message="Error running tool",
data={"tool_name": tool.name, "error": str(e)},
data={
"tool_name": tool.name,
"error": str(e),
"stack_trace": traceback.format_exc(),
},
)
)
# Treat the error as the tool output so the framework can continue
error_output = f"Error: {str(e)}"
error_output = tool.failure_error_function(e)
tool_outputs[call_id] = error_output
output = error_output
@@ -328,8 +333,17 @@ def query(
),
)
else:
not_found_output = f"Tool {name} not found"
tool_outputs[call_id] = _serialize_tool_output(not_found_output)
with function_span(f"tool_not_found_{name}") as span_fn:
not_found_output = {
"error": True,
"error_type": "TOOL_NOT FOUND",
"message": f"The tool {name} does not exist or is not registered.",
"available_tools": [tool.name for tool in tools],
}
tool_outputs[call_id] = _serialize_tool_output(not_found_output)
span_fn.span_data.input = arguments_str
span_fn.span_data.output = not_found_output
yield RunItemStreamEvent(
type="tool_call_output",
details=ToolCallOutputStreamItem(

View File

@@ -56,7 +56,6 @@ from onyx.server.query_and_chat.streaming_models import ReasoningStart
from onyx.server.query_and_chat.streaming_models import SectionEnd
from onyx.tools.adapter_v1_to_v2 import force_use_tool_to_function_tool_names
from onyx.tools.adapter_v1_to_v2 import tools_to_function_tools
from onyx.tools.force import filter_tools_for_force_tool_use
from onyx.tools.force import ForceUseTool
from onyx.tools.tool import Tool
@@ -111,10 +110,6 @@ def _run_agent_loop(
available_tools: Sequence[Tool] = (
dependencies.tools if iteration_count < MAX_ITERATIONS else []
)
if force_use_tool and force_use_tool.force_use:
available_tools = filter_tools_for_force_tool_use(
list(available_tools), force_use_tool
)
memories = get_memories(dependencies.user_or_none, dependencies.db_session)
# TODO: The system is rather prompt-cache efficient except for rebuilding the system prompt.
# The biggest offender is when we hit max iterations and then all the tool calls cannot

View File

@@ -47,6 +47,7 @@ from onyx.llm.llm_provider_options import VERTEX_CREDENTIALS_FILE_KWARG
from onyx.llm.llm_provider_options import VERTEX_LOCATION_KWARG
from onyx.llm.model_response import ModelResponse
from onyx.llm.model_response import ModelResponseStream
from onyx.llm.utils import is_anthropic_model
from onyx.llm.utils import is_true_openai_model
from onyx.llm.utils import model_is_reasoning_model
from onyx.server.utils import mask_string
@@ -466,15 +467,22 @@ class LitellmLLM(LLM):
from litellm.exceptions import Timeout, RateLimitError
tool_choice_formatted: dict[str, Any] | str | None
function_tool_choice = (
tool_choice and tool_choice not in STANDARD_TOOL_CHOICE_OPTIONS
)
if not tools:
tool_choice_formatted = None
elif tool_choice and tool_choice not in STANDARD_TOOL_CHOICE_OPTIONS:
elif function_tool_choice:
tool_choice_formatted = {
"type": "function",
"function": {"name": tool_choice},
}
else:
tool_choice_formatted = tool_choice
exclude_reasoning_config_for_anthropic = (
is_anthropic_model(self.config.model_name, self.config.model_provider)
and function_tool_choice
)
is_reasoning = model_is_reasoning_model(
self.config.model_name, self.config.model_provider
@@ -531,12 +539,16 @@ class LitellmLLM(LLM):
),
**(
{"thinking": {"type": "enabled", "budget_tokens": 10000}}
if reasoning_effort and is_reasoning
if reasoning_effort
and is_reasoning
and not exclude_reasoning_config_for_anthropic
else {}
),
**(
{"reasoning_effort": reasoning_effort}
if reasoning_effort and is_reasoning
if reasoning_effort
and is_reasoning
and not exclude_reasoning_config_for_anthropic
else {}
),
**(

View File

@@ -7,41 +7,28 @@ from typing import Dict
from typing import List
from typing import Optional
from typing import Tuple
from typing import TypedDict
from typing import Union
from litellm import AllMessageValues
from litellm import LiteLLMLoggingObj
from litellm.completion_extras.litellm_responses_transformation.transformation import (
LiteLLMResponsesTransformationHandler,
)
from litellm.completion_extras.litellm_responses_transformation.transformation import (
OpenAiResponsesToChatCompletionStreamIterator,
)
from litellm.litellm_core_utils.prompt_templates.common_utils import (
convert_content_list_to_str,
)
try:
from litellm.litellm_core_utils.prompt_templates.common_utils import (
extract_images_from_message,
)
except ImportError:
extract_images_from_message = None # type: ignore[assignment]
from litellm.litellm_core_utils.prompt_templates.common_utils import (
extract_images_from_message,
)
from litellm.llms.ollama.chat.transformation import OllamaChatCompletionResponseIterator
from litellm.llms.ollama.chat.transformation import OllamaChatConfig
from litellm.llms.ollama.common_utils import OllamaError
try:
from litellm.types.llms.ollama import OllamaChatCompletionMessage
except ImportError:
class OllamaChatCompletionMessage(TypedDict, total=False): # type: ignore[no-redef]
"""Fallback for LiteLLM versions where this TypedDict was removed."""
role: str
content: Optional[str]
images: Optional[List[Any]]
thinking: Optional[str]
tool_calls: Optional[List["OllamaToolCall"]]
from litellm.llms.openai.chat.gpt_5_transformation import OpenAIGPT5Config
from litellm.llms.openai.chat.gpt_transformation import OpenAIGPTConfig
from litellm.types.llms.ollama import OllamaChatCompletionMessage
from litellm.types.llms.ollama import OllamaToolCall
from litellm.types.llms.ollama import OllamaToolCallFunction
from litellm.types.llms.openai import ChatCompletionAssistantToolCall
@@ -52,43 +39,7 @@ from litellm.utils import verbose_logger
from pydantic import BaseModel
if extract_images_from_message is None:
def extract_images_from_message(
message: AllMessageValues,
) -> Optional[List[Any]]:
"""Fallback for LiteLLM versions that dropped extract_images_from_message."""
images: List[Any] = []
content = message.get("content")
if not isinstance(content, list):
return None
for item in content:
if not isinstance(item, Dict):
continue
item_type = item.get("type")
if item_type == "image_url":
image_url = item.get("image_url")
if isinstance(image_url, dict):
if image_url.get("url"):
images.append(image_url)
elif image_url:
images.append(image_url)
elif item_type in {"input_image", "image"}:
image_value = item.get("image")
if image_value:
images.append(image_value)
return images or None
def _patch_ollama_transform_request() -> None:
"""
Patches OllamaChatConfig.transform_request to handle reasoning content
and tool calls properly for Ollama chat completions.
"""
def _patch_ollama_transform_request_so_tool_calls_streamed() -> None:
if (
getattr(OllamaChatConfig.transform_request, "__name__", "")
== "_patched_transform_request"
@@ -180,11 +131,7 @@ def _patch_ollama_transform_request() -> None:
OllamaChatConfig.transform_request = _patched_transform_request # type: ignore[method-assign]
def _patch_ollama_chunk_parser() -> None:
"""
Patches OllamaChatCompletionResponseIterator.chunk_parser to properly handle
reasoning content and content in streaming responses.
"""
def _patch_ollama_chunk_parser_so_reasoning_streamed() -> None:
if (
getattr(OllamaChatCompletionResponseIterator.chunk_parser, "__name__", "")
== "_patched_chunk_parser"
@@ -312,11 +259,7 @@ def _patch_ollama_chunk_parser() -> None:
OllamaChatCompletionResponseIterator.chunk_parser = _patched_chunk_parser # type: ignore[method-assign]
def _patch_openai_responses_chunk_parser() -> None:
"""
Patches OpenAiResponsesToChatCompletionStreamIterator.chunk_parser to properly
handle OpenAI Responses API streaming format and convert it to chat completion format.
"""
def _patch_openai_responses_chunk_parser_so_reasoning_streamed() -> None:
if (
getattr(
OpenAiResponsesToChatCompletionStreamIterator.chunk_parser,
@@ -483,6 +426,174 @@ def _patch_openai_responses_chunk_parser() -> None:
OpenAiResponsesToChatCompletionStreamIterator.chunk_parser = _patched_openai_responses_chunk_parser # type: ignore[method-assign]
def _patch_litellm_responses_transformation_handler_so_tool_choice_formatted() -> None:
if (
getattr(
LiteLLMResponsesTransformationHandler.transform_request,
"__name__",
"",
)
== "_patched_transform_request"
):
return
def _patched_transform_request(
self,
model: str,
messages: List["AllMessageValues"],
optional_params: dict,
litellm_params: dict,
headers: dict,
litellm_logging_obj: "LiteLLMLoggingObj",
client: Optional[Any] = None,
) -> dict:
from litellm.types.llms.openai import ResponsesAPIOptionalRequestParams
(
input_items,
instructions,
) = self.convert_chat_completion_messages_to_responses_api(messages)
# Build responses API request using the reverse transformation logic
responses_api_request = ResponsesAPIOptionalRequestParams()
# Set instructions if we found a system message
if instructions:
responses_api_request["instructions"] = instructions
# Map optional parameters
for key, value in optional_params.items():
if value is None:
continue
if key in ("max_tokens", "max_completion_tokens"):
responses_api_request["max_output_tokens"] = value
elif key == "tools" and value is not None:
# Convert chat completion tools to responses API tools format
responses_api_request["tools"] = (
LiteLLMResponsesTransformationHandler._convert_tools_to_responses_format(
None, cast(List[Dict[str, Any]], value)
)
)
elif key == "tool_choice" and value is not None:
# TODO Right now, we're only supporting function tools and tool choice
# mode. In reality this should support all the possible tool_choice inputs
# documented in the API docs, but these are the only formats Onyx uses.
responses_api_request["tool_choice"] = (
{
"type": "function",
"name": (
value
if isinstance(value, str)
else value["function"]["name"]
),
}
if isinstance(value, dict)
else value
)
elif key in ResponsesAPIOptionalRequestParams.__annotations__.keys():
responses_api_request[key] = value # type: ignore
elif key in ("metadata"):
responses_api_request["metadata"] = value
elif key in ("previous_response_id"):
responses_api_request["previous_response_id"] = value
elif key == "reasoning_effort":
responses_api_request["reasoning"] = self._map_reasoning_effort(value)
# Get stream parameter from litellm_params if not in optional_params
stream = optional_params.get("stream") or litellm_params.get("stream", False)
verbose_logger.debug(f"Chat provider: Stream parameter: {stream}")
# Ensure stream is properly set in the request
if stream:
responses_api_request["stream"] = True
# Handle session management if previous_response_id is provided
previous_response_id = optional_params.get("previous_response_id")
if previous_response_id:
# Use the existing session handler for responses API
verbose_logger.debug(
f"Chat provider: Warning ignoring previous response ID: {previous_response_id}"
)
# Convert back to responses API format for the actual request
api_model = model
from litellm.types.utils import CallTypes
setattr(litellm_logging_obj, "call_type", CallTypes.responses.value)
request_data = {
"model": api_model,
"input": input_items,
"litellm_logging_obj": litellm_logging_obj,
**litellm_params,
"client": client,
}
verbose_logger.debug(
f"Chat provider: Final request model={api_model}, input_items={len(input_items)}"
)
# Add non-None values from responses_api_request
for key, value in responses_api_request.items():
if value is not None:
if key == "instructions" and instructions:
request_data["instructions"] = instructions
elif key == "stream_options" and isinstance(value, dict):
request_data["stream_options"] = value.get("include_obfuscation")
elif key == "user": # string can't be longer than 64 characters
if isinstance(value, str) and len(value) <= 64:
request_data["user"] = value
else:
request_data[key] = value
return request_data
_patched_transform_request.__name__ = "_patched_transform_request"
LiteLLMResponsesTransformationHandler.transform_request = _patched_transform_request # type: ignore[method-assign]
def _patch_litellm_openai_gpt5_coonfig_allow_tool_choice_for_responses_bridge() -> None:
if (
getattr(
OpenAIGPT5Config.get_supported_openai_params,
"__name__",
"",
)
== "_patched_get_supported_openai_params"
):
return
def _patched_get_supported_openai_params(self, model: str) -> list:
# Call the parent class method directly (can't use super() in monkey patches)
base_gpt_series_params = OpenAIGPTConfig.get_supported_openai_params(
self, model=model
)
gpt_5_only_params = ["reasoning_effort"]
base_gpt_series_params.extend(gpt_5_only_params)
non_supported_params = [
"logprobs",
"top_p",
"presence_penalty",
"frequency_penalty",
"top_logprobs",
"stop",
]
return [
param
for param in base_gpt_series_params
if param not in non_supported_params
]
_patched_get_supported_openai_params.__name__ = (
"_patched_get_supported_openai_params"
)
OpenAIGPT5Config.get_supported_openai_params = _patched_get_supported_openai_params # type: ignore[method-assign]
def apply_monkey_patches() -> None:
"""
Apply all necessary monkey patches to LiteLLM for compatibility.
@@ -491,10 +602,14 @@ def apply_monkey_patches() -> None:
- Patching OllamaChatConfig.transform_request for reasoning content support
- Patching OllamaChatCompletionResponseIterator.chunk_parser for streaming content
- Patching OpenAiResponsesToChatCompletionStreamIterator.chunk_parser for OpenAI Responses API
- Patching LiteLLMResponsesTransformationHandler.transform_request for tool choice formatting
- Patching OpenAIGPT5Config.get_supported_openai_params for tool choice support for responses bridge
"""
_patch_ollama_transform_request()
_patch_ollama_chunk_parser()
_patch_openai_responses_chunk_parser()
_patch_ollama_transform_request_so_tool_calls_streamed()
_patch_ollama_chunk_parser_so_reasoning_streamed()
_patch_openai_responses_chunk_parser_so_reasoning_streamed()
_patch_litellm_responses_transformation_handler_so_tool_choice_formatted()
_patch_litellm_openai_gpt5_coonfig_allow_tool_choice_for_responses_bridge()
def _extract_reasoning_content(message: dict) -> Tuple[Optional[str], Optional[str]]:

View File

@@ -935,6 +935,10 @@ def is_true_openai_model(model_provider: str, model_name: str) -> bool:
return False
def is_anthropic_model(model_name: str, model_provider: str) -> bool:
return model_provider == "anthropic" or "claude" in model_name.lower()
def model_needs_formatting_reenabled(model_name: str) -> bool:
# See https://simonwillison.net/tags/markdown/ for context on why this is needed
# for OpenAI reasoning models to have correct markdown generation

View File

@@ -34,12 +34,8 @@ from onyx.server.features.web_search.models import WebSearchToolResponse
from onyx.server.features.web_search.models import WebSearchWithContentResponse
from onyx.server.manage.web_search.models import WebContentProviderView
from onyx.server.manage.web_search.models import WebSearchProviderView
from onyx.tools.tool_implementations_v2.tool_result_models import (
LlmOpenUrlResult,
)
from onyx.tools.tool_implementations_v2.tool_result_models import (
LlmWebSearchResult,
)
from onyx.tools.tool_result_models import LlmOpenUrlResult
from onyx.tools.tool_result_models import LlmWebSearchResult
from onyx.utils.logger import setup_logger
from shared_configs.enums import WebContentProviderType
from shared_configs.enums import WebSearchProviderType

View File

@@ -2,12 +2,8 @@ from pydantic import BaseModel
from pydantic import Field
from pydantic import field_validator
from onyx.tools.tool_implementations_v2.tool_result_models import (
LlmOpenUrlResult,
)
from onyx.tools.tool_implementations_v2.tool_result_models import (
LlmWebSearchResult,
)
from onyx.tools.tool_result_models import LlmOpenUrlResult
from onyx.tools.tool_result_models import LlmWebSearchResult
from shared_configs.enums import WebContentProviderType
from shared_configs.enums import WebSearchProviderType

View File

@@ -13,12 +13,12 @@ from onyx.chat.turn.models import ChatTurnContext
from onyx.server.query_and_chat.streaming_models import CustomToolDelta
from onyx.server.query_and_chat.streaming_models import CustomToolStart
from onyx.server.query_and_chat.streaming_models import Packet
from onyx.server.query_and_chat.streaming_models import SectionEnd
from onyx.tools.built_in_tools_v2 import BUILT_IN_TOOL_MAP_V2
from onyx.tools.force import ForceUseTool
from onyx.tools.tool import Tool
from onyx.tools.tool_implementations.custom.custom_tool import CustomTool
from onyx.tools.tool_implementations.mcp.mcp_tool import MCPTool
from onyx.tools.tool_implementations_v2.tool_accounting import tool_accounting
# Type alias for tools that need custom handling
CustomOrMcpTool = Union[CustomTool, MCPTool]
@@ -29,7 +29,6 @@ def is_custom_or_mcp_tool(tool: Tool) -> bool:
return isinstance(tool, CustomTool) or isinstance(tool, MCPTool)
@tool_accounting
async def _tool_run_wrapper(
run_context: RunContextWrapper[ChatTurnContext], tool: Tool, json_string: str
) -> list[Any]:
@@ -37,7 +36,11 @@ async def _tool_run_wrapper(
Wrapper function to adapt Tool.run() to FunctionTool.on_invoke_tool() signature.
"""
args = json.loads(json_string) if json_string else {}
# Manually handle tool accounting (increment step)
run_context.context.current_run_step += 1
index = run_context.context.current_run_step
run_context.context.run_dependencies.emitter.emit(
Packet(
ind=index,
@@ -95,6 +98,16 @@ async def _tool_run_wrapper(
),
)
)
# Emit section end and increment step (manually handle tool accounting cleanup)
run_context.context.run_dependencies.emitter.emit(
Packet(
ind=index,
obj=SectionEnd(type="section_end"),
)
)
run_context.context.current_run_step += 1
return results

View File

@@ -42,7 +42,7 @@ class BaseTool(Tool[None]):
run_context: RunContextWrapper[Any],
*args: Any,
**kwargs: Any,
) -> Any:
) -> str: # we expect JSON format returned to the model
raise NotImplementedError("BaseTool.run_v2 is not implemented.")
def build_next_prompt(

View File

@@ -2,8 +2,6 @@ from typing import Any
from pydantic import BaseModel
from onyx.tools.tool import Tool
class ForceUseTool(BaseModel):
# Could be not a forced usage of the tool but still have args, in which case
@@ -17,12 +15,3 @@ class ForceUseTool(BaseModel):
def build_openai_tool_choice_dict(self) -> dict[str, Any]:
"""Build dict in the format that OpenAI expects which tells them to use this tool."""
return {"type": "function", "name": self.tool_name}
def filter_tools_for_force_tool_use(
tools: list[Tool], force_use_tool: ForceUseTool
) -> list[Tool]:
if not force_use_tool.force_use:
return tools
return [tool for tool in tools if tool.name == force_use_tool.tool_name]

View File

@@ -100,6 +100,8 @@ class Tool(abc.ABC, Generic[OVERRIDE_T]):
"""Actual execution of the tool"""
# run_V2 should be what's used moving forwards.
# run is only there to support deep research.
@abc.abstractmethod
def run_v2(
self,
@@ -109,6 +111,15 @@ class Tool(abc.ABC, Generic[OVERRIDE_T]):
) -> Any:
raise NotImplementedError
def failure_error_function(self, error: Exception) -> str:
"""
This function defines what is returned to the LLM when the tool fails.
By default, it returns a generic error message.
Subclasses may override to provide a more specific error message, or re-raise the error
for a hard error in the framework.
"""
return f"Error: {str(error)}"
@abc.abstractmethod
def run(
self, override_kwargs: OVERRIDE_T | None = None, **llm_kwargs: Any

View File

@@ -54,6 +54,7 @@ from onyx.tools.tool_implementations.python.python_tool import (
PythonTool,
)
from onyx.tools.tool_implementations.search.search_tool import SearchTool
from onyx.tools.tool_implementations.web_search.open_url_tool import OpenUrlTool
from onyx.tools.tool_implementations.web_search.web_search_tool import (
WebSearchTool,
)
@@ -287,7 +288,8 @@ def construct_tools(
try:
tool_dict[db_tool_model.id] = [
WebSearchTool(tool_id=db_tool_model.id)
WebSearchTool(tool_id=db_tool_model.id),
OpenUrlTool(tool_id=db_tool_model.id),
]
except ValueError as e:
logger.error(f"Failed to initialize Internet Search Tool: {e}")

View File

@@ -15,6 +15,8 @@ from langchain_core.messages import SystemMessage
from pydantic import BaseModel
from requests import JSONDecodeError
from onyx.agents.agent_search.dr.models import IterationAnswer
from onyx.agents.agent_search.dr.models import IterationInstructions
from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
from onyx.configs.constants import FileOrigin
from onyx.file_store.file_store import get_default_file_store
@@ -22,12 +24,16 @@ from onyx.file_store.models import ChatFileType
from onyx.file_store.models import InMemoryChatFile
from onyx.llm.interfaces import LLM
from onyx.llm.models import PreviousMessage
from onyx.server.query_and_chat.streaming_models import CustomToolDelta
from onyx.server.query_and_chat.streaming_models import CustomToolStart
from onyx.server.query_and_chat.streaming_models import Packet
from onyx.tools.base_tool import BaseTool
from onyx.tools.message import ToolCallSummary
from onyx.tools.models import CHAT_SESSION_ID_PLACEHOLDER
from onyx.tools.models import DynamicSchemaInfo
from onyx.tools.models import MESSAGE_ID_PLACEHOLDER
from onyx.tools.models import ToolResponse
from onyx.tools.tool import RunContextWrapper
from onyx.tools.tool_implementations.custom.custom_tool_prompts import (
SHOULD_USE_CUSTOM_TOOL_SYSTEM_PROMPT,
)
@@ -53,6 +59,7 @@ from onyx.tools.tool_implementations.custom.openapi_parsing import (
from onyx.tools.tool_implementations.custom.prompt import (
build_custom_image_generation_user_prompt,
)
from onyx.tools.tool_implementations_v2.tool_accounting import tool_accounting
from onyx.utils.headers import header_list_to_header_dict
from onyx.utils.headers import HeaderItemDict
from onyx.utils.logger import setup_logger
@@ -240,6 +247,74 @@ class CustomTool(BaseTool):
"""Actual execution of the tool"""
@tool_accounting
def run_v2(self, run_context: RunContextWrapper[Any], **args: Any) -> str:
index = run_context.context.current_run_step
run_context.context.run_dependencies.emitter.emit(
Packet(
ind=index,
obj=CustomToolStart(type="custom_tool_start", tool_name=self.name),
)
)
results = []
run_context.context.iteration_instructions.append(
IterationInstructions(
iteration_nr=index,
plan=f"Running {self.name}",
purpose=f"Running {self.name}",
reasoning=f"Running {self.name}",
)
)
for result in self.run(**args):
results.append(result)
# Extract data from CustomToolCallSummary within the ToolResponse
custom_summary = result.response
data = None
file_ids = None
# Handle different response types
if custom_summary.response_type in ["image", "csv"] and hasattr(
custom_summary.tool_result, "file_ids"
):
file_ids = custom_summary.tool_result.file_ids
else:
data = custom_summary.tool_result
run_context.context.global_iteration_responses.append(
IterationAnswer(
tool=self.name,
tool_id=self.id,
iteration_nr=index,
parallelization_nr=0,
question=json.dumps(args) if args else "",
reasoning=f"Running {self.name}",
data=data,
file_ids=file_ids,
cited_documents={},
additional_data=None,
response_type=custom_summary.response_type,
answer=str(data) if data else str(file_ids),
)
)
run_context.context.run_dependencies.emitter.emit(
Packet(
ind=index,
obj=CustomToolDelta(
type="custom_tool_delta",
tool_name=self.name,
response_type=custom_summary.response_type,
data=data,
file_ids=file_ids,
),
)
)
# Return the last result's data as JSON string
if results:
last_summary = results[-1].response
if isinstance(last_summary.tool_result, CustomToolUserFileSnapshot):
return json.dumps(last_summary.tool_result.model_dump())
return json.dumps(last_summary.tool_result)
return json.dumps({})
def run(
self, override_kwargs: dict[str, Any] | None = None, **kwargs: Any
) -> Generator[ToolResponse, None, None]:

View File

@@ -10,18 +10,28 @@ from pydantic import BaseModel
from sqlalchemy.orm import Session
from typing_extensions import override
from onyx.agents.agent_search.dr.models import GeneratedImage
from onyx.agents.agent_search.dr.models import IterationAnswer
from onyx.agents.agent_search.dr.models import IterationInstructions
from onyx.chat.chat_utils import combine_message_chain
from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
from onyx.chat.stop_signal_checker import is_connected
from onyx.configs.app_configs import AZURE_IMAGE_API_KEY
from onyx.configs.app_configs import IMAGE_MODEL_NAME
from onyx.configs.model_configs import GEN_AI_HISTORY_CUTOFF
from onyx.db.llm import fetch_existing_llm_providers
from onyx.file_store.utils import build_frontend_file_url
from onyx.file_store.utils import save_files
from onyx.llm.interfaces import LLM
from onyx.llm.models import PreviousMessage
from onyx.llm.utils import build_content_with_imgs
from onyx.llm.utils import message_to_string
from onyx.llm.utils import model_supports_image_input
from onyx.prompts.constants import GENERAL_SEP_PAT
from onyx.server.query_and_chat.streaming_models import ImageGenerationToolDelta
from onyx.server.query_and_chat.streaming_models import ImageGenerationToolHeartbeat
from onyx.server.query_and_chat.streaming_models import ImageGenerationToolStart
from onyx.server.query_and_chat.streaming_models import Packet
from onyx.tools.message import ToolCallSummary
from onyx.tools.models import ToolResponse
from onyx.tools.tool import RunContextWrapper
@@ -29,6 +39,7 @@ from onyx.tools.tool import Tool
from onyx.tools.tool_implementations.images.prompt import (
build_image_generation_user_prompt,
)
from onyx.tools.tool_implementations_v2.tool_accounting import tool_accounting
from onyx.utils.logger import setup_logger
from onyx.utils.special_types import JSON_ro
from onyx.utils.threadpool_concurrency import run_functions_tuples_in_parallel
@@ -80,7 +91,7 @@ class ImageShape(str, Enum):
# override_kwargs is not supported for image generation tools
class ImageGenerationTool(Tool[None]):
_NAME = "run_image_generation"
_NAME = "image_generation"
_DESCRIPTION = (
"NEVER use generate_image unless the user specifically requests an image."
)
@@ -120,6 +131,115 @@ class ImageGenerationTool(Tool[None]):
def display_name(self) -> str:
return self._DISPLAY_NAME
@tool_accounting
def _image_generation_core(
self,
run_context: RunContextWrapper[Any],
prompt: str,
shape: ImageShape | None = None,
) -> list[GeneratedImage]:
"""Core image generation logic for run_v2."""
index = run_context.context.current_run_step
emitter = run_context.context.run_dependencies.emitter
# Emit start event
emitter.emit(
Packet(
ind=index,
obj=ImageGenerationToolStart(type="image_generation_tool_start"),
)
)
# Prepare tool arguments
tool_args = {"prompt": prompt}
if (
shape and shape != ImageShape.SQUARE
): # Only include shape if it's not the default
tool_args["shape"] = str(shape)
# Run the actual image generation tool with heartbeat handling
generated_images: list[GeneratedImage] = []
heartbeat_count = 0
for tool_response in self.run(**tool_args): # type: ignore[arg-type]
# Check if the session has been cancelled
if not is_connected(
run_context.context.chat_session_id,
run_context.context.run_dependencies.redis_client,
):
break
# Handle heartbeat responses
if tool_response.id == IMAGE_GENERATION_HEARTBEAT_ID:
# Emit heartbeat event for every iteration
emitter.emit(
Packet(
ind=index,
obj=ImageGenerationToolHeartbeat(
type="image_generation_tool_heartbeat"
),
)
)
heartbeat_count += 1
logger.debug(f"Image generation heartbeat #{heartbeat_count}")
continue
# Process the tool response to get the generated images
if tool_response.id == IMAGE_GENERATION_RESPONSE_ID:
image_generation_responses = cast(
list[ImageGenerationResponse], tool_response.response
)
file_ids = save_files(
urls=[],
base64_files=[img.image_data for img in image_generation_responses],
)
generated_images = [
GeneratedImage(
file_id=file_id,
url=build_frontend_file_url(file_id),
revised_prompt=img.revised_prompt,
)
for img, file_id in zip(image_generation_responses, file_ids)
]
break
run_context.context.iteration_instructions.append(
IterationInstructions(
iteration_nr=index,
plan="Generating images",
purpose="Generating images",
reasoning="Generating images",
)
)
run_context.context.global_iteration_responses.append(
IterationAnswer(
tool=self.name,
tool_id=self.id,
iteration_nr=run_context.context.current_run_step,
parallelization_nr=0,
question=prompt,
answer="",
reasoning="",
claims=[],
generated_images=generated_images,
additional_data={},
response_type=None,
data=None,
file_ids=None,
cited_documents={},
)
)
emitter.emit(
Packet(
ind=index,
obj=ImageGenerationToolDelta(
type="image_generation_tool_delta", images=generated_images
),
)
)
return generated_images
@override
@classmethod
def is_available(cls, db_session: Session) -> bool:
@@ -194,13 +314,31 @@ class ImageGenerationTool(Tool[None]):
return None
# Since image generation is a stopping tool, we need to re-raise the error
# in the agent loop since the loop can't recover if the tool fails.
def failure_error_function(self, error: Exception) -> None:
raise error
def run_v2(
self,
run_context: RunContextWrapper[Any],
*args: Any,
**kwargs: Any,
) -> Any:
raise NotImplementedError("ImageGenerationTool.run_v2 is not implemented.")
prompt: str,
shape: ImageShape | None = None,
) -> str:
"""Run image generation via the v2 implementation.
Returns:
JSON string containing a success message
"""
# Call the core implementation
generated_images = self._image_generation_core(
run_context,
prompt,
shape,
)
# Return success message (agent stops after this tool anyway)
return f"Successfully generated {len(generated_images)} images"
def build_tool_message_content(
self, *args: ToolResponse

View File

@@ -6,18 +6,25 @@ from typing import cast
from langchain_core.messages import HumanMessage
from langchain_core.messages import SystemMessage
from onyx.agents.agent_search.dr.models import IterationAnswer
from onyx.agents.agent_search.dr.models import IterationInstructions
from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
from onyx.db.enums import MCPAuthenticationType
from onyx.db.models import MCPConnectionConfig
from onyx.db.models import MCPServer
from onyx.llm.interfaces import LLM
from onyx.llm.models import PreviousMessage
from onyx.server.query_and_chat.streaming_models import CustomToolDelta
from onyx.server.query_and_chat.streaming_models import CustomToolStart
from onyx.server.query_and_chat.streaming_models import Packet
from onyx.tools.base_tool import BaseTool
from onyx.tools.message import ToolCallSummary
from onyx.tools.models import ToolResponse
from onyx.tools.tool import RunContextWrapper
from onyx.tools.tool_implementations.custom.custom_tool import CUSTOM_TOOL_RESPONSE_ID
from onyx.tools.tool_implementations.custom.custom_tool import CustomToolCallSummary
from onyx.tools.tool_implementations.mcp.mcp_client import call_mcp_tool
from onyx.tools.tool_implementations_v2.tool_accounting import tool_accounting
from onyx.utils.logger import setup_logger
from onyx.utils.special_types import JSON_ro
@@ -167,6 +174,72 @@ Return ONLY a valid JSON object with the extracted arguments. If no arguments ar
)
return {}
    @tool_accounting
    def run_v2(self, run_context: RunContextWrapper[Any], **args: Any) -> str:
        """Run the custom tool under the v2 agent framework.

        Emits start/delta packets for streaming, records iteration
        bookkeeping on the run context, and returns the last tool result
        serialized as a JSON string.

        Args:
            run_context: Wrapper holding the chat turn context (emitter,
                iteration tracking, global responses).
            **args: Arguments forwarded verbatim to ``self.run``.

        Returns:
            JSON-encoded ``tool_result`` of the final ``ToolResponse``
            yielded by ``self.run``, or ``"{}"`` if nothing was yielded.
        """
        # NOTE(review): the run-step index doubles as the iteration number;
        # it is assumed to be advanced elsewhere (e.g. by @tool_accounting).
        index = run_context.context.current_run_step
        run_context.context.run_dependencies.emitter.emit(
            Packet(
                ind=index,
                obj=CustomToolStart(type="custom_tool_start", tool_name=self.name),
            )
        )
        results = []
        run_context.context.iteration_instructions.append(
            IterationInstructions(
                iteration_nr=index,
                plan=f"Running {self.name}",
                purpose=f"Running {self.name}",
                reasoning=f"Running {self.name}",
            )
        )
        for result in self.run(**args):
            results.append(result)
            # Extract data from CustomToolCallSummary within the ToolResponse
            custom_summary = result.response
            data = None
            file_ids = None
            # Handle different response types: file-backed results ("image",
            # "csv") carry file ids; everything else is treated as raw data.
            if custom_summary.response_type in ["image", "csv"] and hasattr(
                custom_summary.tool_result, "file_ids"
            ):
                file_ids = custom_summary.tool_result.file_ids
            else:
                data = custom_summary.tool_result
            run_context.context.global_iteration_responses.append(
                IterationAnswer(
                    tool=self.name,
                    tool_id=self.id,
                    iteration_nr=index,
                    parallelization_nr=0,
                    question=json.dumps(args) if args else "",
                    reasoning=f"Running {self.name}",
                    data=data,
                    file_ids=file_ids,
                    cited_documents={},
                    additional_data=None,
                    response_type=custom_summary.response_type,
                    # Falls back to the file ids when there is no raw data.
                    answer=str(data) if data else str(file_ids),
                )
            )
            run_context.context.run_dependencies.emitter.emit(
                Packet(
                    ind=index,
                    obj=CustomToolDelta(
                        type="custom_tool_delta",
                        tool_name=self.name,
                        response_type=custom_summary.response_type,
                        data=data,
                        file_ids=file_ids,
                    ),
                )
            )
        # Return the last result's data as JSON string
        if results:
            last_summary = results[-1].response
            return json.dumps(last_summary.tool_result)
        return json.dumps({})
def run(
self, override_kwargs: dict[str, Any] | None = None, **kwargs: Any
) -> Generator[ToolResponse, None, None]:

View File

@@ -0,0 +1,98 @@
from typing import Literal
from typing import TypedDict
import requests
from pydantic import BaseModel
from onyx.configs.app_configs import CODE_INTERPRETER_BASE_URL
from onyx.utils.logger import setup_logger
logger = setup_logger()
class FileInput(TypedDict):
    """Input file to be staged in execution workspace.

    Keys:
        path: Destination path (filename) inside the workspace.
        file_id: Id of a file previously uploaded to the Code Interpreter
            service (as returned by ``CodeInterpreterClient.upload_file``).
    """

    path: str
    file_id: str
class WorkspaceFile(BaseModel):
    """File in execution workspace (as reported back by the service)."""

    # Path of the entry inside the workspace.
    path: str
    # Whether the entry is a regular file or a directory.
    kind: Literal["file", "directory"]
    # Download id for regular files; presumably None for directories —
    # callers skip entries without a file_id.
    file_id: str | None = None
class ExecuteResponse(BaseModel):
    """Response from code execution."""

    # Captured standard output / standard error of the executed code.
    stdout: str
    stderr: str
    # Process exit code; None presumably when no exit code is available
    # (e.g. on timeout) — TODO confirm against the service contract.
    exit_code: int | None
    # True if execution was cut off for exceeding the timeout.
    timed_out: bool
    # Execution duration in milliseconds.
    duration_ms: int
    # Workspace entries present after execution (staged and generated files).
    files: list[WorkspaceFile]
class CodeInterpreterClient:
    """Thin HTTP client for the Code Interpreter service.

    Wraps the service's ``/v1/execute`` and ``/v1/files`` endpoints using a
    single persistent ``requests.Session``.
    """

    def __init__(self, base_url: str | None = CODE_INTERPRETER_BASE_URL):
        if not base_url:
            raise ValueError("CODE_INTERPRETER_BASE_URL not configured")
        self.base_url = base_url.rstrip("/")
        self.session = requests.Session()

    def execute(
        self,
        code: str,
        stdin: str | None = None,
        timeout_ms: int = 30000,
        files: list[FileInput] | None = None,
    ) -> ExecuteResponse:
        """Execute Python code and return the parsed execution result."""
        body: dict = {"code": code, "timeout_ms": timeout_ms}
        if stdin is not None:
            body["stdin"] = stdin
        if files:
            body["files"] = files
        # HTTP timeout = execution timeout plus a 10s grace period.
        resp = self.session.post(
            f"{self.base_url}/v1/execute",
            json=body,
            timeout=timeout_ms / 1000 + 10,
        )
        resp.raise_for_status()
        return ExecuteResponse(**resp.json())

    def upload_file(self, file_content: bytes, filename: str) -> str:
        """Upload a file to the Code Interpreter; return its file_id."""
        resp = self.session.post(
            f"{self.base_url}/v1/files",
            files={"file": (filename, file_content)},
            timeout=30,
        )
        resp.raise_for_status()
        return resp.json()["file_id"]

    def download_file(self, file_id: str) -> bytes:
        """Download a file's raw bytes from the Code Interpreter."""
        resp = self.session.get(f"{self.base_url}/v1/files/{file_id}", timeout=30)
        resp.raise_for_status()
        return resp.content

    def delete_file(self, file_id: str) -> None:
        """Delete a file from the Code Interpreter workspace."""
        resp = self.session.delete(f"{self.base_url}/v1/files/{file_id}", timeout=10)
        resp.raise_for_status()

View File

@@ -1,17 +1,40 @@
import mimetypes
from collections.abc import Generator
from io import BytesIO
from typing import Any
from pydantic import TypeAdapter
from sqlalchemy.orm import Session
from typing_extensions import override
from onyx.agents.agent_search.dr.models import IterationAnswer
from onyx.agents.agent_search.dr.models import IterationInstructions
from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
from onyx.configs.app_configs import CODE_INTERPRETER_BASE_URL
from onyx.configs.app_configs import CODE_INTERPRETER_DEFAULT_TIMEOUT_MS
from onyx.configs.app_configs import CODE_INTERPRETER_MAX_OUTPUT_LENGTH
from onyx.configs.constants import FileOrigin
from onyx.file_store.utils import build_full_frontend_file_url
from onyx.file_store.utils import get_default_file_store
from onyx.llm.interfaces import LLM
from onyx.llm.models import PreviousMessage
from onyx.server.query_and_chat.streaming_models import Packet
from onyx.server.query_and_chat.streaming_models import PythonToolDelta
from onyx.server.query_and_chat.streaming_models import PythonToolStart
from onyx.tools.message import ToolCallSummary
from onyx.tools.models import ToolResponse
from onyx.tools.tool import RunContextWrapper
from onyx.tools.tool import Tool
from onyx.tools.tool_implementations.python.code_interpreter_client import (
CodeInterpreterClient,
)
from onyx.tools.tool_implementations.python.code_interpreter_client import (
ExecuteResponse,
)
from onyx.tools.tool_implementations.python.code_interpreter_client import FileInput
from onyx.tools.tool_implementations_v2.tool_accounting import tool_accounting
from onyx.tools.tool_result_models import LlmPythonExecutionResult
from onyx.tools.tool_result_models import PythonExecutionFile
from onyx.utils.logger import setup_logger
from onyx.utils.special_types import JSON_ro
@@ -23,6 +46,262 @@ _GENERIC_ERROR_MESSAGE = (
)
def _truncate_output(output: str, max_length: int, label: str = "output") -> str:
"""
Truncate output string to max_length and append truncation message if needed.
Args:
output: The original output string to truncate
max_length: Maximum length before truncation
label: Label for logging (e.g., "stdout", "stderr")
Returns:
Truncated string with truncation message appended if truncated
"""
truncated = output[:max_length]
if len(output) > max_length:
truncated += (
"\n... [output truncated, "
f"{len(output) - max_length} "
"characters omitted]"
)
logger.debug(f"Truncated {label}: {truncated}")
return truncated
def _combine_outputs(stdout: str, stderr: str) -> str:
"""
Combine stdout and stderr into a single string if both exist.
Args:
stdout: Standard output string
stderr: Standard error string
Returns:
Combined output string with labels if both exist, or the non-empty one
if only one exists, or empty string if both are empty
"""
has_stdout = bool(stdout)
has_stderr = bool(stderr)
if has_stdout and has_stderr:
return f"stdout:\n\n{stdout}\n\nstderr:\n\n{stderr}"
elif has_stdout:
return stdout
elif has_stderr:
return stderr
else:
return ""
@tool_accounting
def _python_execution_core(
    run_context: RunContextWrapper[Any],
    code: str,
    client: CodeInterpreterClient,
    tool_id: int,
) -> LlmPythonExecutionResult:
    """Core Python execution logic.

    Stages the chat's uploaded files into the Code Interpreter workspace,
    executes ``code`` there, persists any generated files to the Onyx file
    store, emits streaming start/delta packets, records iteration
    bookkeeping, and returns a result with truncated stdout/stderr.

    Args:
        run_context: Wrapper holding the chat turn context (emitter,
            chat files, iteration tracking).
        code: Python source to execute.
        client: Code Interpreter HTTP client.
        tool_id: Database id of the Python tool, stored on the
            IterationAnswer record.

    Returns:
        LlmPythonExecutionResult; on any exception a synthetic error
        result with exit_code -1 (no exception propagates).
    """
    index = run_context.context.current_run_step
    emitter = run_context.context.run_dependencies.emitter
    # Emit start event
    emitter.emit(
        Packet(
            ind=index,
            obj=PythonToolStart(code=code),
        )
    )
    run_context.context.iteration_instructions.append(
        IterationInstructions(
            iteration_nr=index,
            plan="Executing Python code",
            purpose="Running Python code",
            reasoning="Executing provided Python code in secure environment",
        )
    )
    # Get all files from chat context and upload to Code Interpreter
    files_to_stage: list[FileInput] = []
    file_store = get_default_file_store()
    # Access chat files directly from context (available after Step 0 changes)
    chat_files = run_context.context.chat_files
    for ind, chat_file in enumerate(chat_files):
        file_name = chat_file.filename or f"file_{ind}"
        try:
            # Use file content already loaded in memory
            file_content = chat_file.content
            # Upload to Code Interpreter
            ci_file_id = client.upload_file(file_content, file_name)
            # Stage for execution
            files_to_stage.append({"path": file_name, "file_id": ci_file_id})
            logger.info(f"Staged file for Python execution: {file_name}")
        except Exception as e:
            # Best-effort staging: a file that fails to upload is skipped,
            # not fatal — execution proceeds without it.
            logger.warning(f"Failed to stage file {file_name}: {e}")
    try:
        logger.debug(f"Executing code: {code}")
        # Execute code with fixed timeout
        response: ExecuteResponse = client.execute(
            code=code,
            timeout_ms=CODE_INTERPRETER_DEFAULT_TIMEOUT_MS,
            files=files_to_stage or None,
        )
        # Truncate output for LLM consumption
        truncated_stdout = _truncate_output(
            response.stdout, CODE_INTERPRETER_MAX_OUTPUT_LENGTH, "stdout"
        )
        truncated_stderr = _truncate_output(
            response.stderr, CODE_INTERPRETER_MAX_OUTPUT_LENGTH, "stderr"
        )
        # Handle generated files: copy each regular file from the Code
        # Interpreter workspace into the Onyx file store.
        generated_files: list[PythonExecutionFile] = []
        generated_file_ids: list[str] = []
        file_ids_to_cleanup: list[str] = []
        for workspace_file in response.files:
            if workspace_file.kind != "file" or not workspace_file.file_id:
                continue
            try:
                # Download file from Code Interpreter
                file_content = client.download_file(workspace_file.file_id)
                # Determine MIME type from file extension
                filename = workspace_file.path.split("/")[-1]
                mime_type, _ = mimetypes.guess_type(filename)
                # Default to binary if we can't determine the type
                mime_type = mime_type or "application/octet-stream"
                # Save to Onyx file store directly
                onyx_file_id = file_store.save_file(
                    content=BytesIO(file_content),
                    display_name=filename,
                    file_origin=FileOrigin.CHAT_UPLOAD,
                    file_type=mime_type,
                )
                generated_files.append(
                    PythonExecutionFile(
                        filename=filename,
                        file_link=build_full_frontend_file_url(onyx_file_id),
                    )
                )
                generated_file_ids.append(onyx_file_id)
                # Mark for cleanup
                file_ids_to_cleanup.append(workspace_file.file_id)
            except Exception as e:
                logger.error(
                    f"Failed to handle generated file {workspace_file.path}: {e}"
                )
        # Cleanup Code Interpreter files (both generated and staged input files)
        for ci_file_id in file_ids_to_cleanup:
            try:
                client.delete_file(ci_file_id)
            except Exception as e:
                # TODO: add TTL on code interpreter files themselves so they are automatically
                # cleaned up after some time.
                logger.error(
                    f"Failed to delete Code Interpreter generated file {ci_file_id}: {e}"
                )
        # Cleanup staged input files
        for file_mapping in files_to_stage:
            try:
                client.delete_file(file_mapping["file_id"])
            except Exception as e:
                # TODO: add TTL on code interpreter files themselves so they are automatically
                # cleaned up after some time.
                logger.error(
                    f"Failed to delete Code Interpreter staged file {file_mapping['file_id']}: {e}"
                )
        # Emit delta with stdout/stderr and generated files
        emitter.emit(
            Packet(
                ind=index,
                obj=PythonToolDelta(
                    type="python_tool_delta",
                    stdout=truncated_stdout,
                    stderr=truncated_stderr,
                    file_ids=generated_file_ids,
                ),
            )
        )
        # Build result with truncated output
        result = LlmPythonExecutionResult(
            stdout=truncated_stdout,
            stderr=truncated_stderr,
            exit_code=response.exit_code,
            timed_out=response.timed_out,
            generated_files=generated_files,
            # Non-zero exit (or None) surfaces stderr as the error text.
            error=None if response.exit_code == 0 else truncated_stderr,
        )
        # Store in iteration answer
        run_context.context.global_iteration_responses.append(
            IterationAnswer(
                tool=PythonTool.__name__,
                tool_id=tool_id,
                iteration_nr=index,
                parallelization_nr=0,
                question="Execute Python code",
                reasoning="Executing Python code in secure environment",
                answer=_combine_outputs(truncated_stdout, truncated_stderr),
                cited_documents={},
                file_ids=generated_file_ids,
                additional_data={
                    "stdout": truncated_stdout,
                    "stderr": truncated_stderr,
                    "code": code,
                },
            )
        )
        return result
    except Exception as e:
        # Any failure (service error, parsing, file-store write) collapses
        # into an error result instead of propagating.
        logger.error(f"Python execution failed: {e}")
        error_msg = str(e)
        # Emit error delta
        emitter.emit(
            Packet(
                ind=index,
                obj=PythonToolDelta(
                    type="python_tool_delta",
                    stdout="",
                    stderr=error_msg,
                    file_ids=[],
                ),
            )
        )
        # Return error result
        return LlmPythonExecutionResult(
            stdout="",
            stderr=error_msg,
            exit_code=-1,
            timed_out=False,
            generated_files=[],
            error=error_msg,
        )
class PythonTool(Tool[None]):
"""
Wrapper class for Python code execution tool.
@@ -32,7 +311,7 @@ class PythonTool(Tool[None]):
"""
_NAME = "python"
_DESCRIPTION = "Execute Python code in a secure, isolated environment. Never call this tool directly."
_DESCRIPTION = "Execute Python code in a secure, isolated environment."
# in the UI, call it `Code Interpreter` since this is a well known term for this tool
_DISPLAY_NAME = "Code Interpreter"
@@ -114,9 +393,23 @@ class PythonTool(Tool[None]):
run_context: RunContextWrapper[Any],
*args: Any,
**kwargs: Any,
) -> Any:
"""Not supported - Python tool is only used via v2 agent framework."""
raise ValueError(_GENERIC_ERROR_MESSAGE)
) -> str:
"""Run Python code execution via the v2 implementation.
Returns:
JSON string containing execution results (stdout, stderr, exit code, files)
"""
code = kwargs.get("code")
if not code:
raise ValueError("code is required for python execution")
# Create client and call the core implementation
client = CodeInterpreterClient()
result = _python_execution_core(run_context, code, client, self._id)
# Serialize and return
adapter = TypeAdapter(LlmPythonExecutionResult)
return adapter.dump_json(result).decode()
def run(
self, override_kwargs: None = None, **llm_kwargs: str

View File

@@ -6,11 +6,16 @@ from typing import Any
from typing import cast
from typing import TypeVar
from pydantic import TypeAdapter
from sqlalchemy.orm import Session
from onyx.agents.agent_search.dr.models import IterationAnswer
from onyx.agents.agent_search.dr.models import IterationInstructions
from onyx.agents.agent_search.dr.utils import convert_inference_sections_to_search_docs
from onyx.chat.chat_utils import llm_doc_from_inference_section
from onyx.chat.models import AnswerStyleConfig
from onyx.chat.models import ContextualPruningConfig
from onyx.chat.models import DOCUMENT_CITATION_NUMBER_EMPTY_VALUE
from onyx.chat.models import DocumentPruningConfig
from onyx.chat.models import LlmDoc
from onyx.chat.models import PromptConfig
@@ -19,6 +24,7 @@ from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
from onyx.chat.prompt_builder.citations_prompt import compute_max_llm_input_tokens
from onyx.chat.prune_and_merge import prune_and_merge_sections
from onyx.chat.prune_and_merge import prune_sections
from onyx.chat.stop_signal_checker import is_connected
from onyx.configs.chat_configs import CONTEXT_CHUNKS_ABOVE
from onyx.configs.chat_configs import CONTEXT_CHUNKS_BELOW
from onyx.configs.model_configs import GEN_AI_MODEL_FALLBACK_MAX_TOKENS
@@ -35,6 +41,7 @@ from onyx.context.search.pipeline import SearchPipeline
from onyx.context.search.pipeline import section_relevance_list_impl
from onyx.db.connector import check_connectors_exist
from onyx.db.connector import check_federated_connectors_exist
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.models import Persona
from onyx.db.models import User
from onyx.llm.interfaces import LLM
@@ -42,6 +49,9 @@ from onyx.llm.models import PreviousMessage
from onyx.onyxbot.slack.models import SlackContext
from onyx.secondary_llm_flows.choose_search import check_if_need_search
from onyx.secondary_llm_flows.query_expansion import history_based_query_rephrase
from onyx.server.query_and_chat.streaming_models import Packet
from onyx.server.query_and_chat.streaming_models import SearchToolDelta
from onyx.server.query_and_chat.streaming_models import SearchToolStart
from onyx.tools.message import ToolCallSummary
from onyx.tools.models import SearchQueryInfo
from onyx.tools.models import SearchToolOverrideKwargs
@@ -55,6 +65,8 @@ from onyx.tools.tool_implementations.search_like_tool_utils import (
from onyx.tools.tool_implementations.search_like_tool_utils import (
FINAL_CONTEXT_DOCUMENTS_ID,
)
from onyx.tools.tool_implementations_v2.tool_accounting import tool_accounting
from onyx.tools.tool_result_models import LlmInternalSearchResult
from onyx.utils.logger import setup_logger
from onyx.utils.special_types import JSON_ro
@@ -86,7 +98,7 @@ web pages. If very ambiguious, prioritize internal search or call both tools.
class SearchTool(Tool[SearchToolOverrideKwargs]):
_NAME = "run_search"
_NAME = "internal_search"
_DISPLAY_NAME = "Internal Search"
_DESCRIPTION = SEARCH_TOOL_DESCRIPTION
@@ -112,6 +124,8 @@ class SearchTool(Tool[SearchToolOverrideKwargs]):
bypass_acl: bool = False,
rerank_settings: RerankingDetails | None = None,
slack_context: SlackContext | None = None,
# just doing this for now since we lack dependency injection
search_pipeline_override_for_testing: SearchPipeline | None = None,
) -> None:
self.user = user
self.persona = persona
@@ -177,6 +191,7 @@ class SearchTool(Tool[SearchToolOverrideKwargs]):
)
self._id = tool_id
self.search_pipeline_override_for_testing = search_pipeline_override_for_testing
@classmethod
def is_available(cls, db_session: Session) -> bool:
@@ -260,13 +275,136 @@ class SearchTool(Tool[SearchToolOverrideKwargs]):
"""Actual tool execution"""
@tool_accounting
def run_v2(
self,
run_context: RunContextWrapper[Any],
*args: Any,
**kwargs: Any,
) -> Any:
raise NotImplementedError("SearchTool.run_v2 is not implemented.")
) -> str:
index = run_context.context.current_run_step
query = kwargs[QUERY_FIELD]
run_context.context.run_dependencies.emitter.emit(
Packet(
ind=index,
obj=SearchToolStart(
type="internal_search_tool_start", is_internet_search=False
),
)
)
run_context.context.run_dependencies.emitter.emit(
Packet(
ind=index,
obj=SearchToolDelta(
type="internal_search_tool_delta", queries=[query], documents=[]
),
)
)
run_context.context.iteration_instructions.append(
IterationInstructions(
iteration_nr=index,
plan="plan",
purpose="Searching internally for information",
reasoning=f"I am now using Internal Search to gather information on {query}",
)
)
retrieved_sections: list[InferenceSection] = []
with get_session_with_current_tenant() as search_db_session:
for tool_response in self.run(
query=query,
override_kwargs=SearchToolOverrideKwargs(
force_no_rerank=True,
alternate_db_session=search_db_session,
skip_query_analysis=True,
original_query=query,
),
):
if not is_connected(
run_context.context.chat_session_id,
run_context.context.run_dependencies.redis_client,
):
break
# get retrieved docs to send to the rest of the graph
if tool_response.id == SEARCH_RESPONSE_SUMMARY_ID:
search_response_summary = cast(
SearchResponseSummary, tool_response.response
)
retrieved_sections = search_response_summary.top_sections
break
# Aggregate all results from all queries
# Use the current input token count from context for pruning
# This includes system prompt, history, user message, and any agent turns so far
existing_input_tokens = run_context.context.current_input_tokens
pruned_sections: list[InferenceSection] = prune_and_merge_sections(
sections=retrieved_sections,
section_relevance_list=None,
llm_config=self.llm.config,
existing_input_tokens=existing_input_tokens,
contextual_pruning_config=self.contextual_pruning_config,
)
search_results_for_query = [
LlmInternalSearchResult(
document_citation_number=DOCUMENT_CITATION_NUMBER_EMPTY_VALUE,
title=section.center_chunk.semantic_identifier,
excerpt=section.combined_content,
metadata=section.center_chunk.metadata,
unique_identifier_to_strip_away=section.center_chunk.document_id,
)
for section in pruned_sections
]
from onyx.chat.turn.models import FetchedDocumentCacheEntry
for section in pruned_sections:
unique_id = section.center_chunk.document_id
if unique_id not in run_context.context.fetched_documents_cache:
run_context.context.fetched_documents_cache[unique_id] = (
FetchedDocumentCacheEntry(
inference_section=section,
document_citation_number=DOCUMENT_CITATION_NUMBER_EMPTY_VALUE,
)
)
run_context.context.run_dependencies.emitter.emit(
Packet(
ind=index,
obj=SearchToolDelta(
type="internal_search_tool_delta",
queries=[],
documents=convert_inference_sections_to_search_docs(
pruned_sections, is_internet=False
),
),
)
)
run_context.context.global_iteration_responses.append(
IterationAnswer(
tool=SearchTool.__name__,
tool_id=self.id,
iteration_nr=index,
parallelization_nr=0,
question=query[0] if query else "",
reasoning=f"I am now using Internal Search to gather information on {query[0] if query else ''}",
answer="",
cited_documents={
i: inference_section
for i, inference_section in enumerate(pruned_sections)
},
queries=[query],
)
)
# Set flag to include citation requirements since we retrieved documents
run_context.context.should_cite_documents = (
run_context.context.should_cite_documents or bool(pruned_sections)
)
adapter = TypeAdapter(list[LlmInternalSearchResult])
return adapter.dump_json(search_results_for_query).decode()
def _build_response_for_specified_sections(
self, query: str
@@ -402,7 +540,7 @@ class SearchTool(Tool[SearchToolOverrideKwargs]):
if kg_chunk_id_zero_only:
retrieval_options.filters.kg_chunk_id_zero_only = kg_chunk_id_zero_only
search_pipeline = SearchPipeline(
search_pipeline = self.search_pipeline_override_for_testing or SearchPipeline(
search_request=SearchRequest(
query=query,
evaluation_type=(

View File

@@ -0,0 +1,234 @@
from collections.abc import Generator
from collections.abc import Sequence
from typing import Any
from pydantic import TypeAdapter
from sqlalchemy.orm import Session
from typing_extensions import override
from onyx.agents.agent_search.dr.models import IterationAnswer
from onyx.agents.agent_search.dr.models import IterationInstructions
from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
WebContentProvider,
)
from onyx.agents.agent_search.dr.sub_agents.web_search.providers import (
get_default_content_provider,
)
from onyx.agents.agent_search.dr.sub_agents.web_search.utils import (
dummy_inference_section_from_internet_content,
)
from onyx.agents.agent_search.dr.sub_agents.web_search.utils import (
truncate_search_result_content,
)
from onyx.chat.models import DOCUMENT_CITATION_NUMBER_EMPTY_VALUE
from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
from onyx.chat.turn.models import FetchedDocumentCacheEntry
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.web_search import fetch_active_web_search_provider
from onyx.llm.interfaces import LLM
from onyx.llm.models import PreviousMessage
from onyx.server.query_and_chat.streaming_models import FetchToolStart
from onyx.server.query_and_chat.streaming_models import Packet
from onyx.server.query_and_chat.streaming_models import SavedSearchDoc
from onyx.tools.message import ToolCallSummary
from onyx.tools.models import ToolResponse
from onyx.tools.tool import RunContextWrapper
from onyx.tools.tool import Tool
from onyx.tools.tool_implementations_v2.tool_accounting import tool_accounting
from onyx.tools.tool_result_models import LlmOpenUrlResult
from onyx.utils.logger import setup_logger
from onyx.utils.special_types import JSON_ro
logger = setup_logger()
_OPEN_URL_GENERIC_ERROR_MESSAGE = (
"OpenUrlTool should only be used by the Deep Research Agent, not via tool calling."
)
class OpenUrlTool(Tool[None]):
    """Tool that fetches full page content for a list of URLs.

    Only the v2 entry point (``run_v2``) is supported; the legacy v1
    methods raise with a generic error message.
    """

    _NAME = "open_url"
    _DESCRIPTION = "Fetch and extract full content from web pages."
    _DISPLAY_NAME = "Open URL"

    def __init__(self, tool_id: int) -> None:
        self._id = tool_id

    @property
    def id(self) -> int:
        return self._id

    @property
    def name(self) -> str:
        return self._NAME

    @property
    def description(self) -> str:
        return self._DESCRIPTION

    @property
    def display_name(self) -> str:
        return self._DISPLAY_NAME

    @override
    @classmethod
    def is_available(cls, db_session: Session) -> bool:
        """Available only if an active web search provider is configured in the database."""
        # NOTE: opens its own tenant session rather than using db_session.
        with get_session_with_current_tenant() as session:
            provider = fetch_active_web_search_provider(session)
            return provider is not None

    def tool_definition(self) -> dict:
        """OpenAI-style function schema: a single required ``urls`` array."""
        return {
            "type": "function",
            "function": {
                "name": self.name,
                "description": self.description,
                "parameters": {
                    "type": "object",
                    "properties": {
                        "urls": {
                            "type": "array",
                            "items": {"type": "string"},
                            "description": "URLs to fetch content from",
                        },
                    },
                    "required": ["urls"],
                },
            },
        }

    def get_args_for_non_tool_calling_llm(
        self,
        query: str,
        history: list[PreviousMessage],
        llm: LLM,
        force_run: bool = False,
    ) -> dict[str, Any] | None:
        # v1 pathway is not supported for this tool.
        raise ValueError(_OPEN_URL_GENERIC_ERROR_MESSAGE)

    def build_tool_message_content(
        self, *args: ToolResponse
    ) -> str | list[str | dict[str, Any]]:
        # v1 pathway is not supported for this tool.
        raise ValueError(_OPEN_URL_GENERIC_ERROR_MESSAGE)

    @tool_accounting
    def _open_url_core(
        self,
        run_context: RunContextWrapper[Any],
        urls: Sequence[str],
        content_provider: WebContentProvider,
    ) -> list[LlmOpenUrlResult]:
        """Fetch ``urls`` via the content provider, cache the documents on
        the run context, record iteration bookkeeping, and return truncated
        per-URL results for the LLM."""
        # TODO: Find better way to track index that isn't so implicit
        # based on number of tool calls
        index = run_context.context.current_run_step
        # Create SavedSearchDoc objects from URLs for the FetchToolStart event
        saved_search_docs = [SavedSearchDoc.from_url(url) for url in urls]
        run_context.context.run_dependencies.emitter.emit(
            Packet(
                ind=index,
                obj=FetchToolStart(
                    type="fetch_tool_start", documents=saved_search_docs
                ),
            )
        )
        docs = content_provider.contents(urls)
        results = [
            LlmOpenUrlResult(
                document_citation_number=DOCUMENT_CITATION_NUMBER_EMPTY_VALUE,
                content=truncate_search_result_content(doc.full_content),
                unique_identifier_to_strip_away=doc.link,
            )
            for doc in docs
        ]
        for doc in docs:
            cache = run_context.context.fetched_documents_cache
            entry = cache.setdefault(
                doc.link,
                FetchedDocumentCacheEntry(
                    inference_section=dummy_inference_section_from_internet_content(
                        doc
                    ),
                    document_citation_number=DOCUMENT_CITATION_NUMBER_EMPTY_VALUE,
                ),
            )
            # Refresh the cached section even when the entry already existed,
            # so the latest fetched content wins.
            entry.inference_section = dummy_inference_section_from_internet_content(doc)
        run_context.context.iteration_instructions.append(
            IterationInstructions(
                iteration_nr=index,
                plan="plan",
                purpose="Fetching content from URLs",
                reasoning=f"I am now using Web Fetch to gather information on {', '.join(urls)}",
            )
        )
        # Local imports avoid a circular dependency at module load time —
        # TODO confirm; mirrors the pattern used elsewhere in this package.
        from onyx.db.tools import get_tool_by_name
        from onyx.tools.tool_implementations.web_search.web_search_tool import (
            WebSearchTool,
        )

        run_context.context.global_iteration_responses.append(
            IterationAnswer(
                # TODO: For now, we're using the web_search_tool_name since the web_fetch_tool_name is not a built-in tool
                tool=WebSearchTool.__name__,
                tool_id=get_tool_by_name(
                    WebSearchTool.__name__,
                    run_context.context.run_dependencies.db_session,
                ).id,
                iteration_nr=index,
                parallelization_nr=0,
                question=f"Fetch content from URLs: {', '.join(urls)}",
                reasoning=f"I am now using Web Fetch to gather information on {', '.join(urls)}",
                answer="",
                cited_documents={
                    i: dummy_inference_section_from_internet_content(d)
                    for i, d in enumerate(docs)
                },
                claims=[],
                is_web_fetch=True,
            )
        )
        # Set flag to include citation requirements since we fetched documents
        run_context.context.should_cite_documents = True
        return results

    def run_v2(
        self,
        run_context: RunContextWrapper[Any],
        *args: Any,
        **kwargs: Any,
    ) -> str:
        """Run open_url using the v2 implementation.

        Expects ``urls`` in kwargs; returns the fetched results serialized
        as a JSON array.

        Raises:
            ValueError: If ``urls`` is missing/empty or no web content
                provider is configured.
        """
        urls = kwargs.get("urls", [])
        if not urls:
            raise ValueError("urls parameter is required")
        content_provider = get_default_content_provider()
        if content_provider is None:
            raise ValueError("No web content provider found")
        retrieved_docs = self._open_url_core(run_context, urls, content_provider)  # type: ignore[arg-type]
        adapter = TypeAdapter(list[LlmOpenUrlResult])
        return adapter.dump_json(retrieved_docs).decode()

    def run(
        self, override_kwargs: None = None, **llm_kwargs: str
    ) -> Generator[ToolResponse, None, None]:
        # v1 pathway is not supported for this tool.
        raise ValueError(_OPEN_URL_GENERIC_ERROR_MESSAGE)

    def final_result(self, *args: ToolResponse) -> JSON_ro:
        # v1 pathway is not supported for this tool.
        raise ValueError(_OPEN_URL_GENERIC_ERROR_MESSAGE)

    def build_next_prompt(
        self,
        prompt_builder: AnswerPromptBuilder,
        tool_call_summary: ToolCallSummary,
        tool_responses: list[ToolResponse],
        using_tool_calling_llm: bool,
    ) -> AnswerPromptBuilder:
        # v1 pathway is not supported for this tool.
        raise ValueError(_OPEN_URL_GENERIC_ERROR_MESSAGE)

View File

@@ -1,32 +1,59 @@
from collections.abc import Generator
from typing import Any
from pydantic import TypeAdapter
from sqlalchemy.orm import Session
from typing_extensions import override
from onyx.agents.agent_search.dr.models import IterationAnswer
from onyx.agents.agent_search.dr.models import IterationInstructions
from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
WebSearchProvider,
)
from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
WebSearchResult,
)
from onyx.agents.agent_search.dr.sub_agents.web_search.providers import (
get_default_provider,
)
from onyx.agents.agent_search.dr.sub_agents.web_search.utils import (
dummy_inference_section_from_internet_search_result,
)
from onyx.chat.models import DOCUMENT_CITATION_NUMBER_EMPTY_VALUE
from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
from onyx.chat.turn.models import FetchedDocumentCacheEntry
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.web_search import fetch_active_web_search_provider
from onyx.llm.interfaces import LLM
from onyx.llm.models import PreviousMessage
from onyx.server.query_and_chat.streaming_models import Packet
from onyx.server.query_and_chat.streaming_models import SearchToolDelta
from onyx.server.query_and_chat.streaming_models import SearchToolStart
from onyx.tools.message import ToolCallSummary
from onyx.tools.models import ToolResponse
from onyx.tools.tool import RunContextWrapper
from onyx.tools.tool import Tool
from onyx.tools.tool_implementations_v2.tool_accounting import tool_accounting
from onyx.tools.tool_result_models import LlmWebSearchResult
from onyx.utils.logger import setup_logger
from onyx.utils.special_types import JSON_ro
from onyx.utils.threadpool_concurrency import FunctionCall
from onyx.utils.threadpool_concurrency import run_functions_in_parallel
logger = setup_logger()
# TODO: Align on separation of Tools and SubAgents. Right now, we're only keeping this around for backwards compatibility.
QUERY_FIELD = "query"
QUERY_FIELD = "queries"
_GENERIC_ERROR_MESSAGE = "WebSearchTool should only be used by the Deep Research Agent, not via tool calling."
_OPEN_URL_GENERIC_ERROR_MESSAGE = (
"OpenUrlTool should only be used by the Deep Research Agent, not via tool calling."
)
class WebSearchTool(Tool[None]):
_NAME = "run_web_search"
_DESCRIPTION = "Search the web for information. Never call this tool."
_NAME = "web_search"
_DESCRIPTION = "Search the web for information."
_DISPLAY_NAME = "Web Search"
def __init__(self, tool_id: int) -> None:
@@ -66,7 +93,8 @@ class WebSearchTool(Tool[None]):
"type": "object",
"properties": {
QUERY_FIELD: {
"type": "string",
"type": "array",
"items": {"type": "string"},
"description": "What to search for",
},
},
@@ -89,13 +117,140 @@ class WebSearchTool(Tool[None]):
) -> str | list[str | dict[str, Any]]:
raise ValueError(_GENERIC_ERROR_MESSAGE)
    @tool_accounting
    def _web_search_core(
        self,
        run_context: RunContextWrapper[Any],
        queries: list[str],
        search_provider: WebSearchProvider,
    ) -> list[LlmWebSearchResult]:
        """Run all ``queries`` against the web search provider in parallel,
        stream progress packets, cache the hit documents on the run context,
        and return one LlmWebSearchResult per hit.

        Args:
            run_context: Wrapper holding the chat turn context.
            queries: Search query strings, searched in parallel.
            search_provider: Provider whose ``search`` is invoked per query.
        """
        index = run_context.context.current_run_step
        run_context.context.run_dependencies.emitter.emit(
            Packet(
                ind=index,
                obj=SearchToolStart(
                    type="internal_search_tool_start", is_internet_search=True
                ),
            )
        )
        # Emit a packet in the beginning to communicate queries to the frontend
        run_context.context.run_dependencies.emitter.emit(
            Packet(
                ind=index,
                obj=SearchToolDelta(
                    type="internal_search_tool_delta",
                    queries=queries,
                    documents=[],
                ),
            )
        )
        queries_str = ", ".join(queries)
        run_context.context.iteration_instructions.append(
            IterationInstructions(
                iteration_nr=index,
                plan="plan",
                purpose="Searching the web for information",
                reasoning=f"I am now using Web Search to gather information on {queries_str}",
            )
        )
        # Search all queries in parallel
        function_calls = [
            FunctionCall(func=search_provider.search, args=(query,))
            for query in queries
        ]
        search_results_dict = run_functions_in_parallel(function_calls)
        # Aggregate all results from all queries
        all_hits: list[WebSearchResult] = []
        for result_id in search_results_dict:
            hits = search_results_dict[result_id]
            if hits:
                all_hits.extend(hits)
        inference_sections = [
            dummy_inference_section_from_internet_search_result(r) for r in all_hits
        ]
        # Local import — presumably to avoid a circular dependency; TODO confirm.
        from onyx.agents.agent_search.dr.utils import (
            convert_inference_sections_to_search_docs,
        )

        saved_search_docs = convert_inference_sections_to_search_docs(
            inference_sections, is_internet=True
        )
        run_context.context.run_dependencies.emitter.emit(
            Packet(
                ind=index,
                obj=SearchToolDelta(
                    type="internal_search_tool_delta",
                    queries=queries,
                    documents=saved_search_docs,
                ),
            )
        )
        results = []
        for r in all_hits:
            results.append(
                LlmWebSearchResult(
                    document_citation_number=DOCUMENT_CITATION_NUMBER_EMPTY_VALUE,
                    url=r.link,
                    title=r.title,
                    snippet=r.snippet or "",
                    unique_identifier_to_strip_away=r.link,
                )
            )
            # First hit for a given link wins; later duplicates are ignored.
            if r.link not in run_context.context.fetched_documents_cache:
                run_context.context.fetched_documents_cache[r.link] = (
                    FetchedDocumentCacheEntry(
                        inference_section=dummy_inference_section_from_internet_search_result(
                            r
                        ),
                        document_citation_number=DOCUMENT_CITATION_NUMBER_EMPTY_VALUE,
                    )
                )
        from onyx.db.tools import get_tool_by_name

        run_context.context.global_iteration_responses.append(
            IterationAnswer(
                tool=WebSearchTool.__name__,
                tool_id=get_tool_by_name(
                    WebSearchTool.__name__,
                    run_context.context.run_dependencies.db_session,
                ).id,
                iteration_nr=index,
                parallelization_nr=0,
                question=queries_str,
                reasoning=f"I am now using Web Search to gather information on {queries_str}",
                answer="",
                cited_documents={
                    i: inference_section
                    for i, inference_section in enumerate(inference_sections)
                },
                claims=[],
                queries=queries,
            )
        )
        # Web results were retrieved, so downstream prompts must cite them.
        run_context.context.should_cite_documents = True
        return results
def run_v2(
self,
run_context: RunContextWrapper[Any],
*args: Any,
**kwargs: Any,
) -> Any:
raise NotImplementedError("WebSearchTool.run_v2 is not implemented.")
queries: list[str],
) -> str:
search_provider = get_default_provider()
if search_provider is None:
raise ValueError("No search provider found")
response = self._web_search_core(run_context, queries, search_provider) # type: ignore[arg-type]
adapter = TypeAdapter(list[LlmWebSearchResult])
return adapter.dump_json(response).decode()
def run(
self, override_kwargs: None = None, **llm_kwargs: str

View File

@@ -23,15 +23,15 @@ from onyx.tools.tool_implementations.search.search_tool import (
)
from onyx.tools.tool_implementations.search.search_tool import SearchResponseSummary
from onyx.tools.tool_implementations.search.search_tool import SearchTool
from onyx.tools.tool_implementations_v2.tool_accounting import tool_accounting
from onyx.tools.tool_implementations_v2.tool_result_models import (
LlmInternalSearchResult,
from onyx.tools.tool_implementations_v2.tool_accounting import (
tool_accounting_function,
)
from onyx.tools.tool_result_models import LlmInternalSearchResult
from onyx.utils.threadpool_concurrency import FunctionCall
from onyx.utils.threadpool_concurrency import run_functions_in_parallel
@tool_accounting
@tool_accounting_function
def _internal_search_core(
run_context: RunContextWrapper[ChatTurnContext],
queries: list[str],

View File

@@ -17,17 +17,19 @@ from onyx.file_store.utils import get_default_file_store
from onyx.server.query_and_chat.streaming_models import Packet
from onyx.server.query_and_chat.streaming_models import PythonToolDelta
from onyx.server.query_and_chat.streaming_models import PythonToolStart
from onyx.tools.tool_implementations.python.python_tool import PythonTool
from onyx.tools.tool_implementations_v2.code_interpreter_client import (
from onyx.tools.tool_implementations.python.code_interpreter_client import (
CodeInterpreterClient,
)
from onyx.tools.tool_implementations_v2.code_interpreter_client import ExecuteResponse
from onyx.tools.tool_implementations_v2.code_interpreter_client import FileInput
from onyx.tools.tool_implementations_v2.tool_accounting import tool_accounting
from onyx.tools.tool_implementations_v2.tool_result_models import (
LlmPythonExecutionResult,
from onyx.tools.tool_implementations.python.code_interpreter_client import (
ExecuteResponse,
)
from onyx.tools.tool_implementations_v2.tool_result_models import PythonExecutionFile
from onyx.tools.tool_implementations.python.code_interpreter_client import FileInput
from onyx.tools.tool_implementations.python.python_tool import PythonTool
from onyx.tools.tool_implementations_v2.tool_accounting import (
tool_accounting_function,
)
from onyx.tools.tool_result_models import LlmPythonExecutionResult
from onyx.tools.tool_result_models import PythonExecutionFile
from onyx.utils.logger import setup_logger
logger = setup_logger()
@@ -81,7 +83,7 @@ def _combine_outputs(stdout: str, stderr: str) -> str:
return ""
@tool_accounting
@tool_accounting_function
def _python_execution_core(
run_context: RunContextWrapper[ChatTurnContext],
code: str,

View File

@@ -31,6 +31,76 @@ def tool_accounting(func: F) -> F:
The decorated function with tool accounting functionality.
"""
@functools.wraps(func)
def wrapper(
selfobj: Any,
run_context: RunContextWrapper[ChatTurnContext],
*args: Any,
**kwargs: Any
) -> Any:
# Increment current_run_step at the beginning
run_context.context.current_run_step += 1
try:
# Call the original function (pass selfobj for class methods)
result = func(selfobj, run_context, *args, **kwargs)
# If it's a coroutine, we need to handle it differently
if inspect.iscoroutine(result):
# For async functions, we need to return a coroutine that handles the cleanup
async def async_wrapper() -> Any:
try:
return await result
finally:
_emit_section_end(run_context)
return async_wrapper()
else:
# For sync functions, emit cleanup immediately
_emit_section_end(run_context)
return result
except Exception:
# Always emit cleanup even if an exception occurred
_emit_section_end(run_context)
raise
return cast(F, wrapper)
def _emit_section_end(run_context: RunContextWrapper[ChatTurnContext]) -> None:
"""Helper function to emit section end packet and increment current_run_step.

Uses the context's current_run_step as the packet index, then bumps it so
the next tool/section gets a fresh index. Called from the tool_accounting
wrapper on both success and failure paths.
"""
index = run_context.context.current_run_step
run_context.context.run_dependencies.emitter.emit(
Packet(
ind=index,
obj=SectionEnd(
type="section_end",
),
)
)
# Advance the step counter after the SectionEnd so indices never collide.
run_context.context.current_run_step += 1
def tool_accounting_function(func: F) -> F:
"""
Decorator for standalone functions (not methods) that adds tool accounting functionality.
Use this for standalone functions like _internal_search_core.
Use tool_accounting for class methods.
This decorator:
1. Increments the current_run_step index at the beginning
2. Emits a section end packet and increments current_run_step at the end
3. Ensures the cleanup happens even if an exception occurs
Args:
func: The function to decorate. Must take a RunContextWrapper[ChatTurnContext] as first argument.
Returns:
The decorated function with tool accounting functionality.
"""
@functools.wraps(func)
def wrapper(
run_context: RunContextWrapper[ChatTurnContext], *args: Any, **kwargs: Any
@@ -63,17 +133,3 @@ def tool_accounting(func: F) -> F:
raise
return cast(F, wrapper)
def _emit_section_end(run_context: RunContextWrapper[ChatTurnContext]) -> None:
"""Helper function to emit section end packet and increment current_run_step."""
index = run_context.context.current_run_step
run_context.context.run_dependencies.emitter.emit(
Packet(
ind=index,
obj=SectionEnd(
type="section_end",
),
)
)
run_context.context.current_run_step += 1

View File

@@ -40,13 +40,15 @@ from onyx.server.query_and_chat.streaming_models import SavedSearchDoc
from onyx.server.query_and_chat.streaming_models import SearchToolDelta
from onyx.server.query_and_chat.streaming_models import SearchToolStart
from onyx.tools.tool_implementations.web_search.web_search_tool import WebSearchTool
from onyx.tools.tool_implementations_v2.tool_accounting import tool_accounting
from onyx.tools.tool_implementations_v2.tool_result_models import LlmOpenUrlResult
from onyx.tools.tool_implementations_v2.tool_result_models import LlmWebSearchResult
from onyx.tools.tool_implementations_v2.tool_accounting import (
tool_accounting_function,
)
from onyx.tools.tool_result_models import LlmOpenUrlResult
from onyx.tools.tool_result_models import LlmWebSearchResult
from onyx.utils.threadpool_concurrency import run_functions_in_parallel
@tool_accounting
@tool_accounting_function
def _web_search_core(
run_context: RunContextWrapper[ChatTurnContext],
queries: list[str],
@@ -192,7 +194,7 @@ changing or evolving.
"""
@tool_accounting
@tool_accounting_function
def _open_url_core(
run_context: RunContextWrapper[ChatTurnContext],
urls: Sequence[str],

View File

@@ -0,0 +1,59 @@
"""Base models for tool results with citation support."""
from typing import Any
from typing import Literal
from pydantic import BaseModel
class BaseCiteableToolResult(BaseModel):
"""Base class for tool results that can be cited.

Subclasses set `type` to a Literal discriminator so a union of result
models can be parsed/serialized unambiguously.
"""
# Citation index assigned to this result; producers may fill a sentinel
# "empty" value before real numbers are assigned downstream.
document_citation_number: int
# Identifier (e.g. a URL) that is removed before showing the result to the
# user/LLM; None when there is nothing to strip.
unique_identifier_to_strip_away: str | None = None
# Discriminator field; overridden with a Literal in each subclass.
type: str
class LlmInternalSearchResult(BaseCiteableToolResult):
"""Result from an internal search query"""
type: Literal["internal_search"] = "internal_search"
title: str
excerpt: str
# Arbitrary document metadata (source-specific keys).
metadata: dict[str, Any]
class LlmWebSearchResult(BaseCiteableToolResult):
"""Result from a web search query"""
type: Literal["web_search"] = "web_search"
url: str
title: str
snippet: str
class LlmOpenUrlResult(BaseCiteableToolResult):
"""Result from opening/fetching a URL"""
type: Literal["open_url"] = "open_url"
content: str
class PythonExecutionFile(BaseModel):
"""File generated during Python execution"""
filename: str
# Link for retrieving the stored file (file-store reference).
file_link: str
class LlmPythonExecutionResult(BaseModel):
"""Result from Python code execution.

Not citeable, so it extends BaseModel directly rather than
BaseCiteableToolResult.
"""
type: Literal["python_execution"] = "python_execution"
stdout: str
stderr: str
# None when no exit code is available (e.g. the run timed out).
exit_code: int | None
timed_out: bool
generated_files: list[PythonExecutionFile]
# Execution-level error message, distinct from the program's own stderr.
error: str | None = None

View File

@@ -30,15 +30,12 @@ from onyx.file_store.utils import get_default_file_store
from onyx.server.query_and_chat.streaming_models import Packet
from onyx.server.query_and_chat.streaming_models import PythonToolDelta
from onyx.server.query_and_chat.streaming_models import PythonToolStart
from onyx.tools.tool_implementations.python.python_tool import PythonTool
from onyx.tools.tool_implementations_v2.code_interpreter_client import (
from onyx.tools.tool_implementations.python.code_interpreter_client import (
CodeInterpreterClient,
)
from onyx.tools.tool_implementations_v2.python import _python_execution_core
from onyx.tools.tool_implementations_v2.python import python
from onyx.tools.tool_implementations_v2.tool_result_models import (
LlmPythonExecutionResult,
)
from onyx.tools.tool_implementations.python.python_tool import _python_execution_core
from onyx.tools.tool_implementations.python.python_tool import PythonTool
from onyx.tools.tool_result_models import LlmPythonExecutionResult
# Apply initialize_file_store fixture to all tests in this module
@@ -459,7 +456,7 @@ def test_python_function_tool_wrapper(
mock_client_class.return_value = code_interpreter_client
# Call the function tool wrapper
result_coro = python.on_invoke_tool(mock_run_context, json.dumps({"code": code})) # type: ignore
result_coro = PythonTool.on_invoke_tool(mock_run_context, json.dumps({"code": code})) # type: ignore
result_json: str = asyncio.run(result_coro) # type: ignore
# Verify result is JSON string

View File

@@ -300,3 +300,70 @@ def test_multiple_tool_calls_streaming(default_multi_llm: LitellmLLM) -> None:
mock_response=MOCK_LLM_RESPONSE,
stream_options={"include_usage": True},
)
def test_exclude_reasoning_config_for_anthropic_with_function_tool_choice() -> None:
"""Test that reasoning config is excluded for Anthropic models when using function tool choice.

Anthropic rejects extended-thinking parameters combined with a forced tool
choice — presumably why LitellmLLM must drop them; verify against the
implementation. Only the kwargs passed to litellm.completion are asserted.
"""
# Create an Anthropic LLM
anthropic_llm = LitellmLLM(
api_key="test_key",
timeout=30,
model_provider="anthropic",
model_name="claude-3-5-sonnet-20241022",
max_input_tokens=200000,
)
# Mock the litellm.completion function
with patch("litellm.completion") as mock_completion:
# Canned response: one assistant turn that finishes by calling the
# "search" tool (content is None, as for pure tool-call turns).
mock_response = litellm.ModelResponse(
id="msg-123",
choices=[
litellm.Choices(
finish_reason="tool_use",
index=0,
message=litellm.Message(
content=None,
role="assistant",
tool_calls=[
litellm.ChatCompletionMessageToolCall(
id="call_1",
function=LiteLLMFunction(
name="search",
arguments='{"query": "test"}',
),
type="function",
),
],
),
)
],
model="claude-3-5-sonnet-20241022",
usage=litellm.Usage(
prompt_tokens=50, completion_tokens=30, total_tokens=80
),
)
mock_completion.return_value = mock_response
messages = [HumanMessage(content="Test message")]
# Minimal OpenAI-style tool schema matching the mocked tool call.
tools = [
{
"type": "function",
"function": {
"name": "search",
"description": "Search for information",
"parameters": {
"type": "object",
"properties": {"query": {"type": "string"}},
"required": ["query"],
},
},
},
]
# Call with a specific function tool choice
anthropic_llm.invoke(messages, tools, tool_choice="search")
# Verify that reasoning_effort and thinking are NOT passed to litellm.completion
call_kwargs = mock_completion.call_args[1]
assert "reasoning_effort" not in call_kwargs
assert "thinking" not in call_kwargs

View File

@@ -4,7 +4,6 @@ from tests.unit.onyx.chat.turn.utils import chat_turn_context
from tests.unit.onyx.chat.turn.utils import chat_turn_dependencies
from tests.unit.onyx.chat.turn.utils import fake_db_session
from tests.unit.onyx.chat.turn.utils import fake_llm
from tests.unit.onyx.chat.turn.utils import fake_model
from tests.unit.onyx.chat.turn.utils import fake_redis_client
from tests.unit.onyx.chat.turn.utils import fake_tools
@@ -13,7 +12,6 @@ __all__ = [
"chat_turn_dependencies",
"fake_db_session",
"fake_llm",
"fake_model",
"fake_redis_client",
"fake_tools",
]

View File

@@ -0,0 +1,525 @@
"""Tests for CustomTool and MCPTool run_v2() methods using dependency injection.
This test module focuses on testing the run_v2() methods for CustomTool and MCPTool,
adapted from test_adapter_v1_to_v2.py but directly testing the tool implementations.
"""
import json
import uuid
from typing import Any
from unittest.mock import MagicMock
from unittest.mock import patch
from uuid import UUID
import pytest
from agents import RunContextWrapper
from onyx.agents.agent_search.dr.models import IterationAnswer
from onyx.agents.agent_search.dr.models import IterationInstructions
from onyx.chat.turn.models import ChatTurnContext
from onyx.db.enums import MCPAuthenticationType
from onyx.db.enums import MCPTransport
from onyx.db.models import MCPServer
from onyx.tools.models import DynamicSchemaInfo
from onyx.tools.tool_implementations.custom.custom_tool import (
build_custom_tools_from_openapi_schema_and_headers,
)
from onyx.tools.tool_implementations.custom.custom_tool import CustomTool
from onyx.tools.tool_implementations.mcp.mcp_tool import MCPTool
from tests.unit.onyx.chat.turn.utils import FakeRedis
# =============================================================================
# Fake Classes for Dependency Injection
# =============================================================================
def create_fake_database_session() -> Any:
    """Build a mock SQLAlchemy ``Session`` for tests.

    ``commit()`` / ``rollback()`` flip the ``committed`` / ``rolled_back``
    flags so tests can assert on transaction behavior; ``add`` and ``flush``
    are plain no-op mocks whose calls can still be inspected.
    """
    from unittest.mock import Mock

    from sqlalchemy.orm import Session

    session = Mock(spec=Session)
    session.committed = False
    session.rolled_back = False

    def _commit() -> None:
        session.committed = True

    def _rollback() -> None:
        session.rolled_back = True

    session.commit = _commit
    session.rollback = _rollback
    session.add = Mock()
    session.flush = Mock()
    return session
class FakeEmitter:
    """Test double for the packet emitter.

    Instead of streaming packets anywhere, every emitted packet is appended
    to ``packet_history`` (in emission order) for later inspection by tests.
    """

    def __init__(self) -> None:
        # Recorded packets, oldest first.
        self.packet_history: list[Any] = []

    def emit(self, packet: Any) -> None:
        """Record ``packet`` in the history; performs no I/O."""
        self.packet_history.append(packet)
class FakeRunDependencies:
"""Fake run dependencies for testing.

Mirrors the attributes run_v2 implementations read: db_session,
redis_client, an emitter, and the list of available tools (a single tool
here).
"""
def __init__(self, db_session: Any, redis_client: FakeRedis, tool: Any) -> None:
self.db_session = db_session
self.redis_client = redis_client
# Every emitted packet is captured here for assertions.
self.emitter = FakeEmitter()
self.tools = [tool]
# =============================================================================
# Test Helper Functions
# =============================================================================
def create_fake_run_context(
chat_session_id: UUID,
message_id: int,
db_session: Any,
redis_client: FakeRedis,
tool: Any,
current_run_step: int = 0,
) -> RunContextWrapper[ChatTurnContext]:
"""Create a fake run context for testing.

current_run_step lets tests start mid-run to check iteration numbering.
"""
run_dependencies = FakeRunDependencies(db_session, redis_client, tool)
context = ChatTurnContext(
chat_session_id=chat_session_id,
message_id=message_id,
run_dependencies=run_dependencies, # type: ignore
)
context.current_run_step = current_run_step
return RunContextWrapper(context=context)
# =============================================================================
# Test Fixtures
# =============================================================================
@pytest.fixture
def chat_session_id() -> UUID:
"""Fixture providing fake chat session ID (random per test)."""
return uuid.uuid4()
@pytest.fixture
def message_id() -> int:
"""Fixture providing fake message ID (arbitrary constant)."""
return 123
@pytest.fixture
def fake_db_session() -> Any:
"""Fixture providing a fake database session (mock Session, see helper)."""
return create_fake_database_session()
@pytest.fixture
def fake_redis_client() -> FakeRedis:
"""Fixture providing a fake Redis client."""
return FakeRedis()
@pytest.fixture
def openapi_schema() -> dict[str, Any]:
"""OpenAPI schema for testing."""
return {
"openapi": "3.0.0",
"info": {
"version": "1.0.0",
"title": "Test API",
"description": "A test API for testing",
},
"servers": [
{"url": "http://localhost:8080/CHAT_SESSION_ID/test/MESSAGE_ID"},
],
"paths": {
"/test/{test_id}": {
"GET": {
"summary": "Get a test item",
"operationId": "getTestItem",
"parameters": [
{
"name": "test_id",
"in": "path",
"required": True,
"schema": {"type": "string"},
}
],
},
}
},
}
@pytest.fixture
def dynamic_schema_info(chat_session_id: UUID, message_id: int) -> DynamicSchemaInfo:
"""Dynamic schema info for testing (fills the URL placeholders above)."""
return DynamicSchemaInfo(chat_session_id=chat_session_id, message_id=message_id)
@pytest.fixture
def custom_tool(
openapi_schema: dict[str, Any], dynamic_schema_info: DynamicSchemaInfo
) -> CustomTool:
"""Custom tool for testing.

Built from the single-operation OpenAPI schema; the builder returns one
tool per operation, so [0] is the getTestItem tool.
"""
tools = build_custom_tools_from_openapi_schema_and_headers(
tool_id=-1, # dummy tool id
openapi_schema=openapi_schema,
dynamic_schema_info=dynamic_schema_info,
)
return tools[0]
@pytest.fixture
def mcp_server() -> MCPServer:
"""MCP server for testing (no auth, streamable-HTTP transport)."""
return MCPServer(
id=1,
name="test_mcp_server",
server_url="http://localhost:8080/mcp",
auth_type=MCPAuthenticationType.NONE,
transport=MCPTransport.STREAMABLE_HTTP,
)
@pytest.fixture
def mcp_tool(mcp_server: MCPServer) -> MCPTool:
"""MCP tool for testing: a "search" tool with a single required query arg."""
tool_definition = {
"type": "object",
"properties": {"query": {"type": "string", "description": "The search query"}},
"required": ["query"],
}
return MCPTool(
tool_id=1,
mcp_server=mcp_server,
tool_name="search",
tool_description="Search for information",
tool_definition=tool_definition,
connection_config=None,
user_email="test@example.com",
)
# =============================================================================
# Custom Tool Tests
# =============================================================================
@patch("onyx.tools.tool_implementations.custom.custom_tool.requests.request")
def test_custom_tool_run_v2_basic_invocation(
mock_request: MagicMock,
custom_tool: CustomTool,
chat_session_id: UUID,
message_id: int,
fake_db_session: Any,
fake_redis_client: FakeRedis,
dynamic_schema_info: DynamicSchemaInfo,
) -> None:
"""Test basic functionality of CustomTool.run_v2().

Patches requests.request inside the custom_tool module so no HTTP call
leaves the process; asserts the JSON response is passed through and the
placeholder-substituted URL is hit.
"""
# Arrange
mock_response = MagicMock()
mock_response.headers = {"Content-Type": "application/json"}
mock_response.json.return_value = {
"id": "456",
"name": "Test Item",
"status": "active",
}
mock_request.return_value = mock_response
fake_run_context = create_fake_run_context(
chat_session_id, message_id, fake_db_session, fake_redis_client, custom_tool
)
# Act
result = custom_tool.run_v2(fake_run_context, test_id="456")
# Assert
assert isinstance(result, str)
result_json = json.loads(result)
assert result_json["id"] == "456"
assert result_json["name"] == "Test Item"
assert result_json["status"] == "active"
# Verify HTTP request was made
# (server URL placeholders replaced by the real session/message ids, then
# the path parameter test_id appended per the schema)
expected_url = f"http://localhost:8080/{dynamic_schema_info.chat_session_id}/test/{dynamic_schema_info.message_id}/test/456"
mock_request.assert_called_once_with("GET", expected_url, json=None, headers={})
@patch("onyx.tools.tool_implementations.custom.custom_tool.requests.request")
def test_custom_tool_run_v2_iteration_tracking(
mock_request: MagicMock,
custom_tool: CustomTool,
chat_session_id: UUID,
message_id: int,
fake_db_session: Any,
fake_redis_client: FakeRedis,
) -> None:
"""Test that IterationInstructions and IterationAnswer are properly added.

Starts at current_run_step=0 and expects iteration_nr == 1, so run_v2
presumably increments the step before recording — verify against the
implementation.
"""
# Arrange
mock_response = MagicMock()
mock_response.headers = {"Content-Type": "application/json"}
mock_response.json.return_value = {"result": "success"}
mock_request.return_value = mock_response
fake_run_context = create_fake_run_context(
chat_session_id, message_id, fake_db_session, fake_redis_client, custom_tool
)
# Act
custom_tool.run_v2(fake_run_context, test_id="789")
# Assert - verify IterationInstructions was added
assert len(fake_run_context.context.iteration_instructions) == 1
iteration_instruction = fake_run_context.context.iteration_instructions[0]
assert isinstance(iteration_instruction, IterationInstructions)
assert iteration_instruction.iteration_nr == 1
assert iteration_instruction.plan == f"Running {custom_tool.name}"
assert iteration_instruction.purpose == f"Running {custom_tool.name}"
assert iteration_instruction.reasoning == f"Running {custom_tool.name}"
# Assert - verify IterationAnswer was added
assert len(fake_run_context.context.global_iteration_responses) == 1
iteration_answer = fake_run_context.context.global_iteration_responses[0]
assert isinstance(iteration_answer, IterationAnswer)
assert iteration_answer.tool == custom_tool.name
assert iteration_answer.tool_id == custom_tool.id
assert iteration_answer.iteration_nr == 1
assert iteration_answer.parallelization_nr == 0
assert iteration_answer.question == '{"test_id": "789"}'
assert iteration_answer.reasoning == f"Running {custom_tool.name}"
# answer is the str() of the response dict, not JSON (single quotes).
assert iteration_answer.answer == "{'result': 'success'}"
assert iteration_answer.cited_documents == {}
@patch("onyx.tools.tool_implementations.custom.custom_tool.requests.request")
def test_custom_tool_run_v2_packet_emissions(
mock_request: MagicMock,
custom_tool: CustomTool,
chat_session_id: UUID,
message_id: int,
fake_db_session: Any,
fake_redis_client: FakeRedis,
) -> None:
"""Test that the correct packets are emitted during tool execution."""
# Arrange
mock_response = MagicMock()
mock_response.headers = {"Content-Type": "application/json"}
mock_response.json.return_value = {"test": "data"}
mock_request.return_value = mock_response
fake_run_context = create_fake_run_context(
chat_session_id, message_id, fake_db_session, fake_redis_client, custom_tool
)
# Act
custom_tool.run_v2(fake_run_context, test_id="123")
# Assert - verify emitter captured packets
emitter = fake_run_context.context.run_dependencies.emitter
# Should have: CustomToolStart, CustomToolDelta
# (>= 2 because trailing packets such as SectionEnd may follow)
assert len(emitter.packet_history) >= 2
# Check CustomToolStart
start_packet = emitter.packet_history[0]
assert getattr(start_packet.obj, "type", None) == "custom_tool_start"
assert start_packet.obj.tool_name == custom_tool.name
# Check CustomToolDelta
delta_packet = emitter.packet_history[1]
assert getattr(delta_packet.obj, "type", None) == "custom_tool_delta"
assert delta_packet.obj.tool_name == custom_tool.name
@patch("onyx.tools.tool_implementations.custom.custom_tool.requests.request")
@patch("onyx.tools.tool_implementations.custom.custom_tool.get_default_file_store")
@patch("uuid.uuid4")
def test_custom_tool_run_v2_csv_response_with_file_ids(
mock_uuid: MagicMock,
mock_file_store: MagicMock,
mock_request: MagicMock,
custom_tool: CustomTool,
chat_session_id: UUID,
message_id: int,
fake_db_session: Any,
fake_redis_client: FakeRedis,
) -> None:
"""Test that CSV responses with file_ids are handled correctly.

NOTE(review): patching "uuid.uuid4" globally affects every uuid4() call
during the test, not just the file-id generation — a narrower patch target
would be safer if other code under test generates UUIDs.
"""
# Arrange
# Pin uuid4 so the generated file id is deterministic.
mock_uuid.return_value = uuid.UUID("12345678-1234-5678-9abc-123456789012")
mock_store_instance = MagicMock()
mock_file_store.return_value = mock_store_instance
mock_store_instance.save_file.return_value = "csv_file_123"
mock_response = MagicMock()
mock_response.headers = {"Content-Type": "text/csv"}
mock_response.content = b"name,age,city\nJohn,30,New York\nJane,25,Los Angeles"
mock_request.return_value = mock_response
fake_run_context = create_fake_run_context(
chat_session_id,
message_id,
fake_db_session,
fake_redis_client,
custom_tool,
current_run_step=2,
)
# Act
result = custom_tool.run_v2(fake_run_context, test_id="789")
# Assert - verify result contains file_ids
result_json = json.loads(result)
assert "file_ids" in result_json
assert result_json["file_ids"] == ["12345678-1234-5678-9abc-123456789012"]
# Assert - verify IterationAnswer has correct file_ids
assert len(fake_run_context.context.global_iteration_responses) == 1
iteration_answer = fake_run_context.context.global_iteration_responses[0]
assert iteration_answer.data is None
assert iteration_answer.file_ids == ["12345678-1234-5678-9abc-123456789012"]
assert iteration_answer.response_type == "csv"
# Verify file was saved
mock_store_instance.save_file.assert_called_once()
# =============================================================================
# MCP Tool Tests
# =============================================================================
@patch("onyx.tools.tool_implementations.mcp.mcp_tool.call_mcp_tool")
def test_mcp_tool_run_v2_basic_invocation(
mock_call_mcp_tool: MagicMock,
mcp_tool: MCPTool,
chat_session_id: UUID,
message_id: int,
fake_db_session: Any,
fake_redis_client: FakeRedis,
) -> None:
"""Test basic functionality of MCPTool.run_v2()."""
# Arrange
mock_call_mcp_tool.return_value = "MCP search results: test query"
fake_run_context = create_fake_run_context(
chat_session_id, message_id, fake_db_session, fake_redis_client, mcp_tool
)
# Act
result = mcp_tool.run_v2(fake_run_context, query="test search")
# Assert
assert isinstance(result, str)
# The result is already a JSON string containing {"tool_result": ...}
result_json = json.loads(result)
# The tool_result is itself a JSON string that needs to be parsed
# NOTE(review): this asserts a double-encoded payload (json.loads twice) —
# if that's unintentional in run_v2, the fix belongs there, not here.
inner_result = json.loads(result_json)
assert "tool_result" in inner_result
assert inner_result["tool_result"] == "MCP search results: test query"
# Verify MCP tool was called
mock_call_mcp_tool.assert_called_once_with(
mcp_tool.mcp_server.server_url,
mcp_tool.name,
{"query": "test search"},
connection_headers={},
transport=mcp_tool.mcp_server.transport,
)
@patch("onyx.tools.tool_implementations.mcp.mcp_tool.call_mcp_tool")
def test_mcp_tool_run_v2_iteration_tracking(
mock_call_mcp_tool: MagicMock,
mcp_tool: MCPTool,
chat_session_id: UUID,
message_id: int,
fake_db_session: Any,
fake_redis_client: FakeRedis,
) -> None:
"""Test that IterationInstructions and IterationAnswer are properly added.

Starts at current_run_step=1 and expects iteration_nr == 2, consistent
with the custom-tool test's step-then-record numbering.
"""
# Arrange
mock_call_mcp_tool.return_value = "MCP search results"
fake_run_context = create_fake_run_context(
chat_session_id,
message_id,
fake_db_session,
fake_redis_client,
mcp_tool,
current_run_step=1,
)
# Act
mcp_tool.run_v2(fake_run_context, query="test mcp search")
# Assert - verify IterationInstructions was added
assert len(fake_run_context.context.iteration_instructions) == 1
iteration_instruction = fake_run_context.context.iteration_instructions[0]
assert isinstance(iteration_instruction, IterationInstructions)
assert iteration_instruction.iteration_nr == 2
assert iteration_instruction.plan == f"Running {mcp_tool.name}"
assert iteration_instruction.purpose == f"Running {mcp_tool.name}"
assert iteration_instruction.reasoning == f"Running {mcp_tool.name}"
# Assert - verify IterationAnswer was added
assert len(fake_run_context.context.global_iteration_responses) == 1
iteration_answer = fake_run_context.context.global_iteration_responses[0]
assert isinstance(iteration_answer, IterationAnswer)
assert iteration_answer.tool == mcp_tool.name
assert iteration_answer.tool_id == mcp_tool.id
assert iteration_answer.iteration_nr == 2
assert iteration_answer.parallelization_nr == 0
assert iteration_answer.question == '{"query": "test mcp search"}'
assert iteration_answer.reasoning == f"Running {mcp_tool.name}"
assert iteration_answer.cited_documents == {}
@patch("onyx.tools.tool_implementations.mcp.mcp_tool.call_mcp_tool")
def test_mcp_tool_run_v2_packet_emissions(
mock_call_mcp_tool: MagicMock,
mcp_tool: MCPTool,
chat_session_id: UUID,
message_id: int,
fake_db_session: Any,
fake_redis_client: FakeRedis,
) -> None:
"""Test that the correct packets are emitted during MCP tool execution.

MCP tools reuse the custom-tool packet types (custom_tool_start/delta).
"""
# Arrange
mock_call_mcp_tool.return_value = "MCP result"
fake_run_context = create_fake_run_context(
chat_session_id, message_id, fake_db_session, fake_redis_client, mcp_tool
)
# Act
mcp_tool.run_v2(fake_run_context, query="test")
# Assert - verify emitter captured packets
emitter = fake_run_context.context.run_dependencies.emitter
# Should have: CustomToolStart, CustomToolDelta
assert len(emitter.packet_history) >= 2
# Check CustomToolStart
start_packet = emitter.packet_history[0]
assert getattr(start_packet.obj, "type", None) == "custom_tool_start"
assert start_packet.obj.tool_name == mcp_tool.name
# Check CustomToolDelta
delta_packet = emitter.packet_history[1]
assert getattr(delta_packet.obj, "type", None) == "custom_tool_delta"
assert delta_packet.obj.tool_name == mcp_tool.name

View File

@@ -0,0 +1,360 @@
"""Tests for ImageGenerationTool.run_v2() using dependency injection.
This test module focuses on testing the ImageGenerationTool.run_v2() method directly,
using dependency injection via creating fake implementations instead of using mocks.
"""
from typing import Any
from uuid import UUID
from uuid import uuid4
import pytest
from agents import RunContextWrapper
from onyx.agents.agent_search.dr.enums import ResearchType
from onyx.agents.agent_search.dr.models import GeneratedImage
from onyx.chat.models import PromptConfig
from onyx.chat.turn.models import ChatTurnContext
from onyx.tools.models import ToolResponse
from onyx.tools.tool_implementations.images.image_generation_tool import (
ImageGenerationTool,
)
from tests.unit.onyx.chat.turn.utils import FakeRedis
# =============================================================================
# Fake Classes for Dependency Injection
# =============================================================================
def create_fake_database_session() -> Any:
"""Create a fake SQLAlchemy Session for testing.

commit()/rollback() flip the committed/rolled_back flags for assertions;
add/flush are no-op mocks. Near-duplicate of the helper in the
custom/MCP tool test module — consider sharing via the test utils.
"""
from unittest.mock import Mock
from sqlalchemy.orm import Session
# Create a mock that behaves like a real Session
fake_session = Mock(spec=Session)
fake_session.committed = False
fake_session.rolled_back = False
def mock_commit() -> None:
fake_session.committed = True
def mock_rollback() -> None:
fake_session.rolled_back = True
fake_session.commit = mock_commit
fake_session.rollback = mock_rollback
fake_session.add = Mock()
fake_session.flush = Mock()
return fake_session
class FakeEmitter:
"""Fake emitter for testing that records all emitted packets in order."""
def __init__(self) -> None:
self.packet_history: list[Any] = []
def emit(self, packet: Any) -> None:
self.packet_history.append(packet)
class FakeRunDependencies:
"""Fake run dependencies for testing.

Same shape as the custom-tool variant, plus get_prompt_config(), which
the image-generation path evidently requires.
"""
def __init__(
self,
db_session: Any,
redis_client: FakeRedis,
image_generation_tool: ImageGenerationTool,
) -> None:
self.db_session = db_session
self.redis_client = redis_client
self.emitter = FakeEmitter()
self.tools = [image_generation_tool]
def get_prompt_config(self) -> PromptConfig:
# Minimal static prompt config; values are arbitrary for tests.
return PromptConfig(
default_behavior_system_prompt="You are a helpful assistant.",
reminder="Answer the user's question.",
custom_instructions="",
datetime_aware=False,
)
class FakeCancelledRedis(FakeRedis):
"""Fake Redis client that always reports the session as cancelled."""
# exists() returning True makes any cancellation-key check read as set.
def exists(self, key: str) -> bool: # pragma: no cover - trivial override
return True
# =============================================================================
# Test Helper Functions
# =============================================================================
def create_fake_image_generation_tool(tool_id: int = 1) -> ImageGenerationTool:
    """Construct an ImageGenerationTool wired with dummy credentials.

    The credentials are never used by the tests; the core implementation is
    always patched before any network call could happen.
    """
    tool_kwargs: dict[str, Any] = {
        "api_key": "fake-api-key",
        "api_base": None,
        "api_version": None,
        "tool_id": tool_id,
        "model": "dall-e-3",
        "num_imgs": 1,
    }
    return ImageGenerationTool(**tool_kwargs)
def create_fake_run_context(
    chat_session_id: UUID,
    message_id: int,
    db_session: Any,
    redis_client: FakeRedis,
    image_generation_tool: ImageGenerationTool,
) -> RunContextWrapper[ChatTurnContext]:
    """Assemble a RunContextWrapper whose dependencies are all fakes."""
    deps = FakeRunDependencies(db_session, redis_client, image_generation_tool)
    turn_context = ChatTurnContext(
        chat_session_id=chat_session_id,
        message_id=message_id,
        run_dependencies=deps,  # type: ignore
    )
    return RunContextWrapper(context=turn_context)
# =============================================================================
# Test Fixtures
# =============================================================================
@pytest.fixture
def chat_session_id() -> UUID:
    """Fixture providing a random fake chat session ID."""
    return uuid4()


@pytest.fixture
def message_id() -> int:
    """Fixture providing a fixed fake message ID."""
    return 123


@pytest.fixture
def research_type() -> ResearchType:
    """Fixture providing a fake research type (FAST)."""
    return ResearchType.FAST


@pytest.fixture
def fake_db_session() -> Any:
    """Fixture providing a fake (mocked) database session."""
    return create_fake_database_session()


@pytest.fixture
def fake_redis_client() -> FakeRedis:
    """Fixture providing a fake Redis client."""
    return FakeRedis()


@pytest.fixture
def image_generation_tool() -> ImageGenerationTool:
    """Fixture providing an ImageGenerationTool with fake API credentials."""
    return create_fake_image_generation_tool()


@pytest.fixture
def fake_run_context(
    chat_session_id: UUID,
    message_id: int,
    fake_db_session: Any,
    fake_redis_client: FakeRedis,
    image_generation_tool: ImageGenerationTool,
) -> RunContextWrapper[ChatTurnContext]:
    """Fixture providing a complete RunContextWrapper built from the fakes above."""
    return create_fake_run_context(
        chat_session_id,
        message_id,
        fake_db_session,
        fake_redis_client,
        image_generation_tool,
    )
# =============================================================================
# Test Functions
# =============================================================================
def test_image_generation_tool_run_v2_basic_functionality(
    fake_run_context: RunContextWrapper[ChatTurnContext],
    image_generation_tool: ImageGenerationTool,
) -> None:
    """Test basic functionality of ImageGenerationTool.run_v2() using dependency injection.

    Patches ``_image_generation_core`` so no external image API is hit, then
    verifies run_v2 returns a success message and forwards the run context,
    prompt, and shape to the core implementation.
    """
    from unittest.mock import patch

    # Arrange
    prompt = "A beautiful sunset over mountains"
    shape = "landscape"
    # Create fake generated images
    fake_generated_images = [
        GeneratedImage(
            file_id="file-123",
            url="https://example.com/files/file-123",
            revised_prompt="A stunning sunset over mountains with vibrant colors",
        )
    ]
    # Mock the core implementation
    with patch.object(image_generation_tool, "_image_generation_core") as mock_core:
        mock_core.return_value = fake_generated_images
        # Act
        result = image_generation_tool.run_v2(
            fake_run_context, prompt=prompt, shape=shape
        )
        # Assert - verify result is a success message
        assert isinstance(result, str)
        assert "Successfully generated 1 images" in result
        # Verify the core was called with correct parameters
        mock_core.assert_called_once()
        call_args = mock_core.call_args
        # When patching a bound method, self is bound; first positional arg is run_context
        assert call_args[0][0] == fake_run_context  # run_context
        assert call_args[0][1] == prompt  # prompt
        assert call_args[0][2] == shape  # shape
def test_image_generation_tool_run_v2_missing_prompt(
    fake_run_context: RunContextWrapper[ChatTurnContext],
    image_generation_tool: ImageGenerationTool,
) -> None:
    """run_v2 must reject a call that omits the required ``prompt`` argument."""
    # pytest.raises(match=...) does a substring regex search on the message,
    # equivalent to asserting the text appears in str(exc_info.value).
    with pytest.raises(ValueError, match="prompt is required"):
        image_generation_tool.run_v2(fake_run_context)
def test_image_generation_tool_run_v2_default_shape(
    fake_run_context: RunContextWrapper[ChatTurnContext],
    image_generation_tool: ImageGenerationTool,
) -> None:
    """Test that run_v2 falls back to the "square" shape when none is provided."""
    from unittest.mock import patch

    # Arrange
    prompt = "A cat playing with yarn"
    fake_generated_images = [
        GeneratedImage(
            file_id="file-456",
            url="https://example.com/files/file-456",
            revised_prompt="A playful cat playing with colorful yarn",
        )
    ]
    # Mock the core implementation
    with patch.object(image_generation_tool, "_image_generation_core") as mock_core:
        mock_core.return_value = fake_generated_images
        # Act - don't provide shape parameter
        image_generation_tool.run_v2(fake_run_context, prompt=prompt)
        # Assert - the third positional arg forwarded to the core is the shape
        call_args = mock_core.call_args
        assert call_args[0][2] == "square"  # default shape
def test_image_generation_tool_run_v2_multiple_images(
    fake_run_context: RunContextWrapper[ChatTurnContext],
) -> None:
    """Test that run_v2 reports the correct count when multiple images are generated."""
    from unittest.mock import patch

    # Arrange
    # Create tool configured to generate three images per request
    multi_image_tool = ImageGenerationTool(
        api_key="fake-api-key",
        api_base=None,
        api_version=None,
        tool_id=1,
        model="dall-e-3",
        num_imgs=3,
    )
    # Update run dependencies so the registered tool is the multi-image one
    fake_run_context.context.run_dependencies.tools = [multi_image_tool]
    prompt = "A series of abstract patterns"
    fake_generated_images = [
        GeneratedImage(
            file_id=f"file-{i}",
            url=f"https://example.com/files/file-{i}",
            revised_prompt=f"Abstract pattern variation {i}",
        )
        for i in range(3)
    ]
    # Mock the core implementation
    with patch.object(multi_image_tool, "_image_generation_core") as mock_core:
        mock_core.return_value = fake_generated_images
        # Act
        result = multi_image_tool.run_v2(fake_run_context, prompt=prompt)
        # Assert - the success message reflects all three images
        assert "Successfully generated 3 images" in result
def test_image_generation_tool_run_v2_handles_cancellation_gracefully(
    chat_session_id: UUID,
    message_id: int,
    fake_db_session: Any,
    image_generation_tool: ImageGenerationTool,
) -> None:
    """Test that run_v2 handles cancellation gracefully without calling external APIs.

    Uses FakeCancelledRedis so the session always looks cancelled, and patches
    the tool's ``run`` so no real image API is reachable even if the
    cancellation check were bypassed.
    """
    from unittest.mock import patch

    # Arrange - create a run context with a Redis client that always reports cancellation
    cancelled_run_context = create_fake_run_context(
        chat_session_id=chat_session_id,
        message_id=message_id,
        db_session=fake_db_session,
        redis_client=FakeCancelledRedis(),
        image_generation_tool=image_generation_tool,
    )
    prompt = "A test image prompt that should be cancelled"

    # Patch the tool's run method so it does NOT call the real image API.
    def fake_run(**kwargs: Any) -> Any:
        # Yield a single fake ToolResponse; it will be ignored because of cancellation.
        yield ToolResponse(id="ignored", response=None)

    with patch.object(image_generation_tool, "run", side_effect=fake_run) as mock_run:
        # Act - this should not raise, and should not call external APIs
        result = image_generation_tool.run_v2(
            cancelled_run_context,
            prompt=prompt,
        )
        # Assert - when cancelled gracefully, the tool should report zero generated images
        assert isinstance(result, str)
        assert "Successfully generated 0 images" in result
        # Verify we invoked the patched run exactly once
        mock_run.assert_called_once()

View File

@@ -0,0 +1,350 @@
from collections.abc import Sequence
from datetime import datetime
from typing import cast
from typing import List
from uuid import uuid4
import pytest
from agents import RunContextWrapper
from onyx.agents.agent_search.dr.models import IterationAnswer
from onyx.agents.agent_search.dr.models import IterationInstructions
from onyx.agents.agent_search.dr.sub_agents.web_search.models import WebContent
from onyx.agents.agent_search.dr.sub_agents.web_search.models import WebContentProvider
from onyx.agents.agent_search.dr.sub_agents.web_search.models import WebSearchProvider
from onyx.agents.agent_search.dr.sub_agents.web_search.models import WebSearchResult
from onyx.chat.turn.models import ChatTurnContext
from onyx.configs.constants import DocumentSource
from onyx.server.query_and_chat.streaming_models import FetchToolStart
from onyx.server.query_and_chat.streaming_models import Packet
from onyx.server.query_and_chat.streaming_models import SectionEnd
from onyx.tools.tool_implementations.web_search.open_url_tool import OpenUrlTool
from onyx.tools.tool_implementations.web_search.web_search_tool import WebSearchTool
from onyx.tools.tool_result_models import LlmOpenUrlResult
from onyx.tools.tool_result_models import LlmWebSearchResult
class MockTool:
    """Minimal stand-in for a persisted Tool database row."""

    def __init__(self, tool_id: int = 1, name: str = WebSearchTool.__name__):
        # Mirror the two columns the code under test reads.
        self.name = name
        self.id = tool_id
class MockWebSearchProvider(WebSearchProvider):
    """Injectable WebSearchProvider serving canned search results.

    With ``should_raise_exception`` set, ``search`` raises instead of
    returning results; ``contents`` always yields an empty list.
    """

    def __init__(
        self,
        search_results: List[WebSearchResult] | None = None,
        should_raise_exception: bool = False,
    ):
        self.should_raise_exception = should_raise_exception
        self.search_results = search_results or []

    def search(self, query: str) -> List[WebSearchResult]:
        if not self.should_raise_exception:
            return self.search_results
        raise Exception("Test exception from search provider")

    def contents(self, urls: Sequence[str]) -> List[WebContent]:
        return []
class MockWebContentProvider(WebContentProvider):
    """Injectable WebContentProvider serving canned page contents.

    With ``should_raise_exception`` set, ``contents`` raises instead of
    returning the configured results.
    """

    def __init__(
        self,
        content_results: List[WebContent] | None = None,
        should_raise_exception: bool = False,
    ):
        self.should_raise_exception = should_raise_exception
        self.content_results = content_results or []

    def contents(self, urls: Sequence[str]) -> List[WebContent]:
        if not self.should_raise_exception:
            return self.content_results
        raise Exception("Test exception from content provider")
class MockEmitter:
    """Collects emitted packets in order for later inspection by tests."""

    def __init__(self) -> None:
        # Full emission history, oldest first.
        self.packet_history: list[Packet] = []

    def emit(self, packet: Packet) -> None:
        self.packet_history.append(packet)
class MockRunDependencies:
    """Mock run dependencies wiring an emitter and a stubbed DB session."""

    def __init__(self) -> None:
        from unittest.mock import MagicMock

        self.emitter = MockEmitter()
        self.db_session = MagicMock()
        # Any scalar() lookup on the session resolves to a single mock Tool row.
        self.db_session.scalar.return_value = MockTool()
def create_test_run_context(
    current_run_step: int = 0,
    iteration_instructions: List[IterationInstructions] | None = None,
    global_iteration_responses: List[IterationAnswer] | None = None,
) -> RunContextWrapper[ChatTurnContext]:
    """Build a real RunContextWrapper whose dependencies are test doubles."""
    deps = MockRunDependencies()
    # Install a fresh emitter so each test gets an isolated packet history.
    deps.emitter = MockEmitter()
    turn_context = ChatTurnContext(
        chat_session_id=uuid4(),
        message_id=1,
        current_run_step=current_run_step,
        iteration_instructions=iteration_instructions or [],
        global_iteration_responses=global_iteration_responses or [],
        run_dependencies=deps,  # type: ignore[arg-type]
    )
    return RunContextWrapper(context=turn_context)
def test_open_url_tool_run_v2_basic_functionality() -> None:
    """Test basic functionality of OpenUrlTool.run_v2.

    Injects a MockWebContentProvider, then verifies the returned JSON, the
    fetched_documents_cache, the recorded iteration instruction/answer, and
    the emitted FetchToolStart/SectionEnd packets.
    """
    # Arrange
    test_run_context = create_test_run_context()
    urls = ["https://example.com/1", "https://example.com/2"]
    # Create test content results
    test_content_results = [
        WebContent(
            title="Test Content 1",
            link="https://example.com/1",
            full_content="This is the full content of the first page",
            published_date=datetime(2024, 1, 1, 12, 0, 0),
        ),
        WebContent(
            title="Test Content 2",
            link="https://example.com/2",
            full_content="This is the full content of the second page",
            published_date=None,
        ),
    ]
    test_provider = MockWebContentProvider(content_results=test_content_results)
    # Create tool instance
    open_url_tool = OpenUrlTool(tool_id=1)
    # Mock the get_default_content_provider to return our test provider
    from unittest.mock import patch

    with patch(
        "onyx.tools.tool_implementations.web_search.open_url_tool.get_default_content_provider"
    ) as mock_get_provider:
        mock_get_provider.return_value = test_provider
        # Act
        result_json = open_url_tool.run_v2(test_run_context, urls=urls)
        from pydantic import TypeAdapter

        adapter = TypeAdapter(list[LlmOpenUrlResult])
        result = adapter.validate_json(result_json)
        # Assert
        assert len(result) == 2
        assert all(isinstance(r, LlmOpenUrlResult) for r in result)
        # Check first result
        assert result[0].content == "This is the full content of the first page"
        assert result[0].document_citation_number == -1
        assert result[0].unique_identifier_to_strip_away == "https://example.com/1"
        # Check second result
        assert result[1].content == "This is the full content of the second page"
        assert result[1].document_citation_number == -1
        assert result[1].unique_identifier_to_strip_away == "https://example.com/2"
        # Check that fetched_documents_cache was populated
        assert len(test_run_context.context.fetched_documents_cache) == 2
        assert "https://example.com/1" in test_run_context.context.fetched_documents_cache
        assert "https://example.com/2" in test_run_context.context.fetched_documents_cache
        # Verify cache entries have correct structure
        cache_entry_1 = test_run_context.context.fetched_documents_cache[
            "https://example.com/1"
        ]
        assert cache_entry_1.document_citation_number == -1
        assert cache_entry_1.inference_section is not None
        # Verify context was updated
        assert test_run_context.context.current_run_step == 2
        assert len(test_run_context.context.iteration_instructions) == 1
        assert len(test_run_context.context.global_iteration_responses) == 1
        # Check iteration instruction
        instruction = test_run_context.context.iteration_instructions[0]
        assert isinstance(instruction, IterationInstructions)
        assert instruction.iteration_nr == 1
        assert instruction.purpose == "Fetching content from URLs"
        assert (
            "Web Fetch to gather information on https://example.com/1, https://example.com/2"
            in instruction.reasoning
        )
        # Check iteration answer
        answer = test_run_context.context.global_iteration_responses[0]
        assert isinstance(answer, IterationAnswer)
        assert answer.tool == WebSearchTool.__name__
        assert answer.iteration_nr == 1
        assert (
            answer.question
            == "Fetch content from URLs: https://example.com/1, https://example.com/2"
        )
        assert len(answer.cited_documents) == 2
        # Verify emitter events were captured
        emitter = cast(MockEmitter, test_run_context.context.run_dependencies.emitter)
        assert len(emitter.packet_history) == 2
        # Check the types of emitted events
        assert isinstance(emitter.packet_history[0].obj, FetchToolStart)
        assert isinstance(emitter.packet_history[1].obj, SectionEnd)
        # Verify the FetchToolStart event contains the correct SavedSearchDoc objects
        fetch_start_event = emitter.packet_history[0].obj
        assert len(fetch_start_event.documents) == 2
        assert fetch_start_event.documents[0].link == "https://example.com/1"
        assert fetch_start_event.documents[1].link == "https://example.com/2"
        assert fetch_start_event.documents[0].source_type == DocumentSource.WEB
def test_open_url_tool_run_v2_exception_handling() -> None:
    """Test that OpenUrlTool.run_v2 propagates provider exceptions properly.

    The start/end packets must still be emitted and current_run_step must
    still advance even when the content provider raises.
    """
    # Arrange
    test_run_context = create_test_run_context()
    urls = ["https://example.com/1", "https://example.com/2"]
    # Create a provider that will raise an exception
    test_provider = MockWebContentProvider(should_raise_exception=True)
    # Create tool instance
    open_url_tool = OpenUrlTool(tool_id=1)
    # Mock the get_default_content_provider to return our test provider
    from unittest.mock import patch

    with patch(
        "onyx.tools.tool_implementations.web_search.open_url_tool.get_default_content_provider"
    ) as mock_get_provider:
        mock_get_provider.return_value = test_provider
        # Act & Assert
        with pytest.raises(Exception, match="Test exception from content provider"):
            open_url_tool.run_v2(test_run_context, urls=urls)
    # Verify that even though an exception was raised, we still emitted the initial events
    # and the SectionEnd packet was emitted by the decorator
    emitter = test_run_context.context.run_dependencies.emitter  # type: ignore[attr-defined]
    assert len(emitter.packet_history) == 2  # FetchToolStart and SectionEnd
    # Check the types of emitted events
    assert isinstance(emitter.packet_history[0].obj, FetchToolStart)
    assert isinstance(emitter.packet_history[1].obj, SectionEnd)
    # Verify that the decorator properly handled the exception and updated current_run_step
    assert (
        test_run_context.context.current_run_step == 2
    )  # Should be 2 after proper handling
def test_open_url_tool_run_v2_cache_deduplication() -> None:
    """Test that WebSearchTool.run_v2 and OpenUrlTool.run_v2 share fetched_documents_cache for the same URL.

    A web search first populates the cache; opening the same URL afterwards
    must update the existing entry rather than add a duplicate.
    """
    # Arrange
    test_run_context = create_test_run_context()
    test_url = "https://example.com/1"
    # First, do a web search that returns this URL
    search_results = [
        WebSearchResult(
            title="Test Result",
            link=test_url,
            author="Test Author",
            published_date=datetime(2024, 1, 1, 12, 0, 0),
            snippet="This is a test snippet",
        ),
    ]
    # Then, fetch the full content for the same URL
    content_results = [
        WebContent(
            title="Test Content",
            link=test_url,
            full_content="This is the full content of the page",
            published_date=datetime(2024, 1, 1, 12, 0, 0),
        ),
    ]
    search_provider = MockWebSearchProvider(search_results=search_results)
    content_provider = MockWebContentProvider(content_results=content_results)
    web_search_tool = WebSearchTool(tool_id=1)
    open_url_tool = OpenUrlTool(tool_id=1)
    from unittest.mock import patch

    from pydantic import TypeAdapter

    # Act - first do web_search via run_v2
    with patch(
        "onyx.tools.tool_implementations.web_search.web_search_tool.get_default_provider"
    ) as mock_get_provider:
        mock_get_provider.return_value = search_provider
        search_result_json = web_search_tool.run_v2(
            test_run_context, queries=["test query"]
        )
        adapter_ws = TypeAdapter(list[LlmWebSearchResult])
        search_result = adapter_ws.validate_json(search_result_json)
        # Verify search result
        assert len(search_result) == 1
        assert search_result[0].url == test_url
        # Verify cache was populated by web_search
        assert test_url in test_run_context.context.fetched_documents_cache
    # Now run open_url via run_v2
    with patch(
        "onyx.tools.tool_implementations.web_search.open_url_tool.get_default_content_provider"
    ) as mock_get_content_provider:
        mock_get_content_provider.return_value = content_provider
        open_result_json = open_url_tool.run_v2(test_run_context, urls=[test_url])
        adapter_ou = TypeAdapter(list[LlmOpenUrlResult])
        open_result = adapter_ou.validate_json(open_result_json)
        # Verify open_url result
        assert len(open_result) == 1
        assert open_result[0].content == "This is the full content of the page"
        # Verify cache still has the same entry (not duplicated)
        assert len(test_run_context.context.fetched_documents_cache) == 1
        assert test_url in test_run_context.context.fetched_documents_cache
        cache_entry_after_open = test_run_context.context.fetched_documents_cache[test_url]
        # Verify that the cache entry was updated with the full content
        # (The inference section should be updated, not replaced)
        assert cache_entry_after_open.document_citation_number == -1

View File

@@ -0,0 +1,447 @@
"""Tests for SearchTool.run_v2() using dependency injection via search_pipeline_override_for_testing.
This test module focuses on testing the SearchTool.run_v2() method directly,
using the search_pipeline_override_for_testing parameter for dependency injection
instead of using mocks.
"""
from typing import Any
from uuid import UUID
from uuid import uuid4
import pytest
from agents import RunContextWrapper
from onyx.agents.agent_search.dr.enums import ResearchType
from onyx.agents.agent_search.dr.models import IterationAnswer
from onyx.agents.agent_search.dr.models import IterationInstructions
from onyx.chat.models import PromptConfig
from onyx.chat.turn.models import ChatTurnContext
from onyx.context.search.enums import SearchType
from onyx.context.search.models import IndexFilters
from onyx.context.search.models import InferenceSection
from onyx.server.query_and_chat.streaming_models import SavedSearchDoc
from onyx.server.query_and_chat.streaming_models import SearchToolDelta
from onyx.server.query_and_chat.streaming_models import SearchToolStart
from onyx.tools.tool_implementations.search.search_tool import SearchTool
from tests.unit.onyx.chat.turn.utils import create_test_inference_chunk
from tests.unit.onyx.chat.turn.utils import FakeQuery
from tests.unit.onyx.chat.turn.utils import FakeRedis
from tests.unit.onyx.chat.turn.utils import FakeResult
# =============================================================================
# Fake Classes for Dependency Injection
# =============================================================================
def create_fake_database_session() -> Any:
    """Build a mock SQLAlchemy ``Session`` with canned query results.

    Tracks commit/rollback via ``committed``/``rolled_back`` flags and routes
    ``query``/``execute`` to the shared FakeQuery/FakeResult test helpers.
    """
    from unittest.mock import Mock

    from sqlalchemy.orm import Session

    session_mock = Mock(spec=Session)
    session_mock.committed = False
    session_mock.rolled_back = False

    def _record_commit() -> None:
        session_mock.committed = True

    def _record_rollback() -> None:
        session_mock.rolled_back = True

    session_mock.commit = _record_commit
    session_mock.rollback = _record_rollback
    session_mock.add = Mock()
    session_mock.flush = Mock()
    session_mock.query = Mock(return_value=FakeQuery())
    session_mock.execute = Mock(return_value=FakeResult())
    return session_mock
class FakeSearchQuery:
    """Fake SearchQuery exposing the attributes SearchTool code reads."""

    def __init__(self) -> None:
        # Defaults mirror a plain semantic search: no ACL filtering and
        # neutral recency weighting.
        self.search_type = SearchType.SEMANTIC
        self.filters = IndexFilters(access_control_list=None)
        self.recency_bias_multiplier = 1.0
class FakeSearchPipeline:
    """Fake SearchPipeline for dependency injection into SearchTool.

    Every section-producing property returns the injected list unchanged,
    and relevance filtering is disabled (``section_relevance`` is ``None``).
    """

    def __init__(self, sections: list[InferenceSection] | None = None) -> None:
        self.sections = sections or []
        self.search_query = FakeSearchQuery()

    @property
    def merged_retrieved_sections(self) -> list[InferenceSection]:
        # Same list at every stage; tests only assert on the contents.
        return self.sections

    @property
    def final_context_sections(self) -> list[InferenceSection]:
        return self.sections

    @property
    def section_relevance(self) -> list | None:
        # None signals that no relevance-filter results are available.
        return None
class FakeLLMConfig:
    """Static LLM configuration consumed by FakeLLM."""

    def __init__(self) -> None:
        # Mirrors a GPT-4-sized context window.
        self.max_input_tokens = 128000
        self.model_provider = "openai"
        self.model_name = "gpt-4"
class FakeLLM:
    """Fake LLM exposing only the attributes the tests touch."""

    def __init__(self) -> None:
        self.config = FakeLLMConfig()

    def log_model_configs(self) -> None:
        # Intentionally a no-op for tests.
        pass
class FakePersona:
    """Fake Persona carrying the fields SearchTool construction reads."""

    def __init__(self) -> None:
        self.id = 1
        self.name = "Test Persona"
        # Chunk-context expansion left unset.
        self.chunks_above = None
        self.chunks_below = None
        self.num_chunks = None
        # All LLM-driven filtering turned off.
        self.llm_relevance_filter = False
        self.llm_filter_extraction = False
        self.recency_bias = "auto"
        self.prompt_ids = []
        self.document_sets = []
        self.llm_model_version_override = None
class FakeUser:
    """Fake User exposing just an id and an email address."""

    def __init__(self) -> None:
        self.email = "test@example.com"
        self.id = 1
class FakeEmitter:
    """Records every emitted packet so tests can inspect the stream."""

    def __init__(self) -> None:
        self.packet_history: list[Any] = []

    def emit(self, packet: Any) -> None:
        # Append in order; tests assert on exact sequencing.
        self.packet_history.append(packet)
class FakeRunDependencies:
    """Fake run dependencies wiring a SearchTool into the turn context."""

    def __init__(
        self, db_session: Any, redis_client: FakeRedis, search_tool: SearchTool
    ) -> None:
        self.emitter = FakeEmitter()
        self.db_session = db_session
        self.redis_client = redis_client
        # Only the tool under test is registered.
        self.tools = [search_tool]

    def get_prompt_config(self) -> PromptConfig:
        """Return a fixed prompt configuration for tests."""
        return PromptConfig(
            default_behavior_system_prompt="You are a helpful assistant.",
            reminder="Answer the user's question.",
            custom_instructions="",
            datetime_aware=False,
        )
# =============================================================================
# Test Helper Functions
# =============================================================================
def create_search_section_with_semantic_id(
    document_id: str, semantic_identifier: str, content: str, link: str
) -> InferenceSection:
    """Build a single-chunk InferenceSection with a custom semantic identifier."""
    center = create_test_inference_chunk(
        document_id=document_id,
        semantic_identifier=semantic_identifier,
        content=content,
        link=link,
    )
    # The section consists of exactly the center chunk.
    return InferenceSection(
        center_chunk=center,
        chunks=[center],
        combined_content=content,
    )
def create_fake_search_pipeline_with_results(
    sections: list[InferenceSection] | None = None,
) -> FakeSearchPipeline:
    """Create a fake search pipeline, defaulting to two canned documents."""
    if sections is not None:
        return FakeSearchPipeline(sections=sections)
    default_sections = [
        create_search_section_with_semantic_id(
            document_id="doc1",
            semantic_identifier="test_doc_1",
            content="First test document content",
            link="https://example.com/doc1",
        ),
        create_search_section_with_semantic_id(
            document_id="doc2",
            semantic_identifier="test_doc_2",
            content="Second test document content",
            link="https://example.com/doc2",
        ),
    ]
    return FakeSearchPipeline(sections=default_sections)
def create_search_tool_with_fake_pipeline(
    search_pipeline: FakeSearchPipeline,
    db_session: Any | None = None,
    tool_id: int = 1,
) -> SearchTool:
    """Create a SearchTool instance with a fake search pipeline for testing.

    All collaborators (session, LLMs, persona, user) are fakes; the pipeline
    is injected via ``search_pipeline_override_for_testing`` so no retrieval
    actually runs.
    """
    from onyx.chat.models import AnswerStyleConfig
    from onyx.chat.models import CitationConfig
    from onyx.chat.models import DocumentPruningConfig
    from onyx.context.search.enums import LLMEvaluationType

    fake_db_session = db_session or create_fake_database_session()
    fake_llm = FakeLLM()
    fake_persona = FakePersona()
    fake_user = FakeUser()
    return SearchTool(
        tool_id=tool_id,
        db_session=fake_db_session,
        user=fake_user,  # type: ignore
        persona=fake_persona,  # type: ignore
        retrieval_options=None,
        prompt_config=PromptConfig(
            default_behavior_system_prompt="You are a helpful assistant.",
            reminder="Answer the user's question.",
            custom_instructions="",
            datetime_aware=False,
        ),
        llm=fake_llm,  # type: ignore
        fast_llm=fake_llm,  # type: ignore
        # SKIP disables LLM-based evaluation entirely for the test run.
        evaluation_type=LLMEvaluationType.SKIP,
        answer_style_config=AnswerStyleConfig(citation_config=CitationConfig()),
        document_pruning_config=DocumentPruningConfig(),
        search_pipeline_override_for_testing=search_pipeline,  # type: ignore
    )
def create_fake_run_context(
    chat_session_id: UUID,
    message_id: int,
    db_session: Any,
    redis_client: FakeRedis,
    search_tool: SearchTool,
) -> RunContextWrapper[ChatTurnContext]:
    """Assemble a RunContextWrapper whose dependencies are all fakes."""
    deps = FakeRunDependencies(db_session, redis_client, search_tool)
    turn_context = ChatTurnContext(
        chat_session_id=chat_session_id,
        message_id=message_id,
        run_dependencies=deps,  # type: ignore
    )
    return RunContextWrapper(context=turn_context)
# =============================================================================
# Test Fixtures
# =============================================================================
@pytest.fixture
def chat_session_id() -> UUID:
    """Fixture providing a random fake chat session ID."""
    return uuid4()


@pytest.fixture
def message_id() -> int:
    """Fixture providing a fixed fake message ID."""
    return 123


@pytest.fixture
def research_type() -> ResearchType:
    """Fixture providing a fake research type (FAST)."""
    return ResearchType.FAST


@pytest.fixture
def fake_db_session() -> Any:
    """Fixture providing a fake (mocked) database session."""
    return create_fake_database_session()


@pytest.fixture
def fake_redis_client() -> FakeRedis:
    """Fixture providing a fake Redis client."""
    return FakeRedis()


@pytest.fixture
def fake_search_pipeline() -> FakeSearchPipeline:
    """Fixture providing a fake search pipeline with default test results."""
    return create_fake_search_pipeline_with_results()


@pytest.fixture
def search_tool(
    fake_search_pipeline: FakeSearchPipeline, fake_db_session: Any
) -> SearchTool:
    """Fixture providing a SearchTool wired to the fake search pipeline."""
    return create_search_tool_with_fake_pipeline(fake_search_pipeline, fake_db_session)


@pytest.fixture
def fake_run_context(
    chat_session_id: UUID,
    message_id: int,
    fake_db_session: Any,
    fake_redis_client: FakeRedis,
    search_tool: SearchTool,
) -> RunContextWrapper[ChatTurnContext]:
    """Fixture providing a complete RunContextWrapper built from the fakes above."""
    return create_fake_run_context(
        chat_session_id, message_id, fake_db_session, fake_redis_client, search_tool
    )
# =============================================================================
# Test Functions
# =============================================================================
def test_search_tool_run_v2_basic_functionality(
    fake_run_context: RunContextWrapper[ChatTurnContext],
    search_tool: SearchTool,
    fake_db_session: Any,
) -> None:
    """Test basic functionality of SearchTool.run_v2() using dependency injection.

    This test mirrors the original test_internal_search_core_basic_functionality but
    uses the run_v2 method with search_pipeline_override_for_testing instead of mocks.
    It verifies the JSON result, the documents cache, iteration bookkeeping,
    and the emitted SearchToolStart/SearchToolDelta packets.
    """
    from unittest.mock import patch

    # Arrange
    query = "test search query"

    # Context manager that hands back the fake session when entered.
    class FakeSessionContextManager:
        def __enter__(self) -> Any:
            return fake_db_session

        def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
            pass

    # Act - patch get_session_with_current_tenant to return our fake session
    with patch(
        "onyx.tools.tool_implementations.search.search_tool.get_session_with_current_tenant",
        return_value=FakeSessionContextManager(),
    ):
        result = search_tool.run_v2(fake_run_context, query=query)

    # Assert - verify result is a JSON string
    import json

    result_list = json.loads(result)
    assert isinstance(result_list, list)
    assert len(result_list) == 2
    # Verify result contains InternalSearchResult objects (as dicts in JSON)
    assert result_list[0]["unique_identifier_to_strip_away"] == "doc1"
    assert result_list[0]["title"] == "test_doc_1"
    assert result_list[0]["excerpt"] == "First test document content"
    assert result_list[0]["document_citation_number"] == -1
    assert result_list[1]["unique_identifier_to_strip_away"] == "doc2"
    assert result_list[1]["title"] == "test_doc_2"
    assert result_list[1]["excerpt"] == "Second test document content"
    assert result_list[1]["document_citation_number"] == -1
    # Verify context was updated (decorator increments current_run_step)
    assert fake_run_context.context.current_run_step == 2
    assert len(fake_run_context.context.iteration_instructions) == 1
    assert len(fake_run_context.context.global_iteration_responses) == 1
    # Verify fetched_documents_cache was populated
    assert len(fake_run_context.context.fetched_documents_cache) == 2
    assert "doc1" in fake_run_context.context.fetched_documents_cache
    assert "doc2" in fake_run_context.context.fetched_documents_cache
    # Verify cache entries have correct structure
    cache_entry_1 = fake_run_context.context.fetched_documents_cache["doc1"]
    assert cache_entry_1.document_citation_number == -1
    assert cache_entry_1.inference_section is not None
    assert cache_entry_1.inference_section.center_chunk.document_id == "doc1"
    cache_entry_2 = fake_run_context.context.fetched_documents_cache["doc2"]
    assert cache_entry_2.document_citation_number == -1
    assert cache_entry_2.inference_section.center_chunk.document_id == "doc2"
    # Check iteration instruction
    instruction = fake_run_context.context.iteration_instructions[0]
    assert isinstance(instruction, IterationInstructions)
    assert instruction.iteration_nr == 1
    assert instruction.purpose == "Searching internally for information"
    assert (
        "I am now using Internal Search to gather information on"
        in instruction.reasoning
    )
    # Check iteration answer
    answer = fake_run_context.context.global_iteration_responses[0]
    assert isinstance(answer, IterationAnswer)
    assert answer.tool == SearchTool.__name__
    assert answer.tool_id == search_tool.id
    assert answer.iteration_nr == 1
    assert answer.answer == ""
    assert len(answer.cited_documents) == 2
    # Verify emitter events were captured
    emitter = fake_run_context.context.run_dependencies.emitter
    # Should have: SearchToolStart, SearchToolDelta (query), SearchToolDelta (docs)
    assert len(emitter.packet_history) >= 3
    # Check the types of emitted events
    assert isinstance(emitter.packet_history[0].obj, SearchToolStart)
    assert isinstance(emitter.packet_history[1].obj, SearchToolDelta)
    assert isinstance(emitter.packet_history[2].obj, SearchToolDelta)
    # Check the first SearchToolDelta (query)
    first_delta = emitter.packet_history[1].obj
    assert first_delta.queries == [query]
    assert first_delta.documents == []
    # Check the second SearchToolDelta (documents)
    second_delta = emitter.packet_history[2].obj
    assert second_delta.queries == []
    assert len(second_delta.documents) == 2
    # Verify the SavedSearchDoc objects
    first_doc = second_delta.documents[0]
    assert isinstance(first_doc, SavedSearchDoc)
    assert first_doc.document_id == "doc1"
    assert first_doc.semantic_identifier == "test_doc_1"

View File

@@ -18,9 +18,7 @@ from onyx.tools.tool_implementations.search.search_tool import (
SEARCH_RESPONSE_SUMMARY_ID,
)
from onyx.tools.tool_implementations.search.search_tool import SearchTool
from onyx.tools.tool_implementations_v2.tool_result_models import (
LlmInternalSearchResult,
)
from onyx.tools.tool_result_models import LlmInternalSearchResult
from tests.unit.onyx.chat.turn.utils import create_test_inference_chunk
from tests.unit.onyx.chat.turn.utils import create_test_inference_section
from tests.unit.onyx.chat.turn.utils import FakeQuery

View File

@@ -25,10 +25,10 @@ from onyx.server.query_and_chat.streaming_models import SearchToolDelta
from onyx.server.query_and_chat.streaming_models import SearchToolStart
from onyx.server.query_and_chat.streaming_models import SectionEnd
from onyx.tools.tool_implementations.web_search.web_search_tool import WebSearchTool
from onyx.tools.tool_implementations_v2.tool_result_models import LlmOpenUrlResult
from onyx.tools.tool_implementations_v2.tool_result_models import LlmWebSearchResult
from onyx.tools.tool_implementations_v2.web import _open_url_core
from onyx.tools.tool_implementations_v2.web import _web_search_core
from onyx.tools.tool_result_models import LlmOpenUrlResult
from onyx.tools.tool_result_models import LlmWebSearchResult
class MockTool:

View File

@@ -0,0 +1,398 @@
from collections.abc import Sequence
from datetime import datetime
from typing import cast
from typing import List
from uuid import uuid4
import pytest
from agents import RunContextWrapper
from onyx.agents.agent_search.dr.models import IterationAnswer
from onyx.agents.agent_search.dr.models import IterationInstructions
from onyx.agents.agent_search.dr.sub_agents.web_search.models import WebContent
from onyx.agents.agent_search.dr.sub_agents.web_search.models import WebSearchProvider
from onyx.agents.agent_search.dr.sub_agents.web_search.models import WebSearchResult
from onyx.chat.turn.models import ChatTurnContext
from onyx.configs.constants import DocumentSource
from onyx.context.search.models import SavedSearchDoc
from onyx.server.query_and_chat.streaming_models import Packet
from onyx.server.query_and_chat.streaming_models import SearchToolDelta
from onyx.server.query_and_chat.streaming_models import SearchToolStart
from onyx.server.query_and_chat.streaming_models import SectionEnd
from onyx.tools.tool_implementations.web_search.web_search_tool import WebSearchTool
from onyx.tools.tool_result_models import LlmWebSearchResult
class MockTool:
    """Minimal stand-in for a persisted ``Tool`` row.

    The code under test only ever reads the ``id`` and ``name``
    attributes, so nothing else is modeled.
    """

    def __init__(self, tool_id: int = 1, name: str = WebSearchTool.__name__):
        # Mirror the two columns the v2 tool-calling path looks up.
        self.id = tool_id
        self.name = name
class MockWebSearchProvider(WebSearchProvider):
    """Injectable ``WebSearchProvider`` double.

    Returns a canned list of results from ``search``; when
    ``should_raise_exception`` is set it simulates a provider outage
    instead. ``contents`` always yields nothing.
    """

    def __init__(
        self,
        search_results: List[WebSearchResult] | None = None,
        should_raise_exception: bool = False,
    ):
        self.should_raise_exception = should_raise_exception
        self.search_results = search_results or []

    def search(self, query: str) -> List[WebSearchResult]:
        # Fail loudly when the test asked for a broken provider.
        if self.should_raise_exception:
            raise Exception("Test exception from search provider")
        return self.search_results

    def contents(self, urls: Sequence[str]) -> List[WebContent]:
        # Content fetching is irrelevant to these tests.
        return []
class MockEmitter:
    """Emitter double that simply records every packet it receives."""

    def __init__(self) -> None:
        # Packets are kept in emission order; tests assert on this history.
        self.packet_history: list[Packet] = []

    def emit(self, packet: Packet) -> None:
        """Append *packet* to the recorded history instead of streaming it."""
        self.packet_history.append(packet)
class MockRunDependencies:
    """Dependency bundle double for a chat turn.

    Carries a recording :class:`MockEmitter` and a ``MagicMock`` database
    session whose ``scalar`` lookups resolve to a :class:`MockTool` row.
    """

    def __init__(self) -> None:
        self.emitter = MockEmitter()
        # Local import keeps the mock helper out of the module namespace.
        from unittest.mock import MagicMock

        self.db_session = MagicMock()
        # Any Tool lookup performed via db_session.scalar() finds this row.
        self.db_session.scalar.return_value = MockTool()
def create_test_run_context(
    current_run_step: int = 0,
    iteration_instructions: List[IterationInstructions] | None = None,
    global_iteration_responses: List[IterationAnswer] | None = None,
) -> RunContextWrapper[ChatTurnContext]:
    """Build a real ``RunContextWrapper`` wired to mock dependencies.

    The arguments are passed straight through to the ``ChatTurnContext``
    so individual tests can pre-seed iteration state.
    """
    deps = MockRunDependencies()
    # Install a fresh recording emitter on the dependency bundle so tests
    # can reach it through the context after the tool runs.
    deps.emitter = MockEmitter()

    turn_context = ChatTurnContext(
        chat_session_id=uuid4(),
        message_id=1,
        current_run_step=current_run_step,
        iteration_instructions=iteration_instructions or [],
        global_iteration_responses=global_iteration_responses or [],
        run_dependencies=deps,  # type: ignore[arg-type]
    )
    return RunContextWrapper(context=turn_context)
def test_web_search_tool_run_v2_basic_functionality() -> None:
    """Happy-path check for WebSearchTool.run_v2 with a single query.

    Verifies the LLM-facing results parsed from the returned JSON, the
    URL-keyed document cache, the iteration bookkeeping on the
    ChatTurnContext, and the exact packet sequence pushed to the emitter.
    """
    # Arrange
    test_run_context = create_test_run_context()
    queries = ["test search query"]
    # Create test search results (second one exercises optional fields)
    test_search_results = [
        WebSearchResult(
            title="Test Result 1",
            link="https://example.com/1",
            author="Test Author",
            published_date=datetime(2024, 1, 1, 12, 0, 0),
            snippet="This is a test snippet 1",
        ),
        WebSearchResult(
            title="Test Result 2",
            link="https://example.com/2",
            author=None,
            published_date=None,
            snippet="This is a test snippet 2",
        ),
    ]
    test_provider = MockWebSearchProvider(search_results=test_search_results)
    # Create tool instance
    web_search_tool = WebSearchTool(tool_id=1)
    # Mock the get_default_provider to return our test provider
    from unittest.mock import patch

    with patch(
        "onyx.tools.tool_implementations.web_search.web_search_tool.get_default_provider"
    ) as mock_get_provider:
        mock_get_provider.return_value = test_provider

        # Act
        result_json = web_search_tool.run_v2(test_run_context, queries=queries)

        # Parse the JSON result back into typed results
        from pydantic import TypeAdapter

        adapter = TypeAdapter(list[LlmWebSearchResult])
        result = adapter.validate_json(result_json)

        # Assert -- one LlmWebSearchResult per provider hit
        assert isinstance(result, list)
        assert len(result) == 2
        assert all(isinstance(r, LlmWebSearchResult) for r in result)

        # Check first result
        assert result[0].title == "Test Result 1"
        assert result[0].url == "https://example.com/1"
        assert result[0].snippet == "This is a test snippet 1"
        # NOTE(review): -1 looks like the "no citation assigned yet"
        # sentinel -- confirm against the citation-numbering code
        assert result[0].document_citation_number == -1
        assert result[0].unique_identifier_to_strip_away == "https://example.com/1"

        # Check second result
        assert result[1].title == "Test Result 2"
        assert result[1].url == "https://example.com/2"
        assert result[1].snippet == "This is a test snippet 2"
        assert result[1].document_citation_number == -1
        assert result[1].unique_identifier_to_strip_away == "https://example.com/2"

        # Check that fetched_documents_cache was populated, keyed by URL
        assert len(test_run_context.context.fetched_documents_cache) == 2
        assert "https://example.com/1" in test_run_context.context.fetched_documents_cache
        assert "https://example.com/2" in test_run_context.context.fetched_documents_cache

        # Verify cache entries have correct structure
        cache_entry_1 = test_run_context.context.fetched_documents_cache[
            "https://example.com/1"
        ]
        assert cache_entry_1.document_citation_number == -1
        assert cache_entry_1.inference_section is not None

        # Verify context was updated: run step advanced to 2, and exactly
        # one instruction/answer pair was recorded
        assert test_run_context.context.current_run_step == 2
        assert len(test_run_context.context.iteration_instructions) == 1
        assert len(test_run_context.context.global_iteration_responses) == 1

        # Check iteration instruction
        instruction = test_run_context.context.iteration_instructions[0]
        assert isinstance(instruction, IterationInstructions)
        assert instruction.iteration_nr == 1
        assert instruction.purpose == "Searching the web for information"
        assert (
            "Web Search to gather information on test search query" in instruction.reasoning
        )

        # Check iteration answer
        answer = test_run_context.context.global_iteration_responses[0]
        assert isinstance(answer, IterationAnswer)
        assert answer.tool == WebSearchTool.__name__
        assert answer.iteration_nr == 1
        assert answer.question == queries[0]

        # Verify emitter events were captured
        emitter = cast(MockEmitter, test_run_context.context.run_dependencies.emitter)
        assert len(emitter.packet_history) == 4

        # Check the types of emitted events, in order
        assert isinstance(emitter.packet_history[0].obj, SearchToolStart)
        assert isinstance(emitter.packet_history[1].obj, SearchToolDelta)
        assert isinstance(emitter.packet_history[2].obj, SearchToolDelta)
        assert isinstance(emitter.packet_history[3].obj, SectionEnd)

        # Verify first SearchToolDelta has queries but empty documents
        first_search_delta = cast(SearchToolDelta, emitter.packet_history[1].obj)
        assert first_search_delta.queries == queries
        assert first_search_delta.documents == []

        # Verify second SearchToolDelta contains SavedSearchDoc objects for favicon display
        search_tool_delta = cast(SearchToolDelta, emitter.packet_history[2].obj)
        assert len(search_tool_delta.documents) == 2

        # Verify documents have correct properties for frontend favicon display
        doc1 = search_tool_delta.documents[0]
        assert isinstance(doc1, SavedSearchDoc)
        assert doc1.link == "https://example.com/1"
        assert (
            doc1.semantic_identifier == "https://example.com/1"
        )  # semantic_identifier is the link for web results
        assert doc1.blurb == "Test Result 1"  # title is stored in blurb
        assert doc1.source_type == DocumentSource.WEB
        assert doc1.is_internet is True
def test_web_search_tool_run_v2_exception_handling() -> None:
    """A failing provider must propagate its exception from run_v2 while the
    surrounding machinery still emits the start/delta packets, closes the
    section, and advances the run step."""
    # Arrange: a provider whose search() always raises
    run_ctx = create_test_run_context()
    failing_provider = MockWebSearchProvider(should_raise_exception=True)
    tool = WebSearchTool(tool_id=1)

    from unittest.mock import patch

    with patch(
        "onyx.tools.tool_implementations.web_search.web_search_tool.get_default_provider"
    ) as mock_get_provider:
        mock_get_provider.return_value = failing_provider

        # Act & Assert: the provider error bubbles out of run_v2
        with pytest.raises(Exception, match="Test exception from search provider"):
            tool.run_v2(run_ctx, queries=["test search query"])

    # Even on failure, the initial packets plus the decorator's SectionEnd
    # were emitted, in order.
    emitter = run_ctx.context.run_dependencies.emitter  # type: ignore[attr-defined]
    history = emitter.packet_history
    assert len(history) == 3  # SearchToolStart, first SearchToolDelta, SectionEnd
    assert isinstance(history[0].obj, SearchToolStart)
    assert isinstance(history[1].obj, SearchToolDelta)
    assert isinstance(history[2].obj, SectionEnd)

    # The decorator still finalized the step counter despite the exception.
    assert run_ctx.context.current_run_step == 2
def test_web_search_tool_run_v2_multiple_queries() -> None:
"""Test WebSearchTool.run_v2 with multiple queries searched in parallel"""
# Arrange
test_run_context = create_test_run_context()
queries = ["first query", "second query"]
# Create a mock provider that returns different results based on the query
class MultiQueryMockProvider(WebSearchProvider):
def search(self, query: str) -> List[WebSearchResult]:
if query == "first query":
return [
WebSearchResult(
title="First Result 1",
link="https://example.com/first1",
author="Author 1",
published_date=datetime(2024, 1, 1, 12, 0, 0),
snippet="Snippet for first query result 1",
),
WebSearchResult(
title="First Result 2",
link="https://example.com/first2",
author=None,
published_date=None,
snippet="Snippet for first query result 2",
),
]
elif query == "second query":
return [
WebSearchResult(
title="Second Result 1",
link="https://example.com/second1",
author="Author 2",
published_date=datetime(2024, 2, 1, 12, 0, 0),
snippet="Snippet for second query result 1",
),
]
return []
def contents(self, urls: Sequence[str]) -> List[WebContent]:
return []
test_provider = MultiQueryMockProvider()
# Create tool instance
web_search_tool = WebSearchTool(tool_id=1)
# Mock the get_default_provider to return our test provider
from unittest.mock import patch
with patch(
"onyx.tools.tool_implementations.web_search.web_search_tool.get_default_provider"
) as mock_get_provider:
mock_get_provider.return_value = test_provider
# Act
result_json = web_search_tool.run_v2(test_run_context, queries=queries)
# Parse the JSON result
from pydantic import TypeAdapter
adapter = TypeAdapter(list[LlmWebSearchResult])
result = adapter.validate_json(result_json)
# Assert
assert isinstance(result, list)
# Should have 3 total results (2 from first query + 1 from second query)
assert len(result) == 3
assert all(isinstance(r, LlmWebSearchResult) for r in result)
# Verify all results are present (order may vary due to parallel execution)
titles = {r.title for r in result}
assert "First Result 1" in titles
assert "First Result 2" in titles
assert "Second Result 1" in titles
# Check that fetched_documents_cache was populated with all URLs
assert len(test_run_context.context.fetched_documents_cache) == 3
# Verify context was updated
assert test_run_context.context.current_run_step == 2
assert len(test_run_context.context.iteration_instructions) == 1
assert len(test_run_context.context.global_iteration_responses) == 1
# Check iteration instruction contains both queries
instruction = test_run_context.context.iteration_instructions[0]
assert isinstance(instruction, IterationInstructions)
assert instruction.iteration_nr == 1
assert instruction.purpose == "Searching the web for information"
assert "first query" in instruction.reasoning
assert "second query" in instruction.reasoning
# Check iteration answer
answer = test_run_context.context.global_iteration_responses[0]
assert isinstance(answer, IterationAnswer)
assert answer.tool == WebSearchTool.__name__
assert answer.iteration_nr == 1
assert "first query" in answer.question
assert "second query" in answer.question
# Verify emitter events were captured
emitter = cast(MockEmitter, test_run_context.context.run_dependencies.emitter)
assert len(emitter.packet_history) == 4
# Check the types of emitted events
assert isinstance(emitter.packet_history[0].obj, SearchToolStart)
assert isinstance(emitter.packet_history[1].obj, SearchToolDelta)
assert isinstance(emitter.packet_history[2].obj, SearchToolDelta)
assert isinstance(emitter.packet_history[3].obj, SectionEnd)
# Check that first SearchToolDelta contains both queries (with empty documents)
first_search_delta = cast(SearchToolDelta, emitter.packet_history[1].obj)
assert first_search_delta.queries is not None
assert len(first_search_delta.queries) == 2
assert "first query" in first_search_delta.queries
assert "second query" in first_search_delta.queries
assert first_search_delta.documents == []
# Check that second SearchToolDelta contains documents
second_search_delta = cast(SearchToolDelta, emitter.packet_history[2].obj)
assert (
len(second_search_delta.documents) == 3
) # 2 from first query + 1 from second query