mirror of
https://github.com/onyx-dot-app/onyx.git
synced 2026-02-28 13:15:44 +00:00
Compare commits
5 Commits
fix-playwr
...
add-code-i
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
68a484ae73 | ||
|
|
f4d135d710 | ||
|
|
6094f70ac8 | ||
|
|
a90e58b39b | ||
|
|
e82e3141ed |
73
backend/alembic/versions/1c3f8a7b5d4e_add_python_tool.py
Normal file
73
backend/alembic/versions/1c3f8a7b5d4e_add_python_tool.py
Normal file
@@ -0,0 +1,73 @@
|
||||
"""add_python_tool
|
||||
|
||||
Revision ID: 1c3f8a7b5d4e
|
||||
Revises: 505c488f6662
|
||||
Create Date: 2025-02-14 00:00:00
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "1c3f8a7b5d4e"
|
||||
down_revision = "505c488f6662"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
PYTHON_TOOL = {
|
||||
"name": "PythonTool",
|
||||
"display_name": "Code Interpreter",
|
||||
"description": (
|
||||
"The Code Interpreter Action lets assistants execute Python in an isolated runtime. "
|
||||
"It can process staged files, read and write artifacts, stream stdout and stderr, "
|
||||
"and return generated outputs for the chat session."
|
||||
),
|
||||
"in_code_tool_id": "PythonTool",
|
||||
}
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
conn = op.get_bind()
|
||||
conn.execute(sa.text("BEGIN"))
|
||||
try:
|
||||
existing = conn.execute(
|
||||
sa.text("SELECT id FROM tool WHERE in_code_tool_id = :in_code_tool_id"),
|
||||
PYTHON_TOOL,
|
||||
).fetchone()
|
||||
|
||||
if existing:
|
||||
conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
UPDATE tool
|
||||
SET name = :name,
|
||||
display_name = :display_name,
|
||||
description = :description
|
||||
WHERE in_code_tool_id = :in_code_tool_id
|
||||
"""
|
||||
),
|
||||
PYTHON_TOOL,
|
||||
)
|
||||
else:
|
||||
conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
INSERT INTO tool (name, display_name, description, in_code_tool_id)
|
||||
VALUES (:name, :display_name, :description, :in_code_tool_id)
|
||||
"""
|
||||
),
|
||||
PYTHON_TOOL,
|
||||
)
|
||||
|
||||
conn.execute(sa.text("COMMIT"))
|
||||
except Exception:
|
||||
conn.execute(sa.text("ROLLBACK"))
|
||||
raise
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Do not delete the tool entry on downgrade; leaving it is safe and keeps migrations idempotent.
|
||||
pass
|
||||
@@ -39,6 +39,7 @@ def decision_router(state: MainState) -> list[Send | Hashable] | DRPath | str:
|
||||
DRPath.WEB_SEARCH,
|
||||
DRPath.KNOWLEDGE_GRAPH,
|
||||
DRPath.IMAGE_GENERATION,
|
||||
DRPath.PYTHON_TOOL,
|
||||
)
|
||||
and len(state.query_list) == 0
|
||||
):
|
||||
|
||||
@@ -21,6 +21,7 @@ AVERAGE_TOOL_COSTS: dict[DRPath, float] = {
|
||||
DRPath.WEB_SEARCH: 1.5,
|
||||
DRPath.IMAGE_GENERATION: 3.0,
|
||||
DRPath.GENERIC_TOOL: 1.5, # TODO: see todo in OrchestratorTool
|
||||
DRPath.PYTHON_TOOL: 2.0,
|
||||
DRPath.CLOSER: 0.0,
|
||||
}
|
||||
|
||||
|
||||
@@ -27,6 +27,7 @@ class DRPath(str, Enum):
|
||||
WEB_SEARCH = "Web Search"
|
||||
IMAGE_GENERATION = "Image Generation"
|
||||
GENERIC_INTERNAL_TOOL = "Generic Internal Tool"
|
||||
PYTHON_TOOL = "Python"
|
||||
CLOSER = "Closer"
|
||||
LOGGER = "Logger"
|
||||
END = "End"
|
||||
|
||||
@@ -26,6 +26,9 @@ from onyx.agents.agent_search.dr.sub_agents.image_generation.dr_image_generation
|
||||
from onyx.agents.agent_search.dr.sub_agents.kg_search.dr_kg_search_graph_builder import (
|
||||
dr_kg_search_graph_builder,
|
||||
)
|
||||
from onyx.agents.agent_search.dr.sub_agents.python_tool.dr_python_tool_graph_builder import (
|
||||
dr_python_tool_graph_builder,
|
||||
)
|
||||
from onyx.agents.agent_search.dr.sub_agents.web_search.dr_ws_graph_builder import (
|
||||
dr_ws_graph_builder,
|
||||
)
|
||||
@@ -58,12 +61,15 @@ def dr_graph_builder() -> StateGraph:
|
||||
image_generation_graph = dr_image_generation_graph_builder().compile()
|
||||
graph.add_node(DRPath.IMAGE_GENERATION, image_generation_graph)
|
||||
|
||||
custom_tool_graph = dr_custom_tool_graph_builder().compile()
|
||||
graph.add_node(DRPath.GENERIC_TOOL, custom_tool_graph)
|
||||
|
||||
generic_internal_tool_graph = dr_generic_internal_tool_graph_builder().compile()
|
||||
graph.add_node(DRPath.GENERIC_INTERNAL_TOOL, generic_internal_tool_graph)
|
||||
|
||||
python_tool_graph = dr_python_tool_graph_builder().compile()
|
||||
graph.add_node(DRPath.PYTHON_TOOL, python_tool_graph)
|
||||
|
||||
custom_tool_graph = dr_custom_tool_graph_builder().compile()
|
||||
graph.add_node(DRPath.GENERIC_TOOL, custom_tool_graph)
|
||||
|
||||
graph.add_node(DRPath.CLOSER, closer)
|
||||
graph.add_node(DRPath.LOGGER, logging)
|
||||
|
||||
@@ -81,6 +87,7 @@ def dr_graph_builder() -> StateGraph:
|
||||
graph.add_edge(start_key=DRPath.IMAGE_GENERATION, end_key=DRPath.ORCHESTRATOR)
|
||||
graph.add_edge(start_key=DRPath.GENERIC_TOOL, end_key=DRPath.ORCHESTRATOR)
|
||||
graph.add_edge(start_key=DRPath.GENERIC_INTERNAL_TOOL, end_key=DRPath.ORCHESTRATOR)
|
||||
graph.add_edge(start_key=DRPath.PYTHON_TOOL, end_key=DRPath.ORCHESTRATOR)
|
||||
|
||||
graph.add_conditional_edges(DRPath.CLOSER, completeness_router)
|
||||
graph.add_edge(start_key=DRPath.LOGGER, end_key=END)
|
||||
|
||||
@@ -69,6 +69,7 @@ from onyx.tools.tool_implementations.images.image_generation_tool import (
|
||||
from onyx.tools.tool_implementations.knowledge_graph.knowledge_graph_tool import (
|
||||
KnowledgeGraphTool,
|
||||
)
|
||||
from onyx.tools.tool_implementations.python.python_tool import PythonTool
|
||||
from onyx.tools.tool_implementations.search.search_tool import SearchTool
|
||||
from onyx.tools.tool_implementations.web_search.web_search_tool import (
|
||||
WebSearchTool,
|
||||
@@ -134,6 +135,9 @@ def _get_available_tools(
|
||||
continue
|
||||
llm_path = DRPath.KNOWLEDGE_GRAPH.value
|
||||
path = DRPath.KNOWLEDGE_GRAPH
|
||||
elif isinstance(tool, PythonTool):
|
||||
llm_path = DRPath.PYTHON_TOOL.value
|
||||
path = DRPath.PYTHON_TOOL
|
||||
elif isinstance(tool, ImageGenerationTool):
|
||||
llm_path = DRPath.IMAGE_GENERATION.value
|
||||
path = DRPath.IMAGE_GENERATION
|
||||
@@ -778,6 +782,6 @@ def clarifier(
|
||||
active_source_types_descriptions="\n".join(active_source_types_descriptions),
|
||||
assistant_system_prompt=assistant_system_prompt,
|
||||
assistant_task_prompt=assistant_task_prompt,
|
||||
uploaded_test_context=uploaded_text_context,
|
||||
uploaded_text_context=uploaded_text_context,
|
||||
uploaded_image_context=uploaded_image_context,
|
||||
)
|
||||
|
||||
@@ -140,7 +140,7 @@ def orchestrator(
|
||||
|
||||
available_tools = state.available_tools or {}
|
||||
|
||||
uploaded_context = state.uploaded_test_context or ""
|
||||
uploaded_context = state.uploaded_text_context or ""
|
||||
uploaded_image_context = state.uploaded_image_context or []
|
||||
|
||||
questions = [
|
||||
|
||||
@@ -228,7 +228,7 @@ def closer(
|
||||
assistant_system_prompt = state.assistant_system_prompt
|
||||
assistant_task_prompt = state.assistant_task_prompt
|
||||
|
||||
uploaded_context = state.uploaded_test_context or ""
|
||||
uploaded_context = state.uploaded_text_context or ""
|
||||
|
||||
clarification = state.clarification
|
||||
prompt_question = get_prompt_question(base_question, clarification)
|
||||
|
||||
@@ -46,7 +46,7 @@ class OrchestrationSetup(OrchestrationUpdate):
|
||||
active_source_types_descriptions: str | None = None
|
||||
assistant_system_prompt: str | None = None
|
||||
assistant_task_prompt: str | None = None
|
||||
uploaded_test_context: str | None = None
|
||||
uploaded_text_context: str | None = None
|
||||
uploaded_image_context: list[dict[str, Any]] | None = None
|
||||
|
||||
|
||||
|
||||
@@ -0,0 +1 @@
|
||||
"""Python Tool sub-agent for deep research."""
|
||||
@@ -0,0 +1,36 @@
|
||||
from datetime import datetime
|
||||
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langgraph.types import StreamWriter
|
||||
|
||||
from onyx.agents.agent_search.dr.states import LoggerUpdate
|
||||
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
get_langgraph_node_log_string,
|
||||
)
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def python_tool_branch(
|
||||
state: SubAgentInput, config: RunnableConfig, writer: StreamWriter = lambda _: None
|
||||
) -> LoggerUpdate:
|
||||
"""Log the beginning of a Python Tool branch."""
|
||||
|
||||
node_start_time = datetime.now()
|
||||
iteration_nr = state.iteration_nr
|
||||
|
||||
logger.debug(
|
||||
f"Python Tool branch start for iteration {iteration_nr} at {datetime.now()}"
|
||||
)
|
||||
|
||||
return LoggerUpdate(
|
||||
log_messages=[
|
||||
get_langgraph_node_log_string(
|
||||
graph_component="python_tool",
|
||||
node_name="branching",
|
||||
node_start_time=node_start_time,
|
||||
)
|
||||
],
|
||||
)
|
||||
@@ -0,0 +1,257 @@
|
||||
import base64
|
||||
import json
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
|
||||
from langchain_core.messages import AIMessage
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langgraph.types import StreamWriter
|
||||
|
||||
from onyx.agents.agent_search.dr.sub_agents.states import BranchInput
|
||||
from onyx.agents.agent_search.dr.sub_agents.states import BranchUpdate
|
||||
from onyx.agents.agent_search.dr.sub_agents.states import IterationAnswer
|
||||
from onyx.agents.agent_search.models import GraphConfig
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
get_langgraph_node_log_string,
|
||||
)
|
||||
from onyx.configs.agent_configs import TF_DR_TIMEOUT_SHORT
|
||||
from onyx.configs.constants import MessageType
|
||||
from onyx.file_store.models import ChatFileType
|
||||
from onyx.file_store.models import InMemoryChatFile
|
||||
from onyx.llm.utils import build_content_with_imgs
|
||||
from onyx.prompts.dr_prompts import CUSTOM_TOOL_PREP_PROMPT
|
||||
from onyx.prompts.dr_prompts import PYTHON_TOOL_USE_RESPONSE_PROMPT
|
||||
from onyx.tools.tool_implementations.python.python_tool import PythonTool
|
||||
from onyx.tools.tool_implementations.python.python_tool import PythonToolResult
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def _serialize_chat_files(chat_files: list[InMemoryChatFile]) -> list[dict[str, Any]]:
|
||||
serialized_files: list[dict[str, Any]] = []
|
||||
for chat_file in chat_files:
|
||||
file_payload: dict[str, Any] = {
|
||||
"id": str(chat_file.file_id),
|
||||
"name": chat_file.filename,
|
||||
"type": chat_file.file_type.value,
|
||||
}
|
||||
if chat_file.file_type == ChatFileType.IMAGE:
|
||||
file_payload["content"] = chat_file.to_base64()
|
||||
file_payload["is_base64"] = True
|
||||
elif chat_file.file_type.is_text_file():
|
||||
file_payload["content"] = chat_file.content.decode(
|
||||
"utf-8", errors="replace"
|
||||
)
|
||||
file_payload["is_base64"] = False
|
||||
else:
|
||||
file_payload["content"] = base64.b64encode(chat_file.content).decode(
|
||||
"utf-8"
|
||||
)
|
||||
file_payload["is_base64"] = True
|
||||
serialized_files.append(file_payload)
|
||||
|
||||
return serialized_files
|
||||
|
||||
|
||||
def python_tool_act(
|
||||
state: BranchInput,
|
||||
config: RunnableConfig,
|
||||
writer: StreamWriter = lambda _: None,
|
||||
) -> BranchUpdate:
|
||||
"""Execute the Python Tool with any files supplied by the user."""
|
||||
|
||||
node_start_time = datetime.now()
|
||||
iteration_nr = state.iteration_nr
|
||||
parallelization_nr = state.parallelization_nr
|
||||
|
||||
if not state.available_tools:
|
||||
raise ValueError("available_tools is not set")
|
||||
|
||||
tool_key = state.tools_used[-1]
|
||||
python_tool_info = state.available_tools[tool_key]
|
||||
python_tool = cast(PythonTool | None, python_tool_info.tool_object)
|
||||
|
||||
if python_tool is None:
|
||||
raise ValueError("python_tool is not set")
|
||||
|
||||
branch_query = state.branch_question
|
||||
if not branch_query:
|
||||
raise ValueError("branch_query is not set")
|
||||
|
||||
graph_config = cast(GraphConfig, config["metadata"]["config"])
|
||||
base_question = graph_config.inputs.prompt_builder.raw_user_query
|
||||
files = graph_config.inputs.files
|
||||
|
||||
logger.debug(
|
||||
"Tool call start for %s %s.%s at %s",
|
||||
python_tool.llm_name,
|
||||
iteration_nr,
|
||||
parallelization_nr,
|
||||
datetime.now(),
|
||||
)
|
||||
|
||||
tool_args: dict[str, Any] | None = None
|
||||
if graph_config.tooling.using_tool_calling_llm:
|
||||
tool_use_prompt = CUSTOM_TOOL_PREP_PROMPT.build(
|
||||
query=branch_query,
|
||||
base_question=base_question,
|
||||
tool_description=python_tool_info.description,
|
||||
)
|
||||
|
||||
content_with_files = build_content_with_imgs(
|
||||
message=tool_use_prompt,
|
||||
files=files,
|
||||
message_type=MessageType.USER,
|
||||
)
|
||||
|
||||
tool_prompt_message: dict[str, Any] = {
|
||||
"role": "user",
|
||||
"content": content_with_files,
|
||||
}
|
||||
if files:
|
||||
tool_prompt_message["files"] = _serialize_chat_files(files)
|
||||
|
||||
tool_calling_msg = graph_config.tooling.primary_llm.invoke(
|
||||
[tool_prompt_message],
|
||||
tools=[python_tool.tool_definition()],
|
||||
tool_choice="required",
|
||||
timeout_override=TF_DR_TIMEOUT_SHORT,
|
||||
)
|
||||
|
||||
if isinstance(tool_calling_msg, AIMessage) and tool_calling_msg.tool_calls:
|
||||
tool_args = tool_calling_msg.tool_calls[0].get("args")
|
||||
else:
|
||||
logger.warning("Tool-calling LLM did not emit a tool call for Python Tool")
|
||||
|
||||
if tool_args is None:
|
||||
tool_args = python_tool.get_args_for_non_tool_calling_llm(
|
||||
query=branch_query,
|
||||
history=[],
|
||||
llm=graph_config.tooling.primary_llm,
|
||||
force_run=True,
|
||||
)
|
||||
|
||||
if tool_args is None:
|
||||
raise ValueError("Failed to obtain tool arguments from LLM")
|
||||
|
||||
if "files" in tool_args:
|
||||
tool_args = {key: value for key, value in tool_args.items() if key != "files"}
|
||||
|
||||
override_kwargs = {"files": files or []}
|
||||
|
||||
tool_responses = list(python_tool.run(override_kwargs=override_kwargs, **tool_args))
|
||||
|
||||
python_tool_result: PythonToolResult | None = None
|
||||
for response in tool_responses:
|
||||
if isinstance(response.response, PythonToolResult):
|
||||
python_tool_result = response.response
|
||||
break
|
||||
|
||||
if python_tool_result is None:
|
||||
raise ValueError("Python tool did not return a valid result")
|
||||
|
||||
final_result = python_tool.final_result(*tool_responses)
|
||||
tool_result_str = json.dumps(final_result, ensure_ascii=False)
|
||||
|
||||
tool_summary_prompt = PYTHON_TOOL_USE_RESPONSE_PROMPT.build(
|
||||
base_question=base_question,
|
||||
tool_response=tool_result_str,
|
||||
)
|
||||
|
||||
initial_files = list(files or [])
|
||||
generated_files: list[InMemoryChatFile] = []
|
||||
for artifact in python_tool_result.artifacts:
|
||||
if not artifact.file_id:
|
||||
continue
|
||||
|
||||
chat_file = python_tool._available_files.get(artifact.file_id)
|
||||
if not chat_file:
|
||||
logger.warning(
|
||||
"Generated artifact with id %s not found in available files",
|
||||
artifact.file_id,
|
||||
)
|
||||
continue
|
||||
|
||||
filename = (
|
||||
chat_file.filename
|
||||
or artifact.display_name
|
||||
or artifact.path
|
||||
or str(artifact.file_id)
|
||||
)
|
||||
filename = Path(filename).name or str(artifact.file_id)
|
||||
if not filename.startswith("generated_"):
|
||||
filename = f"generated_{filename}"
|
||||
|
||||
generated_files.append(
|
||||
InMemoryChatFile(
|
||||
file_id=chat_file.file_id,
|
||||
content=chat_file.content,
|
||||
file_type=chat_file.file_type,
|
||||
filename=filename,
|
||||
)
|
||||
)
|
||||
|
||||
summary_files = initial_files + generated_files
|
||||
summary_content = build_content_with_imgs(
|
||||
message=tool_summary_prompt,
|
||||
files=summary_files,
|
||||
message_type=MessageType.USER,
|
||||
)
|
||||
|
||||
summary_message: dict[str, Any] = {
|
||||
"role": "user",
|
||||
"content": summary_content,
|
||||
}
|
||||
if summary_files:
|
||||
summary_message["files"] = _serialize_chat_files(summary_files)
|
||||
|
||||
answer_string = str(
|
||||
graph_config.tooling.primary_llm.invoke(
|
||||
[summary_message],
|
||||
timeout_override=TF_DR_TIMEOUT_SHORT,
|
||||
).content
|
||||
).strip()
|
||||
|
||||
artifact_file_ids = [
|
||||
artifact.file_id
|
||||
for artifact in python_tool_result.artifacts
|
||||
if artifact.file_id
|
||||
]
|
||||
|
||||
logger.debug(
|
||||
"Tool call end for %s %s.%s at %s",
|
||||
python_tool.llm_name,
|
||||
iteration_nr,
|
||||
parallelization_nr,
|
||||
datetime.now(),
|
||||
)
|
||||
|
||||
return BranchUpdate(
|
||||
branch_iteration_responses=[
|
||||
IterationAnswer(
|
||||
tool=python_tool.llm_name,
|
||||
tool_id=python_tool_info.tool_id,
|
||||
iteration_nr=iteration_nr,
|
||||
parallelization_nr=parallelization_nr,
|
||||
question=branch_query,
|
||||
answer=answer_string,
|
||||
claims=[],
|
||||
cited_documents={},
|
||||
reasoning="",
|
||||
additional_data=None,
|
||||
response_type="json",
|
||||
data=final_result,
|
||||
file_ids=artifact_file_ids or None,
|
||||
)
|
||||
],
|
||||
log_messages=[
|
||||
get_langgraph_node_log_string(
|
||||
graph_component="python_tool",
|
||||
node_name="tool_calling",
|
||||
node_start_time=node_start_time,
|
||||
)
|
||||
],
|
||||
)
|
||||
@@ -0,0 +1,76 @@
|
||||
from datetime import datetime
|
||||
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langgraph.types import StreamWriter
|
||||
|
||||
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentMainState
|
||||
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentUpdate
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
get_langgraph_node_log_string,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
|
||||
from onyx.server.query_and_chat.streaming_models import CustomToolDelta
|
||||
from onyx.server.query_and_chat.streaming_models import CustomToolStart
|
||||
from onyx.server.query_and_chat.streaming_models import SectionEnd
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def python_tool_reducer(
|
||||
state: SubAgentMainState,
|
||||
config: RunnableConfig,
|
||||
writer: StreamWriter = lambda _: None,
|
||||
) -> SubAgentUpdate:
|
||||
"""Stream the Python Tool result back to the client."""
|
||||
|
||||
node_start_time = datetime.now()
|
||||
current_step_nr = state.current_step_nr
|
||||
branch_updates = state.branch_iteration_responses
|
||||
current_iteration = state.iteration_nr
|
||||
|
||||
new_updates = [
|
||||
update for update in branch_updates if update.iteration_nr == current_iteration
|
||||
]
|
||||
|
||||
for new_update in new_updates:
|
||||
if not new_update.response_type:
|
||||
raise ValueError("Response type is not returned.")
|
||||
|
||||
write_custom_event(
|
||||
current_step_nr,
|
||||
CustomToolStart(
|
||||
tool_name=new_update.tool,
|
||||
),
|
||||
writer,
|
||||
)
|
||||
|
||||
write_custom_event(
|
||||
current_step_nr,
|
||||
CustomToolDelta(
|
||||
tool_name=new_update.tool,
|
||||
response_type=new_update.response_type,
|
||||
data=new_update.data,
|
||||
file_ids=new_update.file_ids,
|
||||
),
|
||||
writer,
|
||||
)
|
||||
|
||||
write_custom_event(
|
||||
current_step_nr,
|
||||
SectionEnd(),
|
||||
writer,
|
||||
)
|
||||
|
||||
current_step_nr += 1
|
||||
|
||||
return SubAgentUpdate(
|
||||
iteration_responses=new_updates,
|
||||
log_messages=[
|
||||
get_langgraph_node_log_string(
|
||||
graph_component="python_tool",
|
||||
node_name="consolidation",
|
||||
node_start_time=node_start_time,
|
||||
)
|
||||
],
|
||||
)
|
||||
@@ -0,0 +1,26 @@
|
||||
from collections.abc import Hashable
|
||||
|
||||
from langgraph.types import Send
|
||||
|
||||
from onyx.agents.agent_search.dr.sub_agents.states import BranchInput
|
||||
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
|
||||
|
||||
|
||||
def branching_router(state: SubAgentInput) -> list[Send | Hashable]:
|
||||
"""Forward the current query to the Python Tool executor."""
|
||||
|
||||
return [
|
||||
Send(
|
||||
"act",
|
||||
BranchInput(
|
||||
iteration_nr=state.iteration_nr,
|
||||
parallelization_nr=parallelization_nr,
|
||||
branch_question=query,
|
||||
context="",
|
||||
active_source_types=state.active_source_types,
|
||||
tools_used=state.tools_used,
|
||||
available_tools=state.available_tools,
|
||||
),
|
||||
)
|
||||
for parallelization_nr, query in enumerate(state.query_list[:1])
|
||||
]
|
||||
@@ -0,0 +1,38 @@
|
||||
from langgraph.graph import END
|
||||
from langgraph.graph import START
|
||||
from langgraph.graph import StateGraph
|
||||
|
||||
from onyx.agents.agent_search.dr.sub_agents.python_tool.dr_python_tool_1_branch import (
|
||||
python_tool_branch,
|
||||
)
|
||||
from onyx.agents.agent_search.dr.sub_agents.python_tool.dr_python_tool_2_act import (
|
||||
python_tool_act,
|
||||
)
|
||||
from onyx.agents.agent_search.dr.sub_agents.python_tool.dr_python_tool_3_reduce import (
|
||||
python_tool_reducer,
|
||||
)
|
||||
from onyx.agents.agent_search.dr.sub_agents.python_tool.dr_python_tool_conditional_edges import (
|
||||
branching_router,
|
||||
)
|
||||
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
|
||||
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentMainState
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def dr_python_tool_graph_builder() -> StateGraph:
|
||||
"""LangGraph graph builder for the Python Tool sub-agent."""
|
||||
|
||||
graph = StateGraph(state_schema=SubAgentMainState, input=SubAgentInput)
|
||||
|
||||
graph.add_node("branch", python_tool_branch)
|
||||
graph.add_node("act", python_tool_act)
|
||||
graph.add_node("reducer", python_tool_reducer)
|
||||
|
||||
graph.add_edge(start_key=START, end_key="branch")
|
||||
graph.add_conditional_edges("branch", branching_router)
|
||||
graph.add_edge(start_key="act", end_key="reducer")
|
||||
graph.add_edge(start_key="reducer", end_key=END)
|
||||
|
||||
return graph
|
||||
@@ -758,6 +758,24 @@ AZURE_DALLE_DEPLOYMENT_NAME = os.environ.get("AZURE_DALLE_DEPLOYMENT_NAME")
|
||||
# configurable image model
|
||||
IMAGE_MODEL_NAME = os.environ.get("IMAGE_MODEL_NAME", "gpt-image-1")
|
||||
|
||||
CODE_INTERPRETER_BASE_URL = os.environ.get("CODE_INTERPRETER_BASE_URL")
|
||||
_CODE_INTERPRETER_DEFAULT_TIMEOUT_MS_RAW = os.environ.get(
|
||||
"CODE_INTERPRETER_DEFAULT_TIMEOUT_MS"
|
||||
)
|
||||
CODE_INTERPRETER_DEFAULT_TIMEOUT_MS = (
|
||||
int(_CODE_INTERPRETER_DEFAULT_TIMEOUT_MS_RAW)
|
||||
if _CODE_INTERPRETER_DEFAULT_TIMEOUT_MS_RAW
|
||||
else 30_000
|
||||
)
|
||||
_CODE_INTERPRETER_REQUEST_TIMEOUT_SECONDS_RAW = os.environ.get(
|
||||
"CODE_INTERPRETER_REQUEST_TIMEOUT_SECONDS"
|
||||
)
|
||||
CODE_INTERPRETER_REQUEST_TIMEOUT_SECONDS = (
|
||||
int(_CODE_INTERPRETER_REQUEST_TIMEOUT_SECONDS_RAW)
|
||||
if _CODE_INTERPRETER_REQUEST_TIMEOUT_SECONDS_RAW
|
||||
else 30
|
||||
)
|
||||
|
||||
# Use managed Vespa (Vespa Cloud). If set, must also set VESPA_CLOUD_URL, VESPA_CLOUD_CERT_PATH and VESPA_CLOUD_KEY_PATH
|
||||
MANAGED_VESPA = os.environ.get("MANAGED_VESPA", "").lower() == "true"
|
||||
|
||||
|
||||
@@ -1029,7 +1029,7 @@ class SharepointConnector(
|
||||
|
||||
# Filter pages based on time window if specified
|
||||
if start is not None or end is not None:
|
||||
filtered_pages = []
|
||||
filtered_pages: list[dict[str, Any]] = []
|
||||
for page in all_pages:
|
||||
page_modified = page.get("lastModifiedDateTime")
|
||||
if page_modified:
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import copy
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Iterator
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
@@ -14,6 +15,9 @@ from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.cross_connector_utils.miscellaneous_utils import (
|
||||
time_str_to_utc,
|
||||
)
|
||||
from onyx.connectors.cross_connector_utils.rate_limit_wrapper import (
|
||||
rate_limit_builder,
|
||||
)
|
||||
from onyx.connectors.exceptions import ConnectorValidationError
|
||||
from onyx.connectors.exceptions import CredentialExpiredError
|
||||
from onyx.connectors.exceptions import InsufficientPermissionsError
|
||||
@@ -47,14 +51,30 @@ class ZendeskCredentialsNotSetUpError(PermissionError):
|
||||
|
||||
|
||||
class ZendeskClient:
|
||||
def __init__(self, subdomain: str, email: str, token: str):
|
||||
def __init__(
|
||||
self,
|
||||
subdomain: str,
|
||||
email: str,
|
||||
token: str,
|
||||
calls_per_minute: int | None = None,
|
||||
):
|
||||
self.base_url = f"https://{subdomain}.zendesk.com/api/v2"
|
||||
self.auth = (f"{email}/token", token)
|
||||
self.make_request = request_with_rate_limit(self, calls_per_minute)
|
||||
|
||||
|
||||
def request_with_rate_limit(
|
||||
client: ZendeskClient, max_calls_per_minute: int | None = None
|
||||
) -> Callable[[str, dict[str, Any]], dict[str, Any]]:
|
||||
@retry_builder()
|
||||
def make_request(self, endpoint: str, params: dict[str, Any]) -> dict[str, Any]:
|
||||
@(
|
||||
rate_limit_builder(max_calls=max_calls_per_minute, period=60)
|
||||
if max_calls_per_minute
|
||||
else lambda x: x
|
||||
)
|
||||
def make_request(endpoint: str, params: dict[str, Any]) -> dict[str, Any]:
|
||||
response = requests.get(
|
||||
f"{self.base_url}/{endpoint}", auth=self.auth, params=params
|
||||
f"{client.base_url}/{endpoint}", auth=client.auth, params=params
|
||||
)
|
||||
|
||||
if response.status_code == 429:
|
||||
@@ -72,6 +92,8 @@ class ZendeskClient:
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
|
||||
return make_request
|
||||
|
||||
|
||||
class ZendeskPageResponse(BaseModel):
|
||||
data: list[dict[str, Any]]
|
||||
@@ -359,11 +381,13 @@ class ZendeskConnector(
|
||||
def __init__(
|
||||
self,
|
||||
content_type: str = "articles",
|
||||
calls_per_minute: int | None = None,
|
||||
) -> None:
|
||||
self.content_type = content_type
|
||||
self.subdomain = ""
|
||||
# Fetch all tags ahead of time
|
||||
self.content_tags: dict[str, str] = {}
|
||||
self.calls_per_minute = calls_per_minute
|
||||
|
||||
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
|
||||
# Subdomain is actually the whole URL
|
||||
@@ -375,7 +399,10 @@ class ZendeskConnector(
|
||||
self.subdomain = subdomain
|
||||
|
||||
self.client = ZendeskClient(
|
||||
subdomain, credentials["zendesk_email"], credentials["zendesk_token"]
|
||||
subdomain,
|
||||
credentials["zendesk_email"],
|
||||
credentials["zendesk_token"],
|
||||
calls_per_minute=self.calls_per_minute,
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
@@ -72,7 +72,6 @@ def _get_answer(
|
||||
else None
|
||||
)
|
||||
research_type = ResearchType(eval_input.get("research_type", "THOUGHTFUL"))
|
||||
print(eval_input)
|
||||
request = prepare_chat_message_request(
|
||||
message_text=eval_input["message"],
|
||||
user=user,
|
||||
|
||||
@@ -10,11 +10,7 @@ from typing import Any
|
||||
|
||||
import braintrust
|
||||
import requests
|
||||
from braintrust_langchain import BraintrustCallbackHandler
|
||||
from braintrust_langchain import set_global_handler
|
||||
|
||||
from onyx.configs.app_configs import BRAINTRUST_API_KEY
|
||||
from onyx.configs.app_configs import BRAINTRUST_PROJECT
|
||||
from onyx.configs.app_configs import POSTGRES_API_SERVER_POOL_OVERFLOW
|
||||
from onyx.configs.app_configs import POSTGRES_API_SERVER_POOL_SIZE
|
||||
from onyx.configs.constants import POSTGRES_WEB_APP_NAME
|
||||
@@ -22,6 +18,7 @@ from onyx.db.engine.sql_engine import SqlEngine
|
||||
from onyx.evals.eval import run_eval
|
||||
from onyx.evals.models import EvalationAck
|
||||
from onyx.evals.models import EvalConfigurationOptions
|
||||
from onyx.evals.tracing import setup_braintrust
|
||||
|
||||
|
||||
def setup_session_factory() -> None:
|
||||
@@ -32,15 +29,6 @@ def setup_session_factory() -> None:
|
||||
)
|
||||
|
||||
|
||||
def setup_braintrust() -> None:
|
||||
braintrust.init_logger(
|
||||
project=BRAINTRUST_PROJECT,
|
||||
api_key=BRAINTRUST_API_KEY,
|
||||
)
|
||||
handler = BraintrustCallbackHandler()
|
||||
set_global_handler(handler)
|
||||
|
||||
|
||||
def load_data_local(
|
||||
local_data_path: str,
|
||||
) -> list[dict[str, dict[str, str]]]:
|
||||
@@ -67,6 +55,7 @@ def run_local(
|
||||
EvalationAck: The evaluation result
|
||||
"""
|
||||
setup_session_factory()
|
||||
setup_braintrust()
|
||||
|
||||
if search_permissions_email is None:
|
||||
raise ValueError("search_permissions_email is required for local evaluation")
|
||||
|
||||
@@ -32,7 +32,7 @@ class EvalConfigurationOptions(BaseModel):
|
||||
llm: LLMOverride = LLMOverride(
|
||||
model_provider="Default",
|
||||
model_version="gpt-4.1",
|
||||
temperature=0.5,
|
||||
temperature=0.0,
|
||||
)
|
||||
search_permissions_email: str
|
||||
dataset_name: str
|
||||
|
||||
35
backend/onyx/evals/tracing.py
Normal file
35
backend/onyx/evals/tracing.py
Normal file
@@ -0,0 +1,35 @@
|
||||
from typing import Any
|
||||
|
||||
import braintrust
|
||||
from braintrust_langchain import set_global_handler
|
||||
from braintrust_langchain.callbacks import BraintrustCallbackHandler
|
||||
|
||||
from onyx.configs.app_configs import BRAINTRUST_API_KEY
|
||||
from onyx.configs.app_configs import BRAINTRUST_PROJECT
|
||||
|
||||
MASKING_LENGTH = 20000
|
||||
|
||||
|
||||
def _truncate_str(s: str) -> str:
|
||||
tail = MASKING_LENGTH // 5
|
||||
head = MASKING_LENGTH - tail
|
||||
return f"{s[:head]}…{s[-tail:]}[TRUNCATED {len(s)} chars to {MASKING_LENGTH}]"
|
||||
|
||||
|
||||
def _mask(data: Any) -> Any:
|
||||
"""Mask data if it exceeds the maximum length threshold."""
|
||||
if len(str(data)) <= MASKING_LENGTH:
|
||||
return data
|
||||
return _truncate_str(str(data))
|
||||
|
||||
|
||||
def setup_braintrust() -> None:
|
||||
"""Initialize Braintrust logger and set up global callback handler."""
|
||||
|
||||
braintrust.init_logger(
|
||||
project=BRAINTRUST_PROJECT,
|
||||
api_key=BRAINTRUST_API_KEY,
|
||||
)
|
||||
braintrust.set_masking_function(_mask)
|
||||
handler = BraintrustCallbackHandler()
|
||||
set_global_handler(handler)
|
||||
@@ -1,3 +1,4 @@
|
||||
import gc
|
||||
import io
|
||||
import json
|
||||
import os
|
||||
@@ -17,8 +18,10 @@ from typing import NamedTuple
|
||||
from zipfile import BadZipFile
|
||||
|
||||
import chardet
|
||||
import openpyxl
|
||||
from markitdown import FileConversionException
|
||||
from markitdown import MarkItDown
|
||||
from markitdown import StreamInfo
|
||||
from markitdown import UnsupportedFormatException
|
||||
from PIL import Image
|
||||
from pypdf import PdfReader
|
||||
@@ -30,6 +33,8 @@ from onyx.file_processing.file_validation import TEXT_MIME_TYPE
|
||||
from onyx.file_processing.html_utils import parse_html_page_basic
|
||||
from onyx.file_processing.unstructured import get_unstructured_api_key
|
||||
from onyx.file_processing.unstructured import unstructured_to_text
|
||||
from onyx.utils.file_types import PRESENTATION_MIME_TYPE
|
||||
from onyx.utils.file_types import WORD_PROCESSING_MIME_TYPE
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
@@ -80,6 +85,20 @@ IMAGE_MEDIA_TYPES = [
|
||||
"image/webp",
|
||||
]
|
||||
|
||||
_MARKITDOWN_CONVERTER: MarkItDown | None = None
|
||||
|
||||
KNOWN_OPENPYXL_BUGS = [
|
||||
"Value must be either numerical or a string containing a wildcard",
|
||||
"File contains no valid workbook part",
|
||||
]
|
||||
|
||||
|
||||
def get_markitdown_converter() -> MarkItDown:
|
||||
global _MARKITDOWN_CONVERTER
|
||||
if _MARKITDOWN_CONVERTER is None:
|
||||
_MARKITDOWN_CONVERTER = MarkItDown(enable_plugins=False)
|
||||
return _MARKITDOWN_CONVERTER
|
||||
|
||||
|
||||
class OnyxExtensionType(IntFlag):
|
||||
Plain = auto()
|
||||
@@ -338,9 +357,11 @@ def docx_to_text_and_images(
|
||||
of avoiding materializing the list of images in memory.
|
||||
The images list returned is empty in this case.
|
||||
"""
|
||||
md = MarkItDown(enable_plugins=False)
|
||||
md = get_markitdown_converter()
|
||||
try:
|
||||
doc = md.convert(to_bytesio(file))
|
||||
doc = md.convert(
|
||||
to_bytesio(file), stream_info=StreamInfo(mimetype=WORD_PROCESSING_MIME_TYPE)
|
||||
)
|
||||
except (
|
||||
BadZipFile,
|
||||
ValueError,
|
||||
@@ -372,9 +393,12 @@ def docx_to_text_and_images(
|
||||
|
||||
|
||||
def pptx_to_text(file: IO[Any], file_name: str = "") -> str:
|
||||
md = MarkItDown(enable_plugins=False)
|
||||
md = get_markitdown_converter()
|
||||
stream_info = StreamInfo(
|
||||
mimetype=PRESENTATION_MIME_TYPE, filename=file_name or None, extension=".pptx"
|
||||
)
|
||||
try:
|
||||
presentation = md.convert(to_bytesio(file))
|
||||
presentation = md.convert(to_bytesio(file), stream_info=stream_info)
|
||||
except (
|
||||
BadZipFile,
|
||||
ValueError,
|
||||
@@ -388,23 +412,69 @@ def pptx_to_text(file: IO[Any], file_name: str = "") -> str:
|
||||
|
||||
|
||||
def xlsx_to_text(file: IO[Any], file_name: str = "") -> str:
|
||||
md = MarkItDown(enable_plugins=False)
|
||||
# TODO: switch back to this approach in a few months when markitdown
|
||||
# fixes their handling of excel files
|
||||
|
||||
# md = get_markitdown_converter()
|
||||
# stream_info = StreamInfo(
|
||||
# mimetype=SPREADSHEET_MIME_TYPE, filename=file_name or None, extension=".xlsx"
|
||||
# )
|
||||
# try:
|
||||
# workbook = md.convert(to_bytesio(file), stream_info=stream_info)
|
||||
# except (
|
||||
# BadZipFile,
|
||||
# ValueError,
|
||||
# FileConversionException,
|
||||
# UnsupportedFormatException,
|
||||
# ) as e:
|
||||
# error_str = f"Failed to extract text from {file_name or 'xlsx file'}: {e}"
|
||||
# if file_name.startswith("~"):
|
||||
# logger.debug(error_str + " (this is expected for files with ~)")
|
||||
# else:
|
||||
# logger.warning(error_str)
|
||||
# return ""
|
||||
# return workbook.markdown
|
||||
try:
|
||||
workbook = md.convert(to_bytesio(file))
|
||||
except (
|
||||
BadZipFile,
|
||||
ValueError,
|
||||
FileConversionException,
|
||||
UnsupportedFormatException,
|
||||
) as e:
|
||||
workbook = openpyxl.load_workbook(file, read_only=True)
|
||||
except BadZipFile as e:
|
||||
error_str = f"Failed to extract text from {file_name or 'xlsx file'}: {e}"
|
||||
if file_name.startswith("~"):
|
||||
logger.debug(error_str + " (this is expected for files with ~)")
|
||||
else:
|
||||
logger.warning(error_str)
|
||||
return ""
|
||||
except Exception as e:
|
||||
if any(s in str(e) for s in KNOWN_OPENPYXL_BUGS):
|
||||
logger.error(
|
||||
f"Failed to extract text from {file_name or 'xlsx file'}. This happens due to a bug in openpyxl. {e}"
|
||||
)
|
||||
return ""
|
||||
raise e
|
||||
|
||||
return workbook.markdown
|
||||
text_content = []
|
||||
for sheet in workbook.worksheets:
|
||||
rows = []
|
||||
num_empty_consecutive_rows = 0
|
||||
for row in sheet.iter_rows(min_row=1, values_only=True):
|
||||
row_str = ",".join(str(cell or "") for cell in row)
|
||||
|
||||
# Only add the row if there are any values in the cells
|
||||
if len(row_str) >= len(row):
|
||||
rows.append(row_str)
|
||||
num_empty_consecutive_rows = 0
|
||||
else:
|
||||
num_empty_consecutive_rows += 1
|
||||
|
||||
if num_empty_consecutive_rows > 100:
|
||||
# handle massive excel sheets with mostly empty cells
|
||||
logger.warning(
|
||||
f"Found {num_empty_consecutive_rows} empty rows in {file_name},"
|
||||
" skipping rest of file"
|
||||
)
|
||||
break
|
||||
sheet_str = "\n".join(rows)
|
||||
text_content.append(sheet_str)
|
||||
return TEXT_SECTION_SEPARATOR.join(text_content)
|
||||
|
||||
|
||||
def eml_to_text(file: IO[Any]) -> str:
|
||||
@@ -531,6 +601,23 @@ def extract_text_and_images(
|
||||
Primary new function for the updated connector.
|
||||
Returns structured extraction result with text content, embedded images, and metadata.
|
||||
"""
|
||||
res = _extract_text_and_images(
|
||||
file, file_name, pdf_pass, content_type, image_callback
|
||||
)
|
||||
# Clean up any temporary objects and force garbage collection
|
||||
unreachable = gc.collect()
|
||||
logger.info(f"Unreachable objects: {unreachable}")
|
||||
|
||||
return res
|
||||
|
||||
|
||||
def _extract_text_and_images(
|
||||
file: IO[Any],
|
||||
file_name: str,
|
||||
pdf_pass: str | None = None,
|
||||
content_type: str | None = None,
|
||||
image_callback: Callable[[bytes, str], None] | None = None,
|
||||
) -> ExtractionResult:
|
||||
file.seek(0)
|
||||
|
||||
if get_unstructured_api_key():
|
||||
@@ -556,7 +643,6 @@ def extract_text_and_images(
|
||||
# Default processing
|
||||
try:
|
||||
extension = get_file_ext(file_name)
|
||||
|
||||
# docx example for embedded images
|
||||
if extension == ".docx":
|
||||
text_content, images = docx_to_text_and_images(
|
||||
|
||||
@@ -26,6 +26,7 @@ from langchain_core.messages.tool import ToolMessage
|
||||
from langchain_core.prompt_values import PromptValue
|
||||
from litellm.utils import get_supported_openai_params
|
||||
|
||||
from onyx.configs.app_configs import BRAINTRUST_ENABLED
|
||||
from onyx.configs.app_configs import LOG_DANSWER_MODEL_INTERACTIONS
|
||||
from onyx.configs.app_configs import MOCK_LLM_RESPONSE
|
||||
from onyx.configs.chat_configs import QA_TIMEOUT
|
||||
@@ -42,7 +43,6 @@ from onyx.server.utils import mask_string
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.long_term_log import LongTermLogger
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
# If a user configures a different model and it doesn't support all the same
|
||||
@@ -50,6 +50,9 @@ logger = setup_logger()
|
||||
litellm.drop_params = True
|
||||
litellm.telemetry = False
|
||||
|
||||
if BRAINTRUST_ENABLED:
|
||||
litellm.callbacks = ["braintrust"]
|
||||
|
||||
_LLM_PROMPT_LONG_TERM_LOG_CATEGORY = "llm_prompt"
|
||||
VERTEX_CREDENTIALS_FILE_KWARG = "vertex_credentials"
|
||||
VERTEX_LOCATION_KWARG = "vertex_location"
|
||||
|
||||
@@ -2,10 +2,10 @@ import abc
|
||||
from collections.abc import Iterator
|
||||
from typing import Literal
|
||||
|
||||
from braintrust import traced
|
||||
from langchain.schema.language_model import LanguageModelInput
|
||||
from langchain_core.messages import AIMessageChunk
|
||||
from langchain_core.messages import BaseMessage
|
||||
from langsmith.run_helpers import traceable
|
||||
from pydantic import BaseModel
|
||||
|
||||
from onyx.configs.app_configs import DISABLE_GENERATIVE_AI
|
||||
@@ -53,6 +53,11 @@ def log_prompt(prompt: LanguageModelInput) -> None:
|
||||
else:
|
||||
log_msg = ""
|
||||
logger.debug(f"Message {ind}:\n{log_msg}")
|
||||
elif isinstance(msg, dict):
|
||||
log_msg = msg.get("content", "")
|
||||
if "files" in msg:
|
||||
log_msg = f"{log_msg}\nfiles: {msg['files']}"
|
||||
logger.debug(f"Message {ind}:\n{log_msg}")
|
||||
else:
|
||||
logger.debug(f"Message {ind}:\n{msg.content}")
|
||||
if isinstance(prompt, str):
|
||||
@@ -87,7 +92,7 @@ class LLM(abc.ABC):
|
||||
if LOG_DANSWER_MODEL_INTERACTIONS:
|
||||
log_prompt(prompt)
|
||||
|
||||
@traceable(run_type="llm")
|
||||
@traced(name="invoke llm", type="llm")
|
||||
def invoke(
|
||||
self,
|
||||
prompt: LanguageModelInput,
|
||||
@@ -121,7 +126,7 @@ class LLM(abc.ABC):
|
||||
) -> BaseMessage:
|
||||
raise NotImplementedError
|
||||
|
||||
@traceable(run_type="llm")
|
||||
@traced(name="stream llm", type="llm")
|
||||
def stream(
|
||||
self,
|
||||
prompt: LanguageModelInput,
|
||||
|
||||
@@ -35,6 +35,7 @@ from onyx.configs.app_configs import APP_HOST
|
||||
from onyx.configs.app_configs import APP_PORT
|
||||
from onyx.configs.app_configs import AUTH_RATE_LIMITING_ENABLED
|
||||
from onyx.configs.app_configs import AUTH_TYPE
|
||||
from onyx.configs.app_configs import BRAINTRUST_ENABLED
|
||||
from onyx.configs.app_configs import DISABLE_GENERATIVE_AI
|
||||
from onyx.configs.app_configs import LOG_ENDPOINT_LATENCY
|
||||
from onyx.configs.app_configs import OAUTH_CLIENT_ID
|
||||
@@ -51,6 +52,7 @@ from onyx.configs.constants import POSTGRES_WEB_APP_NAME
|
||||
from onyx.db.engine.connection_warmup import warm_up_connections
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.engine.sql_engine import SqlEngine
|
||||
from onyx.evals.tracing import setup_braintrust
|
||||
from onyx.file_store.file_store import get_default_file_store
|
||||
from onyx.server.api_key.api import router as api_key_router
|
||||
from onyx.server.auth_check import check_router_auth
|
||||
@@ -252,6 +254,10 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
|
||||
if DISABLE_GENERATIVE_AI:
|
||||
logger.notice("Generative AI Q&A disabled")
|
||||
|
||||
if BRAINTRUST_ENABLED:
|
||||
setup_braintrust()
|
||||
logger.notice("Braintrust tracing initialized")
|
||||
|
||||
# fill up Postgres connection pools
|
||||
await warm_up_connections()
|
||||
|
||||
|
||||
@@ -97,6 +97,15 @@ some sub-answers, but the question sent to the graph must be a question that can
|
||||
entity and relationship types and their attributes that will be communicated to you.
|
||||
"""
|
||||
|
||||
TOOL_DESCRIPTION[DRPath.PYTHON_TOOL] = (
|
||||
"""\
|
||||
This tool executes Python code within an isolated Code Interpreter environment. \
|
||||
Use it when the user explicitly requests code execution, data analysis, statistical computations, \
|
||||
or operations on uploaded files. The tool has access to any files the user supplied and can generate \
|
||||
new artifacts (for example, charts or datasets) that can be downloaded afterwards.\
|
||||
""".strip()
|
||||
)
|
||||
|
||||
TOOL_DESCRIPTION[
|
||||
DRPath.CLOSER
|
||||
] = f"""\
|
||||
@@ -948,6 +957,31 @@ ANSWER:
|
||||
)
|
||||
|
||||
|
||||
PYTHON_TOOL_USE_RESPONSE_PROMPT = PromptTemplate(
|
||||
f"""\
|
||||
Here is the base question that ultimately needs to be answered:
|
||||
{SEPARATOR_LINE}
|
||||
---base_question---
|
||||
{SEPARATOR_LINE}
|
||||
|
||||
You just used the `python` tool in an attempt to answer the base question. The respone from the \
|
||||
tool is:
|
||||
|
||||
{SEPARATOR_LINE}
|
||||
---tool_response---
|
||||
{SEPARATOR_LINE}
|
||||
|
||||
If any files were uploaded by the user, they are attached. If any files were generated by the \
|
||||
`python` tool, they are also attached, and prefixed with `generated_`.
|
||||
|
||||
|
||||
Please respond with a concise answer to the base question. If you believe additionl `python` \
|
||||
tool calls are necessary, please describe what HIGH LEVEL TASK we would like to accomplish next. \
|
||||
Do NOT describe any specific Python code that we would like to execute.
|
||||
""".strip()
|
||||
)
|
||||
|
||||
|
||||
TEST_INFO_COMPLETE_PROMPT = PromptTemplate(
|
||||
f"""\
|
||||
You are an expert at trying to determine whether \
|
||||
|
||||
@@ -10,6 +10,9 @@ from onyx.tools.tool_implementations.knowledge_graph.knowledge_graph_tool import
|
||||
from onyx.tools.tool_implementations.okta_profile.okta_profile_tool import (
|
||||
OktaProfileTool,
|
||||
)
|
||||
from onyx.tools.tool_implementations.python.python_tool import (
|
||||
PythonTool,
|
||||
)
|
||||
from onyx.tools.tool_implementations.search.search_tool import SearchTool
|
||||
from onyx.tools.tool_implementations.web_search.web_search_tool import (
|
||||
WebSearchTool,
|
||||
@@ -20,7 +23,12 @@ logger = setup_logger()
|
||||
|
||||
|
||||
BUILT_IN_TOOL_TYPES = Union[
|
||||
SearchTool, ImageGenerationTool, WebSearchTool, KnowledgeGraphTool, OktaProfileTool
|
||||
SearchTool,
|
||||
ImageGenerationTool,
|
||||
WebSearchTool,
|
||||
KnowledgeGraphTool,
|
||||
OktaProfileTool,
|
||||
PythonTool,
|
||||
]
|
||||
|
||||
# same as d09fc20a3c66_seed_builtin_tools.py
|
||||
@@ -30,6 +38,7 @@ BUILT_IN_TOOL_MAP: dict[str, Type[BUILT_IN_TOOL_TYPES]] = {
|
||||
WebSearchTool.__name__: WebSearchTool,
|
||||
KnowledgeGraphTool.__name__: KnowledgeGraphTool,
|
||||
OktaProfileTool.__name__: OktaProfileTool,
|
||||
PythonTool.__name__: PythonTool,
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -13,6 +13,9 @@ from onyx.configs.app_configs import AZURE_DALLE_API_BASE
|
||||
from onyx.configs.app_configs import AZURE_DALLE_API_KEY
|
||||
from onyx.configs.app_configs import AZURE_DALLE_API_VERSION
|
||||
from onyx.configs.app_configs import AZURE_DALLE_DEPLOYMENT_NAME
|
||||
from onyx.configs.app_configs import CODE_INTERPRETER_BASE_URL
|
||||
from onyx.configs.app_configs import CODE_INTERPRETER_DEFAULT_TIMEOUT_MS
|
||||
from onyx.configs.app_configs import CODE_INTERPRETER_REQUEST_TIMEOUT_SECONDS
|
||||
from onyx.configs.app_configs import IMAGE_MODEL_NAME
|
||||
from onyx.configs.app_configs import OAUTH_CLIENT_ID
|
||||
from onyx.configs.app_configs import OAUTH_CLIENT_SECRET
|
||||
@@ -55,6 +58,9 @@ from onyx.tools.tool_implementations.mcp.mcp_tool import MCPTool
|
||||
from onyx.tools.tool_implementations.okta_profile.okta_profile_tool import (
|
||||
OktaProfileTool,
|
||||
)
|
||||
from onyx.tools.tool_implementations.python.python_tool import (
|
||||
PythonTool,
|
||||
)
|
||||
from onyx.tools.tool_implementations.search.search_tool import SearchTool
|
||||
from onyx.tools.tool_implementations.web_search.web_search_tool import (
|
||||
WebSearchTool,
|
||||
@@ -280,6 +286,27 @@ def construct_tools(
|
||||
)
|
||||
]
|
||||
|
||||
# Handle Python Tool
|
||||
elif tool_cls.__name__ == PythonTool.__name__:
|
||||
if not CODE_INTERPRETER_BASE_URL:
|
||||
raise ValueError(
|
||||
"Code Interpreter base URL must be configured to use the Python tool"
|
||||
)
|
||||
|
||||
available_files = []
|
||||
if search_tool_config and search_tool_config.latest_query_files:
|
||||
available_files = search_tool_config.latest_query_files
|
||||
|
||||
tool_dict[db_tool_model.id] = [
|
||||
PythonTool(
|
||||
tool_id=db_tool_model.id,
|
||||
base_url=CODE_INTERPRETER_BASE_URL,
|
||||
default_timeout_ms=CODE_INTERPRETER_DEFAULT_TIMEOUT_MS,
|
||||
request_timeout_seconds=CODE_INTERPRETER_REQUEST_TIMEOUT_SECONDS,
|
||||
available_files=available_files,
|
||||
)
|
||||
]
|
||||
|
||||
# Handle Internet Search Tool
|
||||
elif tool_cls.__name__ == WebSearchTool.__name__:
|
||||
if not internet_search_tool_config:
|
||||
|
||||
644
backend/onyx/tools/tool_implementations/python/python_tool.py
Normal file
644
backend/onyx/tools/tool_implementations/python/python_tool.py
Normal file
@@ -0,0 +1,644 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import binascii
|
||||
import json
|
||||
import mimetypes
|
||||
from collections.abc import Generator
|
||||
from contextlib import closing
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import requests
|
||||
from pydantic import BaseModel
|
||||
from pydantic import ConfigDict
|
||||
from pydantic import Field
|
||||
from pydantic import ValidationError
|
||||
from sqlalchemy.orm import Session
|
||||
from typing_extensions import override
|
||||
|
||||
from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
|
||||
from onyx.configs.app_configs import CODE_INTERPRETER_BASE_URL
|
||||
from onyx.configs.constants import FileOrigin
|
||||
from onyx.file_store.file_store import get_default_file_store
|
||||
from onyx.file_store.models import ChatFileType
|
||||
from onyx.file_store.models import InMemoryChatFile
|
||||
from onyx.llm.interfaces import LLM
|
||||
from onyx.llm.models import PreviousMessage
|
||||
from onyx.server.query_and_chat.chat_utils import mime_type_to_chat_file_type
|
||||
from onyx.tools.base_tool import BaseTool
|
||||
from onyx.tools.message import ToolCallSummary
|
||||
from onyx.tools.models import ToolResponse
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.special_types import JSON_ro
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
PYTHON_TOOL_RESPONSE_ID = "python_tool_result"
|
||||
_EXECUTE_PATH = "/execute"
|
||||
_DEFAULT_RESULT_CHAR_LIMIT = 1200
|
||||
|
||||
|
||||
class PythonToolArgs(BaseModel):
|
||||
code: str = Field(
|
||||
...,
|
||||
description="Python source code to execute inside the Code Interpreter service",
|
||||
)
|
||||
stdin: str | None = Field(
|
||||
default=None,
|
||||
description="Optional standard input payload supplied to the process",
|
||||
)
|
||||
timeout_ms: int | None = Field(
|
||||
default=None,
|
||||
ge=1_000,
|
||||
description="Optional execution timeout override in milliseconds",
|
||||
)
|
||||
|
||||
|
||||
class ExecuteRequestFilePayload(BaseModel):
|
||||
path: str
|
||||
content_base64: str
|
||||
|
||||
|
||||
class ExecuteRequestPayload(BaseModel):
|
||||
code: str
|
||||
timeout_ms: int
|
||||
stdin: str | None = None
|
||||
files: list[ExecuteRequestFilePayload] = Field(default_factory=list)
|
||||
|
||||
|
||||
class WorkspaceFilePayload(BaseModel):
|
||||
model_config = ConfigDict(populate_by_name=True)
|
||||
|
||||
path: str
|
||||
kind: str = Field(default="file")
|
||||
content_base64: str | None = None
|
||||
mime_type: str | None = None
|
||||
size_bytes: int | None = Field(default=None, alias="size")
|
||||
|
||||
|
||||
class PythonToolArtifact(BaseModel):
|
||||
model_config = ConfigDict(use_enum_values=True)
|
||||
|
||||
path: str
|
||||
kind: str
|
||||
file_id: str | None = None
|
||||
display_name: str | None = None
|
||||
mime_type: str | None = None
|
||||
size_bytes: int | None = None
|
||||
error: str | None = None
|
||||
chat_file_type: ChatFileType | None = None
|
||||
|
||||
|
||||
class PythonToolResult(BaseModel):
|
||||
model_config = ConfigDict(use_enum_values=True)
|
||||
|
||||
stdout: str
|
||||
stderr: str | None = None
|
||||
exit_code: int | None = None
|
||||
execution_time_ms: int | None = None
|
||||
timeout_ms: int
|
||||
input_files: list[dict[str, Any]] = Field(default_factory=list)
|
||||
artifacts: list[PythonToolArtifact] = Field(default_factory=list)
|
||||
metadata: dict[str, Any] | None = None
|
||||
|
||||
|
||||
_PYTHON_DESCRIPTION = """
|
||||
When you send a message containing Python code to python, it will be executed in an isolated environment.
|
||||
|
||||
The code you write will be placed into a file called `__main__.py` and run like `python __main__.py`. \
|
||||
All other files present in the conversation will also be present in this same directory. E.g. if the \
|
||||
user has uploaded three files, `analytics.csv`, `word.txt`, and `cat.png`, the directory structure will look like:
|
||||
|
||||
workspace/
|
||||
__main__.py
|
||||
analytics.csv
|
||||
word.txt
|
||||
cat.png
|
||||
|
||||
python will respond with the stdout of the `__main__.py` script as well as any new/edited files in the `workspace` directory. \
|
||||
That means, if you want to get access to some result of the script you can either: 1) `print(<result>)` or 2) save \
|
||||
a result to a file in the `workspace` directory.
|
||||
|
||||
Internet access for this session is disabled. Do not make external web requests or API calls as they will fail.
|
||||
|
||||
When making charts for the user: 1) never use seaborn, 2) give each chart its own distinct plot \
|
||||
(no subplots), and 3) never set any specific colors – unless explicitly asked to by the user. \
|
||||
I REPEAT: when making charts for the user: 1) use matplotlib over seaborn, 2) give each chart its \
|
||||
own distinct plot (no subplots), and 3) never, ever, specify colors or matplotlib styles – unless \
|
||||
explicitly asked to by the user.
|
||||
""".strip()
|
||||
|
||||
|
||||
class PythonTool(BaseTool):
|
||||
_NAME = "run_code_interpreter"
|
||||
_DESCRIPTION = _PYTHON_DESCRIPTION
|
||||
_DISPLAY_NAME = "Code Interpreter"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
tool_id: int,
|
||||
base_url: str,
|
||||
default_timeout_ms: int,
|
||||
request_timeout_seconds: int,
|
||||
available_files: list[InMemoryChatFile] | None = None,
|
||||
) -> None:
|
||||
if not base_url:
|
||||
raise ValueError(
|
||||
"Code Interpreter base URL must be configured to use the Python tool"
|
||||
)
|
||||
|
||||
self._id = tool_id
|
||||
self._base_url = base_url.rstrip("/")
|
||||
self._default_timeout_ms = default_timeout_ms
|
||||
self._request_timeout_seconds = request_timeout_seconds
|
||||
self._available_files: dict[str, InMemoryChatFile] = {
|
||||
chat_file.file_id: chat_file for chat_file in available_files or []
|
||||
}
|
||||
|
||||
@property
|
||||
def id(self) -> int:
|
||||
return self._id
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return self._NAME
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return self._DESCRIPTION
|
||||
|
||||
@property
|
||||
def display_name(self) -> str:
|
||||
return self._DISPLAY_NAME
|
||||
|
||||
@override
|
||||
@classmethod
|
||||
def is_available(cls, db_session: Session) -> bool:
|
||||
return bool(CODE_INTERPRETER_BASE_URL)
|
||||
|
||||
def tool_definition(self) -> dict:
|
||||
return {
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": self.name,
|
||||
"description": self.description,
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"code": {
|
||||
"type": "string",
|
||||
"description": "Python source code to execute",
|
||||
},
|
||||
},
|
||||
"required": ["code"],
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
def get_args_for_non_tool_calling_llm(
|
||||
self,
|
||||
query: str,
|
||||
history: list[PreviousMessage],
|
||||
llm: LLM,
|
||||
force_run: bool = False,
|
||||
) -> dict[str, Any] | None:
|
||||
# Not supported for non-tool calling LLMs
|
||||
return None
|
||||
|
||||
def run(
|
||||
self,
|
||||
override_kwargs: dict[str, Any] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> Generator[ToolResponse, None, None]:
|
||||
raw_requested_files = kwargs.pop("files", None)
|
||||
|
||||
try:
|
||||
parsed_args = PythonToolArgs.model_validate(kwargs)
|
||||
except ValidationError as exc:
|
||||
logger.exception("Invalid arguments passed to PythonTool")
|
||||
raise ValueError("Invalid arguments supplied to Code Interpreter") from exc
|
||||
|
||||
timeout_ms = parsed_args.timeout_ms or self._default_timeout_ms
|
||||
override_files = self._coerce_override_files(override_kwargs)
|
||||
requested_files = self._load_requested_files(raw_requested_files)
|
||||
request_files, input_metadata = self._prepare_request_files(
|
||||
override_files + requested_files
|
||||
)
|
||||
|
||||
request_payload = ExecuteRequestPayload(
|
||||
code=parsed_args.code,
|
||||
timeout_ms=timeout_ms,
|
||||
stdin=parsed_args.stdin,
|
||||
files=request_files,
|
||||
)
|
||||
|
||||
try:
|
||||
response = requests.post(
|
||||
f"{self._base_url}{_EXECUTE_PATH}",
|
||||
json=request_payload.model_dump(exclude_none=True),
|
||||
timeout=self._request_timeout_seconds,
|
||||
)
|
||||
response.raise_for_status()
|
||||
except requests.RequestException as exc:
|
||||
logger.exception("Code Interpreter execution failed")
|
||||
raise ValueError(
|
||||
"Failed to reach the Code Interpreter service. Please try again later."
|
||||
) from exc
|
||||
|
||||
response_data = self._parse_response(response)
|
||||
artifacts = self._persist_artifacts(response_data.get("files", []))
|
||||
|
||||
python_result = PythonToolResult(
|
||||
stdout=self._ensure_text(response_data.get("stdout", "")),
|
||||
stderr=self._optional_text(response_data.get("stderr")),
|
||||
exit_code=self._safe_int(response_data.get("exit_code")),
|
||||
execution_time_ms=self._safe_int(
|
||||
response_data.get("execution_time_ms")
|
||||
or response_data.get("duration_ms")
|
||||
),
|
||||
timeout_ms=timeout_ms,
|
||||
input_files=input_metadata,
|
||||
artifacts=artifacts,
|
||||
metadata=self._extract_metadata(response_data.get("metadata")),
|
||||
)
|
||||
|
||||
yield ToolResponse(id=PYTHON_TOOL_RESPONSE_ID, response=python_result)
|
||||
|
||||
def build_tool_message_content(
|
||||
self, *args: ToolResponse
|
||||
) -> str | list[str | dict[str, Any]]:
|
||||
result = self._extract_result(args)
|
||||
sections: list[str] = []
|
||||
|
||||
if result.stdout:
|
||||
sections.append(self._format_section("stdout", result.stdout))
|
||||
if result.stderr:
|
||||
sections.append(self._format_section("stderr", result.stderr))
|
||||
|
||||
if result.artifacts:
|
||||
file_lines = []
|
||||
for artifact in result.artifacts:
|
||||
display_name = artifact.display_name or Path(artifact.path).name
|
||||
if artifact.error:
|
||||
file_lines.append(f"- {display_name}: {artifact.error}")
|
||||
elif artifact.file_id:
|
||||
file_lines.append(
|
||||
f"- {display_name} (file_id={artifact.file_id}, mime={artifact.mime_type or 'unknown'})"
|
||||
)
|
||||
else:
|
||||
file_lines.append(f"- {display_name}")
|
||||
sections.append("Generated files:\n" + "\n".join(file_lines))
|
||||
|
||||
return (
|
||||
"\n\n".join(sections) if sections else "Execution completed with no output."
|
||||
)
|
||||
|
||||
def final_result(self, *args: ToolResponse) -> JSON_ro:
|
||||
return self._extract_result(args).model_dump()
|
||||
|
||||
def build_next_prompt(
|
||||
self,
|
||||
prompt_builder: AnswerPromptBuilder,
|
||||
tool_call_summary: ToolCallSummary,
|
||||
tool_responses: list[ToolResponse],
|
||||
using_tool_calling_llm: bool,
|
||||
) -> AnswerPromptBuilder:
|
||||
updated_prompt_builder = super().build_next_prompt(
|
||||
prompt_builder, tool_call_summary, tool_responses, using_tool_calling_llm
|
||||
)
|
||||
|
||||
result = self._extract_result(tool_responses)
|
||||
if result.artifacts:
|
||||
existing_ids = {
|
||||
file.file_id for file in updated_prompt_builder.raw_user_uploaded_files
|
||||
}
|
||||
for artifact in result.artifacts:
|
||||
if not artifact.file_id:
|
||||
continue
|
||||
chat_file = self._available_files.get(artifact.file_id)
|
||||
if not chat_file or chat_file.file_id in existing_ids:
|
||||
continue
|
||||
if chat_file.file_type.is_text_file():
|
||||
updated_prompt_builder.raw_user_uploaded_files.append(chat_file)
|
||||
existing_ids.add(chat_file.file_id)
|
||||
return updated_prompt_builder
|
||||
|
||||
def _load_requested_files(self, raw_files: Any) -> list[InMemoryChatFile]:
|
||||
if not raw_files:
|
||||
return []
|
||||
|
||||
if not isinstance(raw_files, list):
|
||||
logger.warning(
|
||||
"Ignoring non-list 'files' argument passed to PythonTool: %r",
|
||||
type(raw_files),
|
||||
)
|
||||
return []
|
||||
|
||||
loaded_files: list[InMemoryChatFile] = []
|
||||
for descriptor in raw_files:
|
||||
file_id: str | None = None
|
||||
desired_name: str | None = None
|
||||
|
||||
if isinstance(descriptor, dict):
|
||||
raw_id = descriptor.get("id") or descriptor.get("file_id")
|
||||
if raw_id is not None:
|
||||
file_id = str(raw_id)
|
||||
desired_name = (
|
||||
descriptor.get("path")
|
||||
or descriptor.get("name")
|
||||
or descriptor.get("filename")
|
||||
)
|
||||
elif isinstance(descriptor, str):
|
||||
file_id = descriptor
|
||||
else:
|
||||
logger.warning(
|
||||
"Skipping unsupported file descriptor passed to PythonTool: %r",
|
||||
descriptor,
|
||||
)
|
||||
continue
|
||||
|
||||
if not file_id:
|
||||
logger.warning("Skipping file descriptor without an id: %r", descriptor)
|
||||
continue
|
||||
|
||||
chat_file = self._available_files.get(file_id)
|
||||
if not chat_file:
|
||||
chat_file = self._load_file_from_store(file_id)
|
||||
|
||||
if not chat_file:
|
||||
logger.warning("Requested file '%s' could not be found", file_id)
|
||||
continue
|
||||
|
||||
if desired_name and desired_name != chat_file.filename:
|
||||
chat_file = InMemoryChatFile(
|
||||
file_id=chat_file.file_id,
|
||||
content=chat_file.content,
|
||||
filename=desired_name,
|
||||
file_type=chat_file.file_type,
|
||||
)
|
||||
|
||||
loaded_files.append(chat_file)
|
||||
|
||||
return loaded_files
|
||||
|
||||
def _load_file_from_store(self, file_id: str) -> InMemoryChatFile | None:
|
||||
try:
|
||||
file_store = get_default_file_store()
|
||||
except Exception as exc: # pragma: no cover - defensive guard
|
||||
logger.exception("Failed to obtain file store instance: %s", exc)
|
||||
return None
|
||||
|
||||
try:
|
||||
file_record = file_store.read_file_record(file_id)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Failed to read file record for '%s' from file store", file_id
|
||||
)
|
||||
return None
|
||||
|
||||
if not file_record:
|
||||
logger.warning("No file record found for requested file '%s'", file_id)
|
||||
return None
|
||||
|
||||
try:
|
||||
file_io = file_store.read_file(file_id, mode="b")
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Failed to read file content for '%s' from file store", file_id
|
||||
)
|
||||
return None
|
||||
|
||||
with closing(file_io):
|
||||
data = file_io.read() if hasattr(file_io, "read") else file_io
|
||||
|
||||
if not isinstance(data, (bytes, bytearray)):
|
||||
logger.warning("File store returned non-bytes payload for '%s'", file_id)
|
||||
return None
|
||||
|
||||
mime_type = (
|
||||
getattr(file_record, "file_type", None) or "application/octet-stream"
|
||||
)
|
||||
display_name = getattr(file_record, "display_name", None) or file_id
|
||||
chat_file_type = mime_type_to_chat_file_type(mime_type)
|
||||
|
||||
return InMemoryChatFile(
|
||||
file_id=file_id,
|
||||
content=bytes(data),
|
||||
filename=display_name,
|
||||
file_type=chat_file_type,
|
||||
)
|
||||
|
||||
def _coerce_override_files(
|
||||
self, override_kwargs: dict[str, Any] | None
|
||||
) -> list[InMemoryChatFile]:
|
||||
if not override_kwargs:
|
||||
return []
|
||||
|
||||
raw_files = override_kwargs.get("files")
|
||||
if not raw_files:
|
||||
return []
|
||||
|
||||
coerced_files: list[InMemoryChatFile] = []
|
||||
for raw_file in raw_files:
|
||||
if isinstance(raw_file, InMemoryChatFile):
|
||||
coerced_files.append(raw_file)
|
||||
else:
|
||||
logger.warning(
|
||||
"Ignoring non InMemoryChatFile entry passed to PythonTool override: %r",
|
||||
type(raw_file),
|
||||
)
|
||||
return coerced_files
|
||||
|
||||
def _prepare_request_files(
|
||||
self, chat_files: list[InMemoryChatFile]
|
||||
) -> tuple[list[ExecuteRequestFilePayload], list[dict[str, Any]]]:
|
||||
request_files: list[ExecuteRequestFilePayload] = []
|
||||
input_metadata: list[dict[str, Any]] = []
|
||||
seen_ids: set[str] = set()
|
||||
|
||||
for chat_file in chat_files:
|
||||
file_id = str(chat_file.file_id)
|
||||
if file_id in seen_ids:
|
||||
continue
|
||||
seen_ids.add(file_id)
|
||||
|
||||
self._available_files[file_id] = chat_file
|
||||
target_path = self._resolve_target_path(chat_file)
|
||||
|
||||
request_files.append(
|
||||
ExecuteRequestFilePayload(
|
||||
path=target_path,
|
||||
content_base64=base64.b64encode(chat_file.content).decode("utf-8"),
|
||||
)
|
||||
)
|
||||
|
||||
input_metadata.append(
|
||||
{
|
||||
"id": file_id,
|
||||
"path": target_path,
|
||||
"name": chat_file.filename or file_id,
|
||||
"chat_file_type": chat_file.file_type.value,
|
||||
}
|
||||
)
|
||||
|
||||
return request_files, input_metadata
|
||||
|
||||
def _persist_artifacts(self, files: list[Any]) -> list[PythonToolArtifact]:
|
||||
artifacts: list[PythonToolArtifact] = []
|
||||
if not files:
|
||||
return artifacts
|
||||
|
||||
file_store = get_default_file_store()
|
||||
|
||||
for raw_file in files:
|
||||
try:
|
||||
workspace_file = WorkspaceFilePayload.model_validate(raw_file)
|
||||
except ValidationError as exc:
|
||||
logger.warning("Skipping malformed workspace file entry: %s", exc)
|
||||
continue
|
||||
|
||||
artifact = PythonToolArtifact(
|
||||
path=workspace_file.path, kind=workspace_file.kind
|
||||
)
|
||||
if workspace_file.kind != "file" or not workspace_file.content_base64:
|
||||
artifacts.append(artifact)
|
||||
continue
|
||||
|
||||
try:
|
||||
binary = base64.b64decode(workspace_file.content_base64)
|
||||
except binascii.Error as exc:
|
||||
artifact.error = f"Failed to decode base64 content: {exc}"
|
||||
artifacts.append(artifact)
|
||||
continue
|
||||
|
||||
mime_type = self._infer_mime_type(workspace_file, binary)
|
||||
display_name = Path(workspace_file.path).name or workspace_file.path
|
||||
|
||||
try:
|
||||
file_id = file_store.save_file(
|
||||
content=BytesIO(binary),
|
||||
display_name=display_name,
|
||||
file_origin=FileOrigin.GENERATED_REPORT,
|
||||
file_type=mime_type,
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.exception(
|
||||
"Failed to persist Code Interpreter artifact '%s'",
|
||||
workspace_file.path,
|
||||
)
|
||||
artifact.error = f"Failed to persist artifact: {exc}"
|
||||
artifacts.append(artifact)
|
||||
continue
|
||||
|
||||
chat_file_type = mime_type_to_chat_file_type(mime_type)
|
||||
chat_file = InMemoryChatFile(
|
||||
file_id=file_id,
|
||||
content=binary,
|
||||
filename=display_name,
|
||||
file_type=chat_file_type,
|
||||
)
|
||||
self._available_files[file_id] = chat_file
|
||||
|
||||
artifact.file_id = file_id
|
||||
artifact.display_name = display_name
|
||||
artifact.mime_type = mime_type
|
||||
artifact.size_bytes = len(binary)
|
||||
artifact.chat_file_type = chat_file_type
|
||||
artifacts.append(artifact)
|
||||
|
||||
return artifacts
|
||||
|
||||
def _parse_response(self, response: requests.Response) -> dict[str, Any]:
|
||||
try:
|
||||
response_data = response.json()
|
||||
except json.JSONDecodeError as exc:
|
||||
logger.exception("Code Interpreter returned invalid JSON")
|
||||
raise ValueError(
|
||||
"Code Interpreter returned an invalid JSON response"
|
||||
) from exc
|
||||
|
||||
error_payload = response_data.get("error")
|
||||
if error_payload:
|
||||
if isinstance(error_payload, dict):
|
||||
message = error_payload.get("message") or json.dumps(error_payload)
|
||||
else:
|
||||
message = str(error_payload)
|
||||
raise ValueError(f"Code Interpreter reported an error: {message}")
|
||||
|
||||
return response_data
|
||||
|
||||
def _extract_result(self, responses: tuple[ToolResponse, ...]) -> PythonToolResult:
|
||||
for response in responses:
|
||||
if response.id == PYTHON_TOOL_RESPONSE_ID:
|
||||
if isinstance(response.response, PythonToolResult):
|
||||
return response.response
|
||||
return PythonToolResult.model_validate(response.response)
|
||||
raise ValueError("No python tool result found in tool responses")
|
||||
|
||||
def _resolve_target_path(self, chat_file: InMemoryChatFile) -> str:
|
||||
if chat_file.filename:
|
||||
return Path(chat_file.filename).name
|
||||
return chat_file.file_id
|
||||
|
||||
def _infer_mime_type(
|
||||
self, workspace_file: WorkspaceFilePayload, binary: bytes
|
||||
) -> str:
|
||||
if workspace_file.mime_type:
|
||||
return workspace_file.mime_type
|
||||
|
||||
guess, _ = mimetypes.guess_type(workspace_file.path)
|
||||
if guess:
|
||||
return guess
|
||||
|
||||
if self._looks_like_text(binary):
|
||||
return "text/plain"
|
||||
|
||||
return "application/octet-stream"
|
||||
|
||||
@staticmethod
|
||||
def _ensure_text(value: Any) -> str:
|
||||
if value is None:
|
||||
return ""
|
||||
return str(value)
|
||||
|
||||
@staticmethod
|
||||
def _optional_text(value: Any) -> str | None:
|
||||
if value is None:
|
||||
return None
|
||||
text_value = str(value)
|
||||
return text_value if text_value else None
|
||||
|
||||
@staticmethod
|
||||
def _safe_int(value: Any) -> int | None:
|
||||
try:
|
||||
return int(value) if value is not None else None
|
||||
except (TypeError, ValueError):
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _extract_metadata(value: Any) -> dict[str, Any] | None:
|
||||
if isinstance(value, dict):
|
||||
return value
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _looks_like_text(payload: bytes, sample_size: int = 2048) -> bool:
|
||||
sample = payload[:sample_size]
|
||||
try:
|
||||
sample.decode("utf-8")
|
||||
except UnicodeDecodeError:
|
||||
return False
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def _format_section(title: str, body: str) -> str:
|
||||
truncated = (
|
||||
body
|
||||
if len(body) <= _DEFAULT_RESULT_CHAR_LIMIT
|
||||
else body[: _DEFAULT_RESULT_CHAR_LIMIT - 3] + "..."
|
||||
)
|
||||
return f"{title}:\n{truncated}"
|
||||
@@ -1,3 +1,16 @@
|
||||
PRESENTATION_MIME_TYPE = (
|
||||
"application/vnd.openxmlformats-officedocument.presentationml.presentation"
|
||||
)
|
||||
|
||||
SPREADSHEET_MIME_TYPE = (
|
||||
"application/vnd.openxmlformats-officedocument.spreadsheetml.sheet"
|
||||
)
|
||||
WORD_PROCESSING_MIME_TYPE = (
|
||||
"application/vnd.openxmlformats-officedocument.wordprocessingml.document"
|
||||
)
|
||||
PDF_MIME_TYPE = "application/pdf"
|
||||
|
||||
|
||||
class UploadMimeTypes:
|
||||
IMAGE_MIME_TYPES = {"image/jpeg", "image/png", "image/webp"}
|
||||
CSV_MIME_TYPES = {"text/csv"}
|
||||
@@ -13,10 +26,10 @@ class UploadMimeTypes:
|
||||
"application/x-yaml",
|
||||
}
|
||||
DOCUMENT_MIME_TYPES = {
|
||||
"application/pdf",
|
||||
"application/vnd.openxmlformats-officedocument.wordprocessingml.document",
|
||||
"application/vnd.openxmlformats-officedocument.presentationml.presentation",
|
||||
"application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
|
||||
PDF_MIME_TYPE,
|
||||
WORD_PROCESSING_MIME_TYPE,
|
||||
PRESENTATION_MIME_TYPE,
|
||||
SPREADSHEET_MIME_TYPE,
|
||||
"message/rfc822",
|
||||
"application/epub+zip",
|
||||
}
|
||||
|
||||
53
backend/onyx/utils/memory_logger.py
Normal file
53
backend/onyx/utils/memory_logger.py
Normal file
@@ -0,0 +1,53 @@
|
||||
# # leaving this here for future mem debugging efforts
|
||||
# import os
|
||||
# from typing import Any
|
||||
|
||||
# import psutil
|
||||
# from pympler import asizeof
|
||||
|
||||
# from onyx.utils.logger import setup_logger
|
||||
|
||||
# logger = setup_logger()
|
||||
|
||||
#
|
||||
# def log_memory_usage(
|
||||
# label: str,
|
||||
# specific_object: Any = None,
|
||||
# object_label: str = "",
|
||||
# ) -> None:
|
||||
# """Log current process memory usage and optionally the size of a specific object.
|
||||
|
||||
# Args:
|
||||
# label: A descriptive label for the current location/operation in code
|
||||
# specific_object: Optional object to measure the size of
|
||||
# object_label: Optional label describing the specific object
|
||||
# """
|
||||
# try:
|
||||
# # Get current process memory info
|
||||
# process = psutil.Process(os.getpid())
|
||||
# memory_info = process.memory_info()
|
||||
|
||||
# # Convert to MB for readability
|
||||
# rss_mb = memory_info.rss / (1024 * 1024)
|
||||
# vms_mb = memory_info.vms / (1024 * 1024)
|
||||
|
||||
# log_parts = [f"MEMORY[{label}]", f"RSS: {rss_mb:.2f}MB", f"VMS: {vms_mb:.2f}MB"]
|
||||
|
||||
# # Add object size if provided
|
||||
# if specific_object is not None:
|
||||
# try:
|
||||
# # recursively calculate the size of the object
|
||||
# obj_size = asizeof.asizeof(specific_object)
|
||||
# obj_size_mb = obj_size / (1024 * 1024)
|
||||
# obj_desc = f"[{object_label}]" if object_label else "[object]"
|
||||
# log_parts.append(f"OBJ{obj_desc}: {obj_size_mb:.2f}MB")
|
||||
# except Exception as e:
|
||||
# log_parts.append(f"OBJ_SIZE_ERROR: {str(e)}")
|
||||
|
||||
# logger.info(" | ".join(log_parts))
|
||||
|
||||
# except Exception as e:
|
||||
# logger.warning(f"Failed to log memory usage for {label}: {str(e)}")
|
||||
|
||||
# For example, use this like:
|
||||
# log_memory_usage("my_operation", my_large_object, "my_large_object")
|
||||
@@ -51,6 +51,7 @@ nltk==3.9.1
|
||||
Office365-REST-Python-Client==2.5.9
|
||||
oauthlib==3.2.2
|
||||
openai==1.99.5
|
||||
openpyxl==3.1.5
|
||||
passlib==1.7.4
|
||||
playwright==1.41.2
|
||||
psutil==5.9.5
|
||||
@@ -60,6 +61,7 @@ pyairtable==3.0.1
|
||||
pycryptodome==3.19.1
|
||||
pydantic==2.11.7
|
||||
PyGithub==2.5.0
|
||||
pympler==1.1
|
||||
python-dateutil==2.8.2
|
||||
python-gitlab==5.6.0
|
||||
python-pptx==0.6.23
|
||||
@@ -83,6 +85,7 @@ supervisor==4.2.5
|
||||
RapidFuzz==3.13.0
|
||||
tiktoken==0.7.0
|
||||
timeago==1.0.16
|
||||
types-openpyxl==3.1.5.20250919
|
||||
unstructured==0.15.1
|
||||
unstructured-client==0.25.4
|
||||
uvicorn==0.35.0
|
||||
|
||||
@@ -26,6 +26,7 @@ def mock_zendesk_client() -> MagicMock:
|
||||
mock = MagicMock(spec=ZendeskClient)
|
||||
mock.base_url = "https://test.zendesk.com/api/v2"
|
||||
mock.auth = ("test@example.com/token", "test_token")
|
||||
mock.make_request = MagicMock()
|
||||
return mock
|
||||
|
||||
|
||||
|
||||
@@ -0,0 +1,83 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import types
|
||||
from typing import Any
|
||||
from typing import Dict
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
class _FakeTime:
|
||||
"""A controllable time module replacement.
|
||||
|
||||
- monotonic(): returns an internal counter (seconds)
|
||||
- sleep(x): advances the internal counter by x seconds
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._t = 0.0
|
||||
|
||||
def monotonic(self) -> float: # type: ignore[override]
|
||||
return self._t
|
||||
|
||||
def sleep(self, seconds: float) -> None: # type: ignore[override]
|
||||
# advance time without real waiting
|
||||
self._t += float(seconds)
|
||||
|
||||
|
||||
class _FakeResponse:
|
||||
def __init__(self, json_payload: Dict[str, Any], status_code: int = 200) -> None:
|
||||
self._json = json_payload
|
||||
self.status_code = status_code
|
||||
self.headers: Dict[str, str] = {}
|
||||
|
||||
def json(self) -> Dict[str, Any]:
|
||||
return self._json
|
||||
|
||||
def raise_for_status(self) -> None:
|
||||
# simulate OK
|
||||
return None
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_zendesk_client_per_minute_rate_limiting(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
# Import here to allow monkeypatching modules safely
|
||||
from onyx.connectors.zendesk.connector import ZendeskClient
|
||||
import onyx.connectors.cross_connector_utils.rate_limit_wrapper as rlw
|
||||
import onyx.connectors.zendesk.connector as zendesk_mod
|
||||
|
||||
fake_time = _FakeTime()
|
||||
|
||||
# Patch time in both the rate limit wrapper and the zendesk connector module
|
||||
monkeypatch.setattr(rlw, "time", fake_time, raising=True)
|
||||
monkeypatch.setattr(zendesk_mod, "time", fake_time, raising=True)
|
||||
|
||||
# Stub out requests.get to avoid network and return a minimal valid payload
|
||||
calls: list[str] = []
|
||||
|
||||
def _fake_get(url: str, auth: Any, params: Dict[str, Any]) -> _FakeResponse:
|
||||
calls.append(url)
|
||||
# minimal Zendesk list response (articles path)
|
||||
return _FakeResponse({"articles": [], "meta": {"has_more": False}})
|
||||
|
||||
monkeypatch.setattr(
|
||||
zendesk_mod, "requests", types.SimpleNamespace(get=_fake_get), raising=True
|
||||
)
|
||||
|
||||
# Build client with a small limit: 2 calls per 60 seconds
|
||||
client = ZendeskClient("subd", "e", "t", calls_per_minute=2)
|
||||
|
||||
# Make three calls in quick succession. The third should be rate limited
|
||||
client.make_request("help_center/articles", {"page[size]": 1})
|
||||
client.make_request("help_center/articles", {"page[size]": 1})
|
||||
|
||||
# At this point we've used up the 2 allowed calls within the 60s window
|
||||
# The next call should trigger sleeps with exponential backoff until >60s elapsed
|
||||
client.make_request("help_center/articles", {"page[size]": 1})
|
||||
|
||||
# Ensure we did not actually wait in real time but logically advanced beyond a minute
|
||||
assert fake_time.monotonic() >= 60
|
||||
# Ensure the HTTP function was invoked three times
|
||||
assert len(calls) == 3
|
||||
@@ -88,6 +88,7 @@ import { FilePickerModal } from "@/app/chat/my-documents/components/FilePicker";
|
||||
import { useDocumentsContext } from "@/app/chat/my-documents/DocumentsContext";
|
||||
|
||||
import {
|
||||
PYTHON_TOOL_ID,
|
||||
IMAGE_GENERATION_TOOL_ID,
|
||||
SEARCH_TOOL_ID,
|
||||
WEB_SEARCH_TOOL_ID,
|
||||
@@ -111,6 +112,10 @@ function findWebSearchTool(tools: ToolSnapshot[]) {
|
||||
return tools.find((tool) => tool.in_code_tool_id === WEB_SEARCH_TOOL_ID);
|
||||
}
|
||||
|
||||
function findPythonTool(tools: ToolSnapshot[]) {
|
||||
return tools.find((tool) => tool.in_code_tool_id === PYTHON_TOOL_ID);
|
||||
}
|
||||
|
||||
function SubLabel({ children }: { children: string | JSX.Element }) {
|
||||
return (
|
||||
<div
|
||||
@@ -213,13 +218,15 @@ export function AssistantEditor({
|
||||
const searchTool = findSearchTool(tools);
|
||||
const imageGenerationTool = findImageGenerationTool(tools);
|
||||
const webSearchTool = findWebSearchTool(tools);
|
||||
const pythonTool = findPythonTool(tools);
|
||||
|
||||
// Separate MCP tools from regular custom tools
|
||||
const allCustomTools = tools.filter(
|
||||
(tool) =>
|
||||
tool.in_code_tool_id !== searchTool?.in_code_tool_id &&
|
||||
tool.in_code_tool_id !== imageGenerationTool?.in_code_tool_id &&
|
||||
tool.in_code_tool_id !== webSearchTool?.in_code_tool_id
|
||||
tool.in_code_tool_id !== webSearchTool?.in_code_tool_id &&
|
||||
tool.in_code_tool_id !== pythonTool?.in_code_tool_id
|
||||
);
|
||||
|
||||
const mcpTools = allCustomTools.filter((tool) => tool.mcp_server_id);
|
||||
@@ -290,6 +297,7 @@ export function AssistantEditor({
|
||||
...mcpTools, // Include MCP tools for form logic
|
||||
...(searchTool ? [searchTool] : []),
|
||||
...(imageGenerationTool ? [imageGenerationTool] : []),
|
||||
...(pythonTool ? [pythonTool] : []),
|
||||
...(webSearchTool ? [webSearchTool] : []),
|
||||
];
|
||||
const enabledToolsMap: { [key: number]: boolean } = {};
|
||||
@@ -1211,6 +1219,16 @@ export function AssistantEditor({
|
||||
</>
|
||||
)}
|
||||
|
||||
{pythonTool && (
|
||||
<>
|
||||
<BooleanFormField
|
||||
name={`enabled_tools_map.${pythonTool.id}`}
|
||||
label={pythonTool.display_name}
|
||||
subtext="Execute Python code against staged files using the Code Interpreter sandbox"
|
||||
/>
|
||||
</>
|
||||
)}
|
||||
|
||||
{webSearchTool && (
|
||||
<>
|
||||
<BooleanFormField
|
||||
|
||||
@@ -272,6 +272,9 @@ function ToolToggle({
|
||||
if (tool.in_code_tool_id === "ImageGenerationTool") {
|
||||
return "Add an OpenAI LLM provider with an API key under Admin → Configuration → LLM.";
|
||||
}
|
||||
if (tool.in_code_tool_id === "PythonTool") {
|
||||
return "Set CODE_INTERPRETER_BASE_URL on the server and restart to enable Code Interpreter.";
|
||||
}
|
||||
return "Not configured.";
|
||||
})();
|
||||
return (
|
||||
|
||||
@@ -1,6 +1,9 @@
|
||||
import { FiImage, FiSearch } from "react-icons/fi";
|
||||
import { FiImage, FiSearch, FiTerminal } from "react-icons/fi";
|
||||
import { Persona } from "../admin/assistants/interfaces";
|
||||
import { SEARCH_TOOL_ID } from "../chat/components/tools/constants";
|
||||
import {
|
||||
PYTHON_TOOL_ID,
|
||||
SEARCH_TOOL_ID,
|
||||
} from "../chat/components/tools/constants";
|
||||
|
||||
export function AssistantTools({
|
||||
assistant,
|
||||
@@ -69,6 +72,26 @@ export function AssistantTools({
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
} else if (tool.name === PYTHON_TOOL_ID) {
|
||||
return (
|
||||
<div
|
||||
key={ind}
|
||||
className={`
|
||||
px-1.5
|
||||
py-1
|
||||
rounded-lg
|
||||
border
|
||||
border-border
|
||||
w-fit
|
||||
flex
|
||||
${list ? "bg-background-125" : "bg-background-100"}`}
|
||||
>
|
||||
<div className="flex items-center gap-x-1">
|
||||
<FiTerminal className="ml-1 my-auto h-3 w-3" />
|
||||
{tool.display_name || "Code Interpreter"}
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
} else {
|
||||
return (
|
||||
<div
|
||||
|
||||
@@ -2,8 +2,10 @@
|
||||
export const SEARCH_TOOL_NAME = "run_search";
|
||||
export const INTERNET_SEARCH_TOOL_NAME = "run_internet_search";
|
||||
export const IMAGE_GENERATION_TOOL_NAME = "run_image_generation";
|
||||
export const PYTHON_TOOL_NAME = "run_code_interpreter";
|
||||
|
||||
// In-code tool IDs that also correspond to the tool's name when associated with a persona
|
||||
export const SEARCH_TOOL_ID = "SearchTool";
|
||||
export const IMAGE_GENERATION_TOOL_ID = "ImageGenerationTool";
|
||||
export const WEB_SEARCH_TOOL_ID = "WebSearchTool";
|
||||
export const PYTHON_TOOL_ID = "PythonTool";
|
||||
|
||||
@@ -32,6 +32,14 @@ const isImageGenerationTool = (tool: ToolSnapshot): boolean => {
|
||||
);
|
||||
};
|
||||
|
||||
const isPythonTool = (tool: ToolSnapshot): boolean => {
|
||||
return (
|
||||
tool.in_code_tool_id === "PythonTool" ||
|
||||
tool.in_code_tool_id === "AnalysisTool" ||
|
||||
tool.display_name?.toLowerCase().includes("code interpreter")
|
||||
);
|
||||
};
|
||||
|
||||
const isKnowledgeGraphTool = (tool: ToolSnapshot): boolean => {
|
||||
return (
|
||||
tool.in_code_tool_id === "KnowledgeGraphTool" ||
|
||||
@@ -55,6 +63,8 @@ export function getIconForAction(
|
||||
return GlobeIcon;
|
||||
} else if (isImageGenerationTool(action)) {
|
||||
return ImageIcon;
|
||||
} else if (isPythonTool(action)) {
|
||||
return CpuIcon;
|
||||
} else if (isKnowledgeGraphTool(action)) {
|
||||
return DatabaseIcon;
|
||||
} else if (isOktaProfileTool(action)) {
|
||||
|
||||
@@ -1071,7 +1071,16 @@ For example, specifying .*-support.* as a "channel" will cause the connector to
|
||||
default: "articles",
|
||||
},
|
||||
],
|
||||
advanced_values: [],
|
||||
advanced_values: [
|
||||
{
|
||||
type: "number",
|
||||
label: "API Calls per Minute",
|
||||
name: "calls_per_minute",
|
||||
optional: true,
|
||||
description:
|
||||
"Restricts how many Zendesk API calls this connector can make per minute (applies only to this connector). See defaults: https://developer.zendesk.com/api-reference/introduction/rate-limits/",
|
||||
},
|
||||
],
|
||||
},
|
||||
linear: {
|
||||
description: "Configure Linear connector",
|
||||
@@ -1770,7 +1779,10 @@ export interface XenforoConfig {
|
||||
base_url: string;
|
||||
}
|
||||
|
||||
export interface ZendeskConfig {}
|
||||
export interface ZendeskConfig {
|
||||
content_type?: "articles" | "tickets";
|
||||
calls_per_minute?: number;
|
||||
}
|
||||
|
||||
export interface DropboxConfig {}
|
||||
|
||||
|
||||
@@ -3,7 +3,7 @@ import { dragElementAbove, dragElementBelow } from "../utils/dragUtils";
|
||||
import { loginAsRandomUser } from "../utils/auth";
|
||||
import { createAssistant, pinAssistantByName } from "../utils/assistantUtils";
|
||||
|
||||
test("Assistant Drag and Drop", async ({ page }, testInfo) => {
|
||||
test("Assistant Drag and Drop", async ({ page }) => {
|
||||
await page.context().clearCookies();
|
||||
await loginAsRandomUser(page);
|
||||
|
||||
@@ -60,14 +60,6 @@ test("Assistant Drag and Drop", async ({ page }, testInfo) => {
|
||||
|
||||
// Get the initial order
|
||||
const initialOrder = await getAssistantOrder();
|
||||
await testInfo.attach("assistant-order-initial", {
|
||||
body: Buffer.from(JSON.stringify(initialOrder, null, 2), "utf-8"),
|
||||
contentType: "application/json",
|
||||
});
|
||||
await testInfo.attach("screenshot-initial-order", {
|
||||
body: await page.screenshot({ fullPage: true }),
|
||||
contentType: "image/png",
|
||||
});
|
||||
|
||||
// Drag second assistant above first
|
||||
const secondAssistant = page.locator('[data-testid^="assistant-["]').nth(1);
|
||||
@@ -77,19 +69,10 @@ test("Assistant Drag and Drop", async ({ page }, testInfo) => {
|
||||
|
||||
// Check new order
|
||||
const orderAfterDragUp = await getAssistantOrder();
|
||||
await testInfo.attach("assistant-order-after-drag-up", {
|
||||
body: Buffer.from(JSON.stringify(orderAfterDragUp, null, 2), "utf-8"),
|
||||
contentType: "application/json",
|
||||
});
|
||||
await testInfo.attach("screenshot-after-drag-up", {
|
||||
body: await page.screenshot({ fullPage: true }),
|
||||
contentType: "image/png",
|
||||
});
|
||||
expect(orderAfterDragUp[0]).toBe(initialOrder[1]);
|
||||
expect(orderAfterDragUp[1]).toBe(initialOrder[0]);
|
||||
|
||||
// Drag last assistant to second position
|
||||
console.log("Dragging last assistant to second position");
|
||||
const assistants = page.locator('[data-testid^="assistant-["]');
|
||||
const lastIndex = (await assistants.count()) - 1;
|
||||
const lastAssistant = assistants.nth(lastIndex);
|
||||
@@ -100,26 +83,10 @@ test("Assistant Drag and Drop", async ({ page }, testInfo) => {
|
||||
|
||||
// Check new order
|
||||
const orderAfterDragDown = await getAssistantOrder();
|
||||
await testInfo.attach("assistant-order-after-drag-down", {
|
||||
body: Buffer.from(JSON.stringify(orderAfterDragDown, null, 2), "utf-8"),
|
||||
contentType: "application/json",
|
||||
});
|
||||
await testInfo.attach("screenshot-after-drag-down", {
|
||||
body: await page.screenshot({ fullPage: true }),
|
||||
contentType: "image/png",
|
||||
});
|
||||
expect(orderAfterDragDown[1]).toBe(initialOrder[lastIndex]);
|
||||
|
||||
// Refresh and verify order
|
||||
await page.reload();
|
||||
const orderAfterRefresh = await getAssistantOrder();
|
||||
await testInfo.attach("assistant-order-after-refresh", {
|
||||
body: Buffer.from(JSON.stringify(orderAfterRefresh, null, 2), "utf-8"),
|
||||
contentType: "application/json",
|
||||
});
|
||||
await testInfo.attach("screenshot-after-refresh", {
|
||||
body: await page.screenshot({ fullPage: true }),
|
||||
contentType: "image/png",
|
||||
});
|
||||
expect(orderAfterRefresh).toEqual(orderAfterDragDown);
|
||||
});
|
||||
|
||||
Reference in New Issue
Block a user