mirror of
https://github.com/onyx-dot-app/onyx.git
synced 2026-04-07 07:52:44 +00:00
Compare commits
3 Commits
cli/v0.2.1
...
richard/se
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
764e7f44f6 | ||
|
|
1c376c66d4 | ||
|
|
735b3c4c02 |
@@ -1,4 +1,5 @@
|
||||
import json
|
||||
import traceback
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Iterator
|
||||
from collections.abc import Sequence
|
||||
@@ -312,11 +313,15 @@ def query(
|
||||
_error_tracing.attach_error_to_current_span(
|
||||
SpanError(
|
||||
message="Error running tool",
|
||||
data={"tool_name": tool.name, "error": str(e)},
|
||||
data={
|
||||
"tool_name": tool.name,
|
||||
"error": str(e),
|
||||
"stack_trace": traceback.format_exc(),
|
||||
},
|
||||
)
|
||||
)
|
||||
# Treat the error as the tool output so the framework can continue
|
||||
error_output = f"Error: {str(e)}"
|
||||
error_output = tool.failure_error_function(e)
|
||||
tool_outputs[call_id] = error_output
|
||||
output = error_output
|
||||
|
||||
@@ -328,8 +333,17 @@ def query(
|
||||
),
|
||||
)
|
||||
else:
|
||||
not_found_output = f"Tool {name} not found"
|
||||
tool_outputs[call_id] = _serialize_tool_output(not_found_output)
|
||||
with function_span(f"tool_not_found_{name}") as span_fn:
|
||||
not_found_output = {
|
||||
"error": True,
|
||||
"error_type": "TOOL_NOT FOUND",
|
||||
"message": f"The tool {name} does not exist or is not registered.",
|
||||
"available_tools": [tool.name for tool in tools],
|
||||
}
|
||||
tool_outputs[call_id] = _serialize_tool_output(not_found_output)
|
||||
span_fn.span_data.input = arguments_str
|
||||
span_fn.span_data.output = not_found_output
|
||||
|
||||
yield RunItemStreamEvent(
|
||||
type="tool_call_output",
|
||||
details=ToolCallOutputStreamItem(
|
||||
|
||||
@@ -56,7 +56,6 @@ from onyx.server.query_and_chat.streaming_models import ReasoningStart
|
||||
from onyx.server.query_and_chat.streaming_models import SectionEnd
|
||||
from onyx.tools.adapter_v1_to_v2 import force_use_tool_to_function_tool_names
|
||||
from onyx.tools.adapter_v1_to_v2 import tools_to_function_tools
|
||||
from onyx.tools.force import filter_tools_for_force_tool_use
|
||||
from onyx.tools.force import ForceUseTool
|
||||
from onyx.tools.tool import Tool
|
||||
|
||||
@@ -111,10 +110,6 @@ def _run_agent_loop(
|
||||
available_tools: Sequence[Tool] = (
|
||||
dependencies.tools if iteration_count < MAX_ITERATIONS else []
|
||||
)
|
||||
if force_use_tool and force_use_tool.force_use:
|
||||
available_tools = filter_tools_for_force_tool_use(
|
||||
list(available_tools), force_use_tool
|
||||
)
|
||||
memories = get_memories(dependencies.user_or_none, dependencies.db_session)
|
||||
# TODO: The system is rather prompt-cache efficient except for rebuilding the system prompt.
|
||||
# The biggest offender is when we hit max iterations and then all the tool calls cannot
|
||||
|
||||
@@ -47,6 +47,7 @@ from onyx.llm.llm_provider_options import VERTEX_CREDENTIALS_FILE_KWARG
|
||||
from onyx.llm.llm_provider_options import VERTEX_LOCATION_KWARG
|
||||
from onyx.llm.model_response import ModelResponse
|
||||
from onyx.llm.model_response import ModelResponseStream
|
||||
from onyx.llm.utils import is_anthropic_model
|
||||
from onyx.llm.utils import is_true_openai_model
|
||||
from onyx.llm.utils import model_is_reasoning_model
|
||||
from onyx.server.utils import mask_string
|
||||
@@ -466,15 +467,22 @@ class LitellmLLM(LLM):
|
||||
from litellm.exceptions import Timeout, RateLimitError
|
||||
|
||||
tool_choice_formatted: dict[str, Any] | str | None
|
||||
function_tool_choice = (
|
||||
tool_choice and tool_choice not in STANDARD_TOOL_CHOICE_OPTIONS
|
||||
)
|
||||
if not tools:
|
||||
tool_choice_formatted = None
|
||||
elif tool_choice and tool_choice not in STANDARD_TOOL_CHOICE_OPTIONS:
|
||||
elif function_tool_choice:
|
||||
tool_choice_formatted = {
|
||||
"type": "function",
|
||||
"function": {"name": tool_choice},
|
||||
}
|
||||
else:
|
||||
tool_choice_formatted = tool_choice
|
||||
exclude_reasoning_config_for_anthropic = (
|
||||
is_anthropic_model(self.config.model_name, self.config.model_provider)
|
||||
and function_tool_choice
|
||||
)
|
||||
|
||||
is_reasoning = model_is_reasoning_model(
|
||||
self.config.model_name, self.config.model_provider
|
||||
@@ -531,12 +539,16 @@ class LitellmLLM(LLM):
|
||||
),
|
||||
**(
|
||||
{"thinking": {"type": "enabled", "budget_tokens": 10000}}
|
||||
if reasoning_effort and is_reasoning
|
||||
if reasoning_effort
|
||||
and is_reasoning
|
||||
and not exclude_reasoning_config_for_anthropic
|
||||
else {}
|
||||
),
|
||||
**(
|
||||
{"reasoning_effort": reasoning_effort}
|
||||
if reasoning_effort and is_reasoning
|
||||
if reasoning_effort
|
||||
and is_reasoning
|
||||
and not exclude_reasoning_config_for_anthropic
|
||||
else {}
|
||||
),
|
||||
**(
|
||||
|
||||
@@ -7,41 +7,28 @@ from typing import Dict
|
||||
from typing import List
|
||||
from typing import Optional
|
||||
from typing import Tuple
|
||||
from typing import TypedDict
|
||||
from typing import Union
|
||||
|
||||
from litellm import AllMessageValues
|
||||
from litellm import LiteLLMLoggingObj
|
||||
from litellm.completion_extras.litellm_responses_transformation.transformation import (
|
||||
LiteLLMResponsesTransformationHandler,
|
||||
)
|
||||
from litellm.completion_extras.litellm_responses_transformation.transformation import (
|
||||
OpenAiResponsesToChatCompletionStreamIterator,
|
||||
)
|
||||
from litellm.litellm_core_utils.prompt_templates.common_utils import (
|
||||
convert_content_list_to_str,
|
||||
)
|
||||
|
||||
try:
|
||||
from litellm.litellm_core_utils.prompt_templates.common_utils import (
|
||||
extract_images_from_message,
|
||||
)
|
||||
except ImportError:
|
||||
extract_images_from_message = None # type: ignore[assignment]
|
||||
from litellm.litellm_core_utils.prompt_templates.common_utils import (
|
||||
extract_images_from_message,
|
||||
)
|
||||
from litellm.llms.ollama.chat.transformation import OllamaChatCompletionResponseIterator
|
||||
from litellm.llms.ollama.chat.transformation import OllamaChatConfig
|
||||
from litellm.llms.ollama.common_utils import OllamaError
|
||||
|
||||
try:
|
||||
from litellm.types.llms.ollama import OllamaChatCompletionMessage
|
||||
except ImportError:
|
||||
|
||||
class OllamaChatCompletionMessage(TypedDict, total=False): # type: ignore[no-redef]
|
||||
"""Fallback for LiteLLM versions where this TypedDict was removed."""
|
||||
|
||||
role: str
|
||||
content: Optional[str]
|
||||
images: Optional[List[Any]]
|
||||
thinking: Optional[str]
|
||||
tool_calls: Optional[List["OllamaToolCall"]]
|
||||
|
||||
|
||||
from litellm.llms.openai.chat.gpt_5_transformation import OpenAIGPT5Config
|
||||
from litellm.llms.openai.chat.gpt_transformation import OpenAIGPTConfig
|
||||
from litellm.types.llms.ollama import OllamaChatCompletionMessage
|
||||
from litellm.types.llms.ollama import OllamaToolCall
|
||||
from litellm.types.llms.ollama import OllamaToolCallFunction
|
||||
from litellm.types.llms.openai import ChatCompletionAssistantToolCall
|
||||
@@ -52,43 +39,7 @@ from litellm.utils import verbose_logger
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
if extract_images_from_message is None:
|
||||
|
||||
def extract_images_from_message(
|
||||
message: AllMessageValues,
|
||||
) -> Optional[List[Any]]:
|
||||
"""Fallback for LiteLLM versions that dropped extract_images_from_message."""
|
||||
|
||||
images: List[Any] = []
|
||||
content = message.get("content")
|
||||
if not isinstance(content, list):
|
||||
return None
|
||||
|
||||
for item in content:
|
||||
if not isinstance(item, Dict):
|
||||
continue
|
||||
|
||||
item_type = item.get("type")
|
||||
if item_type == "image_url":
|
||||
image_url = item.get("image_url")
|
||||
if isinstance(image_url, dict):
|
||||
if image_url.get("url"):
|
||||
images.append(image_url)
|
||||
elif image_url:
|
||||
images.append(image_url)
|
||||
elif item_type in {"input_image", "image"}:
|
||||
image_value = item.get("image")
|
||||
if image_value:
|
||||
images.append(image_value)
|
||||
|
||||
return images or None
|
||||
|
||||
|
||||
def _patch_ollama_transform_request() -> None:
|
||||
"""
|
||||
Patches OllamaChatConfig.transform_request to handle reasoning content
|
||||
and tool calls properly for Ollama chat completions.
|
||||
"""
|
||||
def _patch_ollama_transform_request_so_tool_calls_streamed() -> None:
|
||||
if (
|
||||
getattr(OllamaChatConfig.transform_request, "__name__", "")
|
||||
== "_patched_transform_request"
|
||||
@@ -180,11 +131,7 @@ def _patch_ollama_transform_request() -> None:
|
||||
OllamaChatConfig.transform_request = _patched_transform_request # type: ignore[method-assign]
|
||||
|
||||
|
||||
def _patch_ollama_chunk_parser() -> None:
|
||||
"""
|
||||
Patches OllamaChatCompletionResponseIterator.chunk_parser to properly handle
|
||||
reasoning content and content in streaming responses.
|
||||
"""
|
||||
def _patch_ollama_chunk_parser_so_reasoning_streamed() -> None:
|
||||
if (
|
||||
getattr(OllamaChatCompletionResponseIterator.chunk_parser, "__name__", "")
|
||||
== "_patched_chunk_parser"
|
||||
@@ -312,11 +259,7 @@ def _patch_ollama_chunk_parser() -> None:
|
||||
OllamaChatCompletionResponseIterator.chunk_parser = _patched_chunk_parser # type: ignore[method-assign]
|
||||
|
||||
|
||||
def _patch_openai_responses_chunk_parser() -> None:
|
||||
"""
|
||||
Patches OpenAiResponsesToChatCompletionStreamIterator.chunk_parser to properly
|
||||
handle OpenAI Responses API streaming format and convert it to chat completion format.
|
||||
"""
|
||||
def _patch_openai_responses_chunk_parser_so_reasoning_streamed() -> None:
|
||||
if (
|
||||
getattr(
|
||||
OpenAiResponsesToChatCompletionStreamIterator.chunk_parser,
|
||||
@@ -483,6 +426,174 @@ def _patch_openai_responses_chunk_parser() -> None:
|
||||
OpenAiResponsesToChatCompletionStreamIterator.chunk_parser = _patched_openai_responses_chunk_parser # type: ignore[method-assign]
|
||||
|
||||
|
||||
def _patch_litellm_responses_transformation_handler_so_tool_choice_formatted() -> None:
|
||||
if (
|
||||
getattr(
|
||||
LiteLLMResponsesTransformationHandler.transform_request,
|
||||
"__name__",
|
||||
"",
|
||||
)
|
||||
== "_patched_transform_request"
|
||||
):
|
||||
return
|
||||
|
||||
def _patched_transform_request(
|
||||
self,
|
||||
model: str,
|
||||
messages: List["AllMessageValues"],
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
headers: dict,
|
||||
litellm_logging_obj: "LiteLLMLoggingObj",
|
||||
client: Optional[Any] = None,
|
||||
) -> dict:
|
||||
from litellm.types.llms.openai import ResponsesAPIOptionalRequestParams
|
||||
|
||||
(
|
||||
input_items,
|
||||
instructions,
|
||||
) = self.convert_chat_completion_messages_to_responses_api(messages)
|
||||
|
||||
# Build responses API request using the reverse transformation logic
|
||||
responses_api_request = ResponsesAPIOptionalRequestParams()
|
||||
|
||||
# Set instructions if we found a system message
|
||||
if instructions:
|
||||
responses_api_request["instructions"] = instructions
|
||||
|
||||
# Map optional parameters
|
||||
for key, value in optional_params.items():
|
||||
if value is None:
|
||||
continue
|
||||
if key in ("max_tokens", "max_completion_tokens"):
|
||||
responses_api_request["max_output_tokens"] = value
|
||||
elif key == "tools" and value is not None:
|
||||
# Convert chat completion tools to responses API tools format
|
||||
responses_api_request["tools"] = (
|
||||
LiteLLMResponsesTransformationHandler._convert_tools_to_responses_format(
|
||||
None, cast(List[Dict[str, Any]], value)
|
||||
)
|
||||
)
|
||||
elif key == "tool_choice" and value is not None:
|
||||
# TODO Right now, we're only supporting function tools and tool choice
|
||||
# mode. In reality this should support all the possible tool_choice inputs
|
||||
# documented in the API docs, but these are the only formats Onyx uses.
|
||||
responses_api_request["tool_choice"] = (
|
||||
{
|
||||
"type": "function",
|
||||
"name": (
|
||||
value
|
||||
if isinstance(value, str)
|
||||
else value["function"]["name"]
|
||||
),
|
||||
}
|
||||
if isinstance(value, dict)
|
||||
else value
|
||||
)
|
||||
elif key in ResponsesAPIOptionalRequestParams.__annotations__.keys():
|
||||
responses_api_request[key] = value # type: ignore
|
||||
elif key in ("metadata"):
|
||||
responses_api_request["metadata"] = value
|
||||
elif key in ("previous_response_id"):
|
||||
responses_api_request["previous_response_id"] = value
|
||||
elif key == "reasoning_effort":
|
||||
responses_api_request["reasoning"] = self._map_reasoning_effort(value)
|
||||
|
||||
# Get stream parameter from litellm_params if not in optional_params
|
||||
stream = optional_params.get("stream") or litellm_params.get("stream", False)
|
||||
verbose_logger.debug(f"Chat provider: Stream parameter: {stream}")
|
||||
|
||||
# Ensure stream is properly set in the request
|
||||
if stream:
|
||||
responses_api_request["stream"] = True
|
||||
|
||||
# Handle session management if previous_response_id is provided
|
||||
previous_response_id = optional_params.get("previous_response_id")
|
||||
if previous_response_id:
|
||||
# Use the existing session handler for responses API
|
||||
verbose_logger.debug(
|
||||
f"Chat provider: Warning ignoring previous response ID: {previous_response_id}"
|
||||
)
|
||||
|
||||
# Convert back to responses API format for the actual request
|
||||
|
||||
api_model = model
|
||||
|
||||
from litellm.types.utils import CallTypes
|
||||
|
||||
setattr(litellm_logging_obj, "call_type", CallTypes.responses.value)
|
||||
|
||||
request_data = {
|
||||
"model": api_model,
|
||||
"input": input_items,
|
||||
"litellm_logging_obj": litellm_logging_obj,
|
||||
**litellm_params,
|
||||
"client": client,
|
||||
}
|
||||
|
||||
verbose_logger.debug(
|
||||
f"Chat provider: Final request model={api_model}, input_items={len(input_items)}"
|
||||
)
|
||||
|
||||
# Add non-None values from responses_api_request
|
||||
for key, value in responses_api_request.items():
|
||||
if value is not None:
|
||||
if key == "instructions" and instructions:
|
||||
request_data["instructions"] = instructions
|
||||
elif key == "stream_options" and isinstance(value, dict):
|
||||
request_data["stream_options"] = value.get("include_obfuscation")
|
||||
elif key == "user": # string can't be longer than 64 characters
|
||||
if isinstance(value, str) and len(value) <= 64:
|
||||
request_data["user"] = value
|
||||
else:
|
||||
request_data[key] = value
|
||||
|
||||
return request_data
|
||||
|
||||
_patched_transform_request.__name__ = "_patched_transform_request"
|
||||
LiteLLMResponsesTransformationHandler.transform_request = _patched_transform_request # type: ignore[method-assign]
|
||||
|
||||
|
||||
def _patch_litellm_openai_gpt5_coonfig_allow_tool_choice_for_responses_bridge() -> None:
|
||||
if (
|
||||
getattr(
|
||||
OpenAIGPT5Config.get_supported_openai_params,
|
||||
"__name__",
|
||||
"",
|
||||
)
|
||||
== "_patched_get_supported_openai_params"
|
||||
):
|
||||
return
|
||||
|
||||
def _patched_get_supported_openai_params(self, model: str) -> list:
|
||||
# Call the parent class method directly (can't use super() in monkey patches)
|
||||
base_gpt_series_params = OpenAIGPTConfig.get_supported_openai_params(
|
||||
self, model=model
|
||||
)
|
||||
gpt_5_only_params = ["reasoning_effort"]
|
||||
base_gpt_series_params.extend(gpt_5_only_params)
|
||||
|
||||
non_supported_params = [
|
||||
"logprobs",
|
||||
"top_p",
|
||||
"presence_penalty",
|
||||
"frequency_penalty",
|
||||
"top_logprobs",
|
||||
"stop",
|
||||
]
|
||||
|
||||
return [
|
||||
param
|
||||
for param in base_gpt_series_params
|
||||
if param not in non_supported_params
|
||||
]
|
||||
|
||||
_patched_get_supported_openai_params.__name__ = (
|
||||
"_patched_get_supported_openai_params"
|
||||
)
|
||||
OpenAIGPT5Config.get_supported_openai_params = _patched_get_supported_openai_params # type: ignore[method-assign]
|
||||
|
||||
|
||||
def apply_monkey_patches() -> None:
|
||||
"""
|
||||
Apply all necessary monkey patches to LiteLLM for compatibility.
|
||||
@@ -491,10 +602,14 @@ def apply_monkey_patches() -> None:
|
||||
- Patching OllamaChatConfig.transform_request for reasoning content support
|
||||
- Patching OllamaChatCompletionResponseIterator.chunk_parser for streaming content
|
||||
- Patching OpenAiResponsesToChatCompletionStreamIterator.chunk_parser for OpenAI Responses API
|
||||
- Patching LiteLLMResponsesTransformationHandler.transform_request for tool choice formatting
|
||||
- Patching OpenAIGPT5Config.get_supported_openai_params for tool choice support for responses bridge
|
||||
"""
|
||||
_patch_ollama_transform_request()
|
||||
_patch_ollama_chunk_parser()
|
||||
_patch_openai_responses_chunk_parser()
|
||||
_patch_ollama_transform_request_so_tool_calls_streamed()
|
||||
_patch_ollama_chunk_parser_so_reasoning_streamed()
|
||||
_patch_openai_responses_chunk_parser_so_reasoning_streamed()
|
||||
_patch_litellm_responses_transformation_handler_so_tool_choice_formatted()
|
||||
_patch_litellm_openai_gpt5_coonfig_allow_tool_choice_for_responses_bridge()
|
||||
|
||||
|
||||
def _extract_reasoning_content(message: dict) -> Tuple[Optional[str], Optional[str]]:
|
||||
|
||||
@@ -935,6 +935,10 @@ def is_true_openai_model(model_provider: str, model_name: str) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
def is_anthropic_model(model_name: str, model_provider: str) -> bool:
|
||||
return model_provider == "anthropic" or "claude" in model_name.lower()
|
||||
|
||||
|
||||
def model_needs_formatting_reenabled(model_name: str) -> bool:
|
||||
# See https://simonwillison.net/tags/markdown/ for context on why this is needed
|
||||
# for OpenAI reasoning models to have correct markdown generation
|
||||
|
||||
@@ -34,12 +34,8 @@ from onyx.server.features.web_search.models import WebSearchToolResponse
|
||||
from onyx.server.features.web_search.models import WebSearchWithContentResponse
|
||||
from onyx.server.manage.web_search.models import WebContentProviderView
|
||||
from onyx.server.manage.web_search.models import WebSearchProviderView
|
||||
from onyx.tools.tool_implementations_v2.tool_result_models import (
|
||||
LlmOpenUrlResult,
|
||||
)
|
||||
from onyx.tools.tool_implementations_v2.tool_result_models import (
|
||||
LlmWebSearchResult,
|
||||
)
|
||||
from onyx.tools.tool_result_models import LlmOpenUrlResult
|
||||
from onyx.tools.tool_result_models import LlmWebSearchResult
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.enums import WebContentProviderType
|
||||
from shared_configs.enums import WebSearchProviderType
|
||||
|
||||
@@ -2,12 +2,8 @@ from pydantic import BaseModel
|
||||
from pydantic import Field
|
||||
from pydantic import field_validator
|
||||
|
||||
from onyx.tools.tool_implementations_v2.tool_result_models import (
|
||||
LlmOpenUrlResult,
|
||||
)
|
||||
from onyx.tools.tool_implementations_v2.tool_result_models import (
|
||||
LlmWebSearchResult,
|
||||
)
|
||||
from onyx.tools.tool_result_models import LlmOpenUrlResult
|
||||
from onyx.tools.tool_result_models import LlmWebSearchResult
|
||||
from shared_configs.enums import WebContentProviderType
|
||||
from shared_configs.enums import WebSearchProviderType
|
||||
|
||||
|
||||
@@ -13,12 +13,12 @@ from onyx.chat.turn.models import ChatTurnContext
|
||||
from onyx.server.query_and_chat.streaming_models import CustomToolDelta
|
||||
from onyx.server.query_and_chat.streaming_models import CustomToolStart
|
||||
from onyx.server.query_and_chat.streaming_models import Packet
|
||||
from onyx.server.query_and_chat.streaming_models import SectionEnd
|
||||
from onyx.tools.built_in_tools_v2 import BUILT_IN_TOOL_MAP_V2
|
||||
from onyx.tools.force import ForceUseTool
|
||||
from onyx.tools.tool import Tool
|
||||
from onyx.tools.tool_implementations.custom.custom_tool import CustomTool
|
||||
from onyx.tools.tool_implementations.mcp.mcp_tool import MCPTool
|
||||
from onyx.tools.tool_implementations_v2.tool_accounting import tool_accounting
|
||||
|
||||
# Type alias for tools that need custom handling
|
||||
CustomOrMcpTool = Union[CustomTool, MCPTool]
|
||||
@@ -29,7 +29,6 @@ def is_custom_or_mcp_tool(tool: Tool) -> bool:
|
||||
return isinstance(tool, CustomTool) or isinstance(tool, MCPTool)
|
||||
|
||||
|
||||
@tool_accounting
|
||||
async def _tool_run_wrapper(
|
||||
run_context: RunContextWrapper[ChatTurnContext], tool: Tool, json_string: str
|
||||
) -> list[Any]:
|
||||
@@ -37,7 +36,11 @@ async def _tool_run_wrapper(
|
||||
Wrapper function to adapt Tool.run() to FunctionTool.on_invoke_tool() signature.
|
||||
"""
|
||||
args = json.loads(json_string) if json_string else {}
|
||||
|
||||
# Manually handle tool accounting (increment step)
|
||||
run_context.context.current_run_step += 1
|
||||
index = run_context.context.current_run_step
|
||||
|
||||
run_context.context.run_dependencies.emitter.emit(
|
||||
Packet(
|
||||
ind=index,
|
||||
@@ -95,6 +98,16 @@ async def _tool_run_wrapper(
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
# Emit section end and increment step (manually handle tool accounting cleanup)
|
||||
run_context.context.run_dependencies.emitter.emit(
|
||||
Packet(
|
||||
ind=index,
|
||||
obj=SectionEnd(type="section_end"),
|
||||
)
|
||||
)
|
||||
run_context.context.current_run_step += 1
|
||||
|
||||
return results
|
||||
|
||||
|
||||
|
||||
@@ -42,7 +42,7 @@ class BaseTool(Tool[None]):
|
||||
run_context: RunContextWrapper[Any],
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
) -> str: # we expect JSON format returned to the model
|
||||
raise NotImplementedError("BaseTool.run_v2 is not implemented.")
|
||||
|
||||
def build_next_prompt(
|
||||
|
||||
@@ -2,8 +2,6 @@ from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from onyx.tools.tool import Tool
|
||||
|
||||
|
||||
class ForceUseTool(BaseModel):
|
||||
# Could be not a forced usage of the tool but still have args, in which case
|
||||
@@ -17,12 +15,3 @@ class ForceUseTool(BaseModel):
|
||||
def build_openai_tool_choice_dict(self) -> dict[str, Any]:
|
||||
"""Build dict in the format that OpenAI expects which tells them to use this tool."""
|
||||
return {"type": "function", "name": self.tool_name}
|
||||
|
||||
|
||||
def filter_tools_for_force_tool_use(
|
||||
tools: list[Tool], force_use_tool: ForceUseTool
|
||||
) -> list[Tool]:
|
||||
if not force_use_tool.force_use:
|
||||
return tools
|
||||
|
||||
return [tool for tool in tools if tool.name == force_use_tool.tool_name]
|
||||
|
||||
@@ -100,6 +100,8 @@ class Tool(abc.ABC, Generic[OVERRIDE_T]):
|
||||
|
||||
"""Actual execution of the tool"""
|
||||
|
||||
# run_V2 should be what's used moving forwards.
|
||||
# run is only there to support deep research.
|
||||
@abc.abstractmethod
|
||||
def run_v2(
|
||||
self,
|
||||
@@ -109,6 +111,15 @@ class Tool(abc.ABC, Generic[OVERRIDE_T]):
|
||||
) -> Any:
|
||||
raise NotImplementedError
|
||||
|
||||
def failure_error_function(self, error: Exception) -> str:
|
||||
"""
|
||||
This function defines what is returned to the LLM when the tool fails.
|
||||
By default, it returns a generic error message.
|
||||
Subclasses may override to provide a more specific error message, or re-raise the error
|
||||
for a hard error in the framework.
|
||||
"""
|
||||
return f"Error: {str(error)}"
|
||||
|
||||
@abc.abstractmethod
|
||||
def run(
|
||||
self, override_kwargs: OVERRIDE_T | None = None, **llm_kwargs: Any
|
||||
|
||||
@@ -54,6 +54,7 @@ from onyx.tools.tool_implementations.python.python_tool import (
|
||||
PythonTool,
|
||||
)
|
||||
from onyx.tools.tool_implementations.search.search_tool import SearchTool
|
||||
from onyx.tools.tool_implementations.web_search.open_url_tool import OpenUrlTool
|
||||
from onyx.tools.tool_implementations.web_search.web_search_tool import (
|
||||
WebSearchTool,
|
||||
)
|
||||
@@ -287,7 +288,8 @@ def construct_tools(
|
||||
|
||||
try:
|
||||
tool_dict[db_tool_model.id] = [
|
||||
WebSearchTool(tool_id=db_tool_model.id)
|
||||
WebSearchTool(tool_id=db_tool_model.id),
|
||||
OpenUrlTool(tool_id=db_tool_model.id),
|
||||
]
|
||||
except ValueError as e:
|
||||
logger.error(f"Failed to initialize Internet Search Tool: {e}")
|
||||
|
||||
@@ -15,6 +15,8 @@ from langchain_core.messages import SystemMessage
|
||||
from pydantic import BaseModel
|
||||
from requests import JSONDecodeError
|
||||
|
||||
from onyx.agents.agent_search.dr.models import IterationAnswer
|
||||
from onyx.agents.agent_search.dr.models import IterationInstructions
|
||||
from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
|
||||
from onyx.configs.constants import FileOrigin
|
||||
from onyx.file_store.file_store import get_default_file_store
|
||||
@@ -22,12 +24,16 @@ from onyx.file_store.models import ChatFileType
|
||||
from onyx.file_store.models import InMemoryChatFile
|
||||
from onyx.llm.interfaces import LLM
|
||||
from onyx.llm.models import PreviousMessage
|
||||
from onyx.server.query_and_chat.streaming_models import CustomToolDelta
|
||||
from onyx.server.query_and_chat.streaming_models import CustomToolStart
|
||||
from onyx.server.query_and_chat.streaming_models import Packet
|
||||
from onyx.tools.base_tool import BaseTool
|
||||
from onyx.tools.message import ToolCallSummary
|
||||
from onyx.tools.models import CHAT_SESSION_ID_PLACEHOLDER
|
||||
from onyx.tools.models import DynamicSchemaInfo
|
||||
from onyx.tools.models import MESSAGE_ID_PLACEHOLDER
|
||||
from onyx.tools.models import ToolResponse
|
||||
from onyx.tools.tool import RunContextWrapper
|
||||
from onyx.tools.tool_implementations.custom.custom_tool_prompts import (
|
||||
SHOULD_USE_CUSTOM_TOOL_SYSTEM_PROMPT,
|
||||
)
|
||||
@@ -53,6 +59,7 @@ from onyx.tools.tool_implementations.custom.openapi_parsing import (
|
||||
from onyx.tools.tool_implementations.custom.prompt import (
|
||||
build_custom_image_generation_user_prompt,
|
||||
)
|
||||
from onyx.tools.tool_implementations_v2.tool_accounting import tool_accounting
|
||||
from onyx.utils.headers import header_list_to_header_dict
|
||||
from onyx.utils.headers import HeaderItemDict
|
||||
from onyx.utils.logger import setup_logger
|
||||
@@ -240,6 +247,74 @@ class CustomTool(BaseTool):
|
||||
|
||||
"""Actual execution of the tool"""
|
||||
|
||||
@tool_accounting
|
||||
def run_v2(self, run_context: RunContextWrapper[Any], **args: Any) -> str:
|
||||
index = run_context.context.current_run_step
|
||||
run_context.context.run_dependencies.emitter.emit(
|
||||
Packet(
|
||||
ind=index,
|
||||
obj=CustomToolStart(type="custom_tool_start", tool_name=self.name),
|
||||
)
|
||||
)
|
||||
results = []
|
||||
run_context.context.iteration_instructions.append(
|
||||
IterationInstructions(
|
||||
iteration_nr=index,
|
||||
plan=f"Running {self.name}",
|
||||
purpose=f"Running {self.name}",
|
||||
reasoning=f"Running {self.name}",
|
||||
)
|
||||
)
|
||||
for result in self.run(**args):
|
||||
results.append(result)
|
||||
# Extract data from CustomToolCallSummary within the ToolResponse
|
||||
custom_summary = result.response
|
||||
data = None
|
||||
file_ids = None
|
||||
|
||||
# Handle different response types
|
||||
if custom_summary.response_type in ["image", "csv"] and hasattr(
|
||||
custom_summary.tool_result, "file_ids"
|
||||
):
|
||||
file_ids = custom_summary.tool_result.file_ids
|
||||
else:
|
||||
data = custom_summary.tool_result
|
||||
run_context.context.global_iteration_responses.append(
|
||||
IterationAnswer(
|
||||
tool=self.name,
|
||||
tool_id=self.id,
|
||||
iteration_nr=index,
|
||||
parallelization_nr=0,
|
||||
question=json.dumps(args) if args else "",
|
||||
reasoning=f"Running {self.name}",
|
||||
data=data,
|
||||
file_ids=file_ids,
|
||||
cited_documents={},
|
||||
additional_data=None,
|
||||
response_type=custom_summary.response_type,
|
||||
answer=str(data) if data else str(file_ids),
|
||||
)
|
||||
)
|
||||
run_context.context.run_dependencies.emitter.emit(
|
||||
Packet(
|
||||
ind=index,
|
||||
obj=CustomToolDelta(
|
||||
type="custom_tool_delta",
|
||||
tool_name=self.name,
|
||||
response_type=custom_summary.response_type,
|
||||
data=data,
|
||||
file_ids=file_ids,
|
||||
),
|
||||
)
|
||||
)
|
||||
# Return the last result's data as JSON string
|
||||
if results:
|
||||
last_summary = results[-1].response
|
||||
if isinstance(last_summary.tool_result, CustomToolUserFileSnapshot):
|
||||
return json.dumps(last_summary.tool_result.model_dump())
|
||||
return json.dumps(last_summary.tool_result)
|
||||
return json.dumps({})
|
||||
|
||||
def run(
|
||||
self, override_kwargs: dict[str, Any] | None = None, **kwargs: Any
|
||||
) -> Generator[ToolResponse, None, None]:
|
||||
|
||||
@@ -10,18 +10,28 @@ from pydantic import BaseModel
|
||||
from sqlalchemy.orm import Session
|
||||
from typing_extensions import override
|
||||
|
||||
from onyx.agents.agent_search.dr.models import GeneratedImage
|
||||
from onyx.agents.agent_search.dr.models import IterationAnswer
|
||||
from onyx.agents.agent_search.dr.models import IterationInstructions
|
||||
from onyx.chat.chat_utils import combine_message_chain
|
||||
from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
|
||||
from onyx.chat.stop_signal_checker import is_connected
|
||||
from onyx.configs.app_configs import AZURE_IMAGE_API_KEY
|
||||
from onyx.configs.app_configs import IMAGE_MODEL_NAME
|
||||
from onyx.configs.model_configs import GEN_AI_HISTORY_CUTOFF
|
||||
from onyx.db.llm import fetch_existing_llm_providers
|
||||
from onyx.file_store.utils import build_frontend_file_url
|
||||
from onyx.file_store.utils import save_files
|
||||
from onyx.llm.interfaces import LLM
|
||||
from onyx.llm.models import PreviousMessage
|
||||
from onyx.llm.utils import build_content_with_imgs
|
||||
from onyx.llm.utils import message_to_string
|
||||
from onyx.llm.utils import model_supports_image_input
|
||||
from onyx.prompts.constants import GENERAL_SEP_PAT
|
||||
from onyx.server.query_and_chat.streaming_models import ImageGenerationToolDelta
|
||||
from onyx.server.query_and_chat.streaming_models import ImageGenerationToolHeartbeat
|
||||
from onyx.server.query_and_chat.streaming_models import ImageGenerationToolStart
|
||||
from onyx.server.query_and_chat.streaming_models import Packet
|
||||
from onyx.tools.message import ToolCallSummary
|
||||
from onyx.tools.models import ToolResponse
|
||||
from onyx.tools.tool import RunContextWrapper
|
||||
@@ -29,6 +39,7 @@ from onyx.tools.tool import Tool
|
||||
from onyx.tools.tool_implementations.images.prompt import (
|
||||
build_image_generation_user_prompt,
|
||||
)
|
||||
from onyx.tools.tool_implementations_v2.tool_accounting import tool_accounting
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.special_types import JSON_ro
|
||||
from onyx.utils.threadpool_concurrency import run_functions_tuples_in_parallel
|
||||
@@ -80,7 +91,7 @@ class ImageShape(str, Enum):
|
||||
|
||||
# override_kwargs is not supported for image generation tools
|
||||
class ImageGenerationTool(Tool[None]):
|
||||
_NAME = "run_image_generation"
|
||||
_NAME = "image_generation"
|
||||
_DESCRIPTION = (
|
||||
"NEVER use generate_image unless the user specifically requests an image."
|
||||
)
|
||||
@@ -120,6 +131,115 @@ class ImageGenerationTool(Tool[None]):
|
||||
def display_name(self) -> str:
|
||||
return self._DISPLAY_NAME
|
||||
|
||||
@tool_accounting
|
||||
def _image_generation_core(
|
||||
self,
|
||||
run_context: RunContextWrapper[Any],
|
||||
prompt: str,
|
||||
shape: ImageShape | None = None,
|
||||
) -> list[GeneratedImage]:
|
||||
"""Core image generation logic for run_v2."""
|
||||
index = run_context.context.current_run_step
|
||||
emitter = run_context.context.run_dependencies.emitter
|
||||
|
||||
# Emit start event
|
||||
emitter.emit(
|
||||
Packet(
|
||||
ind=index,
|
||||
obj=ImageGenerationToolStart(type="image_generation_tool_start"),
|
||||
)
|
||||
)
|
||||
|
||||
# Prepare tool arguments
|
||||
tool_args = {"prompt": prompt}
|
||||
if (
|
||||
shape and shape != ImageShape.SQUARE
|
||||
): # Only include shape if it's not the default
|
||||
tool_args["shape"] = str(shape)
|
||||
|
||||
# Run the actual image generation tool with heartbeat handling
|
||||
generated_images: list[GeneratedImage] = []
|
||||
heartbeat_count = 0
|
||||
|
||||
for tool_response in self.run(**tool_args): # type: ignore[arg-type]
|
||||
# Check if the session has been cancelled
|
||||
if not is_connected(
|
||||
run_context.context.chat_session_id,
|
||||
run_context.context.run_dependencies.redis_client,
|
||||
):
|
||||
break
|
||||
|
||||
# Handle heartbeat responses
|
||||
if tool_response.id == IMAGE_GENERATION_HEARTBEAT_ID:
|
||||
# Emit heartbeat event for every iteration
|
||||
emitter.emit(
|
||||
Packet(
|
||||
ind=index,
|
||||
obj=ImageGenerationToolHeartbeat(
|
||||
type="image_generation_tool_heartbeat"
|
||||
),
|
||||
)
|
||||
)
|
||||
heartbeat_count += 1
|
||||
logger.debug(f"Image generation heartbeat #{heartbeat_count}")
|
||||
continue
|
||||
|
||||
# Process the tool response to get the generated images
|
||||
if tool_response.id == IMAGE_GENERATION_RESPONSE_ID:
|
||||
image_generation_responses = cast(
|
||||
list[ImageGenerationResponse], tool_response.response
|
||||
)
|
||||
file_ids = save_files(
|
||||
urls=[],
|
||||
base64_files=[img.image_data for img in image_generation_responses],
|
||||
)
|
||||
generated_images = [
|
||||
GeneratedImage(
|
||||
file_id=file_id,
|
||||
url=build_frontend_file_url(file_id),
|
||||
revised_prompt=img.revised_prompt,
|
||||
)
|
||||
for img, file_id in zip(image_generation_responses, file_ids)
|
||||
]
|
||||
break
|
||||
|
||||
run_context.context.iteration_instructions.append(
|
||||
IterationInstructions(
|
||||
iteration_nr=index,
|
||||
plan="Generating images",
|
||||
purpose="Generating images",
|
||||
reasoning="Generating images",
|
||||
)
|
||||
)
|
||||
run_context.context.global_iteration_responses.append(
|
||||
IterationAnswer(
|
||||
tool=self.name,
|
||||
tool_id=self.id,
|
||||
iteration_nr=run_context.context.current_run_step,
|
||||
parallelization_nr=0,
|
||||
question=prompt,
|
||||
answer="",
|
||||
reasoning="",
|
||||
claims=[],
|
||||
generated_images=generated_images,
|
||||
additional_data={},
|
||||
response_type=None,
|
||||
data=None,
|
||||
file_ids=None,
|
||||
cited_documents={},
|
||||
)
|
||||
)
|
||||
emitter.emit(
|
||||
Packet(
|
||||
ind=index,
|
||||
obj=ImageGenerationToolDelta(
|
||||
type="image_generation_tool_delta", images=generated_images
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
return generated_images
|
||||
|
||||
@override
|
||||
@classmethod
|
||||
def is_available(cls, db_session: Session) -> bool:
|
||||
@@ -194,13 +314,31 @@ class ImageGenerationTool(Tool[None]):
|
||||
|
||||
return None
|
||||
|
||||
# Since image generation is a stopping tool, we need to re-raise the error
|
||||
# in the agent loop since the loop can't recover if the tool fails.
|
||||
def failure_error_function(self, error: Exception) -> None:
|
||||
raise error
|
||||
|
||||
def run_v2(
|
||||
self,
|
||||
run_context: RunContextWrapper[Any],
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
raise NotImplementedError("ImageGenerationTool.run_v2 is not implemented.")
|
||||
prompt: str,
|
||||
shape: ImageShape | None = None,
|
||||
) -> str:
|
||||
"""Run image generation via the v2 implementation.
|
||||
|
||||
Returns:
|
||||
JSON string containing a success message
|
||||
"""
|
||||
# Call the core implementation
|
||||
generated_images = self._image_generation_core(
|
||||
run_context,
|
||||
prompt,
|
||||
shape,
|
||||
)
|
||||
|
||||
# Return success message (agent stops after this tool anyway)
|
||||
return f"Successfully generated {len(generated_images)} images"
|
||||
|
||||
def build_tool_message_content(
|
||||
self, *args: ToolResponse
|
||||
|
||||
@@ -6,18 +6,25 @@ from typing import cast
|
||||
from langchain_core.messages import HumanMessage
|
||||
from langchain_core.messages import SystemMessage
|
||||
|
||||
from onyx.agents.agent_search.dr.models import IterationAnswer
|
||||
from onyx.agents.agent_search.dr.models import IterationInstructions
|
||||
from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
|
||||
from onyx.db.enums import MCPAuthenticationType
|
||||
from onyx.db.models import MCPConnectionConfig
|
||||
from onyx.db.models import MCPServer
|
||||
from onyx.llm.interfaces import LLM
|
||||
from onyx.llm.models import PreviousMessage
|
||||
from onyx.server.query_and_chat.streaming_models import CustomToolDelta
|
||||
from onyx.server.query_and_chat.streaming_models import CustomToolStart
|
||||
from onyx.server.query_and_chat.streaming_models import Packet
|
||||
from onyx.tools.base_tool import BaseTool
|
||||
from onyx.tools.message import ToolCallSummary
|
||||
from onyx.tools.models import ToolResponse
|
||||
from onyx.tools.tool import RunContextWrapper
|
||||
from onyx.tools.tool_implementations.custom.custom_tool import CUSTOM_TOOL_RESPONSE_ID
|
||||
from onyx.tools.tool_implementations.custom.custom_tool import CustomToolCallSummary
|
||||
from onyx.tools.tool_implementations.mcp.mcp_client import call_mcp_tool
|
||||
from onyx.tools.tool_implementations_v2.tool_accounting import tool_accounting
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.special_types import JSON_ro
|
||||
|
||||
@@ -167,6 +174,72 @@ Return ONLY a valid JSON object with the extracted arguments. If no arguments ar
|
||||
)
|
||||
return {}
|
||||
|
||||
@tool_accounting
|
||||
def run_v2(self, run_context: RunContextWrapper[Any], **args: Any) -> str:
|
||||
index = run_context.context.current_run_step
|
||||
run_context.context.run_dependencies.emitter.emit(
|
||||
Packet(
|
||||
ind=index,
|
||||
obj=CustomToolStart(type="custom_tool_start", tool_name=self.name),
|
||||
)
|
||||
)
|
||||
results = []
|
||||
run_context.context.iteration_instructions.append(
|
||||
IterationInstructions(
|
||||
iteration_nr=index,
|
||||
plan=f"Running {self.name}",
|
||||
purpose=f"Running {self.name}",
|
||||
reasoning=f"Running {self.name}",
|
||||
)
|
||||
)
|
||||
for result in self.run(**args):
|
||||
results.append(result)
|
||||
# Extract data from CustomToolCallSummary within the ToolResponse
|
||||
custom_summary = result.response
|
||||
data = None
|
||||
file_ids = None
|
||||
|
||||
# Handle different response types
|
||||
if custom_summary.response_type in ["image", "csv"] and hasattr(
|
||||
custom_summary.tool_result, "file_ids"
|
||||
):
|
||||
file_ids = custom_summary.tool_result.file_ids
|
||||
else:
|
||||
data = custom_summary.tool_result
|
||||
run_context.context.global_iteration_responses.append(
|
||||
IterationAnswer(
|
||||
tool=self.name,
|
||||
tool_id=self.id,
|
||||
iteration_nr=index,
|
||||
parallelization_nr=0,
|
||||
question=json.dumps(args) if args else "",
|
||||
reasoning=f"Running {self.name}",
|
||||
data=data,
|
||||
file_ids=file_ids,
|
||||
cited_documents={},
|
||||
additional_data=None,
|
||||
response_type=custom_summary.response_type,
|
||||
answer=str(data) if data else str(file_ids),
|
||||
)
|
||||
)
|
||||
run_context.context.run_dependencies.emitter.emit(
|
||||
Packet(
|
||||
ind=index,
|
||||
obj=CustomToolDelta(
|
||||
type="custom_tool_delta",
|
||||
tool_name=self.name,
|
||||
response_type=custom_summary.response_type,
|
||||
data=data,
|
||||
file_ids=file_ids,
|
||||
),
|
||||
)
|
||||
)
|
||||
# Return the last result's data as JSON string
|
||||
if results:
|
||||
last_summary = results[-1].response
|
||||
return json.dumps(last_summary.tool_result)
|
||||
return json.dumps({})
|
||||
|
||||
def run(
|
||||
self, override_kwargs: dict[str, Any] | None = None, **kwargs: Any
|
||||
) -> Generator[ToolResponse, None, None]:
|
||||
|
||||
@@ -0,0 +1,98 @@
|
||||
from typing import Literal
|
||||
from typing import TypedDict
|
||||
|
||||
import requests
|
||||
from pydantic import BaseModel
|
||||
|
||||
from onyx.configs.app_configs import CODE_INTERPRETER_BASE_URL
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
class FileInput(TypedDict):
|
||||
"""Input file to be staged in execution workspace"""
|
||||
|
||||
path: str
|
||||
file_id: str
|
||||
|
||||
|
||||
class WorkspaceFile(BaseModel):
|
||||
"""File in execution workspace"""
|
||||
|
||||
path: str
|
||||
kind: Literal["file", "directory"]
|
||||
file_id: str | None = None
|
||||
|
||||
|
||||
class ExecuteResponse(BaseModel):
|
||||
"""Response from code execution"""
|
||||
|
||||
stdout: str
|
||||
stderr: str
|
||||
exit_code: int | None
|
||||
timed_out: bool
|
||||
duration_ms: int
|
||||
files: list[WorkspaceFile]
|
||||
|
||||
|
||||
class CodeInterpreterClient:
|
||||
"""Client for Code Interpreter service"""
|
||||
|
||||
def __init__(self, base_url: str | None = CODE_INTERPRETER_BASE_URL):
|
||||
if not base_url:
|
||||
raise ValueError("CODE_INTERPRETER_BASE_URL not configured")
|
||||
self.base_url = base_url.rstrip("/")
|
||||
self.session = requests.Session()
|
||||
|
||||
def execute(
|
||||
self,
|
||||
code: str,
|
||||
stdin: str | None = None,
|
||||
timeout_ms: int = 30000,
|
||||
files: list[FileInput] | None = None,
|
||||
) -> ExecuteResponse:
|
||||
"""Execute Python code"""
|
||||
url = f"{self.base_url}/v1/execute"
|
||||
|
||||
payload = {
|
||||
"code": code,
|
||||
"timeout_ms": timeout_ms,
|
||||
}
|
||||
|
||||
if stdin is not None:
|
||||
payload["stdin"] = stdin
|
||||
|
||||
if files:
|
||||
payload["files"] = files
|
||||
|
||||
response = self.session.post(url, json=payload, timeout=timeout_ms / 1000 + 10)
|
||||
response.raise_for_status()
|
||||
|
||||
return ExecuteResponse(**response.json())
|
||||
|
||||
def upload_file(self, file_content: bytes, filename: str) -> str:
|
||||
"""Upload file to Code Interpreter and return file_id"""
|
||||
url = f"{self.base_url}/v1/files"
|
||||
|
||||
files = {"file": (filename, file_content)}
|
||||
response = self.session.post(url, files=files, timeout=30)
|
||||
response.raise_for_status()
|
||||
|
||||
return response.json()["file_id"]
|
||||
|
||||
def download_file(self, file_id: str) -> bytes:
|
||||
"""Download file from Code Interpreter"""
|
||||
url = f"{self.base_url}/v1/files/{file_id}"
|
||||
|
||||
response = self.session.get(url, timeout=30)
|
||||
response.raise_for_status()
|
||||
|
||||
return response.content
|
||||
|
||||
def delete_file(self, file_id: str) -> None:
|
||||
"""Delete file from Code Interpreter"""
|
||||
url = f"{self.base_url}/v1/files/{file_id}"
|
||||
|
||||
response = self.session.delete(url, timeout=10)
|
||||
response.raise_for_status()
|
||||
@@ -1,17 +1,40 @@
|
||||
import mimetypes
|
||||
from collections.abc import Generator
|
||||
from io import BytesIO
|
||||
from typing import Any
|
||||
|
||||
from pydantic import TypeAdapter
|
||||
from sqlalchemy.orm import Session
|
||||
from typing_extensions import override
|
||||
|
||||
from onyx.agents.agent_search.dr.models import IterationAnswer
|
||||
from onyx.agents.agent_search.dr.models import IterationInstructions
|
||||
from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
|
||||
from onyx.configs.app_configs import CODE_INTERPRETER_BASE_URL
|
||||
from onyx.configs.app_configs import CODE_INTERPRETER_DEFAULT_TIMEOUT_MS
|
||||
from onyx.configs.app_configs import CODE_INTERPRETER_MAX_OUTPUT_LENGTH
|
||||
from onyx.configs.constants import FileOrigin
|
||||
from onyx.file_store.utils import build_full_frontend_file_url
|
||||
from onyx.file_store.utils import get_default_file_store
|
||||
from onyx.llm.interfaces import LLM
|
||||
from onyx.llm.models import PreviousMessage
|
||||
from onyx.server.query_and_chat.streaming_models import Packet
|
||||
from onyx.server.query_and_chat.streaming_models import PythonToolDelta
|
||||
from onyx.server.query_and_chat.streaming_models import PythonToolStart
|
||||
from onyx.tools.message import ToolCallSummary
|
||||
from onyx.tools.models import ToolResponse
|
||||
from onyx.tools.tool import RunContextWrapper
|
||||
from onyx.tools.tool import Tool
|
||||
from onyx.tools.tool_implementations.python.code_interpreter_client import (
|
||||
CodeInterpreterClient,
|
||||
)
|
||||
from onyx.tools.tool_implementations.python.code_interpreter_client import (
|
||||
ExecuteResponse,
|
||||
)
|
||||
from onyx.tools.tool_implementations.python.code_interpreter_client import FileInput
|
||||
from onyx.tools.tool_implementations_v2.tool_accounting import tool_accounting
|
||||
from onyx.tools.tool_result_models import LlmPythonExecutionResult
|
||||
from onyx.tools.tool_result_models import PythonExecutionFile
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.special_types import JSON_ro
|
||||
|
||||
@@ -23,6 +46,262 @@ _GENERIC_ERROR_MESSAGE = (
|
||||
)
|
||||
|
||||
|
||||
def _truncate_output(output: str, max_length: int, label: str = "output") -> str:
|
||||
"""
|
||||
Truncate output string to max_length and append truncation message if needed.
|
||||
|
||||
Args:
|
||||
output: The original output string to truncate
|
||||
max_length: Maximum length before truncation
|
||||
label: Label for logging (e.g., "stdout", "stderr")
|
||||
|
||||
Returns:
|
||||
Truncated string with truncation message appended if truncated
|
||||
"""
|
||||
truncated = output[:max_length]
|
||||
if len(output) > max_length:
|
||||
truncated += (
|
||||
"\n... [output truncated, "
|
||||
f"{len(output) - max_length} "
|
||||
"characters omitted]"
|
||||
)
|
||||
logger.debug(f"Truncated {label}: {truncated}")
|
||||
return truncated
|
||||
|
||||
|
||||
def _combine_outputs(stdout: str, stderr: str) -> str:
|
||||
"""
|
||||
Combine stdout and stderr into a single string if both exist.
|
||||
|
||||
Args:
|
||||
stdout: Standard output string
|
||||
stderr: Standard error string
|
||||
|
||||
Returns:
|
||||
Combined output string with labels if both exist, or the non-empty one
|
||||
if only one exists, or empty string if both are empty
|
||||
"""
|
||||
has_stdout = bool(stdout)
|
||||
has_stderr = bool(stderr)
|
||||
|
||||
if has_stdout and has_stderr:
|
||||
return f"stdout:\n\n{stdout}\n\nstderr:\n\n{stderr}"
|
||||
elif has_stdout:
|
||||
return stdout
|
||||
elif has_stderr:
|
||||
return stderr
|
||||
else:
|
||||
return ""
|
||||
|
||||
|
||||
@tool_accounting
|
||||
def _python_execution_core(
|
||||
run_context: RunContextWrapper[Any],
|
||||
code: str,
|
||||
client: CodeInterpreterClient,
|
||||
tool_id: int,
|
||||
) -> LlmPythonExecutionResult:
|
||||
"""Core Python execution logic"""
|
||||
index = run_context.context.current_run_step
|
||||
emitter = run_context.context.run_dependencies.emitter
|
||||
|
||||
# Emit start event
|
||||
emitter.emit(
|
||||
Packet(
|
||||
ind=index,
|
||||
obj=PythonToolStart(code=code),
|
||||
)
|
||||
)
|
||||
|
||||
run_context.context.iteration_instructions.append(
|
||||
IterationInstructions(
|
||||
iteration_nr=index,
|
||||
plan="Executing Python code",
|
||||
purpose="Running Python code",
|
||||
reasoning="Executing provided Python code in secure environment",
|
||||
)
|
||||
)
|
||||
|
||||
# Get all files from chat context and upload to Code Interpreter
|
||||
files_to_stage: list[FileInput] = []
|
||||
file_store = get_default_file_store()
|
||||
|
||||
# Access chat files directly from context (available after Step 0 changes)
|
||||
chat_files = run_context.context.chat_files
|
||||
|
||||
for ind, chat_file in enumerate(chat_files):
|
||||
file_name = chat_file.filename or f"file_{ind}"
|
||||
try:
|
||||
# Use file content already loaded in memory
|
||||
file_content = chat_file.content
|
||||
|
||||
# Upload to Code Interpreter
|
||||
ci_file_id = client.upload_file(file_content, file_name)
|
||||
|
||||
# Stage for execution
|
||||
files_to_stage.append({"path": file_name, "file_id": ci_file_id})
|
||||
|
||||
logger.info(f"Staged file for Python execution: {file_name}")
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to stage file {file_name}: {e}")
|
||||
|
||||
try:
|
||||
logger.debug(f"Executing code: {code}")
|
||||
|
||||
# Execute code with fixed timeout
|
||||
response: ExecuteResponse = client.execute(
|
||||
code=code,
|
||||
timeout_ms=CODE_INTERPRETER_DEFAULT_TIMEOUT_MS,
|
||||
files=files_to_stage or None,
|
||||
)
|
||||
|
||||
# Truncate output for LLM consumption
|
||||
truncated_stdout = _truncate_output(
|
||||
response.stdout, CODE_INTERPRETER_MAX_OUTPUT_LENGTH, "stdout"
|
||||
)
|
||||
truncated_stderr = _truncate_output(
|
||||
response.stderr, CODE_INTERPRETER_MAX_OUTPUT_LENGTH, "stderr"
|
||||
)
|
||||
|
||||
# Handle generated files
|
||||
generated_files: list[PythonExecutionFile] = []
|
||||
generated_file_ids: list[str] = []
|
||||
file_ids_to_cleanup: list[str] = []
|
||||
|
||||
for workspace_file in response.files:
|
||||
if workspace_file.kind != "file" or not workspace_file.file_id:
|
||||
continue
|
||||
|
||||
try:
|
||||
# Download file from Code Interpreter
|
||||
file_content = client.download_file(workspace_file.file_id)
|
||||
|
||||
# Determine MIME type from file extension
|
||||
filename = workspace_file.path.split("/")[-1]
|
||||
mime_type, _ = mimetypes.guess_type(filename)
|
||||
# Default to binary if we can't determine the type
|
||||
mime_type = mime_type or "application/octet-stream"
|
||||
|
||||
# Save to Onyx file store directly
|
||||
onyx_file_id = file_store.save_file(
|
||||
content=BytesIO(file_content),
|
||||
display_name=filename,
|
||||
file_origin=FileOrigin.CHAT_UPLOAD,
|
||||
file_type=mime_type,
|
||||
)
|
||||
|
||||
generated_files.append(
|
||||
PythonExecutionFile(
|
||||
filename=filename,
|
||||
file_link=build_full_frontend_file_url(onyx_file_id),
|
||||
)
|
||||
)
|
||||
generated_file_ids.append(onyx_file_id)
|
||||
|
||||
# Mark for cleanup
|
||||
file_ids_to_cleanup.append(workspace_file.file_id)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to handle generated file {workspace_file.path}: {e}"
|
||||
)
|
||||
|
||||
# Cleanup Code Interpreter files (both generated and staged input files)
|
||||
for ci_file_id in file_ids_to_cleanup:
|
||||
try:
|
||||
client.delete_file(ci_file_id)
|
||||
except Exception as e:
|
||||
# TODO: add TTL on code interpreter files themselves so they are automatically
|
||||
# cleaned up after some time.
|
||||
logger.error(
|
||||
f"Failed to delete Code Interpreter generated file {ci_file_id}: {e}"
|
||||
)
|
||||
|
||||
# Cleanup staged input files
|
||||
for file_mapping in files_to_stage:
|
||||
try:
|
||||
client.delete_file(file_mapping["file_id"])
|
||||
except Exception as e:
|
||||
# TODO: add TTL on code interpreter files themselves so they are automatically
|
||||
# cleaned up after some time.
|
||||
logger.error(
|
||||
f"Failed to delete Code Interpreter staged file {file_mapping['file_id']}: {e}"
|
||||
)
|
||||
|
||||
# Emit delta with stdout/stderr and generated files
|
||||
emitter.emit(
|
||||
Packet(
|
||||
ind=index,
|
||||
obj=PythonToolDelta(
|
||||
type="python_tool_delta",
|
||||
stdout=truncated_stdout,
|
||||
stderr=truncated_stderr,
|
||||
file_ids=generated_file_ids,
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
# Build result with truncated output
|
||||
result = LlmPythonExecutionResult(
|
||||
stdout=truncated_stdout,
|
||||
stderr=truncated_stderr,
|
||||
exit_code=response.exit_code,
|
||||
timed_out=response.timed_out,
|
||||
generated_files=generated_files,
|
||||
error=None if response.exit_code == 0 else truncated_stderr,
|
||||
)
|
||||
|
||||
# Store in iteration answer
|
||||
run_context.context.global_iteration_responses.append(
|
||||
IterationAnswer(
|
||||
tool=PythonTool.__name__,
|
||||
tool_id=tool_id,
|
||||
iteration_nr=index,
|
||||
parallelization_nr=0,
|
||||
question="Execute Python code",
|
||||
reasoning="Executing Python code in secure environment",
|
||||
answer=_combine_outputs(truncated_stdout, truncated_stderr),
|
||||
cited_documents={},
|
||||
file_ids=generated_file_ids,
|
||||
additional_data={
|
||||
"stdout": truncated_stdout,
|
||||
"stderr": truncated_stderr,
|
||||
"code": code,
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Python execution failed: {e}")
|
||||
error_msg = str(e)
|
||||
|
||||
# Emit error delta
|
||||
emitter.emit(
|
||||
Packet(
|
||||
ind=index,
|
||||
obj=PythonToolDelta(
|
||||
type="python_tool_delta",
|
||||
stdout="",
|
||||
stderr=error_msg,
|
||||
file_ids=[],
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
# Return error result
|
||||
return LlmPythonExecutionResult(
|
||||
stdout="",
|
||||
stderr=error_msg,
|
||||
exit_code=-1,
|
||||
timed_out=False,
|
||||
generated_files=[],
|
||||
error=error_msg,
|
||||
)
|
||||
|
||||
|
||||
class PythonTool(Tool[None]):
|
||||
"""
|
||||
Wrapper class for Python code execution tool.
|
||||
@@ -32,7 +311,7 @@ class PythonTool(Tool[None]):
|
||||
"""
|
||||
|
||||
_NAME = "python"
|
||||
_DESCRIPTION = "Execute Python code in a secure, isolated environment. Never call this tool directly."
|
||||
_DESCRIPTION = "Execute Python code in a secure, isolated environment."
|
||||
# in the UI, call it `Code Interpreter` since this is a well known term for this tool
|
||||
_DISPLAY_NAME = "Code Interpreter"
|
||||
|
||||
@@ -114,9 +393,23 @@ class PythonTool(Tool[None]):
|
||||
run_context: RunContextWrapper[Any],
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
"""Not supported - Python tool is only used via v2 agent framework."""
|
||||
raise ValueError(_GENERIC_ERROR_MESSAGE)
|
||||
) -> str:
|
||||
"""Run Python code execution via the v2 implementation.
|
||||
|
||||
Returns:
|
||||
JSON string containing execution results (stdout, stderr, exit code, files)
|
||||
"""
|
||||
code = kwargs.get("code")
|
||||
if not code:
|
||||
raise ValueError("code is required for python execution")
|
||||
|
||||
# Create client and call the core implementation
|
||||
client = CodeInterpreterClient()
|
||||
result = _python_execution_core(run_context, code, client, self._id)
|
||||
|
||||
# Serialize and return
|
||||
adapter = TypeAdapter(LlmPythonExecutionResult)
|
||||
return adapter.dump_json(result).decode()
|
||||
|
||||
def run(
|
||||
self, override_kwargs: None = None, **llm_kwargs: str
|
||||
|
||||
@@ -6,11 +6,16 @@ from typing import Any
|
||||
from typing import cast
|
||||
from typing import TypeVar
|
||||
|
||||
from pydantic import TypeAdapter
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.agents.agent_search.dr.models import IterationAnswer
|
||||
from onyx.agents.agent_search.dr.models import IterationInstructions
|
||||
from onyx.agents.agent_search.dr.utils import convert_inference_sections_to_search_docs
|
||||
from onyx.chat.chat_utils import llm_doc_from_inference_section
|
||||
from onyx.chat.models import AnswerStyleConfig
|
||||
from onyx.chat.models import ContextualPruningConfig
|
||||
from onyx.chat.models import DOCUMENT_CITATION_NUMBER_EMPTY_VALUE
|
||||
from onyx.chat.models import DocumentPruningConfig
|
||||
from onyx.chat.models import LlmDoc
|
||||
from onyx.chat.models import PromptConfig
|
||||
@@ -19,6 +24,7 @@ from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
|
||||
from onyx.chat.prompt_builder.citations_prompt import compute_max_llm_input_tokens
|
||||
from onyx.chat.prune_and_merge import prune_and_merge_sections
|
||||
from onyx.chat.prune_and_merge import prune_sections
|
||||
from onyx.chat.stop_signal_checker import is_connected
|
||||
from onyx.configs.chat_configs import CONTEXT_CHUNKS_ABOVE
|
||||
from onyx.configs.chat_configs import CONTEXT_CHUNKS_BELOW
|
||||
from onyx.configs.model_configs import GEN_AI_MODEL_FALLBACK_MAX_TOKENS
|
||||
@@ -35,6 +41,7 @@ from onyx.context.search.pipeline import SearchPipeline
|
||||
from onyx.context.search.pipeline import section_relevance_list_impl
|
||||
from onyx.db.connector import check_connectors_exist
|
||||
from onyx.db.connector import check_federated_connectors_exist
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.models import Persona
|
||||
from onyx.db.models import User
|
||||
from onyx.llm.interfaces import LLM
|
||||
@@ -42,6 +49,9 @@ from onyx.llm.models import PreviousMessage
|
||||
from onyx.onyxbot.slack.models import SlackContext
|
||||
from onyx.secondary_llm_flows.choose_search import check_if_need_search
|
||||
from onyx.secondary_llm_flows.query_expansion import history_based_query_rephrase
|
||||
from onyx.server.query_and_chat.streaming_models import Packet
|
||||
from onyx.server.query_and_chat.streaming_models import SearchToolDelta
|
||||
from onyx.server.query_and_chat.streaming_models import SearchToolStart
|
||||
from onyx.tools.message import ToolCallSummary
|
||||
from onyx.tools.models import SearchQueryInfo
|
||||
from onyx.tools.models import SearchToolOverrideKwargs
|
||||
@@ -55,6 +65,8 @@ from onyx.tools.tool_implementations.search_like_tool_utils import (
|
||||
from onyx.tools.tool_implementations.search_like_tool_utils import (
|
||||
FINAL_CONTEXT_DOCUMENTS_ID,
|
||||
)
|
||||
from onyx.tools.tool_implementations_v2.tool_accounting import tool_accounting
|
||||
from onyx.tools.tool_result_models import LlmInternalSearchResult
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.special_types import JSON_ro
|
||||
|
||||
@@ -86,7 +98,7 @@ web pages. If very ambiguious, prioritize internal search or call both tools.
|
||||
|
||||
|
||||
class SearchTool(Tool[SearchToolOverrideKwargs]):
|
||||
_NAME = "run_search"
|
||||
_NAME = "internal_search"
|
||||
_DISPLAY_NAME = "Internal Search"
|
||||
_DESCRIPTION = SEARCH_TOOL_DESCRIPTION
|
||||
|
||||
@@ -112,6 +124,8 @@ class SearchTool(Tool[SearchToolOverrideKwargs]):
|
||||
bypass_acl: bool = False,
|
||||
rerank_settings: RerankingDetails | None = None,
|
||||
slack_context: SlackContext | None = None,
|
||||
# just doing this for now since we lack dependency injection
|
||||
search_pipeline_override_for_testing: SearchPipeline | None = None,
|
||||
) -> None:
|
||||
self.user = user
|
||||
self.persona = persona
|
||||
@@ -177,6 +191,7 @@ class SearchTool(Tool[SearchToolOverrideKwargs]):
|
||||
)
|
||||
|
||||
self._id = tool_id
|
||||
self.search_pipeline_override_for_testing = search_pipeline_override_for_testing
|
||||
|
||||
@classmethod
|
||||
def is_available(cls, db_session: Session) -> bool:
|
||||
@@ -260,13 +275,136 @@ class SearchTool(Tool[SearchToolOverrideKwargs]):
|
||||
|
||||
"""Actual tool execution"""
|
||||
|
||||
@tool_accounting
|
||||
def run_v2(
|
||||
self,
|
||||
run_context: RunContextWrapper[Any],
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
raise NotImplementedError("SearchTool.run_v2 is not implemented.")
|
||||
) -> str:
|
||||
index = run_context.context.current_run_step
|
||||
query = kwargs[QUERY_FIELD]
|
||||
run_context.context.run_dependencies.emitter.emit(
|
||||
Packet(
|
||||
ind=index,
|
||||
obj=SearchToolStart(
|
||||
type="internal_search_tool_start", is_internet_search=False
|
||||
),
|
||||
)
|
||||
)
|
||||
run_context.context.run_dependencies.emitter.emit(
|
||||
Packet(
|
||||
ind=index,
|
||||
obj=SearchToolDelta(
|
||||
type="internal_search_tool_delta", queries=[query], documents=[]
|
||||
),
|
||||
)
|
||||
)
|
||||
run_context.context.iteration_instructions.append(
|
||||
IterationInstructions(
|
||||
iteration_nr=index,
|
||||
plan="plan",
|
||||
purpose="Searching internally for information",
|
||||
reasoning=f"I am now using Internal Search to gather information on {query}",
|
||||
)
|
||||
)
|
||||
|
||||
retrieved_sections: list[InferenceSection] = []
|
||||
|
||||
with get_session_with_current_tenant() as search_db_session:
|
||||
for tool_response in self.run(
|
||||
query=query,
|
||||
override_kwargs=SearchToolOverrideKwargs(
|
||||
force_no_rerank=True,
|
||||
alternate_db_session=search_db_session,
|
||||
skip_query_analysis=True,
|
||||
original_query=query,
|
||||
),
|
||||
):
|
||||
if not is_connected(
|
||||
run_context.context.chat_session_id,
|
||||
run_context.context.run_dependencies.redis_client,
|
||||
):
|
||||
break
|
||||
# get retrieved docs to send to the rest of the graph
|
||||
if tool_response.id == SEARCH_RESPONSE_SUMMARY_ID:
|
||||
search_response_summary = cast(
|
||||
SearchResponseSummary, tool_response.response
|
||||
)
|
||||
retrieved_sections = search_response_summary.top_sections
|
||||
break
|
||||
|
||||
# Aggregate all results from all queries
|
||||
# Use the current input token count from context for pruning
|
||||
# This includes system prompt, history, user message, and any agent turns so far
|
||||
existing_input_tokens = run_context.context.current_input_tokens
|
||||
|
||||
pruned_sections: list[InferenceSection] = prune_and_merge_sections(
|
||||
sections=retrieved_sections,
|
||||
section_relevance_list=None,
|
||||
llm_config=self.llm.config,
|
||||
existing_input_tokens=existing_input_tokens,
|
||||
contextual_pruning_config=self.contextual_pruning_config,
|
||||
)
|
||||
|
||||
search_results_for_query = [
|
||||
LlmInternalSearchResult(
|
||||
document_citation_number=DOCUMENT_CITATION_NUMBER_EMPTY_VALUE,
|
||||
title=section.center_chunk.semantic_identifier,
|
||||
excerpt=section.combined_content,
|
||||
metadata=section.center_chunk.metadata,
|
||||
unique_identifier_to_strip_away=section.center_chunk.document_id,
|
||||
)
|
||||
for section in pruned_sections
|
||||
]
|
||||
|
||||
from onyx.chat.turn.models import FetchedDocumentCacheEntry
|
||||
|
||||
for section in pruned_sections:
|
||||
unique_id = section.center_chunk.document_id
|
||||
if unique_id not in run_context.context.fetched_documents_cache:
|
||||
run_context.context.fetched_documents_cache[unique_id] = (
|
||||
FetchedDocumentCacheEntry(
|
||||
inference_section=section,
|
||||
document_citation_number=DOCUMENT_CITATION_NUMBER_EMPTY_VALUE,
|
||||
)
|
||||
)
|
||||
|
||||
run_context.context.run_dependencies.emitter.emit(
|
||||
Packet(
|
||||
ind=index,
|
||||
obj=SearchToolDelta(
|
||||
type="internal_search_tool_delta",
|
||||
queries=[],
|
||||
documents=convert_inference_sections_to_search_docs(
|
||||
pruned_sections, is_internet=False
|
||||
),
|
||||
),
|
||||
)
|
||||
)
|
||||
run_context.context.global_iteration_responses.append(
|
||||
IterationAnswer(
|
||||
tool=SearchTool.__name__,
|
||||
tool_id=self.id,
|
||||
iteration_nr=index,
|
||||
parallelization_nr=0,
|
||||
question=query[0] if query else "",
|
||||
reasoning=f"I am now using Internal Search to gather information on {query[0] if query else ''}",
|
||||
answer="",
|
||||
cited_documents={
|
||||
i: inference_section
|
||||
for i, inference_section in enumerate(pruned_sections)
|
||||
},
|
||||
queries=[query],
|
||||
)
|
||||
)
|
||||
# Set flag to include citation requirements since we retrieved documents
|
||||
run_context.context.should_cite_documents = (
|
||||
run_context.context.should_cite_documents or bool(pruned_sections)
|
||||
)
|
||||
|
||||
adapter = TypeAdapter(list[LlmInternalSearchResult])
|
||||
return adapter.dump_json(search_results_for_query).decode()
|
||||
|
||||
def _build_response_for_specified_sections(
|
||||
self, query: str
|
||||
@@ -402,7 +540,7 @@ class SearchTool(Tool[SearchToolOverrideKwargs]):
|
||||
if kg_chunk_id_zero_only:
|
||||
retrieval_options.filters.kg_chunk_id_zero_only = kg_chunk_id_zero_only
|
||||
|
||||
search_pipeline = SearchPipeline(
|
||||
search_pipeline = self.search_pipeline_override_for_testing or SearchPipeline(
|
||||
search_request=SearchRequest(
|
||||
query=query,
|
||||
evaluation_type=(
|
||||
|
||||
@@ -0,0 +1,234 @@
|
||||
from collections.abc import Generator
|
||||
from collections.abc import Sequence
|
||||
from typing import Any
|
||||
|
||||
from pydantic import TypeAdapter
|
||||
from sqlalchemy.orm import Session
|
||||
from typing_extensions import override
|
||||
|
||||
from onyx.agents.agent_search.dr.models import IterationAnswer
|
||||
from onyx.agents.agent_search.dr.models import IterationInstructions
|
||||
from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
|
||||
WebContentProvider,
|
||||
)
|
||||
from onyx.agents.agent_search.dr.sub_agents.web_search.providers import (
|
||||
get_default_content_provider,
|
||||
)
|
||||
from onyx.agents.agent_search.dr.sub_agents.web_search.utils import (
|
||||
dummy_inference_section_from_internet_content,
|
||||
)
|
||||
from onyx.agents.agent_search.dr.sub_agents.web_search.utils import (
|
||||
truncate_search_result_content,
|
||||
)
|
||||
from onyx.chat.models import DOCUMENT_CITATION_NUMBER_EMPTY_VALUE
|
||||
from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
|
||||
from onyx.chat.turn.models import FetchedDocumentCacheEntry
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.web_search import fetch_active_web_search_provider
|
||||
from onyx.llm.interfaces import LLM
|
||||
from onyx.llm.models import PreviousMessage
|
||||
from onyx.server.query_and_chat.streaming_models import FetchToolStart
|
||||
from onyx.server.query_and_chat.streaming_models import Packet
|
||||
from onyx.server.query_and_chat.streaming_models import SavedSearchDoc
|
||||
from onyx.tools.message import ToolCallSummary
|
||||
from onyx.tools.models import ToolResponse
|
||||
from onyx.tools.tool import RunContextWrapper
|
||||
from onyx.tools.tool import Tool
|
||||
from onyx.tools.tool_implementations_v2.tool_accounting import tool_accounting
|
||||
from onyx.tools.tool_result_models import LlmOpenUrlResult
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.special_types import JSON_ro
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
_OPEN_URL_GENERIC_ERROR_MESSAGE = (
|
||||
"OpenUrlTool should only be used by the Deep Research Agent, not via tool calling."
|
||||
)
|
||||
|
||||
|
||||
class OpenUrlTool(Tool[None]):
|
||||
_NAME = "open_url"
|
||||
_DESCRIPTION = "Fetch and extract full content from web pages."
|
||||
_DISPLAY_NAME = "Open URL"
|
||||
|
||||
def __init__(self, tool_id: int) -> None:
|
||||
self._id = tool_id
|
||||
|
||||
@property
|
||||
def id(self) -> int:
|
||||
return self._id
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return self._NAME
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return self._DESCRIPTION
|
||||
|
||||
@property
|
||||
def display_name(self) -> str:
|
||||
return self._DISPLAY_NAME
|
||||
|
||||
@override
|
||||
@classmethod
|
||||
def is_available(cls, db_session: Session) -> bool:
|
||||
"""Available only if an active web search provider is configured in the database."""
|
||||
with get_session_with_current_tenant() as session:
|
||||
provider = fetch_active_web_search_provider(session)
|
||||
return provider is not None
|
||||
|
||||
def tool_definition(self) -> dict:
|
||||
return {
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": self.name,
|
||||
"description": self.description,
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"urls": {
|
||||
"type": "array",
|
||||
"items": {"type": "string"},
|
||||
"description": "URLs to fetch content from",
|
||||
},
|
||||
},
|
||||
"required": ["urls"],
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
def get_args_for_non_tool_calling_llm(
|
||||
self,
|
||||
query: str,
|
||||
history: list[PreviousMessage],
|
||||
llm: LLM,
|
||||
force_run: bool = False,
|
||||
) -> dict[str, Any] | None:
|
||||
raise ValueError(_OPEN_URL_GENERIC_ERROR_MESSAGE)
|
||||
|
||||
def build_tool_message_content(
|
||||
self, *args: ToolResponse
|
||||
) -> str | list[str | dict[str, Any]]:
|
||||
raise ValueError(_OPEN_URL_GENERIC_ERROR_MESSAGE)
|
||||
|
||||
@tool_accounting
|
||||
def _open_url_core(
|
||||
self,
|
||||
run_context: RunContextWrapper[Any],
|
||||
urls: Sequence[str],
|
||||
content_provider: WebContentProvider,
|
||||
) -> list[LlmOpenUrlResult]:
|
||||
# TODO: Find better way to track index that isn't so implicit
|
||||
# based on number of tool calls
|
||||
index = run_context.context.current_run_step
|
||||
|
||||
# Create SavedSearchDoc objects from URLs for the FetchToolStart event
|
||||
saved_search_docs = [SavedSearchDoc.from_url(url) for url in urls]
|
||||
|
||||
run_context.context.run_dependencies.emitter.emit(
|
||||
Packet(
|
||||
ind=index,
|
||||
obj=FetchToolStart(
|
||||
type="fetch_tool_start", documents=saved_search_docs
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
docs = content_provider.contents(urls)
|
||||
results = [
|
||||
LlmOpenUrlResult(
|
||||
document_citation_number=DOCUMENT_CITATION_NUMBER_EMPTY_VALUE,
|
||||
content=truncate_search_result_content(doc.full_content),
|
||||
unique_identifier_to_strip_away=doc.link,
|
||||
)
|
||||
for doc in docs
|
||||
]
|
||||
for doc in docs:
|
||||
cache = run_context.context.fetched_documents_cache
|
||||
entry = cache.setdefault(
|
||||
doc.link,
|
||||
FetchedDocumentCacheEntry(
|
||||
inference_section=dummy_inference_section_from_internet_content(
|
||||
doc
|
||||
),
|
||||
document_citation_number=DOCUMENT_CITATION_NUMBER_EMPTY_VALUE,
|
||||
),
|
||||
)
|
||||
entry.inference_section = dummy_inference_section_from_internet_content(doc)
|
||||
run_context.context.iteration_instructions.append(
|
||||
IterationInstructions(
|
||||
iteration_nr=index,
|
||||
plan="plan",
|
||||
purpose="Fetching content from URLs",
|
||||
reasoning=f"I am now using Web Fetch to gather information on {', '.join(urls)}",
|
||||
)
|
||||
)
|
||||
from onyx.db.tools import get_tool_by_name
|
||||
from onyx.tools.tool_implementations.web_search.web_search_tool import (
|
||||
WebSearchTool,
|
||||
)
|
||||
|
||||
run_context.context.global_iteration_responses.append(
|
||||
IterationAnswer(
|
||||
# TODO: For now, we're using the web_search_tool_name since the web_fetch_tool_name is not a built-in tool
|
||||
tool=WebSearchTool.__name__,
|
||||
tool_id=get_tool_by_name(
|
||||
WebSearchTool.__name__,
|
||||
run_context.context.run_dependencies.db_session,
|
||||
).id,
|
||||
iteration_nr=index,
|
||||
parallelization_nr=0,
|
||||
question=f"Fetch content from URLs: {', '.join(urls)}",
|
||||
reasoning=f"I am now using Web Fetch to gather information on {', '.join(urls)}",
|
||||
answer="",
|
||||
cited_documents={
|
||||
i: dummy_inference_section_from_internet_content(d)
|
||||
for i, d in enumerate(docs)
|
||||
},
|
||||
claims=[],
|
||||
is_web_fetch=True,
|
||||
)
|
||||
)
|
||||
|
||||
# Set flag to include citation requirements since we fetched documents
|
||||
run_context.context.should_cite_documents = True
|
||||
|
||||
return results
|
||||
|
||||
def run_v2(
|
||||
self,
|
||||
run_context: RunContextWrapper[Any],
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
"""Run open_url using the v2 implementation"""
|
||||
urls = kwargs.get("urls", [])
|
||||
if not urls:
|
||||
raise ValueError("urls parameter is required")
|
||||
|
||||
content_provider = get_default_content_provider()
|
||||
if content_provider is None:
|
||||
raise ValueError("No web content provider found")
|
||||
|
||||
retrieved_docs = self._open_url_core(run_context, urls, content_provider) # type: ignore[arg-type]
|
||||
adapter = TypeAdapter(list[LlmOpenUrlResult])
|
||||
return adapter.dump_json(retrieved_docs).decode()
|
||||
|
||||
def run(
|
||||
self, override_kwargs: None = None, **llm_kwargs: str
|
||||
) -> Generator[ToolResponse, None, None]:
|
||||
raise ValueError(_OPEN_URL_GENERIC_ERROR_MESSAGE)
|
||||
|
||||
def final_result(self, *args: ToolResponse) -> JSON_ro:
|
||||
raise ValueError(_OPEN_URL_GENERIC_ERROR_MESSAGE)
|
||||
|
||||
def build_next_prompt(
|
||||
self,
|
||||
prompt_builder: AnswerPromptBuilder,
|
||||
tool_call_summary: ToolCallSummary,
|
||||
tool_responses: list[ToolResponse],
|
||||
using_tool_calling_llm: bool,
|
||||
) -> AnswerPromptBuilder:
|
||||
raise ValueError(_OPEN_URL_GENERIC_ERROR_MESSAGE)
|
||||
@@ -1,32 +1,59 @@
|
||||
from collections.abc import Generator
|
||||
from typing import Any
|
||||
|
||||
from pydantic import TypeAdapter
|
||||
from sqlalchemy.orm import Session
|
||||
from typing_extensions import override
|
||||
|
||||
from onyx.agents.agent_search.dr.models import IterationAnswer
|
||||
from onyx.agents.agent_search.dr.models import IterationInstructions
|
||||
from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
|
||||
WebSearchProvider,
|
||||
)
|
||||
from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
|
||||
WebSearchResult,
|
||||
)
|
||||
from onyx.agents.agent_search.dr.sub_agents.web_search.providers import (
|
||||
get_default_provider,
|
||||
)
|
||||
from onyx.agents.agent_search.dr.sub_agents.web_search.utils import (
|
||||
dummy_inference_section_from_internet_search_result,
|
||||
)
|
||||
from onyx.chat.models import DOCUMENT_CITATION_NUMBER_EMPTY_VALUE
|
||||
from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
|
||||
from onyx.chat.turn.models import FetchedDocumentCacheEntry
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.web_search import fetch_active_web_search_provider
|
||||
from onyx.llm.interfaces import LLM
|
||||
from onyx.llm.models import PreviousMessage
|
||||
from onyx.server.query_and_chat.streaming_models import Packet
|
||||
from onyx.server.query_and_chat.streaming_models import SearchToolDelta
|
||||
from onyx.server.query_and_chat.streaming_models import SearchToolStart
|
||||
from onyx.tools.message import ToolCallSummary
|
||||
from onyx.tools.models import ToolResponse
|
||||
from onyx.tools.tool import RunContextWrapper
|
||||
from onyx.tools.tool import Tool
|
||||
from onyx.tools.tool_implementations_v2.tool_accounting import tool_accounting
|
||||
from onyx.tools.tool_result_models import LlmWebSearchResult
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.special_types import JSON_ro
|
||||
from onyx.utils.threadpool_concurrency import FunctionCall
|
||||
from onyx.utils.threadpool_concurrency import run_functions_in_parallel
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
# TODO: Align on separation of Tools and SubAgents. Right now, we're only keeping this around for backwards compatibility.
|
||||
QUERY_FIELD = "query"
|
||||
QUERY_FIELD = "queries"
|
||||
_GENERIC_ERROR_MESSAGE = "WebSearchTool should only be used by the Deep Research Agent, not via tool calling."
|
||||
_OPEN_URL_GENERIC_ERROR_MESSAGE = (
|
||||
"OpenUrlTool should only be used by the Deep Research Agent, not via tool calling."
|
||||
)
|
||||
|
||||
|
||||
class WebSearchTool(Tool[None]):
|
||||
_NAME = "run_web_search"
|
||||
_DESCRIPTION = "Search the web for information. Never call this tool."
|
||||
_NAME = "web_search"
|
||||
_DESCRIPTION = "Search the web for information."
|
||||
_DISPLAY_NAME = "Web Search"
|
||||
|
||||
def __init__(self, tool_id: int) -> None:
|
||||
@@ -66,7 +93,8 @@ class WebSearchTool(Tool[None]):
|
||||
"type": "object",
|
||||
"properties": {
|
||||
QUERY_FIELD: {
|
||||
"type": "string",
|
||||
"type": "array",
|
||||
"items": {"type": "string"},
|
||||
"description": "What to search for",
|
||||
},
|
||||
},
|
||||
@@ -89,13 +117,140 @@ class WebSearchTool(Tool[None]):
|
||||
) -> str | list[str | dict[str, Any]]:
|
||||
raise ValueError(_GENERIC_ERROR_MESSAGE)
|
||||
|
||||
@tool_accounting
|
||||
def _web_search_core(
|
||||
self,
|
||||
run_context: RunContextWrapper[Any],
|
||||
queries: list[str],
|
||||
search_provider: WebSearchProvider,
|
||||
) -> list[LlmWebSearchResult]:
|
||||
index = run_context.context.current_run_step
|
||||
run_context.context.run_dependencies.emitter.emit(
|
||||
Packet(
|
||||
ind=index,
|
||||
obj=SearchToolStart(
|
||||
type="internal_search_tool_start", is_internet_search=True
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
# Emit a packet in the beginning to communicate queries to the frontend
|
||||
run_context.context.run_dependencies.emitter.emit(
|
||||
Packet(
|
||||
ind=index,
|
||||
obj=SearchToolDelta(
|
||||
type="internal_search_tool_delta",
|
||||
queries=queries,
|
||||
documents=[],
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
queries_str = ", ".join(queries)
|
||||
run_context.context.iteration_instructions.append(
|
||||
IterationInstructions(
|
||||
iteration_nr=index,
|
||||
plan="plan",
|
||||
purpose="Searching the web for information",
|
||||
reasoning=f"I am now using Web Search to gather information on {queries_str}",
|
||||
)
|
||||
)
|
||||
|
||||
# Search all queries in parallel
|
||||
function_calls = [
|
||||
FunctionCall(func=search_provider.search, args=(query,))
|
||||
for query in queries
|
||||
]
|
||||
search_results_dict = run_functions_in_parallel(function_calls)
|
||||
|
||||
# Aggregate all results from all queries
|
||||
all_hits: list[WebSearchResult] = []
|
||||
for result_id in search_results_dict:
|
||||
hits = search_results_dict[result_id]
|
||||
if hits:
|
||||
all_hits.extend(hits)
|
||||
|
||||
inference_sections = [
|
||||
dummy_inference_section_from_internet_search_result(r) for r in all_hits
|
||||
]
|
||||
|
||||
from onyx.agents.agent_search.dr.utils import (
|
||||
convert_inference_sections_to_search_docs,
|
||||
)
|
||||
|
||||
saved_search_docs = convert_inference_sections_to_search_docs(
|
||||
inference_sections, is_internet=True
|
||||
)
|
||||
|
||||
run_context.context.run_dependencies.emitter.emit(
|
||||
Packet(
|
||||
ind=index,
|
||||
obj=SearchToolDelta(
|
||||
type="internal_search_tool_delta",
|
||||
queries=queries,
|
||||
documents=saved_search_docs,
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
results = []
|
||||
for r in all_hits:
|
||||
results.append(
|
||||
LlmWebSearchResult(
|
||||
document_citation_number=DOCUMENT_CITATION_NUMBER_EMPTY_VALUE,
|
||||
url=r.link,
|
||||
title=r.title,
|
||||
snippet=r.snippet or "",
|
||||
unique_identifier_to_strip_away=r.link,
|
||||
)
|
||||
)
|
||||
if r.link not in run_context.context.fetched_documents_cache:
|
||||
run_context.context.fetched_documents_cache[r.link] = (
|
||||
FetchedDocumentCacheEntry(
|
||||
inference_section=dummy_inference_section_from_internet_search_result(
|
||||
r
|
||||
),
|
||||
document_citation_number=DOCUMENT_CITATION_NUMBER_EMPTY_VALUE,
|
||||
)
|
||||
)
|
||||
|
||||
from onyx.db.tools import get_tool_by_name
|
||||
|
||||
run_context.context.global_iteration_responses.append(
|
||||
IterationAnswer(
|
||||
tool=WebSearchTool.__name__,
|
||||
tool_id=get_tool_by_name(
|
||||
WebSearchTool.__name__,
|
||||
run_context.context.run_dependencies.db_session,
|
||||
).id,
|
||||
iteration_nr=index,
|
||||
parallelization_nr=0,
|
||||
question=queries_str,
|
||||
reasoning=f"I am now using Web Search to gather information on {queries_str}",
|
||||
answer="",
|
||||
cited_documents={
|
||||
i: inference_section
|
||||
for i, inference_section in enumerate(inference_sections)
|
||||
},
|
||||
claims=[],
|
||||
queries=queries,
|
||||
)
|
||||
)
|
||||
run_context.context.should_cite_documents = True
|
||||
return results
|
||||
|
||||
def run_v2(
|
||||
self,
|
||||
run_context: RunContextWrapper[Any],
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
raise NotImplementedError("WebSearchTool.run_v2 is not implemented.")
|
||||
queries: list[str],
|
||||
) -> str:
|
||||
search_provider = get_default_provider()
|
||||
if search_provider is None:
|
||||
raise ValueError("No search provider found")
|
||||
|
||||
response = self._web_search_core(run_context, queries, search_provider) # type: ignore[arg-type]
|
||||
adapter = TypeAdapter(list[LlmWebSearchResult])
|
||||
return adapter.dump_json(response).decode()
|
||||
|
||||
def run(
|
||||
self, override_kwargs: None = None, **llm_kwargs: str
|
||||
|
||||
@@ -23,15 +23,15 @@ from onyx.tools.tool_implementations.search.search_tool import (
|
||||
)
|
||||
from onyx.tools.tool_implementations.search.search_tool import SearchResponseSummary
|
||||
from onyx.tools.tool_implementations.search.search_tool import SearchTool
|
||||
from onyx.tools.tool_implementations_v2.tool_accounting import tool_accounting
|
||||
from onyx.tools.tool_implementations_v2.tool_result_models import (
|
||||
LlmInternalSearchResult,
|
||||
from onyx.tools.tool_implementations_v2.tool_accounting import (
|
||||
tool_accounting_function,
|
||||
)
|
||||
from onyx.tools.tool_result_models import LlmInternalSearchResult
|
||||
from onyx.utils.threadpool_concurrency import FunctionCall
|
||||
from onyx.utils.threadpool_concurrency import run_functions_in_parallel
|
||||
|
||||
|
||||
@tool_accounting
|
||||
@tool_accounting_function
|
||||
def _internal_search_core(
|
||||
run_context: RunContextWrapper[ChatTurnContext],
|
||||
queries: list[str],
|
||||
|
||||
@@ -17,17 +17,19 @@ from onyx.file_store.utils import get_default_file_store
|
||||
from onyx.server.query_and_chat.streaming_models import Packet
|
||||
from onyx.server.query_and_chat.streaming_models import PythonToolDelta
|
||||
from onyx.server.query_and_chat.streaming_models import PythonToolStart
|
||||
from onyx.tools.tool_implementations.python.python_tool import PythonTool
|
||||
from onyx.tools.tool_implementations_v2.code_interpreter_client import (
|
||||
from onyx.tools.tool_implementations.python.code_interpreter_client import (
|
||||
CodeInterpreterClient,
|
||||
)
|
||||
from onyx.tools.tool_implementations_v2.code_interpreter_client import ExecuteResponse
|
||||
from onyx.tools.tool_implementations_v2.code_interpreter_client import FileInput
|
||||
from onyx.tools.tool_implementations_v2.tool_accounting import tool_accounting
|
||||
from onyx.tools.tool_implementations_v2.tool_result_models import (
|
||||
LlmPythonExecutionResult,
|
||||
from onyx.tools.tool_implementations.python.code_interpreter_client import (
|
||||
ExecuteResponse,
|
||||
)
|
||||
from onyx.tools.tool_implementations_v2.tool_result_models import PythonExecutionFile
|
||||
from onyx.tools.tool_implementations.python.code_interpreter_client import FileInput
|
||||
from onyx.tools.tool_implementations.python.python_tool import PythonTool
|
||||
from onyx.tools.tool_implementations_v2.tool_accounting import (
|
||||
tool_accounting_function,
|
||||
)
|
||||
from onyx.tools.tool_result_models import LlmPythonExecutionResult
|
||||
from onyx.tools.tool_result_models import PythonExecutionFile
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
@@ -81,7 +83,7 @@ def _combine_outputs(stdout: str, stderr: str) -> str:
|
||||
return ""
|
||||
|
||||
|
||||
@tool_accounting
|
||||
@tool_accounting_function
|
||||
def _python_execution_core(
|
||||
run_context: RunContextWrapper[ChatTurnContext],
|
||||
code: str,
|
||||
|
||||
@@ -31,6 +31,76 @@ def tool_accounting(func: F) -> F:
|
||||
The decorated function with tool accounting functionality.
|
||||
"""
|
||||
|
||||
@functools.wraps(func)
|
||||
def wrapper(
|
||||
selfobj: Any,
|
||||
run_context: RunContextWrapper[ChatTurnContext],
|
||||
*args: Any,
|
||||
**kwargs: Any
|
||||
) -> Any:
|
||||
# Increment current_run_step at the beginning
|
||||
run_context.context.current_run_step += 1
|
||||
|
||||
try:
|
||||
# Call the original function (pass selfobj for class methods)
|
||||
result = func(selfobj, run_context, *args, **kwargs)
|
||||
|
||||
# If it's a coroutine, we need to handle it differently
|
||||
if inspect.iscoroutine(result):
|
||||
# For async functions, we need to return a coroutine that handles the cleanup
|
||||
async def async_wrapper() -> Any:
|
||||
try:
|
||||
return await result
|
||||
finally:
|
||||
_emit_section_end(run_context)
|
||||
|
||||
return async_wrapper()
|
||||
else:
|
||||
# For sync functions, emit cleanup immediately
|
||||
_emit_section_end(run_context)
|
||||
return result
|
||||
|
||||
except Exception:
|
||||
# Always emit cleanup even if an exception occurred
|
||||
_emit_section_end(run_context)
|
||||
raise
|
||||
|
||||
return cast(F, wrapper)
|
||||
|
||||
|
||||
def _emit_section_end(run_context: RunContextWrapper[ChatTurnContext]) -> None:
|
||||
"""Helper function to emit section end packet and increment current_run_step."""
|
||||
index = run_context.context.current_run_step
|
||||
run_context.context.run_dependencies.emitter.emit(
|
||||
Packet(
|
||||
ind=index,
|
||||
obj=SectionEnd(
|
||||
type="section_end",
|
||||
),
|
||||
)
|
||||
)
|
||||
run_context.context.current_run_step += 1
|
||||
|
||||
|
||||
def tool_accounting_function(func: F) -> F:
|
||||
"""
|
||||
Decorator for standalone functions (not methods) that adds tool accounting functionality.
|
||||
|
||||
Use this for standalone functions like _internal_search_core.
|
||||
Use tool_accounting for class methods.
|
||||
|
||||
This decorator:
|
||||
1. Increments the current_run_step index at the beginning
|
||||
2. Emits a section end packet and increments current_run_step at the end
|
||||
3. Ensures the cleanup happens even if an exception occurs
|
||||
|
||||
Args:
|
||||
func: The function to decorate. Must take a RunContextWrapper[ChatTurnContext] as first argument.
|
||||
|
||||
Returns:
|
||||
The decorated function with tool accounting functionality.
|
||||
"""
|
||||
|
||||
@functools.wraps(func)
|
||||
def wrapper(
|
||||
run_context: RunContextWrapper[ChatTurnContext], *args: Any, **kwargs: Any
|
||||
@@ -63,17 +133,3 @@ def tool_accounting(func: F) -> F:
|
||||
raise
|
||||
|
||||
return cast(F, wrapper)
|
||||
|
||||
|
||||
def _emit_section_end(run_context: RunContextWrapper[ChatTurnContext]) -> None:
|
||||
"""Helper function to emit section end packet and increment current_run_step."""
|
||||
index = run_context.context.current_run_step
|
||||
run_context.context.run_dependencies.emitter.emit(
|
||||
Packet(
|
||||
ind=index,
|
||||
obj=SectionEnd(
|
||||
type="section_end",
|
||||
),
|
||||
)
|
||||
)
|
||||
run_context.context.current_run_step += 1
|
||||
|
||||
@@ -40,13 +40,15 @@ from onyx.server.query_and_chat.streaming_models import SavedSearchDoc
|
||||
from onyx.server.query_and_chat.streaming_models import SearchToolDelta
|
||||
from onyx.server.query_and_chat.streaming_models import SearchToolStart
|
||||
from onyx.tools.tool_implementations.web_search.web_search_tool import WebSearchTool
|
||||
from onyx.tools.tool_implementations_v2.tool_accounting import tool_accounting
|
||||
from onyx.tools.tool_implementations_v2.tool_result_models import LlmOpenUrlResult
|
||||
from onyx.tools.tool_implementations_v2.tool_result_models import LlmWebSearchResult
|
||||
from onyx.tools.tool_implementations_v2.tool_accounting import (
|
||||
tool_accounting_function,
|
||||
)
|
||||
from onyx.tools.tool_result_models import LlmOpenUrlResult
|
||||
from onyx.tools.tool_result_models import LlmWebSearchResult
|
||||
from onyx.utils.threadpool_concurrency import run_functions_in_parallel
|
||||
|
||||
|
||||
@tool_accounting
|
||||
@tool_accounting_function
|
||||
def _web_search_core(
|
||||
run_context: RunContextWrapper[ChatTurnContext],
|
||||
queries: list[str],
|
||||
@@ -192,7 +194,7 @@ changing or evolving.
|
||||
"""
|
||||
|
||||
|
||||
@tool_accounting
|
||||
@tool_accounting_function
|
||||
def _open_url_core(
|
||||
run_context: RunContextWrapper[ChatTurnContext],
|
||||
urls: Sequence[str],
|
||||
|
||||
59
backend/onyx/tools/tool_result_models.py
Normal file
59
backend/onyx/tools/tool_result_models.py
Normal file
@@ -0,0 +1,59 @@
|
||||
"""Base models for tool results with citation support."""
|
||||
|
||||
from typing import Any
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class BaseCiteableToolResult(BaseModel):
|
||||
"""Base class for tool results that can be cited."""
|
||||
|
||||
document_citation_number: int
|
||||
unique_identifier_to_strip_away: str | None = None
|
||||
type: str
|
||||
|
||||
|
||||
class LlmInternalSearchResult(BaseCiteableToolResult):
|
||||
"""Result from an internal search query"""
|
||||
|
||||
type: Literal["internal_search"] = "internal_search"
|
||||
title: str
|
||||
excerpt: str
|
||||
metadata: dict[str, Any]
|
||||
|
||||
|
||||
class LlmWebSearchResult(BaseCiteableToolResult):
|
||||
"""Result from a web search query"""
|
||||
|
||||
type: Literal["web_search"] = "web_search"
|
||||
url: str
|
||||
title: str
|
||||
snippet: str
|
||||
|
||||
|
||||
class LlmOpenUrlResult(BaseCiteableToolResult):
|
||||
"""Result from opening/fetching a URL"""
|
||||
|
||||
type: Literal["open_url"] = "open_url"
|
||||
content: str
|
||||
|
||||
|
||||
class PythonExecutionFile(BaseModel):
|
||||
"""File generated during Python execution"""
|
||||
|
||||
filename: str
|
||||
file_link: str
|
||||
|
||||
|
||||
class LlmPythonExecutionResult(BaseModel):
|
||||
"""Result from Python code execution"""
|
||||
|
||||
type: Literal["python_execution"] = "python_execution"
|
||||
|
||||
stdout: str
|
||||
stderr: str
|
||||
exit_code: int | None
|
||||
timed_out: bool
|
||||
generated_files: list[PythonExecutionFile]
|
||||
error: str | None = None
|
||||
@@ -30,15 +30,12 @@ from onyx.file_store.utils import get_default_file_store
|
||||
from onyx.server.query_and_chat.streaming_models import Packet
|
||||
from onyx.server.query_and_chat.streaming_models import PythonToolDelta
|
||||
from onyx.server.query_and_chat.streaming_models import PythonToolStart
|
||||
from onyx.tools.tool_implementations.python.python_tool import PythonTool
|
||||
from onyx.tools.tool_implementations_v2.code_interpreter_client import (
|
||||
from onyx.tools.tool_implementations.python.code_interpreter_client import (
|
||||
CodeInterpreterClient,
|
||||
)
|
||||
from onyx.tools.tool_implementations_v2.python import _python_execution_core
|
||||
from onyx.tools.tool_implementations_v2.python import python
|
||||
from onyx.tools.tool_implementations_v2.tool_result_models import (
|
||||
LlmPythonExecutionResult,
|
||||
)
|
||||
from onyx.tools.tool_implementations.python.python_tool import _python_execution_core
|
||||
from onyx.tools.tool_implementations.python.python_tool import PythonTool
|
||||
from onyx.tools.tool_result_models import LlmPythonExecutionResult
|
||||
|
||||
|
||||
# Apply initialize_file_store fixture to all tests in this module
|
||||
@@ -459,7 +456,7 @@ def test_python_function_tool_wrapper(
|
||||
mock_client_class.return_value = code_interpreter_client
|
||||
|
||||
# Call the function tool wrapper
|
||||
result_coro = python.on_invoke_tool(mock_run_context, json.dumps({"code": code})) # type: ignore
|
||||
result_coro = PythonTool.on_invoke_tool(mock_run_context, json.dumps({"code": code})) # type: ignore
|
||||
result_json: str = asyncio.run(result_coro) # type: ignore
|
||||
|
||||
# Verify result is JSON string
|
||||
|
||||
@@ -300,3 +300,70 @@ def test_multiple_tool_calls_streaming(default_multi_llm: LitellmLLM) -> None:
|
||||
mock_response=MOCK_LLM_RESPONSE,
|
||||
stream_options={"include_usage": True},
|
||||
)
|
||||
|
||||
|
||||
def test_exclude_reasoning_config_for_anthropic_with_function_tool_choice() -> None:
|
||||
"""Test that reasoning config is excluded for Anthropic models when using function tool choice."""
|
||||
# Create an Anthropic LLM
|
||||
anthropic_llm = LitellmLLM(
|
||||
api_key="test_key",
|
||||
timeout=30,
|
||||
model_provider="anthropic",
|
||||
model_name="claude-3-5-sonnet-20241022",
|
||||
max_input_tokens=200000,
|
||||
)
|
||||
|
||||
# Mock the litellm.completion function
|
||||
with patch("litellm.completion") as mock_completion:
|
||||
mock_response = litellm.ModelResponse(
|
||||
id="msg-123",
|
||||
choices=[
|
||||
litellm.Choices(
|
||||
finish_reason="tool_use",
|
||||
index=0,
|
||||
message=litellm.Message(
|
||||
content=None,
|
||||
role="assistant",
|
||||
tool_calls=[
|
||||
litellm.ChatCompletionMessageToolCall(
|
||||
id="call_1",
|
||||
function=LiteLLMFunction(
|
||||
name="search",
|
||||
arguments='{"query": "test"}',
|
||||
),
|
||||
type="function",
|
||||
),
|
||||
],
|
||||
),
|
||||
)
|
||||
],
|
||||
model="claude-3-5-sonnet-20241022",
|
||||
usage=litellm.Usage(
|
||||
prompt_tokens=50, completion_tokens=30, total_tokens=80
|
||||
),
|
||||
)
|
||||
mock_completion.return_value = mock_response
|
||||
|
||||
messages = [HumanMessage(content="Test message")]
|
||||
tools = [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "search",
|
||||
"description": "Search for information",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {"query": {"type": "string"}},
|
||||
"required": ["query"],
|
||||
},
|
||||
},
|
||||
},
|
||||
]
|
||||
|
||||
# Call with a specific function tool choice
|
||||
anthropic_llm.invoke(messages, tools, tool_choice="search")
|
||||
|
||||
# Verify that reasoning_effort and thinking are NOT passed to litellm.completion
|
||||
call_kwargs = mock_completion.call_args[1]
|
||||
assert "reasoning_effort" not in call_kwargs
|
||||
assert "thinking" not in call_kwargs
|
||||
|
||||
@@ -4,7 +4,6 @@ from tests.unit.onyx.chat.turn.utils import chat_turn_context
|
||||
from tests.unit.onyx.chat.turn.utils import chat_turn_dependencies
|
||||
from tests.unit.onyx.chat.turn.utils import fake_db_session
|
||||
from tests.unit.onyx.chat.turn.utils import fake_llm
|
||||
from tests.unit.onyx.chat.turn.utils import fake_model
|
||||
from tests.unit.onyx.chat.turn.utils import fake_redis_client
|
||||
from tests.unit.onyx.chat.turn.utils import fake_tools
|
||||
|
||||
@@ -13,7 +12,6 @@ __all__ = [
|
||||
"chat_turn_dependencies",
|
||||
"fake_db_session",
|
||||
"fake_llm",
|
||||
"fake_model",
|
||||
"fake_redis_client",
|
||||
"fake_tools",
|
||||
]
|
||||
|
||||
525
backend/tests/unit/onyx/tools/test_custom_and_mcp_tool_run_v2.py
Normal file
525
backend/tests/unit/onyx/tools/test_custom_and_mcp_tool_run_v2.py
Normal file
@@ -0,0 +1,525 @@
|
||||
"""Tests for CustomTool and MCPTool run_v2() methods using dependency injection.
|
||||
|
||||
This test module focuses on testing the run_v2() methods for CustomTool and MCPTool,
|
||||
adapted from test_adapter_v1_to_v2.py but directly testing the tool implementations.
|
||||
"""
|
||||
|
||||
import json
|
||||
import uuid
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
from uuid import UUID
|
||||
|
||||
import pytest
|
||||
from agents import RunContextWrapper
|
||||
|
||||
from onyx.agents.agent_search.dr.models import IterationAnswer
|
||||
from onyx.agents.agent_search.dr.models import IterationInstructions
|
||||
from onyx.chat.turn.models import ChatTurnContext
|
||||
from onyx.db.enums import MCPAuthenticationType
|
||||
from onyx.db.enums import MCPTransport
|
||||
from onyx.db.models import MCPServer
|
||||
from onyx.tools.models import DynamicSchemaInfo
|
||||
from onyx.tools.tool_implementations.custom.custom_tool import (
|
||||
build_custom_tools_from_openapi_schema_and_headers,
|
||||
)
|
||||
from onyx.tools.tool_implementations.custom.custom_tool import CustomTool
|
||||
from onyx.tools.tool_implementations.mcp.mcp_tool import MCPTool
|
||||
from tests.unit.onyx.chat.turn.utils import FakeRedis
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Fake Classes for Dependency Injection
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def create_fake_database_session() -> Any:
    """Build a mock SQLAlchemy Session that records commit/rollback calls."""
    from unittest.mock import Mock

    from sqlalchemy.orm import Session

    # spec=Session makes attribute access fail for non-Session attributes,
    # keeping the double honest about the real interface.
    session = Mock(spec=Session)
    session.committed = False
    session.rolled_back = False

    def _record_commit() -> None:
        session.committed = True

    def _record_rollback() -> None:
        session.rolled_back = True

    session.commit = _record_commit
    session.rollback = _record_rollback
    session.add = Mock()
    session.flush = Mock()

    return session
|
||||
|
||||
|
||||
class FakeEmitter:
    """Recording emitter: captures every packet passed to emit() for later inspection."""

    def __init__(self) -> None:
        # All packets ever emitted, in call order.
        self.packet_history: list[Any] = []

    def emit(self, packet: Any) -> None:
        # Record instead of transmitting; tests assert on packet_history.
        self.packet_history.append(packet)
|
||||
|
||||
|
||||
class FakeRunDependencies:
    """Minimal stand-in for run dependencies wired to a single tool."""

    def __init__(self, db_session: Any, redis_client: FakeRedis, tool: Any) -> None:
        # Each run gets its own recording emitter and a one-tool registry.
        self.emitter = FakeEmitter()
        self.tools = [tool]
        self.db_session = db_session
        self.redis_client = redis_client
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Test Helper Functions
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def create_fake_run_context(
    chat_session_id: UUID,
    message_id: int,
    db_session: Any,
    redis_client: FakeRedis,
    tool: Any,
    current_run_step: int = 0,
) -> RunContextWrapper[ChatTurnContext]:
    """Assemble a RunContextWrapper around a ChatTurnContext built from fakes."""
    deps = FakeRunDependencies(db_session, redis_client, tool)

    turn_context = ChatTurnContext(
        chat_session_id=chat_session_id,
        message_id=message_id,
        run_dependencies=deps,  # type: ignore
    )
    # Lets tests start mid-run (affects iteration numbering in the tools).
    turn_context.current_run_step = current_run_step

    return RunContextWrapper(context=turn_context)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Test Fixtures
|
||||
# =============================================================================
|
||||
|
||||
|
||||
@pytest.fixture
def chat_session_id() -> UUID:
    """Fixture providing fake chat session ID."""
    # A fresh random UUID per test keeps sessions isolated from each other.
    return uuid.uuid4()
|
||||
|
||||
|
||||
@pytest.fixture
def message_id() -> int:
    """Fixture providing fake message ID."""
    # Arbitrary constant; tests only require it to be stable within a run.
    return 123
|
||||
|
||||
|
||||
@pytest.fixture
def fake_db_session() -> Any:
    """Fixture providing a fake database session."""
    # Delegates to the module-level factory so non-fixture helpers can share it.
    return create_fake_database_session()
|
||||
|
||||
|
||||
@pytest.fixture
def fake_redis_client() -> FakeRedis:
    """Fixture providing a fake Redis client."""
    # Default FakeRedis reports no cancellation keys.
    return FakeRedis()
|
||||
|
||||
|
||||
@pytest.fixture
def openapi_schema() -> dict[str, Any]:
    """OpenAPI schema with one GET operation and one required path parameter.

    The server URL embeds CHAT_SESSION_ID / MESSAGE_ID placeholders so tests
    can verify dynamic substitution from DynamicSchemaInfo.
    """
    get_operation = {
        "summary": "Get a test item",
        "operationId": "getTestItem",
        "parameters": [
            {
                "name": "test_id",
                "in": "path",
                "required": True,
                "schema": {"type": "string"},
            }
        ],
    }
    return {
        "openapi": "3.0.0",
        "info": {
            "version": "1.0.0",
            "title": "Test API",
            "description": "A test API for testing",
        },
        "servers": [
            {"url": "http://localhost:8080/CHAT_SESSION_ID/test/MESSAGE_ID"},
        ],
        "paths": {"/test/{test_id}": {"GET": get_operation}},
    }
|
||||
|
||||
|
||||
@pytest.fixture
def dynamic_schema_info(chat_session_id: UUID, message_id: int) -> DynamicSchemaInfo:
    """Dynamic schema info for testing."""
    # Carries the IDs substituted into the CHAT_SESSION_ID / MESSAGE_ID
    # placeholders of the openapi_schema fixture's server URL.
    return DynamicSchemaInfo(chat_session_id=chat_session_id, message_id=message_id)
|
||||
|
||||
|
||||
@pytest.fixture
def custom_tool(
    openapi_schema: dict[str, Any], dynamic_schema_info: DynamicSchemaInfo
) -> CustomTool:
    """Custom tool for testing."""
    tools = build_custom_tools_from_openapi_schema_and_headers(
        tool_id=-1,  # dummy tool id
        openapi_schema=openapi_schema,
        dynamic_schema_info=dynamic_schema_info,
    )
    # The schema declares a single operation, so exactly one tool is built.
    return tools[0]
|
||||
|
||||
|
||||
@pytest.fixture
def mcp_server() -> MCPServer:
    """MCP server for testing."""
    # Unauthenticated server using the streamable-HTTP transport; the URL is
    # never contacted because call_mcp_tool is patched in the tests below.
    return MCPServer(
        id=1,
        name="test_mcp_server",
        server_url="http://localhost:8080/mcp",
        auth_type=MCPAuthenticationType.NONE,
        transport=MCPTransport.STREAMABLE_HTTP,
    )
|
||||
|
||||
|
||||
@pytest.fixture
def mcp_tool(mcp_server: MCPServer) -> MCPTool:
    """MCP search tool with a single required string argument."""
    return MCPTool(
        tool_id=1,
        mcp_server=mcp_server,
        tool_name="search",
        tool_description="Search for information",
        # JSON-schema definition for the tool's arguments.
        tool_definition={
            "type": "object",
            "properties": {
                "query": {"type": "string", "description": "The search query"}
            },
            "required": ["query"],
        },
        connection_config=None,
        user_email="test@example.com",
    )
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Custom Tool Tests
|
||||
# =============================================================================
|
||||
|
||||
|
||||
@patch("onyx.tools.tool_implementations.custom.custom_tool.requests.request")
|
||||
def test_custom_tool_run_v2_basic_invocation(
|
||||
mock_request: MagicMock,
|
||||
custom_tool: CustomTool,
|
||||
chat_session_id: UUID,
|
||||
message_id: int,
|
||||
fake_db_session: Any,
|
||||
fake_redis_client: FakeRedis,
|
||||
dynamic_schema_info: DynamicSchemaInfo,
|
||||
) -> None:
|
||||
"""Test basic functionality of CustomTool.run_v2()."""
|
||||
# Arrange
|
||||
mock_response = MagicMock()
|
||||
mock_response.headers = {"Content-Type": "application/json"}
|
||||
mock_response.json.return_value = {
|
||||
"id": "456",
|
||||
"name": "Test Item",
|
||||
"status": "active",
|
||||
}
|
||||
mock_request.return_value = mock_response
|
||||
|
||||
fake_run_context = create_fake_run_context(
|
||||
chat_session_id, message_id, fake_db_session, fake_redis_client, custom_tool
|
||||
)
|
||||
|
||||
# Act
|
||||
result = custom_tool.run_v2(fake_run_context, test_id="456")
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, str)
|
||||
result_json = json.loads(result)
|
||||
assert result_json["id"] == "456"
|
||||
assert result_json["name"] == "Test Item"
|
||||
assert result_json["status"] == "active"
|
||||
|
||||
# Verify HTTP request was made
|
||||
expected_url = f"http://localhost:8080/{dynamic_schema_info.chat_session_id}/test/{dynamic_schema_info.message_id}/test/456"
|
||||
mock_request.assert_called_once_with("GET", expected_url, json=None, headers={})
|
||||
|
||||
|
||||
@patch("onyx.tools.tool_implementations.custom.custom_tool.requests.request")
|
||||
def test_custom_tool_run_v2_iteration_tracking(
|
||||
mock_request: MagicMock,
|
||||
custom_tool: CustomTool,
|
||||
chat_session_id: UUID,
|
||||
message_id: int,
|
||||
fake_db_session: Any,
|
||||
fake_redis_client: FakeRedis,
|
||||
) -> None:
|
||||
"""Test that IterationInstructions and IterationAnswer are properly added."""
|
||||
# Arrange
|
||||
mock_response = MagicMock()
|
||||
mock_response.headers = {"Content-Type": "application/json"}
|
||||
mock_response.json.return_value = {"result": "success"}
|
||||
mock_request.return_value = mock_response
|
||||
|
||||
fake_run_context = create_fake_run_context(
|
||||
chat_session_id, message_id, fake_db_session, fake_redis_client, custom_tool
|
||||
)
|
||||
|
||||
# Act
|
||||
custom_tool.run_v2(fake_run_context, test_id="789")
|
||||
|
||||
# Assert - verify IterationInstructions was added
|
||||
assert len(fake_run_context.context.iteration_instructions) == 1
|
||||
iteration_instruction = fake_run_context.context.iteration_instructions[0]
|
||||
assert isinstance(iteration_instruction, IterationInstructions)
|
||||
assert iteration_instruction.iteration_nr == 1
|
||||
assert iteration_instruction.plan == f"Running {custom_tool.name}"
|
||||
assert iteration_instruction.purpose == f"Running {custom_tool.name}"
|
||||
assert iteration_instruction.reasoning == f"Running {custom_tool.name}"
|
||||
|
||||
# Assert - verify IterationAnswer was added
|
||||
assert len(fake_run_context.context.global_iteration_responses) == 1
|
||||
iteration_answer = fake_run_context.context.global_iteration_responses[0]
|
||||
assert isinstance(iteration_answer, IterationAnswer)
|
||||
assert iteration_answer.tool == custom_tool.name
|
||||
assert iteration_answer.tool_id == custom_tool.id
|
||||
assert iteration_answer.iteration_nr == 1
|
||||
assert iteration_answer.parallelization_nr == 0
|
||||
assert iteration_answer.question == '{"test_id": "789"}'
|
||||
assert iteration_answer.reasoning == f"Running {custom_tool.name}"
|
||||
assert iteration_answer.answer == "{'result': 'success'}"
|
||||
assert iteration_answer.cited_documents == {}
|
||||
|
||||
|
||||
@patch("onyx.tools.tool_implementations.custom.custom_tool.requests.request")
|
||||
def test_custom_tool_run_v2_packet_emissions(
|
||||
mock_request: MagicMock,
|
||||
custom_tool: CustomTool,
|
||||
chat_session_id: UUID,
|
||||
message_id: int,
|
||||
fake_db_session: Any,
|
||||
fake_redis_client: FakeRedis,
|
||||
) -> None:
|
||||
"""Test that the correct packets are emitted during tool execution."""
|
||||
# Arrange
|
||||
mock_response = MagicMock()
|
||||
mock_response.headers = {"Content-Type": "application/json"}
|
||||
mock_response.json.return_value = {"test": "data"}
|
||||
mock_request.return_value = mock_response
|
||||
|
||||
fake_run_context = create_fake_run_context(
|
||||
chat_session_id, message_id, fake_db_session, fake_redis_client, custom_tool
|
||||
)
|
||||
|
||||
# Act
|
||||
custom_tool.run_v2(fake_run_context, test_id="123")
|
||||
|
||||
# Assert - verify emitter captured packets
|
||||
emitter = fake_run_context.context.run_dependencies.emitter
|
||||
# Should have: CustomToolStart, CustomToolDelta
|
||||
assert len(emitter.packet_history) >= 2
|
||||
|
||||
# Check CustomToolStart
|
||||
start_packet = emitter.packet_history[0]
|
||||
assert getattr(start_packet.obj, "type", None) == "custom_tool_start"
|
||||
assert start_packet.obj.tool_name == custom_tool.name
|
||||
|
||||
# Check CustomToolDelta
|
||||
delta_packet = emitter.packet_history[1]
|
||||
assert getattr(delta_packet.obj, "type", None) == "custom_tool_delta"
|
||||
assert delta_packet.obj.tool_name == custom_tool.name
|
||||
|
||||
|
||||
@patch("onyx.tools.tool_implementations.custom.custom_tool.requests.request")
|
||||
@patch("onyx.tools.tool_implementations.custom.custom_tool.get_default_file_store")
|
||||
@patch("uuid.uuid4")
|
||||
def test_custom_tool_run_v2_csv_response_with_file_ids(
|
||||
mock_uuid: MagicMock,
|
||||
mock_file_store: MagicMock,
|
||||
mock_request: MagicMock,
|
||||
custom_tool: CustomTool,
|
||||
chat_session_id: UUID,
|
||||
message_id: int,
|
||||
fake_db_session: Any,
|
||||
fake_redis_client: FakeRedis,
|
||||
) -> None:
|
||||
"""Test that CSV responses with file_ids are handled correctly."""
|
||||
# Arrange
|
||||
mock_uuid.return_value = uuid.UUID("12345678-1234-5678-9abc-123456789012")
|
||||
mock_store_instance = MagicMock()
|
||||
mock_file_store.return_value = mock_store_instance
|
||||
mock_store_instance.save_file.return_value = "csv_file_123"
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.headers = {"Content-Type": "text/csv"}
|
||||
mock_response.content = b"name,age,city\nJohn,30,New York\nJane,25,Los Angeles"
|
||||
mock_request.return_value = mock_response
|
||||
|
||||
fake_run_context = create_fake_run_context(
|
||||
chat_session_id,
|
||||
message_id,
|
||||
fake_db_session,
|
||||
fake_redis_client,
|
||||
custom_tool,
|
||||
current_run_step=2,
|
||||
)
|
||||
|
||||
# Act
|
||||
result = custom_tool.run_v2(fake_run_context, test_id="789")
|
||||
|
||||
# Assert - verify result contains file_ids
|
||||
result_json = json.loads(result)
|
||||
assert "file_ids" in result_json
|
||||
assert result_json["file_ids"] == ["12345678-1234-5678-9abc-123456789012"]
|
||||
|
||||
# Assert - verify IterationAnswer has correct file_ids
|
||||
assert len(fake_run_context.context.global_iteration_responses) == 1
|
||||
iteration_answer = fake_run_context.context.global_iteration_responses[0]
|
||||
assert iteration_answer.data is None
|
||||
assert iteration_answer.file_ids == ["12345678-1234-5678-9abc-123456789012"]
|
||||
assert iteration_answer.response_type == "csv"
|
||||
|
||||
# Verify file was saved
|
||||
mock_store_instance.save_file.assert_called_once()
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# MCP Tool Tests
|
||||
# =============================================================================
|
||||
|
||||
|
||||
@patch("onyx.tools.tool_implementations.mcp.mcp_tool.call_mcp_tool")
|
||||
def test_mcp_tool_run_v2_basic_invocation(
|
||||
mock_call_mcp_tool: MagicMock,
|
||||
mcp_tool: MCPTool,
|
||||
chat_session_id: UUID,
|
||||
message_id: int,
|
||||
fake_db_session: Any,
|
||||
fake_redis_client: FakeRedis,
|
||||
) -> None:
|
||||
"""Test basic functionality of MCPTool.run_v2()."""
|
||||
# Arrange
|
||||
mock_call_mcp_tool.return_value = "MCP search results: test query"
|
||||
|
||||
fake_run_context = create_fake_run_context(
|
||||
chat_session_id, message_id, fake_db_session, fake_redis_client, mcp_tool
|
||||
)
|
||||
|
||||
# Act
|
||||
result = mcp_tool.run_v2(fake_run_context, query="test search")
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, str)
|
||||
# The result is already a JSON string containing {"tool_result": ...}
|
||||
result_json = json.loads(result)
|
||||
# The tool_result is itself a JSON string that needs to be parsed
|
||||
inner_result = json.loads(result_json)
|
||||
assert "tool_result" in inner_result
|
||||
assert inner_result["tool_result"] == "MCP search results: test query"
|
||||
|
||||
# Verify MCP tool was called
|
||||
mock_call_mcp_tool.assert_called_once_with(
|
||||
mcp_tool.mcp_server.server_url,
|
||||
mcp_tool.name,
|
||||
{"query": "test search"},
|
||||
connection_headers={},
|
||||
transport=mcp_tool.mcp_server.transport,
|
||||
)
|
||||
|
||||
|
||||
@patch("onyx.tools.tool_implementations.mcp.mcp_tool.call_mcp_tool")
|
||||
def test_mcp_tool_run_v2_iteration_tracking(
|
||||
mock_call_mcp_tool: MagicMock,
|
||||
mcp_tool: MCPTool,
|
||||
chat_session_id: UUID,
|
||||
message_id: int,
|
||||
fake_db_session: Any,
|
||||
fake_redis_client: FakeRedis,
|
||||
) -> None:
|
||||
"""Test that IterationInstructions and IterationAnswer are properly added."""
|
||||
# Arrange
|
||||
mock_call_mcp_tool.return_value = "MCP search results"
|
||||
|
||||
fake_run_context = create_fake_run_context(
|
||||
chat_session_id,
|
||||
message_id,
|
||||
fake_db_session,
|
||||
fake_redis_client,
|
||||
mcp_tool,
|
||||
current_run_step=1,
|
||||
)
|
||||
|
||||
# Act
|
||||
mcp_tool.run_v2(fake_run_context, query="test mcp search")
|
||||
|
||||
# Assert - verify IterationInstructions was added
|
||||
assert len(fake_run_context.context.iteration_instructions) == 1
|
||||
iteration_instruction = fake_run_context.context.iteration_instructions[0]
|
||||
assert isinstance(iteration_instruction, IterationInstructions)
|
||||
assert iteration_instruction.iteration_nr == 2
|
||||
assert iteration_instruction.plan == f"Running {mcp_tool.name}"
|
||||
assert iteration_instruction.purpose == f"Running {mcp_tool.name}"
|
||||
assert iteration_instruction.reasoning == f"Running {mcp_tool.name}"
|
||||
|
||||
# Assert - verify IterationAnswer was added
|
||||
assert len(fake_run_context.context.global_iteration_responses) == 1
|
||||
iteration_answer = fake_run_context.context.global_iteration_responses[0]
|
||||
assert isinstance(iteration_answer, IterationAnswer)
|
||||
assert iteration_answer.tool == mcp_tool.name
|
||||
assert iteration_answer.tool_id == mcp_tool.id
|
||||
assert iteration_answer.iteration_nr == 2
|
||||
assert iteration_answer.parallelization_nr == 0
|
||||
assert iteration_answer.question == '{"query": "test mcp search"}'
|
||||
assert iteration_answer.reasoning == f"Running {mcp_tool.name}"
|
||||
assert iteration_answer.cited_documents == {}
|
||||
|
||||
|
||||
@patch("onyx.tools.tool_implementations.mcp.mcp_tool.call_mcp_tool")
|
||||
def test_mcp_tool_run_v2_packet_emissions(
|
||||
mock_call_mcp_tool: MagicMock,
|
||||
mcp_tool: MCPTool,
|
||||
chat_session_id: UUID,
|
||||
message_id: int,
|
||||
fake_db_session: Any,
|
||||
fake_redis_client: FakeRedis,
|
||||
) -> None:
|
||||
"""Test that the correct packets are emitted during MCP tool execution."""
|
||||
# Arrange
|
||||
mock_call_mcp_tool.return_value = "MCP result"
|
||||
|
||||
fake_run_context = create_fake_run_context(
|
||||
chat_session_id, message_id, fake_db_session, fake_redis_client, mcp_tool
|
||||
)
|
||||
|
||||
# Act
|
||||
mcp_tool.run_v2(fake_run_context, query="test")
|
||||
|
||||
# Assert - verify emitter captured packets
|
||||
emitter = fake_run_context.context.run_dependencies.emitter
|
||||
# Should have: CustomToolStart, CustomToolDelta
|
||||
assert len(emitter.packet_history) >= 2
|
||||
|
||||
# Check CustomToolStart
|
||||
start_packet = emitter.packet_history[0]
|
||||
assert getattr(start_packet.obj, "type", None) == "custom_tool_start"
|
||||
assert start_packet.obj.tool_name == mcp_tool.name
|
||||
|
||||
# Check CustomToolDelta
|
||||
delta_packet = emitter.packet_history[1]
|
||||
assert getattr(delta_packet.obj, "type", None) == "custom_tool_delta"
|
||||
assert delta_packet.obj.tool_name == mcp_tool.name
|
||||
@@ -0,0 +1,360 @@
|
||||
"""Tests for ImageGenerationTool.run_v2() using dependency injection.
|
||||
|
||||
This test module focuses on testing the ImageGenerationTool.run_v2() method directly,
|
||||
using dependency injection via creating fake implementations instead of using mocks.
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from agents import RunContextWrapper
|
||||
|
||||
from onyx.agents.agent_search.dr.enums import ResearchType
|
||||
from onyx.agents.agent_search.dr.models import GeneratedImage
|
||||
from onyx.chat.models import PromptConfig
|
||||
from onyx.chat.turn.models import ChatTurnContext
|
||||
from onyx.tools.models import ToolResponse
|
||||
from onyx.tools.tool_implementations.images.image_generation_tool import (
|
||||
ImageGenerationTool,
|
||||
)
|
||||
from tests.unit.onyx.chat.turn.utils import FakeRedis
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Fake Classes for Dependency Injection
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def create_fake_database_session() -> Any:
    """Build a mock SQLAlchemy Session that records commit/rollback calls."""
    from unittest.mock import Mock

    from sqlalchemy.orm import Session

    # spec=Session keeps the mock faithful to the real Session interface.
    session = Mock(spec=Session)
    session.committed = False
    session.rolled_back = False

    def _record_commit() -> None:
        session.committed = True

    def _record_rollback() -> None:
        session.rolled_back = True

    session.commit = _record_commit
    session.rollback = _record_rollback
    session.add = Mock()
    session.flush = Mock()

    return session
|
||||
|
||||
|
||||
class FakeEmitter:
    """Recording emitter: captures every packet passed to emit() for later inspection."""

    def __init__(self) -> None:
        # All packets ever emitted, in call order.
        self.packet_history: list[Any] = []

    def emit(self, packet: Any) -> None:
        # Record instead of transmitting; tests assert on packet_history.
        self.packet_history.append(packet)
|
||||
|
||||
|
||||
class FakeRunDependencies:
    """Minimal run-dependency stand-in exposing one image-generation tool."""

    def __init__(
        self,
        db_session: Any,
        redis_client: FakeRedis,
        image_generation_tool: ImageGenerationTool,
    ) -> None:
        # Fresh recording emitter per run; the registry holds just this tool.
        self.emitter = FakeEmitter()
        self.tools = [image_generation_tool]
        self.db_session = db_session
        self.redis_client = redis_client

    def get_prompt_config(self) -> PromptConfig:
        # Static prompt config; the values are arbitrary but stable for tests.
        return PromptConfig(
            default_behavior_system_prompt="You are a helpful assistant.",
            reminder="Answer the user's question.",
            custom_instructions="",
            datetime_aware=False,
        )
|
||||
|
||||
|
||||
class FakeCancelledRedis(FakeRedis):
    """Fake Redis client that always reports the session as cancelled."""

    def exists(self, key: str) -> bool:  # pragma: no cover - trivial override
        # Cancellation checks look for a stop-signal key; claiming every key
        # exists makes run_v2 observe a cancelled session immediately.
        return True
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Test Helper Functions
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def create_fake_image_generation_tool(tool_id: int = 1) -> ImageGenerationTool:
    """Create an ImageGenerationTool instance for testing"""
    # Credentials are fake: tests patch the generation core / run method so
    # no external image API is ever contacted.
    return ImageGenerationTool(
        api_key="fake-api-key",
        api_base=None,
        api_version=None,
        tool_id=tool_id,
        model="dall-e-3",
        num_imgs=1,
    )
|
||||
|
||||
|
||||
def create_fake_run_context(
    chat_session_id: UUID,
    message_id: int,
    db_session: Any,
    redis_client: FakeRedis,
    image_generation_tool: ImageGenerationTool,
) -> RunContextWrapper[ChatTurnContext]:
    """Assemble a RunContextWrapper around a ChatTurnContext built from fakes."""
    deps = FakeRunDependencies(db_session, redis_client, image_generation_tool)

    turn_context = ChatTurnContext(
        chat_session_id=chat_session_id,
        message_id=message_id,
        run_dependencies=deps,  # type: ignore
    )

    return RunContextWrapper(context=turn_context)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Test Fixtures
|
||||
# =============================================================================
|
||||
|
||||
|
||||
@pytest.fixture
def chat_session_id() -> UUID:
    """Fixture providing fake chat session ID."""
    # A fresh random UUID per test keeps sessions isolated from each other.
    return uuid4()
|
||||
|
||||
|
||||
@pytest.fixture
def message_id() -> int:
    """Fixture providing fake message ID."""
    # Arbitrary constant; tests only require it to be stable within a run.
    return 123
|
||||
|
||||
|
||||
@pytest.fixture
def research_type() -> ResearchType:
    """Fixture providing fake research type."""
    # FAST is the simplest research mode for these tests.
    return ResearchType.FAST
|
||||
|
||||
|
||||
@pytest.fixture
def fake_db_session() -> Any:
    """Fixture providing a fake database session."""
    # Delegates to the module-level factory so non-fixture helpers can share it.
    return create_fake_database_session()
|
||||
|
||||
|
||||
@pytest.fixture
def fake_redis_client() -> FakeRedis:
    """Fixture providing a fake Redis client."""
    # Default FakeRedis reports no cancellation keys.
    return FakeRedis()
|
||||
|
||||
|
||||
@pytest.fixture
def image_generation_tool() -> ImageGenerationTool:
    """Fixture providing an ImageGenerationTool with fake API credentials."""
    # Single-image tool (num_imgs=1); see create_fake_image_generation_tool.
    return create_fake_image_generation_tool()
|
||||
|
||||
|
||||
@pytest.fixture
def fake_run_context(
    chat_session_id: UUID,
    message_id: int,
    fake_db_session: Any,
    fake_redis_client: FakeRedis,
    image_generation_tool: ImageGenerationTool,
) -> RunContextWrapper[ChatTurnContext]:
    """Complete RunContextWrapper wired entirely from fake implementations."""
    return create_fake_run_context(
        chat_session_id=chat_session_id,
        message_id=message_id,
        db_session=fake_db_session,
        redis_client=fake_redis_client,
        image_generation_tool=image_generation_tool,
    )
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Test Functions
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def test_image_generation_tool_run_v2_basic_functionality(
    fake_run_context: RunContextWrapper[ChatTurnContext],
    image_generation_tool: ImageGenerationTool,
) -> None:
    """run_v2 reports the number of generated images and forwards its args.

    Verifies the run_v2 entry point integrates with the v2 core implementation.
    """
    from unittest.mock import patch

    # Arrange
    prompt = "A beautiful sunset over mountains"
    shape = "landscape"
    generated = [
        GeneratedImage(
            file_id="file-123",
            url="https://example.com/files/file-123",
            revised_prompt="A stunning sunset over mountains with vibrant colors",
        )
    ]

    with patch.object(image_generation_tool, "_image_generation_core") as mock_core:
        mock_core.return_value = generated

        # Act
        result = image_generation_tool.run_v2(
            fake_run_context, prompt=prompt, shape=shape
        )

        # Assert - the tool reports how many images it produced.
        assert isinstance(result, str)
        assert "Successfully generated 1 images" in result

        # The core received (run_context, prompt, shape) positionally;
        # self is already bound because we patched a bound method.
        mock_core.assert_called_once()
        positional = mock_core.call_args[0]
        assert positional[0] == fake_run_context
        assert positional[1] == prompt
        assert positional[2] == shape
|
||||
|
||||
|
||||
def test_image_generation_tool_run_v2_missing_prompt(
    fake_run_context: RunContextWrapper[ChatTurnContext],
    image_generation_tool: ImageGenerationTool,
) -> None:
    """Calling run_v2 without the required prompt kwarg raises ValueError."""
    with pytest.raises(ValueError) as exc_info:
        image_generation_tool.run_v2(fake_run_context)
    assert "prompt is required" in str(exc_info.value)
|
||||
|
||||
|
||||
def test_image_generation_tool_run_v2_default_shape(
    fake_run_context: RunContextWrapper[ChatTurnContext],
    image_generation_tool: ImageGenerationTool,
) -> None:
    """Omitting the shape argument falls back to the 'square' default."""
    from unittest.mock import patch

    # Arrange
    prompt = "A cat playing with yarn"
    generated = [
        GeneratedImage(
            file_id="file-456",
            url="https://example.com/files/file-456",
            revised_prompt="A playful cat playing with colorful yarn",
        )
    ]

    with patch.object(image_generation_tool, "_image_generation_core") as mock_core:
        mock_core.return_value = generated

        # Act - no shape parameter supplied.
        image_generation_tool.run_v2(fake_run_context, prompt=prompt)

        # Assert - the third positional argument to the core is the shape.
        assert mock_core.call_args[0][2] == "square"  # default shape
|
||||
|
||||
|
||||
def test_image_generation_tool_run_v2_multiple_images(
    fake_run_context: RunContextWrapper[ChatTurnContext],
) -> None:
    """A tool configured for three images reports all three in its summary."""
    from unittest.mock import patch

    # Arrange - tool configured to produce three images per invocation.
    multi_image_tool = ImageGenerationTool(
        api_key="fake-api-key",
        api_base=None,
        api_version=None,
        tool_id=1,
        model="dall-e-3",
        num_imgs=3,
    )
    # Register the multi-image tool in the run dependencies.
    fake_run_context.context.run_dependencies.tools = [multi_image_tool]

    prompt = "A series of abstract patterns"
    generated = [
        GeneratedImage(
            file_id=f"file-{i}",
            url=f"https://example.com/files/file-{i}",
            revised_prompt=f"Abstract pattern variation {i}",
        )
        for i in range(3)
    ]

    with patch.object(multi_image_tool, "_image_generation_core") as mock_core:
        mock_core.return_value = generated

        # Act
        result = multi_image_tool.run_v2(fake_run_context, prompt=prompt)

        # Assert
        assert "Successfully generated 3 images" in result
|
||||
|
||||
|
||||
def test_image_generation_tool_run_v2_handles_cancellation_gracefully(
    chat_session_id: UUID,
    message_id: int,
    fake_db_session: Any,
    image_generation_tool: ImageGenerationTool,
) -> None:
    """A cancelled session completes cleanly, reporting zero generated images."""
    from unittest.mock import patch

    # Arrange - Redis double that reports the session as already cancelled.
    cancelled_run_context = create_fake_run_context(
        chat_session_id=chat_session_id,
        message_id=message_id,
        db_session=fake_db_session,
        redis_client=FakeCancelledRedis(),
        image_generation_tool=image_generation_tool,
    )

    prompt = "A test image prompt that should be cancelled"

    def fake_run(**kwargs: Any) -> Any:
        # One fake ToolResponse; discarded once cancellation is observed,
        # and no real image API is ever called.
        yield ToolResponse(id="ignored", response=None)

    with patch.object(image_generation_tool, "run", side_effect=fake_run) as mock_run:
        # Act - must neither raise nor reach an external API.
        result = image_generation_tool.run_v2(
            cancelled_run_context,
            prompt=prompt,
        )

        # Assert - graceful cancellation yields an empty success summary.
        assert isinstance(result, str)
        assert "Successfully generated 0 images" in result

        # The patched run was invoked exactly once.
        mock_run.assert_called_once()
|
||||
350
backend/tests/unit/onyx/tools/test_open_url_tool_run_v2.py
Normal file
350
backend/tests/unit/onyx/tools/test_open_url_tool_run_v2.py
Normal file
@@ -0,0 +1,350 @@
|
||||
from collections.abc import Sequence
|
||||
from datetime import datetime
|
||||
from typing import cast
|
||||
from typing import List
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from agents import RunContextWrapper
|
||||
|
||||
from onyx.agents.agent_search.dr.models import IterationAnswer
|
||||
from onyx.agents.agent_search.dr.models import IterationInstructions
|
||||
from onyx.agents.agent_search.dr.sub_agents.web_search.models import WebContent
|
||||
from onyx.agents.agent_search.dr.sub_agents.web_search.models import WebContentProvider
|
||||
from onyx.agents.agent_search.dr.sub_agents.web_search.models import WebSearchProvider
|
||||
from onyx.agents.agent_search.dr.sub_agents.web_search.models import WebSearchResult
|
||||
from onyx.chat.turn.models import ChatTurnContext
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.server.query_and_chat.streaming_models import FetchToolStart
|
||||
from onyx.server.query_and_chat.streaming_models import Packet
|
||||
from onyx.server.query_and_chat.streaming_models import SectionEnd
|
||||
from onyx.tools.tool_implementations.web_search.open_url_tool import OpenUrlTool
|
||||
from onyx.tools.tool_implementations.web_search.web_search_tool import WebSearchTool
|
||||
from onyx.tools.tool_result_models import LlmOpenUrlResult
|
||||
from onyx.tools.tool_result_models import LlmWebSearchResult
|
||||
|
||||
|
||||
class MockTool:
|
||||
"""Mock Tool object for testing"""
|
||||
|
||||
def __init__(self, tool_id: int = 1, name: str = WebSearchTool.__name__):
|
||||
self.id = tool_id
|
||||
self.name = name
|
||||
|
||||
|
||||
class MockWebSearchProvider(WebSearchProvider):
|
||||
"""Mock implementation of WebSearchProvider for dependency injection"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
search_results: List[WebSearchResult] | None = None,
|
||||
should_raise_exception: bool = False,
|
||||
):
|
||||
self.search_results = search_results or []
|
||||
self.should_raise_exception = should_raise_exception
|
||||
|
||||
def search(self, query: str) -> List[WebSearchResult]:
|
||||
if self.should_raise_exception:
|
||||
raise Exception("Test exception from search provider")
|
||||
return self.search_results
|
||||
|
||||
def contents(self, urls: Sequence[str]) -> List[WebContent]:
|
||||
return []
|
||||
|
||||
|
||||
class MockWebContentProvider(WebContentProvider):
|
||||
"""Mock implementation of WebContentProvider for dependency injection"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
content_results: List[WebContent] | None = None,
|
||||
should_raise_exception: bool = False,
|
||||
):
|
||||
self.content_results = content_results or []
|
||||
self.should_raise_exception = should_raise_exception
|
||||
|
||||
def contents(self, urls: Sequence[str]) -> List[WebContent]:
|
||||
if self.should_raise_exception:
|
||||
raise Exception("Test exception from content provider")
|
||||
return self.content_results
|
||||
|
||||
|
||||
class MockEmitter:
|
||||
"""Mock emitter for dependency injection"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.packet_history: list[Packet] = []
|
||||
|
||||
def emit(self, packet: Packet) -> None:
|
||||
self.packet_history.append(packet)
|
||||
|
||||
|
||||
class MockRunDependencies:
|
||||
"""Mock run dependencies for dependency injection"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.emitter = MockEmitter()
|
||||
# Set up mock database session
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
self.db_session = MagicMock()
|
||||
# Configure the scalar method to return our mock tool
|
||||
mock_tool = MockTool()
|
||||
self.db_session.scalar.return_value = mock_tool
|
||||
|
||||
|
||||
def create_test_run_context(
|
||||
current_run_step: int = 0,
|
||||
iteration_instructions: List[IterationInstructions] | None = None,
|
||||
global_iteration_responses: List[IterationAnswer] | None = None,
|
||||
) -> RunContextWrapper[ChatTurnContext]:
|
||||
"""Create a real RunContextWrapper with test dependencies"""
|
||||
|
||||
# Create test dependencies
|
||||
emitter = MockEmitter()
|
||||
run_dependencies = MockRunDependencies()
|
||||
run_dependencies.emitter = emitter
|
||||
|
||||
# Create the actual context object
|
||||
context = ChatTurnContext(
|
||||
chat_session_id=uuid4(),
|
||||
message_id=1,
|
||||
current_run_step=current_run_step,
|
||||
iteration_instructions=iteration_instructions or [],
|
||||
global_iteration_responses=global_iteration_responses or [],
|
||||
run_dependencies=run_dependencies, # type: ignore[arg-type]
|
||||
)
|
||||
|
||||
# Create the run context wrapper
|
||||
run_context = RunContextWrapper(context=context)
|
||||
|
||||
return run_context
|
||||
|
||||
|
||||
def test_open_url_tool_run_v2_basic_functionality() -> None:
|
||||
"""Test basic functionality of OpenUrlTool.run_v2"""
|
||||
# Arrange
|
||||
test_run_context = create_test_run_context()
|
||||
urls = ["https://example.com/1", "https://example.com/2"]
|
||||
|
||||
# Create test content results
|
||||
test_content_results = [
|
||||
WebContent(
|
||||
title="Test Content 1",
|
||||
link="https://example.com/1",
|
||||
full_content="This is the full content of the first page",
|
||||
published_date=datetime(2024, 1, 1, 12, 0, 0),
|
||||
),
|
||||
WebContent(
|
||||
title="Test Content 2",
|
||||
link="https://example.com/2",
|
||||
full_content="This is the full content of the second page",
|
||||
published_date=None,
|
||||
),
|
||||
]
|
||||
|
||||
test_provider = MockWebContentProvider(content_results=test_content_results)
|
||||
|
||||
# Create tool instance
|
||||
open_url_tool = OpenUrlTool(tool_id=1)
|
||||
|
||||
# Mock the get_default_content_provider to return our test provider
|
||||
from unittest.mock import patch
|
||||
|
||||
with patch(
|
||||
"onyx.tools.tool_implementations.web_search.open_url_tool.get_default_content_provider"
|
||||
) as mock_get_provider:
|
||||
mock_get_provider.return_value = test_provider
|
||||
|
||||
# Act
|
||||
result_json = open_url_tool.run_v2(test_run_context, urls=urls)
|
||||
|
||||
from pydantic import TypeAdapter
|
||||
|
||||
adapter = TypeAdapter(list[LlmOpenUrlResult])
|
||||
result = adapter.validate_json(result_json)
|
||||
|
||||
# Assert
|
||||
assert len(result) == 2
|
||||
assert all(isinstance(r, LlmOpenUrlResult) for r in result)
|
||||
|
||||
# Check first result
|
||||
assert result[0].content == "This is the full content of the first page"
|
||||
assert result[0].document_citation_number == -1
|
||||
assert result[0].unique_identifier_to_strip_away == "https://example.com/1"
|
||||
|
||||
# Check second result
|
||||
assert result[1].content == "This is the full content of the second page"
|
||||
assert result[1].document_citation_number == -1
|
||||
assert result[1].unique_identifier_to_strip_away == "https://example.com/2"
|
||||
|
||||
# Check that fetched_documents_cache was populated
|
||||
assert len(test_run_context.context.fetched_documents_cache) == 2
|
||||
assert "https://example.com/1" in test_run_context.context.fetched_documents_cache
|
||||
assert "https://example.com/2" in test_run_context.context.fetched_documents_cache
|
||||
|
||||
# Verify cache entries have correct structure
|
||||
cache_entry_1 = test_run_context.context.fetched_documents_cache[
|
||||
"https://example.com/1"
|
||||
]
|
||||
assert cache_entry_1.document_citation_number == -1
|
||||
assert cache_entry_1.inference_section is not None
|
||||
|
||||
# Verify context was updated
|
||||
assert test_run_context.context.current_run_step == 2
|
||||
assert len(test_run_context.context.iteration_instructions) == 1
|
||||
assert len(test_run_context.context.global_iteration_responses) == 1
|
||||
|
||||
# Check iteration instruction
|
||||
instruction = test_run_context.context.iteration_instructions[0]
|
||||
assert isinstance(instruction, IterationInstructions)
|
||||
assert instruction.iteration_nr == 1
|
||||
assert instruction.purpose == "Fetching content from URLs"
|
||||
assert (
|
||||
"Web Fetch to gather information on https://example.com/1, https://example.com/2"
|
||||
in instruction.reasoning
|
||||
)
|
||||
|
||||
# Check iteration answer
|
||||
answer = test_run_context.context.global_iteration_responses[0]
|
||||
assert isinstance(answer, IterationAnswer)
|
||||
assert answer.tool == WebSearchTool.__name__
|
||||
assert answer.iteration_nr == 1
|
||||
assert (
|
||||
answer.question
|
||||
== "Fetch content from URLs: https://example.com/1, https://example.com/2"
|
||||
)
|
||||
assert len(answer.cited_documents) == 2
|
||||
|
||||
# Verify emitter events were captured
|
||||
emitter = cast(MockEmitter, test_run_context.context.run_dependencies.emitter)
|
||||
assert len(emitter.packet_history) == 2
|
||||
|
||||
# Check the types of emitted events
|
||||
assert isinstance(emitter.packet_history[0].obj, FetchToolStart)
|
||||
assert isinstance(emitter.packet_history[1].obj, SectionEnd)
|
||||
|
||||
# Verify the FetchToolStart event contains the correct SavedSearchDoc objects
|
||||
fetch_start_event = emitter.packet_history[0].obj
|
||||
assert len(fetch_start_event.documents) == 2
|
||||
assert fetch_start_event.documents[0].link == "https://example.com/1"
|
||||
assert fetch_start_event.documents[1].link == "https://example.com/2"
|
||||
assert fetch_start_event.documents[0].source_type == DocumentSource.WEB
|
||||
|
||||
|
||||
def test_open_url_tool_run_v2_exception_handling() -> None:
|
||||
"""Test that OpenUrlTool.run_v2 handles exceptions properly"""
|
||||
# Arrange
|
||||
test_run_context = create_test_run_context()
|
||||
urls = ["https://example.com/1", "https://example.com/2"]
|
||||
|
||||
# Create a provider that will raise an exception
|
||||
test_provider = MockWebContentProvider(should_raise_exception=True)
|
||||
|
||||
# Create tool instance
|
||||
open_url_tool = OpenUrlTool(tool_id=1)
|
||||
|
||||
# Mock the get_default_content_provider to return our test provider
|
||||
from unittest.mock import patch
|
||||
|
||||
with patch(
|
||||
"onyx.tools.tool_implementations.web_search.open_url_tool.get_default_content_provider"
|
||||
) as mock_get_provider:
|
||||
mock_get_provider.return_value = test_provider
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(Exception, match="Test exception from content provider"):
|
||||
open_url_tool.run_v2(test_run_context, urls=urls)
|
||||
|
||||
# Verify that even though an exception was raised, we still emitted the initial events
|
||||
# and the SectionEnd packet was emitted by the decorator
|
||||
emitter = test_run_context.context.run_dependencies.emitter # type: ignore[attr-defined]
|
||||
assert len(emitter.packet_history) == 2 # FetchToolStart and SectionEnd
|
||||
|
||||
# Check the types of emitted events
|
||||
assert isinstance(emitter.packet_history[0].obj, FetchToolStart)
|
||||
assert isinstance(emitter.packet_history[1].obj, SectionEnd)
|
||||
|
||||
# Verify that the decorator properly handled the exception and updated current_run_step
|
||||
assert (
|
||||
test_run_context.context.current_run_step == 2
|
||||
) # Should be 2 after proper handling
|
||||
|
||||
|
||||
def test_open_url_tool_run_v2_cache_deduplication() -> None:
|
||||
"""Test that WebSearchTool.run_v2 and OpenUrlTool.run_v2 share fetched_documents_cache for the same URL"""
|
||||
# Arrange
|
||||
test_run_context = create_test_run_context()
|
||||
test_url = "https://example.com/1"
|
||||
|
||||
# First, do a web search that returns this URL
|
||||
search_results = [
|
||||
WebSearchResult(
|
||||
title="Test Result",
|
||||
link=test_url,
|
||||
author="Test Author",
|
||||
published_date=datetime(2024, 1, 1, 12, 0, 0),
|
||||
snippet="This is a test snippet",
|
||||
),
|
||||
]
|
||||
|
||||
# Then, fetch the full content for the same URL
|
||||
content_results = [
|
||||
WebContent(
|
||||
title="Test Content",
|
||||
link=test_url,
|
||||
full_content="This is the full content of the page",
|
||||
published_date=datetime(2024, 1, 1, 12, 0, 0),
|
||||
),
|
||||
]
|
||||
|
||||
search_provider = MockWebSearchProvider(search_results=search_results)
|
||||
content_provider = MockWebContentProvider(content_results=content_results)
|
||||
|
||||
web_search_tool = WebSearchTool(tool_id=1)
|
||||
open_url_tool = OpenUrlTool(tool_id=1)
|
||||
|
||||
from unittest.mock import patch
|
||||
from pydantic import TypeAdapter
|
||||
|
||||
# Act - first do web_search via run_v2
|
||||
with patch(
|
||||
"onyx.tools.tool_implementations.web_search.web_search_tool.get_default_provider"
|
||||
) as mock_get_provider:
|
||||
mock_get_provider.return_value = search_provider
|
||||
search_result_json = web_search_tool.run_v2(
|
||||
test_run_context, queries=["test query"]
|
||||
)
|
||||
|
||||
adapter_ws = TypeAdapter(list[LlmWebSearchResult])
|
||||
search_result = adapter_ws.validate_json(search_result_json)
|
||||
|
||||
# Verify search result
|
||||
assert len(search_result) == 1
|
||||
assert search_result[0].url == test_url
|
||||
|
||||
# Verify cache was populated by web_search
|
||||
assert test_url in test_run_context.context.fetched_documents_cache
|
||||
|
||||
# Now run open_url via run_v2
|
||||
with patch(
|
||||
"onyx.tools.tool_implementations.web_search.open_url_tool.get_default_content_provider"
|
||||
) as mock_get_content_provider:
|
||||
mock_get_content_provider.return_value = content_provider
|
||||
open_result_json = open_url_tool.run_v2(test_run_context, urls=[test_url])
|
||||
|
||||
adapter_ou = TypeAdapter(list[LlmOpenUrlResult])
|
||||
open_result = adapter_ou.validate_json(open_result_json)
|
||||
|
||||
# Verify open_url result
|
||||
assert len(open_result) == 1
|
||||
assert open_result[0].content == "This is the full content of the page"
|
||||
|
||||
# Verify cache still has the same entry (not duplicated)
|
||||
assert len(test_run_context.context.fetched_documents_cache) == 1
|
||||
assert test_url in test_run_context.context.fetched_documents_cache
|
||||
cache_entry_after_open = test_run_context.context.fetched_documents_cache[test_url]
|
||||
|
||||
# Verify that the cache entry was updated with the full content
|
||||
# (The inference section should be updated, not replaced)
|
||||
assert cache_entry_after_open.document_citation_number == -1
|
||||
447
backend/tests/unit/onyx/tools/test_search_tool_run_v2.py
Normal file
447
backend/tests/unit/onyx/tools/test_search_tool_run_v2.py
Normal file
@@ -0,0 +1,447 @@
|
||||
"""Tests for SearchTool.run_v2() using dependency injection via search_pipeline_override_for_testing.
|
||||
|
||||
This test module focuses on testing the SearchTool.run_v2() method directly,
|
||||
using the search_pipeline_override_for_testing parameter for dependency injection
|
||||
instead of using mocks.
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from agents import RunContextWrapper
|
||||
|
||||
from onyx.agents.agent_search.dr.enums import ResearchType
|
||||
from onyx.agents.agent_search.dr.models import IterationAnswer
|
||||
from onyx.agents.agent_search.dr.models import IterationInstructions
|
||||
from onyx.chat.models import PromptConfig
|
||||
from onyx.chat.turn.models import ChatTurnContext
|
||||
from onyx.context.search.enums import SearchType
|
||||
from onyx.context.search.models import IndexFilters
|
||||
from onyx.context.search.models import InferenceSection
|
||||
from onyx.server.query_and_chat.streaming_models import SavedSearchDoc
|
||||
from onyx.server.query_and_chat.streaming_models import SearchToolDelta
|
||||
from onyx.server.query_and_chat.streaming_models import SearchToolStart
|
||||
from onyx.tools.tool_implementations.search.search_tool import SearchTool
|
||||
from tests.unit.onyx.chat.turn.utils import create_test_inference_chunk
|
||||
from tests.unit.onyx.chat.turn.utils import FakeQuery
|
||||
from tests.unit.onyx.chat.turn.utils import FakeRedis
|
||||
from tests.unit.onyx.chat.turn.utils import FakeResult
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Fake Classes for Dependency Injection
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def create_fake_database_session() -> Any:
|
||||
"""Create a fake SQLAlchemy Session for testing"""
|
||||
from unittest.mock import Mock
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
# Create a mock that behaves like a real Session
|
||||
fake_session = Mock(spec=Session)
|
||||
fake_session.committed = False
|
||||
fake_session.rolled_back = False
|
||||
|
||||
def mock_commit() -> None:
|
||||
fake_session.committed = True
|
||||
|
||||
def mock_rollback() -> None:
|
||||
fake_session.rolled_back = True
|
||||
|
||||
fake_session.commit = mock_commit
|
||||
fake_session.rollback = mock_rollback
|
||||
fake_session.add = Mock()
|
||||
fake_session.flush = Mock()
|
||||
fake_session.query = Mock(return_value=FakeQuery())
|
||||
fake_session.execute = Mock(return_value=FakeResult())
|
||||
|
||||
return fake_session
|
||||
|
||||
|
||||
class FakeSearchQuery:
|
||||
"""Fake SearchQuery for testing"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.search_type = SearchType.SEMANTIC
|
||||
self.filters = IndexFilters(access_control_list=None)
|
||||
self.recency_bias_multiplier = 1.0
|
||||
|
||||
|
||||
class FakeSearchPipeline:
|
||||
"""Fake SearchPipeline for dependency injection in SearchTool"""
|
||||
|
||||
def __init__(self, sections: list[InferenceSection] | None = None) -> None:
|
||||
self.sections = sections or []
|
||||
self.search_query = FakeSearchQuery()
|
||||
|
||||
@property
|
||||
def merged_retrieved_sections(self) -> list[InferenceSection]:
|
||||
return self.sections
|
||||
|
||||
@property
|
||||
def final_context_sections(self) -> list[InferenceSection]:
|
||||
return self.sections
|
||||
|
||||
@property
|
||||
def section_relevance(self) -> list | None:
|
||||
return None
|
||||
|
||||
|
||||
class FakeLLMConfig:
|
||||
"""Fake LLM config for testing"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.max_input_tokens = 128000 # Default GPT-4 context
|
||||
self.model_name = "gpt-4"
|
||||
self.model_provider = "openai"
|
||||
|
||||
|
||||
class FakeLLM:
|
||||
"""Fake LLM for testing"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.config = FakeLLMConfig()
|
||||
|
||||
def log_model_configs(self) -> None:
|
||||
pass
|
||||
|
||||
|
||||
class FakePersona:
|
||||
"""Fake Persona for testing"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.id = 1
|
||||
self.name = "Test Persona"
|
||||
self.chunks_above = None
|
||||
self.chunks_below = None
|
||||
self.llm_relevance_filter = False
|
||||
self.llm_filter_extraction = False
|
||||
self.recency_bias = "auto"
|
||||
self.prompt_ids = []
|
||||
self.document_sets = []
|
||||
self.num_chunks = None
|
||||
self.llm_model_version_override = None
|
||||
|
||||
|
||||
class FakeUser:
|
||||
"""Fake User for testing"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.id = 1
|
||||
self.email = "test@example.com"
|
||||
|
||||
|
||||
class FakeEmitter:
|
||||
"""Fake emitter for testing that records all emitted packets"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.packet_history: list[Any] = []
|
||||
|
||||
def emit(self, packet: Any) -> None:
|
||||
self.packet_history.append(packet)
|
||||
|
||||
|
||||
class FakeRunDependencies:
|
||||
"""Fake run dependencies for testing"""
|
||||
|
||||
def __init__(
|
||||
self, db_session: Any, redis_client: FakeRedis, search_tool: SearchTool
|
||||
) -> None:
|
||||
self.db_session = db_session
|
||||
self.redis_client = redis_client
|
||||
self.emitter = FakeEmitter()
|
||||
self.tools = [search_tool]
|
||||
|
||||
def get_prompt_config(self) -> PromptConfig:
|
||||
return PromptConfig(
|
||||
default_behavior_system_prompt="You are a helpful assistant.",
|
||||
reminder="Answer the user's question.",
|
||||
custom_instructions="",
|
||||
datetime_aware=False,
|
||||
)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Test Helper Functions
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def create_search_section_with_semantic_id(
|
||||
document_id: str, semantic_identifier: str, content: str, link: str
|
||||
) -> InferenceSection:
|
||||
"""Create a test inference section with custom semantic_identifier"""
|
||||
chunk = create_test_inference_chunk(
|
||||
document_id=document_id,
|
||||
semantic_identifier=semantic_identifier,
|
||||
content=content,
|
||||
link=link,
|
||||
)
|
||||
return InferenceSection(
|
||||
center_chunk=chunk,
|
||||
chunks=[chunk],
|
||||
combined_content=content,
|
||||
)
|
||||
|
||||
|
||||
def create_fake_search_pipeline_with_results(
|
||||
sections: list[InferenceSection] | None = None,
|
||||
) -> FakeSearchPipeline:
|
||||
"""Create a fake search pipeline with test results"""
|
||||
if sections is None:
|
||||
sections = [
|
||||
create_search_section_with_semantic_id(
|
||||
document_id="doc1",
|
||||
semantic_identifier="test_doc_1",
|
||||
content="First test document content",
|
||||
link="https://example.com/doc1",
|
||||
),
|
||||
create_search_section_with_semantic_id(
|
||||
document_id="doc2",
|
||||
semantic_identifier="test_doc_2",
|
||||
content="Second test document content",
|
||||
link="https://example.com/doc2",
|
||||
),
|
||||
]
|
||||
|
||||
return FakeSearchPipeline(sections=sections)
|
||||
|
||||
|
||||
def create_search_tool_with_fake_pipeline(
|
||||
search_pipeline: FakeSearchPipeline,
|
||||
db_session: Any | None = None,
|
||||
tool_id: int = 1,
|
||||
) -> SearchTool:
|
||||
"""Create a SearchTool instance with a fake search pipeline for testing"""
|
||||
from onyx.chat.models import AnswerStyleConfig
|
||||
from onyx.chat.models import CitationConfig
|
||||
from onyx.chat.models import DocumentPruningConfig
|
||||
from onyx.context.search.enums import LLMEvaluationType
|
||||
|
||||
fake_db_session = db_session or create_fake_database_session()
|
||||
fake_llm = FakeLLM()
|
||||
fake_persona = FakePersona()
|
||||
fake_user = FakeUser()
|
||||
|
||||
return SearchTool(
|
||||
tool_id=tool_id,
|
||||
db_session=fake_db_session,
|
||||
user=fake_user, # type: ignore
|
||||
persona=fake_persona, # type: ignore
|
||||
retrieval_options=None,
|
||||
prompt_config=PromptConfig(
|
||||
default_behavior_system_prompt="You are a helpful assistant.",
|
||||
reminder="Answer the user's question.",
|
||||
custom_instructions="",
|
||||
datetime_aware=False,
|
||||
),
|
||||
llm=fake_llm, # type: ignore
|
||||
fast_llm=fake_llm, # type: ignore
|
||||
evaluation_type=LLMEvaluationType.SKIP,
|
||||
answer_style_config=AnswerStyleConfig(citation_config=CitationConfig()),
|
||||
document_pruning_config=DocumentPruningConfig(),
|
||||
search_pipeline_override_for_testing=search_pipeline, # type: ignore
|
||||
)
|
||||
|
||||
|
||||
def create_fake_run_context(
|
||||
chat_session_id: UUID,
|
||||
message_id: int,
|
||||
db_session: Any,
|
||||
redis_client: FakeRedis,
|
||||
search_tool: SearchTool,
|
||||
) -> RunContextWrapper[ChatTurnContext]:
|
||||
"""Create a fake run context for testing"""
|
||||
run_dependencies = FakeRunDependencies(db_session, redis_client, search_tool)
|
||||
|
||||
context = ChatTurnContext(
|
||||
chat_session_id=chat_session_id,
|
||||
message_id=message_id,
|
||||
run_dependencies=run_dependencies, # type: ignore
|
||||
)
|
||||
|
||||
return RunContextWrapper(context=context)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Test Fixtures
|
||||
# =============================================================================
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def chat_session_id() -> UUID:
|
||||
"""Fixture providing fake chat session ID."""
|
||||
return uuid4()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def message_id() -> int:
|
||||
"""Fixture providing fake message ID."""
|
||||
return 123
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def research_type() -> ResearchType:
|
||||
"""Fixture providing fake research type."""
|
||||
return ResearchType.FAST
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def fake_db_session() -> Any:
|
||||
"""Fixture providing a fake database session."""
|
||||
return create_fake_database_session()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def fake_redis_client() -> FakeRedis:
|
||||
"""Fixture providing a fake Redis client."""
|
||||
return FakeRedis()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def fake_search_pipeline() -> FakeSearchPipeline:
|
||||
"""Fixture providing a fake search pipeline with default test results."""
|
||||
return create_fake_search_pipeline_with_results()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def search_tool(
|
||||
fake_search_pipeline: FakeSearchPipeline, fake_db_session: Any
|
||||
) -> SearchTool:
|
||||
"""Fixture providing a SearchTool with fake search pipeline."""
|
||||
return create_search_tool_with_fake_pipeline(fake_search_pipeline, fake_db_session)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def fake_run_context(
|
||||
chat_session_id: UUID,
|
||||
message_id: int,
|
||||
fake_db_session: Any,
|
||||
fake_redis_client: FakeRedis,
|
||||
search_tool: SearchTool,
|
||||
) -> RunContextWrapper[ChatTurnContext]:
|
||||
"""Fixture providing a complete RunContextWrapper with fake implementations."""
|
||||
return create_fake_run_context(
|
||||
chat_session_id, message_id, fake_db_session, fake_redis_client, search_tool
|
||||
)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Test Functions
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def test_search_tool_run_v2_basic_functionality(
|
||||
fake_run_context: RunContextWrapper[ChatTurnContext],
|
||||
search_tool: SearchTool,
|
||||
fake_db_session: Any,
|
||||
) -> None:
|
||||
"""Test basic functionality of SearchTool.run_v2() using dependency injection.
|
||||
|
||||
This test mirrors the original test_internal_search_core_basic_functionality but
|
||||
uses the run_v2 method with search_pipeline_override_for_testing instead of mocks.
|
||||
"""
|
||||
from unittest.mock import patch
|
||||
|
||||
# Arrange
|
||||
query = "test search query"
|
||||
|
||||
# Create a session context manager
|
||||
class FakeSessionContextManager:
|
||||
def __enter__(self) -> Any:
|
||||
return fake_db_session
|
||||
|
||||
def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
|
||||
pass
|
||||
|
||||
# Act - patch get_session_with_current_tenant to return our fake session
|
||||
with patch(
|
||||
"onyx.tools.tool_implementations.search.search_tool.get_session_with_current_tenant",
|
||||
return_value=FakeSessionContextManager(),
|
||||
):
|
||||
result = search_tool.run_v2(fake_run_context, query=query)
|
||||
|
||||
# Assert - verify result is a JSON string
|
||||
import json
|
||||
|
||||
result_list = json.loads(result)
|
||||
assert isinstance(result_list, list)
|
||||
assert len(result_list) == 2
|
||||
|
||||
# Verify result contains InternalSearchResult objects (as dicts in JSON)
|
||||
assert result_list[0]["unique_identifier_to_strip_away"] == "doc1"
|
||||
assert result_list[0]["title"] == "test_doc_1"
|
||||
assert result_list[0]["excerpt"] == "First test document content"
|
||||
assert result_list[0]["document_citation_number"] == -1
|
||||
|
||||
assert result_list[1]["unique_identifier_to_strip_away"] == "doc2"
|
||||
assert result_list[1]["title"] == "test_doc_2"
|
||||
assert result_list[1]["excerpt"] == "Second test document content"
|
||||
assert result_list[1]["document_citation_number"] == -1
|
||||
|
||||
# Verify context was updated (decorator increments current_run_step)
|
||||
assert fake_run_context.context.current_run_step == 2
|
||||
assert len(fake_run_context.context.iteration_instructions) == 1
|
||||
assert len(fake_run_context.context.global_iteration_responses) == 1
|
||||
|
||||
# Verify fetched_documents_cache was populated
|
||||
assert len(fake_run_context.context.fetched_documents_cache) == 2
|
||||
assert "doc1" in fake_run_context.context.fetched_documents_cache
|
||||
assert "doc2" in fake_run_context.context.fetched_documents_cache
|
||||
|
||||
# Verify cache entries have correct structure
|
||||
cache_entry_1 = fake_run_context.context.fetched_documents_cache["doc1"]
|
||||
assert cache_entry_1.document_citation_number == -1
|
||||
assert cache_entry_1.inference_section is not None
|
||||
assert cache_entry_1.inference_section.center_chunk.document_id == "doc1"
|
||||
|
||||
cache_entry_2 = fake_run_context.context.fetched_documents_cache["doc2"]
|
||||
assert cache_entry_2.document_citation_number == -1
|
||||
assert cache_entry_2.inference_section.center_chunk.document_id == "doc2"
|
||||
|
||||
# Check iteration instruction
|
||||
instruction = fake_run_context.context.iteration_instructions[0]
|
||||
assert isinstance(instruction, IterationInstructions)
|
||||
assert instruction.iteration_nr == 1
|
||||
assert instruction.purpose == "Searching internally for information"
|
||||
assert (
|
||||
"I am now using Internal Search to gather information on"
|
||||
in instruction.reasoning
|
||||
)
|
||||
|
||||
# Check iteration answer
|
||||
answer = fake_run_context.context.global_iteration_responses[0]
|
||||
assert isinstance(answer, IterationAnswer)
|
||||
assert answer.tool == SearchTool.__name__
|
||||
assert answer.tool_id == search_tool.id
|
||||
assert answer.iteration_nr == 1
|
||||
assert answer.answer == ""
|
||||
assert len(answer.cited_documents) == 2
|
||||
|
||||
# Verify emitter events were captured
|
||||
emitter = fake_run_context.context.run_dependencies.emitter
|
||||
# Should have: SearchToolStart, SearchToolDelta (query), SearchToolDelta (docs)
|
||||
assert len(emitter.packet_history) >= 3
|
||||
|
||||
# Check the types of emitted events
|
||||
assert isinstance(emitter.packet_history[0].obj, SearchToolStart)
|
||||
assert isinstance(emitter.packet_history[1].obj, SearchToolDelta)
|
||||
assert isinstance(emitter.packet_history[2].obj, SearchToolDelta)
|
||||
|
||||
# Check the first SearchToolDelta (query)
|
||||
first_delta = emitter.packet_history[1].obj
|
||||
assert first_delta.queries == [query]
|
||||
assert first_delta.documents == []
|
||||
|
||||
# Check the second SearchToolDelta (documents)
|
||||
second_delta = emitter.packet_history[2].obj
|
||||
assert second_delta.queries == []
|
||||
assert len(second_delta.documents) == 2
|
||||
|
||||
# Verify the SavedSearchDoc objects
|
||||
first_doc = second_delta.documents[0]
|
||||
assert isinstance(first_doc, SavedSearchDoc)
|
||||
assert first_doc.document_id == "doc1"
|
||||
assert first_doc.semantic_identifier == "test_doc_1"
|
||||
@@ -18,9 +18,7 @@ from onyx.tools.tool_implementations.search.search_tool import (
|
||||
SEARCH_RESPONSE_SUMMARY_ID,
|
||||
)
|
||||
from onyx.tools.tool_implementations.search.search_tool import SearchTool
|
||||
from onyx.tools.tool_implementations_v2.tool_result_models import (
|
||||
LlmInternalSearchResult,
|
||||
)
|
||||
from onyx.tools.tool_result_models import LlmInternalSearchResult
|
||||
from tests.unit.onyx.chat.turn.utils import create_test_inference_chunk
|
||||
from tests.unit.onyx.chat.turn.utils import create_test_inference_section
|
||||
from tests.unit.onyx.chat.turn.utils import FakeQuery
|
||||
|
||||
@@ -25,10 +25,10 @@ from onyx.server.query_and_chat.streaming_models import SearchToolDelta
|
||||
from onyx.server.query_and_chat.streaming_models import SearchToolStart
|
||||
from onyx.server.query_and_chat.streaming_models import SectionEnd
|
||||
from onyx.tools.tool_implementations.web_search.web_search_tool import WebSearchTool
|
||||
from onyx.tools.tool_implementations_v2.tool_result_models import LlmOpenUrlResult
|
||||
from onyx.tools.tool_implementations_v2.tool_result_models import LlmWebSearchResult
|
||||
from onyx.tools.tool_implementations_v2.web import _open_url_core
|
||||
from onyx.tools.tool_implementations_v2.web import _web_search_core
|
||||
from onyx.tools.tool_result_models import LlmOpenUrlResult
|
||||
from onyx.tools.tool_result_models import LlmWebSearchResult
|
||||
|
||||
|
||||
class MockTool:
|
||||
|
||||
398
backend/tests/unit/onyx/tools/test_web_search_tool_run_v2.py
Normal file
398
backend/tests/unit/onyx/tools/test_web_search_tool_run_v2.py
Normal file
@@ -0,0 +1,398 @@
|
||||
from collections.abc import Sequence
|
||||
from datetime import datetime
|
||||
from typing import cast
|
||||
from typing import List
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from agents import RunContextWrapper
|
||||
|
||||
from onyx.agents.agent_search.dr.models import IterationAnswer
|
||||
from onyx.agents.agent_search.dr.models import IterationInstructions
|
||||
from onyx.agents.agent_search.dr.sub_agents.web_search.models import WebContent
|
||||
from onyx.agents.agent_search.dr.sub_agents.web_search.models import WebSearchProvider
|
||||
from onyx.agents.agent_search.dr.sub_agents.web_search.models import WebSearchResult
|
||||
from onyx.chat.turn.models import ChatTurnContext
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.context.search.models import SavedSearchDoc
|
||||
from onyx.server.query_and_chat.streaming_models import Packet
|
||||
from onyx.server.query_and_chat.streaming_models import SearchToolDelta
|
||||
from onyx.server.query_and_chat.streaming_models import SearchToolStart
|
||||
from onyx.server.query_and_chat.streaming_models import SectionEnd
|
||||
from onyx.tools.tool_implementations.web_search.web_search_tool import WebSearchTool
|
||||
from onyx.tools.tool_result_models import LlmWebSearchResult
|
||||
|
||||
|
||||
class MockTool:
|
||||
"""Mock Tool object for testing"""
|
||||
|
||||
def __init__(self, tool_id: int = 1, name: str = WebSearchTool.__name__):
|
||||
self.id = tool_id
|
||||
self.name = name
|
||||
|
||||
|
||||
class MockWebSearchProvider(WebSearchProvider):
|
||||
"""Mock implementation of WebSearchProvider for dependency injection"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
search_results: List[WebSearchResult] | None = None,
|
||||
should_raise_exception: bool = False,
|
||||
):
|
||||
self.search_results = search_results or []
|
||||
self.should_raise_exception = should_raise_exception
|
||||
|
||||
def search(self, query: str) -> List[WebSearchResult]:
|
||||
if self.should_raise_exception:
|
||||
raise Exception("Test exception from search provider")
|
||||
return self.search_results
|
||||
|
||||
def contents(self, urls: Sequence[str]) -> List[WebContent]:
|
||||
return []
|
||||
|
||||
|
||||
class MockEmitter:
|
||||
"""Mock emitter for dependency injection"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.packet_history: list[Packet] = []
|
||||
|
||||
def emit(self, packet: Packet) -> None:
|
||||
self.packet_history.append(packet)
|
||||
|
||||
|
||||
class MockRunDependencies:
|
||||
"""Mock run dependencies for dependency injection"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.emitter = MockEmitter()
|
||||
# Set up mock database session
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
self.db_session = MagicMock()
|
||||
# Configure the scalar method to return our mock tool
|
||||
mock_tool = MockTool()
|
||||
self.db_session.scalar.return_value = mock_tool
|
||||
|
||||
|
||||
def create_test_run_context(
|
||||
current_run_step: int = 0,
|
||||
iteration_instructions: List[IterationInstructions] | None = None,
|
||||
global_iteration_responses: List[IterationAnswer] | None = None,
|
||||
) -> RunContextWrapper[ChatTurnContext]:
|
||||
"""Create a real RunContextWrapper with test dependencies"""
|
||||
|
||||
# Create test dependencies
|
||||
emitter = MockEmitter()
|
||||
run_dependencies = MockRunDependencies()
|
||||
run_dependencies.emitter = emitter
|
||||
|
||||
# Create the actual context object
|
||||
context = ChatTurnContext(
|
||||
chat_session_id=uuid4(),
|
||||
message_id=1,
|
||||
current_run_step=current_run_step,
|
||||
iteration_instructions=iteration_instructions or [],
|
||||
global_iteration_responses=global_iteration_responses or [],
|
||||
run_dependencies=run_dependencies, # type: ignore[arg-type]
|
||||
)
|
||||
|
||||
# Create the run context wrapper
|
||||
run_context = RunContextWrapper(context=context)
|
||||
|
||||
return run_context
|
||||
|
||||
|
||||
def test_web_search_tool_run_v2_basic_functionality() -> None:
|
||||
"""Test basic functionality of WebSearchTool.run_v2 with a single query"""
|
||||
# Arrange
|
||||
test_run_context = create_test_run_context()
|
||||
queries = ["test search query"]
|
||||
|
||||
# Create test search results
|
||||
test_search_results = [
|
||||
WebSearchResult(
|
||||
title="Test Result 1",
|
||||
link="https://example.com/1",
|
||||
author="Test Author",
|
||||
published_date=datetime(2024, 1, 1, 12, 0, 0),
|
||||
snippet="This is a test snippet 1",
|
||||
),
|
||||
WebSearchResult(
|
||||
title="Test Result 2",
|
||||
link="https://example.com/2",
|
||||
author=None,
|
||||
published_date=None,
|
||||
snippet="This is a test snippet 2",
|
||||
),
|
||||
]
|
||||
|
||||
test_provider = MockWebSearchProvider(search_results=test_search_results)
|
||||
|
||||
# Create tool instance
|
||||
web_search_tool = WebSearchTool(tool_id=1)
|
||||
|
||||
# Mock the get_default_provider to return our test provider
|
||||
from unittest.mock import patch
|
||||
|
||||
with patch(
|
||||
"onyx.tools.tool_implementations.web_search.web_search_tool.get_default_provider"
|
||||
) as mock_get_provider:
|
||||
mock_get_provider.return_value = test_provider
|
||||
|
||||
# Act
|
||||
result_json = web_search_tool.run_v2(test_run_context, queries=queries)
|
||||
|
||||
# Parse the JSON result
|
||||
from pydantic import TypeAdapter
|
||||
|
||||
adapter = TypeAdapter(list[LlmWebSearchResult])
|
||||
result = adapter.validate_json(result_json)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, list)
|
||||
assert len(result) == 2
|
||||
assert all(isinstance(r, LlmWebSearchResult) for r in result)
|
||||
|
||||
# Check first result
|
||||
assert result[0].title == "Test Result 1"
|
||||
assert result[0].url == "https://example.com/1"
|
||||
assert result[0].snippet == "This is a test snippet 1"
|
||||
assert result[0].document_citation_number == -1
|
||||
assert result[0].unique_identifier_to_strip_away == "https://example.com/1"
|
||||
|
||||
# Check second result
|
||||
assert result[1].title == "Test Result 2"
|
||||
assert result[1].url == "https://example.com/2"
|
||||
assert result[1].snippet == "This is a test snippet 2"
|
||||
assert result[1].document_citation_number == -1
|
||||
assert result[1].unique_identifier_to_strip_away == "https://example.com/2"
|
||||
|
||||
# Check that fetched_documents_cache was populated
|
||||
assert len(test_run_context.context.fetched_documents_cache) == 2
|
||||
assert "https://example.com/1" in test_run_context.context.fetched_documents_cache
|
||||
assert "https://example.com/2" in test_run_context.context.fetched_documents_cache
|
||||
|
||||
# Verify cache entries have correct structure
|
||||
cache_entry_1 = test_run_context.context.fetched_documents_cache[
|
||||
"https://example.com/1"
|
||||
]
|
||||
assert cache_entry_1.document_citation_number == -1
|
||||
assert cache_entry_1.inference_section is not None
|
||||
|
||||
# Verify context was updated
|
||||
assert test_run_context.context.current_run_step == 2
|
||||
assert len(test_run_context.context.iteration_instructions) == 1
|
||||
assert len(test_run_context.context.global_iteration_responses) == 1
|
||||
|
||||
# Check iteration instruction
|
||||
instruction = test_run_context.context.iteration_instructions[0]
|
||||
assert isinstance(instruction, IterationInstructions)
|
||||
assert instruction.iteration_nr == 1
|
||||
assert instruction.purpose == "Searching the web for information"
|
||||
assert (
|
||||
"Web Search to gather information on test search query" in instruction.reasoning
|
||||
)
|
||||
|
||||
# Check iteration answer
|
||||
answer = test_run_context.context.global_iteration_responses[0]
|
||||
assert isinstance(answer, IterationAnswer)
|
||||
assert answer.tool == WebSearchTool.__name__
|
||||
assert answer.iteration_nr == 1
|
||||
assert answer.question == queries[0]
|
||||
|
||||
# Verify emitter events were captured
|
||||
emitter = cast(MockEmitter, test_run_context.context.run_dependencies.emitter)
|
||||
assert len(emitter.packet_history) == 4
|
||||
|
||||
# Check the types of emitted events
|
||||
assert isinstance(emitter.packet_history[0].obj, SearchToolStart)
|
||||
assert isinstance(emitter.packet_history[1].obj, SearchToolDelta)
|
||||
assert isinstance(emitter.packet_history[2].obj, SearchToolDelta)
|
||||
assert isinstance(emitter.packet_history[3].obj, SectionEnd)
|
||||
|
||||
# Verify first SearchToolDelta has queries but empty documents
|
||||
first_search_delta = cast(SearchToolDelta, emitter.packet_history[1].obj)
|
||||
assert first_search_delta.queries == queries
|
||||
assert first_search_delta.documents == []
|
||||
|
||||
# Verify second SearchToolDelta contains SavedSearchDoc objects for favicon display
|
||||
search_tool_delta = cast(SearchToolDelta, emitter.packet_history[2].obj)
|
||||
assert len(search_tool_delta.documents) == 2
|
||||
|
||||
# Verify documents have correct properties for frontend favicon display
|
||||
doc1 = search_tool_delta.documents[0]
|
||||
assert isinstance(doc1, SavedSearchDoc)
|
||||
assert doc1.link == "https://example.com/1"
|
||||
assert (
|
||||
doc1.semantic_identifier == "https://example.com/1"
|
||||
) # semantic_identifier is the link for web results
|
||||
assert doc1.blurb == "Test Result 1" # title is stored in blurb
|
||||
assert doc1.source_type == DocumentSource.WEB
|
||||
assert doc1.is_internet is True
|
||||
|
||||
|
||||
def test_web_search_tool_run_v2_exception_handling() -> None:
|
||||
"""Test that WebSearchTool.run_v2 handles exceptions properly"""
|
||||
# Arrange
|
||||
test_run_context = create_test_run_context()
|
||||
queries = ["test search query"]
|
||||
|
||||
# Create a provider that will raise an exception
|
||||
test_provider = MockWebSearchProvider(should_raise_exception=True)
|
||||
|
||||
# Create tool instance
|
||||
web_search_tool = WebSearchTool(tool_id=1)
|
||||
|
||||
# Mock the get_default_provider to return our test provider
|
||||
from unittest.mock import patch
|
||||
|
||||
with patch(
|
||||
"onyx.tools.tool_implementations.web_search.web_search_tool.get_default_provider"
|
||||
) as mock_get_provider:
|
||||
mock_get_provider.return_value = test_provider
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(Exception, match="Test exception from search provider"):
|
||||
web_search_tool.run_v2(test_run_context, queries=queries)
|
||||
|
||||
# Verify that even though an exception was raised, we still emitted the initial events
|
||||
# and the SectionEnd packet was emitted by the decorator
|
||||
emitter = test_run_context.context.run_dependencies.emitter # type: ignore[attr-defined]
|
||||
assert (
|
||||
len(emitter.packet_history) == 3
|
||||
) # SearchToolStart, first SearchToolDelta, and SectionEnd
|
||||
|
||||
# Check the types of emitted events
|
||||
assert isinstance(emitter.packet_history[0].obj, SearchToolStart)
|
||||
assert isinstance(emitter.packet_history[1].obj, SearchToolDelta)
|
||||
assert isinstance(emitter.packet_history[2].obj, SectionEnd)
|
||||
|
||||
# Verify that the decorator properly handled the exception and updated current_run_step
|
||||
assert (
|
||||
test_run_context.context.current_run_step == 2
|
||||
) # Should be 2 after proper handling
|
||||
|
||||
|
||||
def test_web_search_tool_run_v2_multiple_queries() -> None:
|
||||
"""Test WebSearchTool.run_v2 with multiple queries searched in parallel"""
|
||||
# Arrange
|
||||
test_run_context = create_test_run_context()
|
||||
queries = ["first query", "second query"]
|
||||
|
||||
# Create a mock provider that returns different results based on the query
|
||||
class MultiQueryMockProvider(WebSearchProvider):
|
||||
def search(self, query: str) -> List[WebSearchResult]:
|
||||
if query == "first query":
|
||||
return [
|
||||
WebSearchResult(
|
||||
title="First Result 1",
|
||||
link="https://example.com/first1",
|
||||
author="Author 1",
|
||||
published_date=datetime(2024, 1, 1, 12, 0, 0),
|
||||
snippet="Snippet for first query result 1",
|
||||
),
|
||||
WebSearchResult(
|
||||
title="First Result 2",
|
||||
link="https://example.com/first2",
|
||||
author=None,
|
||||
published_date=None,
|
||||
snippet="Snippet for first query result 2",
|
||||
),
|
||||
]
|
||||
elif query == "second query":
|
||||
return [
|
||||
WebSearchResult(
|
||||
title="Second Result 1",
|
||||
link="https://example.com/second1",
|
||||
author="Author 2",
|
||||
published_date=datetime(2024, 2, 1, 12, 0, 0),
|
||||
snippet="Snippet for second query result 1",
|
||||
),
|
||||
]
|
||||
return []
|
||||
|
||||
def contents(self, urls: Sequence[str]) -> List[WebContent]:
|
||||
return []
|
||||
|
||||
test_provider = MultiQueryMockProvider()
|
||||
|
||||
# Create tool instance
|
||||
web_search_tool = WebSearchTool(tool_id=1)
|
||||
|
||||
# Mock the get_default_provider to return our test provider
|
||||
from unittest.mock import patch
|
||||
|
||||
with patch(
|
||||
"onyx.tools.tool_implementations.web_search.web_search_tool.get_default_provider"
|
||||
) as mock_get_provider:
|
||||
mock_get_provider.return_value = test_provider
|
||||
|
||||
# Act
|
||||
result_json = web_search_tool.run_v2(test_run_context, queries=queries)
|
||||
|
||||
# Parse the JSON result
|
||||
from pydantic import TypeAdapter
|
||||
|
||||
adapter = TypeAdapter(list[LlmWebSearchResult])
|
||||
result = adapter.validate_json(result_json)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, list)
|
||||
# Should have 3 total results (2 from first query + 1 from second query)
|
||||
assert len(result) == 3
|
||||
assert all(isinstance(r, LlmWebSearchResult) for r in result)
|
||||
|
||||
# Verify all results are present (order may vary due to parallel execution)
|
||||
titles = {r.title for r in result}
|
||||
assert "First Result 1" in titles
|
||||
assert "First Result 2" in titles
|
||||
assert "Second Result 1" in titles
|
||||
|
||||
# Check that fetched_documents_cache was populated with all URLs
|
||||
assert len(test_run_context.context.fetched_documents_cache) == 3
|
||||
|
||||
# Verify context was updated
|
||||
assert test_run_context.context.current_run_step == 2
|
||||
assert len(test_run_context.context.iteration_instructions) == 1
|
||||
assert len(test_run_context.context.global_iteration_responses) == 1
|
||||
|
||||
# Check iteration instruction contains both queries
|
||||
instruction = test_run_context.context.iteration_instructions[0]
|
||||
assert isinstance(instruction, IterationInstructions)
|
||||
assert instruction.iteration_nr == 1
|
||||
assert instruction.purpose == "Searching the web for information"
|
||||
assert "first query" in instruction.reasoning
|
||||
assert "second query" in instruction.reasoning
|
||||
|
||||
# Check iteration answer
|
||||
answer = test_run_context.context.global_iteration_responses[0]
|
||||
assert isinstance(answer, IterationAnswer)
|
||||
assert answer.tool == WebSearchTool.__name__
|
||||
assert answer.iteration_nr == 1
|
||||
assert "first query" in answer.question
|
||||
assert "second query" in answer.question
|
||||
|
||||
# Verify emitter events were captured
|
||||
emitter = cast(MockEmitter, test_run_context.context.run_dependencies.emitter)
|
||||
assert len(emitter.packet_history) == 4
|
||||
|
||||
# Check the types of emitted events
|
||||
assert isinstance(emitter.packet_history[0].obj, SearchToolStart)
|
||||
assert isinstance(emitter.packet_history[1].obj, SearchToolDelta)
|
||||
assert isinstance(emitter.packet_history[2].obj, SearchToolDelta)
|
||||
assert isinstance(emitter.packet_history[3].obj, SectionEnd)
|
||||
|
||||
# Check that first SearchToolDelta contains both queries (with empty documents)
|
||||
first_search_delta = cast(SearchToolDelta, emitter.packet_history[1].obj)
|
||||
assert first_search_delta.queries is not None
|
||||
assert len(first_search_delta.queries) == 2
|
||||
assert "first query" in first_search_delta.queries
|
||||
assert "second query" in first_search_delta.queries
|
||||
assert first_search_delta.documents == []
|
||||
|
||||
# Check that second SearchToolDelta contains documents
|
||||
second_search_delta = cast(SearchToolDelta, emitter.packet_history[2].obj)
|
||||
assert (
|
||||
len(second_search_delta.documents) == 3
|
||||
) # 2 from first query + 1 from second query
|
||||
Reference in New Issue
Block a user