mirror of
https://github.com/onyx-dot-app/onyx.git
synced 2026-02-26 12:15:48 +00:00
Compare commits
1 Commits
sharepoint
...
add-code-i
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
68a484ae73 |
73
backend/alembic/versions/1c3f8a7b5d4e_add_python_tool.py
Normal file
73
backend/alembic/versions/1c3f8a7b5d4e_add_python_tool.py
Normal file
@@ -0,0 +1,73 @@
|
||||
"""add_python_tool
|
||||
|
||||
Revision ID: 1c3f8a7b5d4e
|
||||
Revises: 505c488f6662
|
||||
Create Date: 2025-02-14 00:00:00
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "1c3f8a7b5d4e"
|
||||
down_revision = "505c488f6662"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
PYTHON_TOOL = {
|
||||
"name": "PythonTool",
|
||||
"display_name": "Code Interpreter",
|
||||
"description": (
|
||||
"The Code Interpreter Action lets assistants execute Python in an isolated runtime. "
|
||||
"It can process staged files, read and write artifacts, stream stdout and stderr, "
|
||||
"and return generated outputs for the chat session."
|
||||
),
|
||||
"in_code_tool_id": "PythonTool",
|
||||
}
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
conn = op.get_bind()
|
||||
conn.execute(sa.text("BEGIN"))
|
||||
try:
|
||||
existing = conn.execute(
|
||||
sa.text("SELECT id FROM tool WHERE in_code_tool_id = :in_code_tool_id"),
|
||||
PYTHON_TOOL,
|
||||
).fetchone()
|
||||
|
||||
if existing:
|
||||
conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
UPDATE tool
|
||||
SET name = :name,
|
||||
display_name = :display_name,
|
||||
description = :description
|
||||
WHERE in_code_tool_id = :in_code_tool_id
|
||||
"""
|
||||
),
|
||||
PYTHON_TOOL,
|
||||
)
|
||||
else:
|
||||
conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
INSERT INTO tool (name, display_name, description, in_code_tool_id)
|
||||
VALUES (:name, :display_name, :description, :in_code_tool_id)
|
||||
"""
|
||||
),
|
||||
PYTHON_TOOL,
|
||||
)
|
||||
|
||||
conn.execute(sa.text("COMMIT"))
|
||||
except Exception:
|
||||
conn.execute(sa.text("ROLLBACK"))
|
||||
raise
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Do not delete the tool entry on downgrade; leaving it is safe and keeps migrations idempotent.
|
||||
pass
|
||||
@@ -39,6 +39,7 @@ def decision_router(state: MainState) -> list[Send | Hashable] | DRPath | str:
|
||||
DRPath.WEB_SEARCH,
|
||||
DRPath.KNOWLEDGE_GRAPH,
|
||||
DRPath.IMAGE_GENERATION,
|
||||
DRPath.PYTHON_TOOL,
|
||||
)
|
||||
and len(state.query_list) == 0
|
||||
):
|
||||
|
||||
@@ -21,6 +21,7 @@ AVERAGE_TOOL_COSTS: dict[DRPath, float] = {
|
||||
DRPath.WEB_SEARCH: 1.5,
|
||||
DRPath.IMAGE_GENERATION: 3.0,
|
||||
DRPath.GENERIC_TOOL: 1.5, # TODO: see todo in OrchestratorTool
|
||||
DRPath.PYTHON_TOOL: 2.0,
|
||||
DRPath.CLOSER: 0.0,
|
||||
}
|
||||
|
||||
|
||||
@@ -27,6 +27,7 @@ class DRPath(str, Enum):
|
||||
WEB_SEARCH = "Web Search"
|
||||
IMAGE_GENERATION = "Image Generation"
|
||||
GENERIC_INTERNAL_TOOL = "Generic Internal Tool"
|
||||
PYTHON_TOOL = "Python"
|
||||
CLOSER = "Closer"
|
||||
LOGGER = "Logger"
|
||||
END = "End"
|
||||
|
||||
@@ -26,6 +26,9 @@ from onyx.agents.agent_search.dr.sub_agents.image_generation.dr_image_generation
|
||||
from onyx.agents.agent_search.dr.sub_agents.kg_search.dr_kg_search_graph_builder import (
|
||||
dr_kg_search_graph_builder,
|
||||
)
|
||||
from onyx.agents.agent_search.dr.sub_agents.python_tool.dr_python_tool_graph_builder import (
|
||||
dr_python_tool_graph_builder,
|
||||
)
|
||||
from onyx.agents.agent_search.dr.sub_agents.web_search.dr_ws_graph_builder import (
|
||||
dr_ws_graph_builder,
|
||||
)
|
||||
@@ -58,12 +61,15 @@ def dr_graph_builder() -> StateGraph:
|
||||
image_generation_graph = dr_image_generation_graph_builder().compile()
|
||||
graph.add_node(DRPath.IMAGE_GENERATION, image_generation_graph)
|
||||
|
||||
custom_tool_graph = dr_custom_tool_graph_builder().compile()
|
||||
graph.add_node(DRPath.GENERIC_TOOL, custom_tool_graph)
|
||||
|
||||
generic_internal_tool_graph = dr_generic_internal_tool_graph_builder().compile()
|
||||
graph.add_node(DRPath.GENERIC_INTERNAL_TOOL, generic_internal_tool_graph)
|
||||
|
||||
python_tool_graph = dr_python_tool_graph_builder().compile()
|
||||
graph.add_node(DRPath.PYTHON_TOOL, python_tool_graph)
|
||||
|
||||
custom_tool_graph = dr_custom_tool_graph_builder().compile()
|
||||
graph.add_node(DRPath.GENERIC_TOOL, custom_tool_graph)
|
||||
|
||||
graph.add_node(DRPath.CLOSER, closer)
|
||||
graph.add_node(DRPath.LOGGER, logging)
|
||||
|
||||
@@ -81,6 +87,7 @@ def dr_graph_builder() -> StateGraph:
|
||||
graph.add_edge(start_key=DRPath.IMAGE_GENERATION, end_key=DRPath.ORCHESTRATOR)
|
||||
graph.add_edge(start_key=DRPath.GENERIC_TOOL, end_key=DRPath.ORCHESTRATOR)
|
||||
graph.add_edge(start_key=DRPath.GENERIC_INTERNAL_TOOL, end_key=DRPath.ORCHESTRATOR)
|
||||
graph.add_edge(start_key=DRPath.PYTHON_TOOL, end_key=DRPath.ORCHESTRATOR)
|
||||
|
||||
graph.add_conditional_edges(DRPath.CLOSER, completeness_router)
|
||||
graph.add_edge(start_key=DRPath.LOGGER, end_key=END)
|
||||
|
||||
@@ -69,6 +69,7 @@ from onyx.tools.tool_implementations.images.image_generation_tool import (
|
||||
from onyx.tools.tool_implementations.knowledge_graph.knowledge_graph_tool import (
|
||||
KnowledgeGraphTool,
|
||||
)
|
||||
from onyx.tools.tool_implementations.python.python_tool import PythonTool
|
||||
from onyx.tools.tool_implementations.search.search_tool import SearchTool
|
||||
from onyx.tools.tool_implementations.web_search.web_search_tool import (
|
||||
WebSearchTool,
|
||||
@@ -134,6 +135,9 @@ def _get_available_tools(
|
||||
continue
|
||||
llm_path = DRPath.KNOWLEDGE_GRAPH.value
|
||||
path = DRPath.KNOWLEDGE_GRAPH
|
||||
elif isinstance(tool, PythonTool):
|
||||
llm_path = DRPath.PYTHON_TOOL.value
|
||||
path = DRPath.PYTHON_TOOL
|
||||
elif isinstance(tool, ImageGenerationTool):
|
||||
llm_path = DRPath.IMAGE_GENERATION.value
|
||||
path = DRPath.IMAGE_GENERATION
|
||||
@@ -778,6 +782,6 @@ def clarifier(
|
||||
active_source_types_descriptions="\n".join(active_source_types_descriptions),
|
||||
assistant_system_prompt=assistant_system_prompt,
|
||||
assistant_task_prompt=assistant_task_prompt,
|
||||
uploaded_test_context=uploaded_text_context,
|
||||
uploaded_text_context=uploaded_text_context,
|
||||
uploaded_image_context=uploaded_image_context,
|
||||
)
|
||||
|
||||
@@ -140,7 +140,7 @@ def orchestrator(
|
||||
|
||||
available_tools = state.available_tools or {}
|
||||
|
||||
uploaded_context = state.uploaded_test_context or ""
|
||||
uploaded_context = state.uploaded_text_context or ""
|
||||
uploaded_image_context = state.uploaded_image_context or []
|
||||
|
||||
questions = [
|
||||
|
||||
@@ -228,7 +228,7 @@ def closer(
|
||||
assistant_system_prompt = state.assistant_system_prompt
|
||||
assistant_task_prompt = state.assistant_task_prompt
|
||||
|
||||
uploaded_context = state.uploaded_test_context or ""
|
||||
uploaded_context = state.uploaded_text_context or ""
|
||||
|
||||
clarification = state.clarification
|
||||
prompt_question = get_prompt_question(base_question, clarification)
|
||||
|
||||
@@ -46,7 +46,7 @@ class OrchestrationSetup(OrchestrationUpdate):
|
||||
active_source_types_descriptions: str | None = None
|
||||
assistant_system_prompt: str | None = None
|
||||
assistant_task_prompt: str | None = None
|
||||
uploaded_test_context: str | None = None
|
||||
uploaded_text_context: str | None = None
|
||||
uploaded_image_context: list[dict[str, Any]] | None = None
|
||||
|
||||
|
||||
|
||||
@@ -0,0 +1 @@
|
||||
"""Python Tool sub-agent for deep research."""
|
||||
@@ -0,0 +1,36 @@
|
||||
from datetime import datetime

from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter

from onyx.agents.agent_search.dr.states import LoggerUpdate
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
from onyx.agents.agent_search.shared_graph_utils.utils import (
    get_langgraph_node_log_string,
)
from onyx.utils.logger import setup_logger

logger = setup_logger()


def python_tool_branch(
    state: SubAgentInput, config: RunnableConfig, writer: StreamWriter = lambda _: None
) -> LoggerUpdate:
    """Log the beginning of a Python Tool branch."""

    start_time = datetime.now()
    iteration_nr = state.iteration_nr

    logger.debug(
        f"Python Tool branch start for iteration {iteration_nr} at {datetime.now()}"
    )

    # Record this node's timing in the shared log stream.
    branch_log_entry = get_langgraph_node_log_string(
        graph_component="python_tool",
        node_name="branching",
        node_start_time=start_time,
    )
    return LoggerUpdate(log_messages=[branch_log_entry])
|
||||
@@ -0,0 +1,257 @@
|
||||
import base64
|
||||
import json
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
|
||||
from langchain_core.messages import AIMessage
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langgraph.types import StreamWriter
|
||||
|
||||
from onyx.agents.agent_search.dr.sub_agents.states import BranchInput
|
||||
from onyx.agents.agent_search.dr.sub_agents.states import BranchUpdate
|
||||
from onyx.agents.agent_search.dr.sub_agents.states import IterationAnswer
|
||||
from onyx.agents.agent_search.models import GraphConfig
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
get_langgraph_node_log_string,
|
||||
)
|
||||
from onyx.configs.agent_configs import TF_DR_TIMEOUT_SHORT
|
||||
from onyx.configs.constants import MessageType
|
||||
from onyx.file_store.models import ChatFileType
|
||||
from onyx.file_store.models import InMemoryChatFile
|
||||
from onyx.llm.utils import build_content_with_imgs
|
||||
from onyx.prompts.dr_prompts import CUSTOM_TOOL_PREP_PROMPT
|
||||
from onyx.prompts.dr_prompts import PYTHON_TOOL_USE_RESPONSE_PROMPT
|
||||
from onyx.tools.tool_implementations.python.python_tool import PythonTool
|
||||
from onyx.tools.tool_implementations.python.python_tool import PythonToolResult
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def _serialize_chat_files(chat_files: list[InMemoryChatFile]) -> list[dict[str, Any]]:
    """Convert in-memory chat files into JSON-safe payload dicts.

    Each payload carries id/name/type plus an ``is_base64`` flag telling the
    consumer how to decode ``content``: images and unknown binary files are
    base64-encoded, text files are UTF-8 decoded (invalid bytes replaced).
    """
    payloads: list[dict[str, Any]] = []
    for f in chat_files:
        entry: dict[str, Any] = {
            "id": str(f.file_id),
            "name": f.filename,
            "type": f.file_type.value,
        }
        if f.file_type == ChatFileType.IMAGE:
            entry["content"] = f.to_base64()
            entry["is_base64"] = True
        elif f.file_type.is_text_file():
            entry["content"] = f.content.decode("utf-8", errors="replace")
            entry["is_base64"] = False
        else:
            # Arbitrary binary content — ship it base64-encoded.
            entry["content"] = base64.b64encode(f.content).decode("utf-8")
            entry["is_base64"] = True
        payloads.append(entry)

    return payloads
|
||||
|
||||
|
||||
def python_tool_act(
    state: BranchInput,
    config: RunnableConfig,
    writer: StreamWriter = lambda _: None,
) -> BranchUpdate:
    """Execute the Python Tool with any files supplied by the user.

    Steps:
      1. Resolve the PythonTool instance and branch question from state.
      2. Obtain tool arguments — via the tool-calling LLM when available,
         falling back to the non-tool-calling extraction path.
      3. Run the tool, then ask the primary LLM to summarize the result,
         attaching the user's uploaded files and any generated artifacts.

    Raises:
        ValueError: when required state (tools, question, tool args, or the
            tool's result) is missing.
    """

    node_start_time = datetime.now()
    iteration_nr = state.iteration_nr
    parallelization_nr = state.parallelization_nr

    if not state.available_tools:
        raise ValueError("available_tools is not set")

    # The orchestrator appends the chosen tool last; that is the one to run.
    tool_key = state.tools_used[-1]
    python_tool_info = state.available_tools[tool_key]
    python_tool = cast(PythonTool | None, python_tool_info.tool_object)

    if python_tool is None:
        raise ValueError("python_tool is not set")

    branch_query = state.branch_question
    if not branch_query:
        raise ValueError("branch_query is not set")

    graph_config = cast(GraphConfig, config["metadata"]["config"])
    base_question = graph_config.inputs.prompt_builder.raw_user_query
    files = graph_config.inputs.files

    logger.debug(
        "Tool call start for %s %s.%s at %s",
        python_tool.llm_name,
        iteration_nr,
        parallelization_nr,
        datetime.now(),
    )

    tool_args: dict[str, Any] | None = None
    if graph_config.tooling.using_tool_calling_llm:
        tool_use_prompt = CUSTOM_TOOL_PREP_PROMPT.build(
            query=branch_query,
            base_question=base_question,
            tool_description=python_tool_info.description,
        )

        content_with_files = build_content_with_imgs(
            message=tool_use_prompt,
            files=files,
            message_type=MessageType.USER,
        )

        tool_prompt_message: dict[str, Any] = {
            "role": "user",
            "content": content_with_files,
        }
        if files:
            tool_prompt_message["files"] = _serialize_chat_files(files)

        # Force the LLM to emit a tool call so we get structured arguments.
        tool_calling_msg = graph_config.tooling.primary_llm.invoke(
            [tool_prompt_message],
            tools=[python_tool.tool_definition()],
            tool_choice="required",
            timeout_override=TF_DR_TIMEOUT_SHORT,
        )

        if isinstance(tool_calling_msg, AIMessage) and tool_calling_msg.tool_calls:
            tool_args = tool_calling_msg.tool_calls[0].get("args")
        else:
            logger.warning("Tool-calling LLM did not emit a tool call for Python Tool")

    if tool_args is None:
        # Fallback: extract arguments via the non-tool-calling code path.
        tool_args = python_tool.get_args_for_non_tool_calling_llm(
            query=branch_query,
            history=[],
            llm=graph_config.tooling.primary_llm,
            force_run=True,
        )

    if tool_args is None:
        raise ValueError("Failed to obtain tool arguments from LLM")

    # Files are injected via override_kwargs below; never let the LLM supply them.
    if "files" in tool_args:
        tool_args = {key: value for key, value in tool_args.items() if key != "files"}

    override_kwargs = {"files": files or []}

    tool_responses = list(python_tool.run(override_kwargs=override_kwargs, **tool_args))

    python_tool_result: PythonToolResult | None = None
    for response in tool_responses:
        if isinstance(response.response, PythonToolResult):
            python_tool_result = response.response
            break

    if python_tool_result is None:
        raise ValueError("Python tool did not return a valid result")

    final_result = python_tool.final_result(*tool_responses)
    tool_result_str = json.dumps(final_result, ensure_ascii=False)

    tool_summary_prompt = PYTHON_TOOL_USE_RESPONSE_PROMPT.build(
        base_question=base_question,
        tool_response=tool_result_str,
    )

    # Collect generated artifacts, renaming each with a "generated_" prefix so
    # the summary prompt (and the user) can tell them apart from uploads.
    initial_files = list(files or [])
    generated_files: list[InMemoryChatFile] = []
    for artifact in python_tool_result.artifacts:
        if not artifact.file_id:
            continue

        # NOTE(review): reaches into PythonTool's private `_available_files`
        # cache — consider exposing a public accessor on PythonTool instead.
        chat_file = python_tool._available_files.get(artifact.file_id)
        if not chat_file:
            logger.warning(
                "Generated artifact with id %s not found in available files",
                artifact.file_id,
            )
            continue

        filename = (
            chat_file.filename
            or artifact.display_name
            or artifact.path
            or str(artifact.file_id)
        )
        filename = Path(filename).name or str(artifact.file_id)
        if not filename.startswith("generated_"):
            # BUGFIX: previously used an f-string with no placeholder
            # (f"generated_(unknown)"), which collapsed every artifact to the
            # same constant filename. Interpolate the real filename instead.
            filename = f"generated_{filename}"

        generated_files.append(
            InMemoryChatFile(
                file_id=chat_file.file_id,
                content=chat_file.content,
                file_type=chat_file.file_type,
                filename=filename,
            )
        )

    summary_files = initial_files + generated_files
    summary_content = build_content_with_imgs(
        message=tool_summary_prompt,
        files=summary_files,
        message_type=MessageType.USER,
    )

    summary_message: dict[str, Any] = {
        "role": "user",
        "content": summary_content,
    }
    if summary_files:
        summary_message["files"] = _serialize_chat_files(summary_files)

    answer_string = str(
        graph_config.tooling.primary_llm.invoke(
            [summary_message],
            timeout_override=TF_DR_TIMEOUT_SHORT,
        ).content
    ).strip()

    artifact_file_ids = [
        artifact.file_id
        for artifact in python_tool_result.artifacts
        if artifact.file_id
    ]

    logger.debug(
        "Tool call end for %s %s.%s at %s",
        python_tool.llm_name,
        iteration_nr,
        parallelization_nr,
        datetime.now(),
    )

    return BranchUpdate(
        branch_iteration_responses=[
            IterationAnswer(
                tool=python_tool.llm_name,
                tool_id=python_tool_info.tool_id,
                iteration_nr=iteration_nr,
                parallelization_nr=parallelization_nr,
                question=branch_query,
                answer=answer_string,
                claims=[],
                cited_documents={},
                reasoning="",
                additional_data=None,
                response_type="json",
                data=final_result,
                file_ids=artifact_file_ids or None,
            )
        ],
        log_messages=[
            get_langgraph_node_log_string(
                graph_component="python_tool",
                node_name="tool_calling",
                node_start_time=node_start_time,
            )
        ],
    )
|
||||
@@ -0,0 +1,76 @@
|
||||
from datetime import datetime

from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter

from onyx.agents.agent_search.dr.sub_agents.states import SubAgentMainState
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentUpdate
from onyx.agents.agent_search.shared_graph_utils.utils import (
    get_langgraph_node_log_string,
)
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.server.query_and_chat.streaming_models import CustomToolDelta
from onyx.server.query_and_chat.streaming_models import CustomToolStart
from onyx.server.query_and_chat.streaming_models import SectionEnd
from onyx.utils.logger import setup_logger

logger = setup_logger()


def python_tool_reducer(
    state: SubAgentMainState,
    config: RunnableConfig,
    writer: StreamWriter = lambda _: None,
) -> SubAgentUpdate:
    """Stream the Python Tool result back to the client."""

    start_time = datetime.now()
    step_nr = state.current_step_nr
    current_iteration = state.iteration_nr

    # Only forward the branch responses produced by the current iteration.
    relevant_updates = [
        update
        for update in state.branch_iteration_responses
        if update.iteration_nr == current_iteration
    ]

    for update in relevant_updates:
        if not update.response_type:
            raise ValueError("Response type is not returned.")

        # Emit the start/delta/end event triplet for this tool section.
        write_custom_event(
            step_nr,
            CustomToolStart(
                tool_name=update.tool,
            ),
            writer,
        )
        write_custom_event(
            step_nr,
            CustomToolDelta(
                tool_name=update.tool,
                response_type=update.response_type,
                data=update.data,
                file_ids=update.file_ids,
            ),
            writer,
        )
        write_custom_event(
            step_nr,
            SectionEnd(),
            writer,
        )
        step_nr += 1

    return SubAgentUpdate(
        iteration_responses=relevant_updates,
        log_messages=[
            get_langgraph_node_log_string(
                graph_component="python_tool",
                node_name="consolidation",
                node_start_time=start_time,
            )
        ],
    )
|
||||
@@ -0,0 +1,26 @@
|
||||
from collections.abc import Hashable

from langgraph.types import Send

from onyx.agents.agent_search.dr.sub_agents.states import BranchInput
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput


def branching_router(state: SubAgentInput) -> list[Send | Hashable]:
    """Forward the current query to the Python Tool executor."""

    # The Python Tool branch is intentionally not parallelized: only the
    # first query in the list is dispatched to the "act" node.
    dispatches: list[Send | Hashable] = []
    for branch_idx, question in enumerate(state.query_list[:1]):
        branch_input = BranchInput(
            iteration_nr=state.iteration_nr,
            parallelization_nr=branch_idx,
            branch_question=question,
            context="",
            active_source_types=state.active_source_types,
            tools_used=state.tools_used,
            available_tools=state.available_tools,
        )
        dispatches.append(Send("act", branch_input))
    return dispatches
|
||||
@@ -0,0 +1,38 @@
|
||||
from langgraph.graph import END
from langgraph.graph import START
from langgraph.graph import StateGraph

from onyx.agents.agent_search.dr.sub_agents.python_tool.dr_python_tool_1_branch import (
    python_tool_branch,
)
from onyx.agents.agent_search.dr.sub_agents.python_tool.dr_python_tool_2_act import (
    python_tool_act,
)
from onyx.agents.agent_search.dr.sub_agents.python_tool.dr_python_tool_3_reduce import (
    python_tool_reducer,
)
from onyx.agents.agent_search.dr.sub_agents.python_tool.dr_python_tool_conditional_edges import (
    branching_router,
)
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentMainState
from onyx.utils.logger import setup_logger

logger = setup_logger()


def dr_python_tool_graph_builder() -> StateGraph:
    """LangGraph graph builder for the Python Tool sub-agent.

    Topology: START -> branch -> (conditional: branching_router) act -> reducer -> END.
    """

    builder = StateGraph(state_schema=SubAgentMainState, input=SubAgentInput)

    # Nodes: log the branch, execute the tool, then stream/consolidate results.
    builder.add_node("branch", python_tool_branch)
    builder.add_node("act", python_tool_act)
    builder.add_node("reducer", python_tool_reducer)

    # Wiring.
    builder.add_edge(START, "branch")
    builder.add_conditional_edges("branch", branching_router)
    builder.add_edge("act", "reducer")
    builder.add_edge("reducer", END)

    return builder
||||
@@ -758,6 +758,24 @@ AZURE_DALLE_DEPLOYMENT_NAME = os.environ.get("AZURE_DALLE_DEPLOYMENT_NAME")
|
||||
# configurable image model
|
||||
IMAGE_MODEL_NAME = os.environ.get("IMAGE_MODEL_NAME", "gpt-image-1")
|
||||
|
||||
# Code Interpreter (Python tool) service configuration.
# Base URL of the isolated execution service; the Python tool requires this to be set.
CODE_INTERPRETER_BASE_URL = os.environ.get("CODE_INTERPRETER_BASE_URL")


def _env_int(var_name: str, default: int) -> int:
    """Read an integer env var, falling back to `default` when unset or empty."""
    raw_value = os.environ.get(var_name)
    return int(raw_value) if raw_value else default


# Per-execution timeout inside the sandbox, in milliseconds.
CODE_INTERPRETER_DEFAULT_TIMEOUT_MS = _env_int(
    "CODE_INTERPRETER_DEFAULT_TIMEOUT_MS", 30_000
)
# HTTP request timeout when calling the service, in seconds.
CODE_INTERPRETER_REQUEST_TIMEOUT_SECONDS = _env_int(
    "CODE_INTERPRETER_REQUEST_TIMEOUT_SECONDS", 30
)
|
||||
|
||||
# Use managed Vespa (Vespa Cloud). If set, must also set VESPA_CLOUD_URL, VESPA_CLOUD_CERT_PATH and VESPA_CLOUD_KEY_PATH
|
||||
MANAGED_VESPA = os.environ.get("MANAGED_VESPA", "").lower() == "true"
|
||||
|
||||
|
||||
@@ -53,6 +53,11 @@ def log_prompt(prompt: LanguageModelInput) -> None:
|
||||
else:
|
||||
log_msg = ""
|
||||
logger.debug(f"Message {ind}:\n{log_msg}")
|
||||
elif isinstance(msg, dict):
|
||||
log_msg = msg.get("content", "")
|
||||
if "files" in msg:
|
||||
log_msg = f"{log_msg}\nfiles: {msg['files']}"
|
||||
logger.debug(f"Message {ind}:\n{log_msg}")
|
||||
else:
|
||||
logger.debug(f"Message {ind}:\n{msg.content}")
|
||||
if isinstance(prompt, str):
|
||||
|
||||
@@ -97,6 +97,15 @@ some sub-answers, but the question sent to the graph must be a question that can
|
||||
entity and relationship types and their attributes that will be communicated to you.
|
||||
"""
|
||||
|
||||
TOOL_DESCRIPTION[DRPath.PYTHON_TOOL] = (
|
||||
"""\
|
||||
This tool executes Python code within an isolated Code Interpreter environment. \
|
||||
Use it when the user explicitly requests code execution, data analysis, statistical computations, \
|
||||
or operations on uploaded files. The tool has access to any files the user supplied and can generate \
|
||||
new artifacts (for example, charts or datasets) that can be downloaded afterwards.\
|
||||
""".strip()
|
||||
)
|
||||
|
||||
TOOL_DESCRIPTION[
|
||||
DRPath.CLOSER
|
||||
] = f"""\
|
||||
@@ -948,6 +957,31 @@ ANSWER:
|
||||
)
|
||||
|
||||
|
||||
# Prompt used to summarize a `python` tool run into an answer for the base
# question. Placeholders ---base_question--- / ---tool_response--- are filled
# by PromptTemplate.build(). (Fixes typos: "respone" -> "response",
# "additionl" -> "additional".)
PYTHON_TOOL_USE_RESPONSE_PROMPT = PromptTemplate(
    f"""\
Here is the base question that ultimately needs to be answered:
{SEPARATOR_LINE}
---base_question---
{SEPARATOR_LINE}

You just used the `python` tool in an attempt to answer the base question. The response from the \
tool is:

{SEPARATOR_LINE}
---tool_response---
{SEPARATOR_LINE}

If any files were uploaded by the user, they are attached. If any files were generated by the \
`python` tool, they are also attached, and prefixed with `generated_`.


Please respond with a concise answer to the base question. If you believe additional `python` \
tool calls are necessary, please describe what HIGH LEVEL TASK we would like to accomplish next. \
Do NOT describe any specific Python code that we would like to execute.
""".strip()
)
|
||||
|
||||
|
||||
TEST_INFO_COMPLETE_PROMPT = PromptTemplate(
|
||||
f"""\
|
||||
You are an expert at trying to determine whether \
|
||||
|
||||
@@ -10,6 +10,9 @@ from onyx.tools.tool_implementations.knowledge_graph.knowledge_graph_tool import
|
||||
from onyx.tools.tool_implementations.okta_profile.okta_profile_tool import (
|
||||
OktaProfileTool,
|
||||
)
|
||||
from onyx.tools.tool_implementations.python.python_tool import (
|
||||
PythonTool,
|
||||
)
|
||||
from onyx.tools.tool_implementations.search.search_tool import SearchTool
|
||||
from onyx.tools.tool_implementations.web_search.web_search_tool import (
|
||||
WebSearchTool,
|
||||
@@ -20,7 +23,12 @@ logger = setup_logger()
|
||||
|
||||
|
||||
BUILT_IN_TOOL_TYPES = Union[
|
||||
SearchTool, ImageGenerationTool, WebSearchTool, KnowledgeGraphTool, OktaProfileTool
|
||||
SearchTool,
|
||||
ImageGenerationTool,
|
||||
WebSearchTool,
|
||||
KnowledgeGraphTool,
|
||||
OktaProfileTool,
|
||||
PythonTool,
|
||||
]
|
||||
|
||||
# same as d09fc20a3c66_seed_builtin_tools.py
|
||||
@@ -30,6 +38,7 @@ BUILT_IN_TOOL_MAP: dict[str, Type[BUILT_IN_TOOL_TYPES]] = {
|
||||
WebSearchTool.__name__: WebSearchTool,
|
||||
KnowledgeGraphTool.__name__: KnowledgeGraphTool,
|
||||
OktaProfileTool.__name__: OktaProfileTool,
|
||||
PythonTool.__name__: PythonTool,
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -13,6 +13,9 @@ from onyx.configs.app_configs import AZURE_DALLE_API_BASE
|
||||
from onyx.configs.app_configs import AZURE_DALLE_API_KEY
|
||||
from onyx.configs.app_configs import AZURE_DALLE_API_VERSION
|
||||
from onyx.configs.app_configs import AZURE_DALLE_DEPLOYMENT_NAME
|
||||
from onyx.configs.app_configs import CODE_INTERPRETER_BASE_URL
|
||||
from onyx.configs.app_configs import CODE_INTERPRETER_DEFAULT_TIMEOUT_MS
|
||||
from onyx.configs.app_configs import CODE_INTERPRETER_REQUEST_TIMEOUT_SECONDS
|
||||
from onyx.configs.app_configs import IMAGE_MODEL_NAME
|
||||
from onyx.configs.app_configs import OAUTH_CLIENT_ID
|
||||
from onyx.configs.app_configs import OAUTH_CLIENT_SECRET
|
||||
@@ -55,6 +58,9 @@ from onyx.tools.tool_implementations.mcp.mcp_tool import MCPTool
|
||||
from onyx.tools.tool_implementations.okta_profile.okta_profile_tool import (
|
||||
OktaProfileTool,
|
||||
)
|
||||
from onyx.tools.tool_implementations.python.python_tool import (
|
||||
PythonTool,
|
||||
)
|
||||
from onyx.tools.tool_implementations.search.search_tool import SearchTool
|
||||
from onyx.tools.tool_implementations.web_search.web_search_tool import (
|
||||
WebSearchTool,
|
||||
@@ -280,6 +286,27 @@ def construct_tools(
|
||||
)
|
||||
]
|
||||
|
||||
# Handle Python Tool
|
||||
elif tool_cls.__name__ == PythonTool.__name__:
|
||||
if not CODE_INTERPRETER_BASE_URL:
|
||||
raise ValueError(
|
||||
"Code Interpreter base URL must be configured to use the Python tool"
|
||||
)
|
||||
|
||||
available_files = []
|
||||
if search_tool_config and search_tool_config.latest_query_files:
|
||||
available_files = search_tool_config.latest_query_files
|
||||
|
||||
tool_dict[db_tool_model.id] = [
|
||||
PythonTool(
|
||||
tool_id=db_tool_model.id,
|
||||
base_url=CODE_INTERPRETER_BASE_URL,
|
||||
default_timeout_ms=CODE_INTERPRETER_DEFAULT_TIMEOUT_MS,
|
||||
request_timeout_seconds=CODE_INTERPRETER_REQUEST_TIMEOUT_SECONDS,
|
||||
available_files=available_files,
|
||||
)
|
||||
]
|
||||
|
||||
# Handle Internet Search Tool
|
||||
elif tool_cls.__name__ == WebSearchTool.__name__:
|
||||
if not internet_search_tool_config:
|
||||
|
||||
644
backend/onyx/tools/tool_implementations/python/python_tool.py
Normal file
644
backend/onyx/tools/tool_implementations/python/python_tool.py
Normal file
@@ -0,0 +1,644 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import binascii
|
||||
import json
|
||||
import mimetypes
|
||||
from collections.abc import Generator
|
||||
from contextlib import closing
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import requests
|
||||
from pydantic import BaseModel
|
||||
from pydantic import ConfigDict
|
||||
from pydantic import Field
|
||||
from pydantic import ValidationError
|
||||
from sqlalchemy.orm import Session
|
||||
from typing_extensions import override
|
||||
|
||||
from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
|
||||
from onyx.configs.app_configs import CODE_INTERPRETER_BASE_URL
|
||||
from onyx.configs.constants import FileOrigin
|
||||
from onyx.file_store.file_store import get_default_file_store
|
||||
from onyx.file_store.models import ChatFileType
|
||||
from onyx.file_store.models import InMemoryChatFile
|
||||
from onyx.llm.interfaces import LLM
|
||||
from onyx.llm.models import PreviousMessage
|
||||
from onyx.server.query_and_chat.chat_utils import mime_type_to_chat_file_type
|
||||
from onyx.tools.base_tool import BaseTool
|
||||
from onyx.tools.message import ToolCallSummary
|
||||
from onyx.tools.models import ToolResponse
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.special_types import JSON_ro
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
# Response ID used to tag the PythonTool's ToolResponse payloads.
PYTHON_TOOL_RESPONSE_ID = "python_tool_result"
# Code Interpreter service endpoint that executes submitted code.
_EXECUTE_PATH = "/execute"
# Character limit for result text — presumably applied when truncating output
# for the LLM; usage is below this chunk, TODO confirm.
_DEFAULT_RESULT_CHAR_LIMIT = 1200
|
||||
|
||||
|
||||
class PythonToolArgs(BaseModel):
    """Arguments the LLM supplies when invoking the Python tool."""

    code: str = Field(
        ...,
        description="Python source code to execute inside the Code Interpreter service",
    )
    stdin: str | None = Field(
        default=None,
        description="Optional standard input payload supplied to the process",
    )
    timeout_ms: int | None = Field(
        default=None,
        ge=1_000,  # enforce a sane lower bound of 1 second
        description="Optional execution timeout override in milliseconds",
    )
|
||||
|
||||
|
||||
class ExecuteRequestFilePayload(BaseModel):
    """A single file staged into the sandbox workspace, base64-encoded."""

    # Workspace-relative path the file is written to before execution.
    path: str
    content_base64: str
|
||||
|
||||
|
||||
class ExecuteRequestPayload(BaseModel):
    """Request body sent to the Code Interpreter execute endpoint."""

    code: str  # the script to run (placed into __main__.py server-side — TODO confirm)
    timeout_ms: int
    stdin: str | None = None
    # Files staged alongside the script in the workspace directory.
    files: list[ExecuteRequestFilePayload] = Field(default_factory=list)
|
||||
|
||||
|
||||
class WorkspaceFilePayload(BaseModel):
    """One workspace entry returned by the Code Interpreter service.

    Parsed leniently from the service's ``files`` array; malformed entries are
    skipped by the caller.
    """

    # populate_by_name lets callers provide either "size_bytes" or the
    # service's wire name "size".
    model_config = ConfigDict(populate_by_name=True)

    # Path of the entry relative to the sandbox workspace.
    path: str
    # Entry kind; only "file" entries with content get persisted.
    kind: str = Field(default="file")
    # Base64-encoded content; may be absent for non-file or empty entries.
    content_base64: str | None = None
    # MIME type reported by the service, if any.
    mime_type: str | None = None
    # Size on the wire is named "size"; exposed here as size_bytes.
    size_bytes: int | None = Field(default=None, alias="size")
|
||||
|
||||
|
||||
class PythonToolArtifact(BaseModel):
    """A workspace file produced by an execution, after persistence.

    ``file_id`` is set once the artifact is saved to the file store; ``error``
    records why persistence (or decoding) failed for this entry.
    """

    # Serialize enum fields (chat_file_type) by value for JSON-friendly dumps.
    model_config = ConfigDict(use_enum_values=True)

    # Workspace-relative path reported by the service.
    path: str
    # Entry kind as reported by the service (e.g. "file").
    kind: str
    # File-store id once persisted; None if the entry was skipped or failed.
    file_id: str | None = None
    # Human-readable name (basename of path) once persisted.
    display_name: str | None = None
    # Inferred or reported MIME type.
    mime_type: str | None = None
    # Decoded content length in bytes.
    size_bytes: int | None = None
    # Populated instead of file_id when decode/persist failed.
    error: str | None = None
    # Chat-level file classification derived from the MIME type.
    chat_file_type: ChatFileType | None = None
|
||||
|
||||
|
||||
class PythonToolResult(BaseModel):
    """Full outcome of one Code Interpreter execution.

    Yielded as the payload of the tool's single ``ToolResponse`` and also
    serialized via ``model_dump`` for ``final_result``.
    """

    # Serialize enum values (in nested artifacts) by value.
    model_config = ConfigDict(use_enum_values=True)

    # Captured standard output (empty string when the service returned none).
    stdout: str
    # Captured standard error, if any.
    stderr: str | None = None
    # Process exit code, when the service reported one.
    exit_code: int | None = None
    # Wall-clock execution time; sourced from either "execution_time_ms"
    # or the service's alternate "duration_ms" key.
    execution_time_ms: int | None = None
    # Timeout that was actually applied to this run.
    timeout_ms: int
    # Metadata (id/path/name/type) for each file staged into the sandbox.
    input_files: list[dict[str, Any]] = Field(default_factory=list)
    # Persisted workspace outputs.
    artifacts: list[PythonToolArtifact] = Field(default_factory=list)
    # Free-form metadata dict passed through from the service, if present.
    metadata: dict[str, Any] | None = None
|
||||
|
||||
|
||||
_PYTHON_DESCRIPTION = """
|
||||
When you send a message containing Python code to python, it will be executed in an isolated environment.
|
||||
|
||||
The code you write will be placed into a file called `__main__.py` and run like `python __main__.py`. \
|
||||
All other files present in the conversation will also be present in this same directory. E.g. if the \
|
||||
user has uploaded three files, `analytics.csv`, `word.txt`, and `cat.png`, the directory structure will look like:
|
||||
|
||||
workspace/
|
||||
__main__.py
|
||||
analytics.csv
|
||||
word.txt
|
||||
cat.png
|
||||
|
||||
python will respond with the stdout of the `__main__.py` script as well as any new/edited files in the `workspace` directory. \
|
||||
That means, if you want to get access to some result of the script you can either: 1) `print(<result>)` or 2) save \
|
||||
a result to a file in the `workspace` directory.
|
||||
|
||||
Internet access for this session is disabled. Do not make external web requests or API calls as they will fail.
|
||||
|
||||
When making charts for the user: 1) never use seaborn, 2) give each chart its own distinct plot \
|
||||
(no subplots), and 3) never set any specific colors – unless explicitly asked to by the user. \
|
||||
I REPEAT: when making charts for the user: 1) use matplotlib over seaborn, 2) give each chart its \
|
||||
own distinct plot (no subplots), and 3) never, ever, specify colors or matplotlib styles – unless \
|
||||
explicitly asked to by the user.
|
||||
""".strip()
|
||||
|
||||
|
||||
class PythonTool(BaseTool):
    """Tool that runs LLM-authored Python code via an external Code Interpreter service.

    ``run`` POSTs the code (plus any staged chat files, base64-encoded) to the
    service's ``/execute`` endpoint, persists returned workspace files into the
    default file store, and yields a single ``ToolResponse`` whose payload is a
    ``PythonToolResult``.
    """

    _NAME = "run_code_interpreter"
    _DESCRIPTION = _PYTHON_DESCRIPTION
    _DISPLAY_NAME = "Code Interpreter"

    def __init__(
        self,
        tool_id: int,
        base_url: str,
        default_timeout_ms: int,
        request_timeout_seconds: int,
        available_files: list[InMemoryChatFile] | None = None,
    ) -> None:
        """Initialize the tool.

        Args:
            tool_id: Database id of this tool instance.
            base_url: Root URL of the Code Interpreter service; must be non-empty.
            default_timeout_ms: Execution timeout used when the model does not
                supply ``timeout_ms``.
            request_timeout_seconds: HTTP timeout for the POST to the service.
            available_files: Chat files eligible for staging into the sandbox.

        Raises:
            ValueError: If ``base_url`` is falsy.
        """
        if not base_url:
            raise ValueError(
                "Code Interpreter base URL must be configured to use the Python tool"
            )

        self._id = tool_id
        # Normalize so _EXECUTE_PATH can be appended without doubling slashes.
        self._base_url = base_url.rstrip("/")
        self._default_timeout_ms = default_timeout_ms
        self._request_timeout_seconds = request_timeout_seconds
        # Index by file_id; grows as files are staged or artifacts persisted,
        # so build_next_prompt can resolve ids back to in-memory files.
        self._available_files: dict[str, InMemoryChatFile] = {
            chat_file.file_id: chat_file for chat_file in available_files or []
        }

    @property
    def id(self) -> int:
        """Database id of this tool instance."""
        return self._id

    @property
    def name(self) -> str:
        """Function name exposed in the LLM tool definition."""
        return self._NAME

    @property
    def description(self) -> str:
        """Prompt-facing usage description."""
        return self._DESCRIPTION

    @property
    def display_name(self) -> str:
        """Human-readable name shown in the UI."""
        return self._DISPLAY_NAME

    @override
    @classmethod
    def is_available(cls, db_session: Session) -> bool:
        """Available only when a Code Interpreter base URL is configured."""
        return bool(CODE_INTERPRETER_BASE_URL)

    def tool_definition(self) -> dict:
        """Return the OpenAI-style function schema for tool-calling LLMs.

        Only ``code`` is exposed to the model; ``stdin``/``timeout_ms`` remain
        internal overrides.
        """
        return {
            "type": "function",
            "function": {
                "name": self.name,
                "description": self.description,
                "parameters": {
                    "type": "object",
                    "properties": {
                        "code": {
                            "type": "string",
                            "description": "Python source code to execute",
                        },
                    },
                    "required": ["code"],
                },
            },
        }

    def get_args_for_non_tool_calling_llm(
        self,
        query: str,
        history: list[PreviousMessage],
        llm: LLM,
        force_run: bool = False,
    ) -> dict[str, Any] | None:
        """Always None: this tool requires native tool calling."""
        # Not supported for non-tool calling LLMs
        return None

    def run(
        self,
        override_kwargs: dict[str, Any] | None = None,
        **kwargs: Any,
    ) -> Generator[ToolResponse, None, None]:
        """Execute the supplied code remotely and yield one ToolResponse.

        Args:
            override_kwargs: May carry pre-loaded ``files``
                (``InMemoryChatFile``) to stage into the sandbox.
            **kwargs: Model-provided arguments (``code``, optional ``stdin``,
                ``timeout_ms``, and optional ``files`` descriptors).

        Raises:
            ValueError: On invalid arguments, an unreachable/failing service,
                invalid JSON, or a service-reported error.
        """
        # "files" is not part of PythonToolArgs, so pull it out before
        # validation to avoid polluting the model_validate input.
        raw_requested_files = kwargs.pop("files", None)

        try:
            parsed_args = PythonToolArgs.model_validate(kwargs)
        except ValidationError as exc:
            logger.exception("Invalid arguments passed to PythonTool")
            raise ValueError("Invalid arguments supplied to Code Interpreter") from exc

        timeout_ms = parsed_args.timeout_ms or self._default_timeout_ms
        override_files = self._coerce_override_files(override_kwargs)
        requested_files = self._load_requested_files(raw_requested_files)
        # Override files come first so they win ties during de-duplication.
        request_files, input_metadata = self._prepare_request_files(
            override_files + requested_files
        )

        request_payload = ExecuteRequestPayload(
            code=parsed_args.code,
            timeout_ms=timeout_ms,
            stdin=parsed_args.stdin,
            files=request_files,
        )

        try:
            response = requests.post(
                f"{self._base_url}{_EXECUTE_PATH}",
                json=request_payload.model_dump(exclude_none=True),
                timeout=self._request_timeout_seconds,
            )
            response.raise_for_status()
        except requests.RequestException as exc:
            logger.exception("Code Interpreter execution failed")
            raise ValueError(
                "Failed to reach the Code Interpreter service. Please try again later."
            ) from exc

        response_data = self._parse_response(response)
        artifacts = self._persist_artifacts(response_data.get("files", []))

        python_result = PythonToolResult(
            stdout=self._ensure_text(response_data.get("stdout", "")),
            stderr=self._optional_text(response_data.get("stderr")),
            exit_code=self._safe_int(response_data.get("exit_code")),
            # The service may report timing under either key.
            execution_time_ms=self._safe_int(
                response_data.get("execution_time_ms")
                or response_data.get("duration_ms")
            ),
            timeout_ms=timeout_ms,
            input_files=input_metadata,
            artifacts=artifacts,
            metadata=self._extract_metadata(response_data.get("metadata")),
        )

        yield ToolResponse(id=PYTHON_TOOL_RESPONSE_ID, response=python_result)

    def build_tool_message_content(
        self, *args: ToolResponse
    ) -> str | list[str | dict[str, Any]]:
        """Render the result as text for the tool message shown to the LLM.

        Includes (possibly truncated) stdout/stderr sections plus a bullet list
        of generated files.
        """
        result = self._extract_result(args)
        sections: list[str] = []

        if result.stdout:
            sections.append(self._format_section("stdout", result.stdout))
        if result.stderr:
            sections.append(self._format_section("stderr", result.stderr))

        if result.artifacts:
            file_lines: list[str] = []
            for artifact in result.artifacts:
                display_name = artifact.display_name or Path(artifact.path).name
                if artifact.error:
                    file_lines.append(f"- {display_name}: {artifact.error}")
                elif artifact.file_id:
                    file_lines.append(
                        f"- {display_name} (file_id={artifact.file_id}, mime={artifact.mime_type or 'unknown'})"
                    )
                else:
                    file_lines.append(f"- {display_name}")
            sections.append("Generated files:\n" + "\n".join(file_lines))

        return (
            "\n\n".join(sections) if sections else "Execution completed with no output."
        )

    def final_result(self, *args: ToolResponse) -> JSON_ro:
        """Return the full result as a JSON-serializable dict."""
        return self._extract_result(args).model_dump()

    def build_next_prompt(
        self,
        prompt_builder: AnswerPromptBuilder,
        tool_call_summary: ToolCallSummary,
        tool_responses: list[ToolResponse],
        using_tool_calling_llm: bool,
    ) -> AnswerPromptBuilder:
        """Augment the next prompt with text artifacts produced by the run.

        Only text-type artifacts are appended (as raw user-uploaded files);
        binary outputs are referenced by id in the tool message instead.
        """
        updated_prompt_builder = super().build_next_prompt(
            prompt_builder, tool_call_summary, tool_responses, using_tool_calling_llm
        )

        result = self._extract_result(tool_responses)
        if result.artifacts:
            # Avoid attaching a file the prompt already carries.
            existing_ids = {
                file.file_id for file in updated_prompt_builder.raw_user_uploaded_files
            }
            for artifact in result.artifacts:
                if not artifact.file_id:
                    continue
                chat_file = self._available_files.get(artifact.file_id)
                if not chat_file or chat_file.file_id in existing_ids:
                    continue
                if chat_file.file_type.is_text_file():
                    updated_prompt_builder.raw_user_uploaded_files.append(chat_file)
                    existing_ids.add(chat_file.file_id)
        return updated_prompt_builder

    def _load_requested_files(self, raw_files: Any) -> list[InMemoryChatFile]:
        """Resolve model-supplied file descriptors into in-memory files.

        Accepts a list of dicts (``id``/``file_id`` plus optional
        ``path``/``name``/``filename``) or bare id strings. Unresolvable or
        malformed entries are logged and skipped rather than failing the run.
        """
        if not raw_files:
            return []

        if not isinstance(raw_files, list):
            logger.warning(
                "Ignoring non-list 'files' argument passed to PythonTool: %r",
                type(raw_files),
            )
            return []

        loaded_files: list[InMemoryChatFile] = []
        for descriptor in raw_files:
            file_id: str | None = None
            desired_name: str | None = None

            if isinstance(descriptor, dict):
                raw_id = descriptor.get("id") or descriptor.get("file_id")
                if raw_id is not None:
                    file_id = str(raw_id)
                desired_name = (
                    descriptor.get("path")
                    or descriptor.get("name")
                    or descriptor.get("filename")
                )
            elif isinstance(descriptor, str):
                file_id = descriptor
            else:
                logger.warning(
                    "Skipping unsupported file descriptor passed to PythonTool: %r",
                    descriptor,
                )
                continue

            if not file_id:
                logger.warning("Skipping file descriptor without an id: %r", descriptor)
                continue

            # Prefer files already in memory; fall back to the file store.
            chat_file = self._available_files.get(file_id)
            if not chat_file:
                chat_file = self._load_file_from_store(file_id)

            if not chat_file:
                logger.warning("Requested file '%s' could not be found", file_id)
                continue

            # Honor a caller-requested filename without mutating the cached file.
            if desired_name and desired_name != chat_file.filename:
                chat_file = InMemoryChatFile(
                    file_id=chat_file.file_id,
                    content=chat_file.content,
                    filename=desired_name,
                    file_type=chat_file.file_type,
                )

            loaded_files.append(chat_file)

        return loaded_files

    def _load_file_from_store(self, file_id: str) -> InMemoryChatFile | None:
        """Fetch a file's record and bytes from the file store.

        Returns None (after logging) on any failure — missing record,
        unreadable content, or a non-bytes payload.
        """
        try:
            file_store = get_default_file_store()
        except Exception as exc:  # pragma: no cover - defensive guard
            logger.exception("Failed to obtain file store instance: %s", exc)
            return None

        try:
            file_record = file_store.read_file_record(file_id)
        except Exception:
            logger.exception(
                "Failed to read file record for '%s' from file store", file_id
            )
            return None

        if not file_record:
            logger.warning("No file record found for requested file '%s'", file_id)
            return None

        try:
            file_io = file_store.read_file(file_id, mode="b")
        except Exception:
            logger.exception(
                "Failed to read file content for '%s' from file store", file_id
            )
            return None

        # read_file may return a stream or raw bytes; handle both and always
        # close the handle if it is closable.
        with closing(file_io):
            data = file_io.read() if hasattr(file_io, "read") else file_io

        if not isinstance(data, (bytes, bytearray)):
            logger.warning("File store returned non-bytes payload for '%s'", file_id)
            return None

        # NOTE(review): file_record.file_type appears to hold a MIME type here —
        # confirm against the file store's record schema.
        mime_type = (
            getattr(file_record, "file_type", None) or "application/octet-stream"
        )
        display_name = getattr(file_record, "display_name", None) or file_id
        chat_file_type = mime_type_to_chat_file_type(mime_type)

        return InMemoryChatFile(
            file_id=file_id,
            content=bytes(data),
            filename=display_name,
            file_type=chat_file_type,
        )

    def _coerce_override_files(
        self, override_kwargs: dict[str, Any] | None
    ) -> list[InMemoryChatFile]:
        """Extract pre-loaded InMemoryChatFile entries from override kwargs.

        Non-InMemoryChatFile entries are logged and dropped.
        """
        if not override_kwargs:
            return []

        raw_files = override_kwargs.get("files")
        if not raw_files:
            return []

        coerced_files: list[InMemoryChatFile] = []
        for raw_file in raw_files:
            if isinstance(raw_file, InMemoryChatFile):
                coerced_files.append(raw_file)
            else:
                logger.warning(
                    "Ignoring non InMemoryChatFile entry passed to PythonTool override: %r",
                    type(raw_file),
                )
        return coerced_files

    def _prepare_request_files(
        self, chat_files: list[InMemoryChatFile]
    ) -> tuple[list[ExecuteRequestFilePayload], list[dict[str, Any]]]:
        """Convert chat files into request payloads plus input metadata.

        De-duplicates by file id (first occurrence wins) and registers each
        file in ``self._available_files`` for later artifact resolution.
        """
        request_files: list[ExecuteRequestFilePayload] = []
        input_metadata: list[dict[str, Any]] = []
        seen_ids: set[str] = set()

        for chat_file in chat_files:
            file_id = str(chat_file.file_id)
            if file_id in seen_ids:
                continue
            seen_ids.add(file_id)

            self._available_files[file_id] = chat_file
            target_path = self._resolve_target_path(chat_file)

            request_files.append(
                ExecuteRequestFilePayload(
                    path=target_path,
                    content_base64=base64.b64encode(chat_file.content).decode("utf-8"),
                )
            )

            input_metadata.append(
                {
                    "id": file_id,
                    "path": target_path,
                    "name": chat_file.filename or file_id,
                    "chat_file_type": chat_file.file_type.value,
                }
            )

        return request_files, input_metadata

    def _persist_artifacts(self, files: list[Any]) -> list[PythonToolArtifact]:
        """Save workspace outputs from the service into the file store.

        Malformed entries are skipped; decode or persistence failures are
        recorded on the artifact's ``error`` field rather than raised, so one
        bad file never loses the rest of the run's output.
        """
        artifacts: list[PythonToolArtifact] = []
        if not files:
            return artifacts

        file_store = get_default_file_store()

        for raw_file in files:
            try:
                workspace_file = WorkspaceFilePayload.model_validate(raw_file)
            except ValidationError as exc:
                logger.warning("Skipping malformed workspace file entry: %s", exc)
                continue

            artifact = PythonToolArtifact(
                path=workspace_file.path, kind=workspace_file.kind
            )
            # Directories / contentless entries are reported but not persisted.
            if workspace_file.kind != "file" or not workspace_file.content_base64:
                artifacts.append(artifact)
                continue

            try:
                binary = base64.b64decode(workspace_file.content_base64)
            except binascii.Error as exc:
                artifact.error = f"Failed to decode base64 content: {exc}"
                artifacts.append(artifact)
                continue

            mime_type = self._infer_mime_type(workspace_file, binary)
            display_name = Path(workspace_file.path).name or workspace_file.path

            try:
                file_id = file_store.save_file(
                    content=BytesIO(binary),
                    display_name=display_name,
                    file_origin=FileOrigin.GENERATED_REPORT,
                    file_type=mime_type,
                )
            except Exception as exc:
                logger.exception(
                    "Failed to persist Code Interpreter artifact '%s'",
                    workspace_file.path,
                )
                artifact.error = f"Failed to persist artifact: {exc}"
                artifacts.append(artifact)
                continue

            chat_file_type = mime_type_to_chat_file_type(mime_type)
            chat_file = InMemoryChatFile(
                file_id=file_id,
                content=binary,
                filename=display_name,
                file_type=chat_file_type,
            )
            # Register so build_next_prompt can attach text artifacts later.
            self._available_files[file_id] = chat_file

            artifact.file_id = file_id
            artifact.display_name = display_name
            artifact.mime_type = mime_type
            artifact.size_bytes = len(binary)
            artifact.chat_file_type = chat_file_type
            artifacts.append(artifact)

        return artifacts

    def _parse_response(self, response: requests.Response) -> dict[str, Any]:
        """Decode the service response and surface service-level errors.

        Raises:
            ValueError: If the body is not valid JSON or contains an "error"
                payload (dict with "message", or any truthy value).
        """
        try:
            response_data = response.json()
        except json.JSONDecodeError as exc:
            logger.exception("Code Interpreter returned invalid JSON")
            raise ValueError(
                "Code Interpreter returned an invalid JSON response"
            ) from exc

        error_payload = response_data.get("error")
        if error_payload:
            if isinstance(error_payload, dict):
                message = error_payload.get("message") or json.dumps(error_payload)
            else:
                message = str(error_payload)
            raise ValueError(f"Code Interpreter reported an error: {message}")

        return response_data

    def _extract_result(self, responses: tuple[ToolResponse, ...]) -> PythonToolResult:
        """Find this tool's result among the responses (re-validating dicts).

        Raises:
            ValueError: If no response with PYTHON_TOOL_RESPONSE_ID is present.
        """
        for response in responses:
            if response.id == PYTHON_TOOL_RESPONSE_ID:
                if isinstance(response.response, PythonToolResult):
                    return response.response
                return PythonToolResult.model_validate(response.response)
        raise ValueError("No python tool result found in tool responses")

    def _resolve_target_path(self, chat_file: InMemoryChatFile) -> str:
        """Sandbox path for a staged file: its basename, or the id as fallback."""
        if chat_file.filename:
            # Basename only — prevents path components escaping the workspace.
            return Path(chat_file.filename).name
        return chat_file.file_id

    def _infer_mime_type(
        self, workspace_file: WorkspaceFilePayload, binary: bytes
    ) -> str:
        """Best-effort MIME type: reported value, extension guess, UTF-8 sniff."""
        if workspace_file.mime_type:
            return workspace_file.mime_type

        guess, _ = mimetypes.guess_type(workspace_file.path)
        if guess:
            return guess

        if self._looks_like_text(binary):
            return "text/plain"

        return "application/octet-stream"

    @staticmethod
    def _ensure_text(value: Any) -> str:
        """Coerce any value to str; None becomes the empty string."""
        if value is None:
            return ""
        return str(value)

    @staticmethod
    def _optional_text(value: Any) -> str | None:
        """Coerce to str, mapping both None and empty string to None."""
        if value is None:
            return None
        text_value = str(value)
        return text_value if text_value else None

    @staticmethod
    def _safe_int(value: Any) -> int | None:
        """Convert to int, returning None for None or unconvertible values."""
        try:
            return int(value) if value is not None else None
        except (TypeError, ValueError):
            return None

    @staticmethod
    def _extract_metadata(value: Any) -> dict[str, Any] | None:
        """Pass through dict metadata; anything else becomes None."""
        if isinstance(value, dict):
            return value
        return None

    @staticmethod
    def _looks_like_text(payload: bytes, sample_size: int = 2048) -> bool:
        """Heuristic: the first sample_size bytes decode cleanly as UTF-8.

        May report False for valid UTF-8 when the sample boundary splits a
        multi-byte sequence — acceptable for a best-effort MIME guess.
        """
        sample = payload[:sample_size]
        try:
            sample.decode("utf-8")
        except UnicodeDecodeError:
            return False
        return True

    @staticmethod
    def _format_section(title: str, body: str) -> str:
        """Render a titled section, truncating the body with an ellipsis so the
        total body length never exceeds _DEFAULT_RESULT_CHAR_LIMIT."""
        truncated = (
            body
            if len(body) <= _DEFAULT_RESULT_CHAR_LIMIT
            else body[: _DEFAULT_RESULT_CHAR_LIMIT - 3] + "..."
        )
        return f"{title}:\n{truncated}"
|
||||
@@ -88,6 +88,7 @@ import { FilePickerModal } from "@/app/chat/my-documents/components/FilePicker";
|
||||
import { useDocumentsContext } from "@/app/chat/my-documents/DocumentsContext";
|
||||
|
||||
import {
|
||||
PYTHON_TOOL_ID,
|
||||
IMAGE_GENERATION_TOOL_ID,
|
||||
SEARCH_TOOL_ID,
|
||||
WEB_SEARCH_TOOL_ID,
|
||||
@@ -111,6 +112,10 @@ function findWebSearchTool(tools: ToolSnapshot[]) {
|
||||
return tools.find((tool) => tool.in_code_tool_id === WEB_SEARCH_TOOL_ID);
|
||||
}
|
||||
|
||||
function findPythonTool(tools: ToolSnapshot[]) {
|
||||
return tools.find((tool) => tool.in_code_tool_id === PYTHON_TOOL_ID);
|
||||
}
|
||||
|
||||
function SubLabel({ children }: { children: string | JSX.Element }) {
|
||||
return (
|
||||
<div
|
||||
@@ -213,13 +218,15 @@ export function AssistantEditor({
|
||||
const searchTool = findSearchTool(tools);
|
||||
const imageGenerationTool = findImageGenerationTool(tools);
|
||||
const webSearchTool = findWebSearchTool(tools);
|
||||
const pythonTool = findPythonTool(tools);
|
||||
|
||||
// Separate MCP tools from regular custom tools
|
||||
const allCustomTools = tools.filter(
|
||||
(tool) =>
|
||||
tool.in_code_tool_id !== searchTool?.in_code_tool_id &&
|
||||
tool.in_code_tool_id !== imageGenerationTool?.in_code_tool_id &&
|
||||
tool.in_code_tool_id !== webSearchTool?.in_code_tool_id
|
||||
tool.in_code_tool_id !== webSearchTool?.in_code_tool_id &&
|
||||
tool.in_code_tool_id !== pythonTool?.in_code_tool_id
|
||||
);
|
||||
|
||||
const mcpTools = allCustomTools.filter((tool) => tool.mcp_server_id);
|
||||
@@ -290,6 +297,7 @@ export function AssistantEditor({
|
||||
...mcpTools, // Include MCP tools for form logic
|
||||
...(searchTool ? [searchTool] : []),
|
||||
...(imageGenerationTool ? [imageGenerationTool] : []),
|
||||
...(pythonTool ? [pythonTool] : []),
|
||||
...(webSearchTool ? [webSearchTool] : []),
|
||||
];
|
||||
const enabledToolsMap: { [key: number]: boolean } = {};
|
||||
@@ -1211,6 +1219,16 @@ export function AssistantEditor({
|
||||
</>
|
||||
)}
|
||||
|
||||
{pythonTool && (
|
||||
<>
|
||||
<BooleanFormField
|
||||
name={`enabled_tools_map.${pythonTool.id}`}
|
||||
label={pythonTool.display_name}
|
||||
subtext="Execute Python code against staged files using the Code Interpreter sandbox"
|
||||
/>
|
||||
</>
|
||||
)}
|
||||
|
||||
{webSearchTool && (
|
||||
<>
|
||||
<BooleanFormField
|
||||
|
||||
@@ -272,6 +272,9 @@ function ToolToggle({
|
||||
if (tool.in_code_tool_id === "ImageGenerationTool") {
|
||||
return "Add an OpenAI LLM provider with an API key under Admin → Configuration → LLM.";
|
||||
}
|
||||
if (tool.in_code_tool_id === "PythonTool") {
|
||||
return "Set CODE_INTERPRETER_BASE_URL on the server and restart to enable Code Interpreter.";
|
||||
}
|
||||
return "Not configured.";
|
||||
})();
|
||||
return (
|
||||
|
||||
@@ -1,6 +1,9 @@
|
||||
import { FiImage, FiSearch } from "react-icons/fi";
|
||||
import { FiImage, FiSearch, FiTerminal } from "react-icons/fi";
|
||||
import { Persona } from "../admin/assistants/interfaces";
|
||||
import { SEARCH_TOOL_ID } from "../chat/components/tools/constants";
|
||||
import {
|
||||
PYTHON_TOOL_ID,
|
||||
SEARCH_TOOL_ID,
|
||||
} from "../chat/components/tools/constants";
|
||||
|
||||
export function AssistantTools({
|
||||
assistant,
|
||||
@@ -69,6 +72,26 @@ export function AssistantTools({
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
} else if (tool.name === PYTHON_TOOL_ID) {
|
||||
return (
|
||||
<div
|
||||
key={ind}
|
||||
className={`
|
||||
px-1.5
|
||||
py-1
|
||||
rounded-lg
|
||||
border
|
||||
border-border
|
||||
w-fit
|
||||
flex
|
||||
${list ? "bg-background-125" : "bg-background-100"}`}
|
||||
>
|
||||
<div className="flex items-center gap-x-1">
|
||||
<FiTerminal className="ml-1 my-auto h-3 w-3" />
|
||||
{tool.display_name || "Code Interpreter"}
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
} else {
|
||||
return (
|
||||
<div
|
||||
|
||||
@@ -2,8 +2,10 @@
|
||||
export const SEARCH_TOOL_NAME = "run_search";
|
||||
export const INTERNET_SEARCH_TOOL_NAME = "run_internet_search";
|
||||
export const IMAGE_GENERATION_TOOL_NAME = "run_image_generation";
|
||||
export const PYTHON_TOOL_NAME = "run_code_interpreter";
|
||||
|
||||
// In-code tool IDs that also correspond to the tool's name when associated with a persona
|
||||
export const SEARCH_TOOL_ID = "SearchTool";
|
||||
export const IMAGE_GENERATION_TOOL_ID = "ImageGenerationTool";
|
||||
export const WEB_SEARCH_TOOL_ID = "WebSearchTool";
|
||||
export const PYTHON_TOOL_ID = "PythonTool";
|
||||
|
||||
@@ -32,6 +32,14 @@ const isImageGenerationTool = (tool: ToolSnapshot): boolean => {
|
||||
);
|
||||
};
|
||||
|
||||
const isPythonTool = (tool: ToolSnapshot): boolean => {
|
||||
return (
|
||||
tool.in_code_tool_id === "PythonTool" ||
|
||||
tool.in_code_tool_id === "AnalysisTool" ||
|
||||
tool.display_name?.toLowerCase().includes("code interpreter")
|
||||
);
|
||||
};
|
||||
|
||||
const isKnowledgeGraphTool = (tool: ToolSnapshot): boolean => {
|
||||
return (
|
||||
tool.in_code_tool_id === "KnowledgeGraphTool" ||
|
||||
@@ -55,6 +63,8 @@ export function getIconForAction(
|
||||
return GlobeIcon;
|
||||
} else if (isImageGenerationTool(action)) {
|
||||
return ImageIcon;
|
||||
} else if (isPythonTool(action)) {
|
||||
return CpuIcon;
|
||||
} else if (isKnowledgeGraphTool(action)) {
|
||||
return DatabaseIcon;
|
||||
} else if (isOktaProfileTool(action)) {
|
||||
|
||||
Reference in New Issue
Block a user