Mirror of https://github.com/onyx-dot-app/onyx.git (synced 2026-02-17 07:45:47 +00:00)

Compare commits: debug-shar...add-code-i (12 commits)
| Author | SHA1 | Date |
|---|---|---|
|  | 68a484ae73 |  |
|  | f4d135d710 |  |
|  | 6094f70ac8 |  |
|  | a90e58b39b |  |
|  | e82e3141ed |  |
|  | f8e9060bab |  |
|  | 24831fa1a1 |  |
|  | f6a0e69b2a |  |
|  | 0394eaea7f |  |
|  | 898b8c316e |  |
|  | 4b0c6d1e54 |  |
|  | da7dc33afa |  |
@@ -23,6 +23,7 @@ env:
   # LLMs
   OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
+  ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }}

 jobs:
   discover-test-dirs:

.github/workflows/pr-python-tests.yml (vendored, 2 changes)
@@ -31,12 +31,14 @@ jobs:
           cache-dependency-path: |
             backend/requirements/default.txt
             backend/requirements/dev.txt
+            backend/requirements/model_server.txt

       - name: Install Dependencies
         run: |
           python -m pip install --upgrade pip
           pip install --retries 5 --timeout 30 -r backend/requirements/default.txt
           pip install --retries 5 --timeout 30 -r backend/requirements/dev.txt
+          pip install --retries 5 --timeout 30 -r backend/requirements/model_server.txt

       - name: Run Tests
         shell: script -q -e -c "bash --noprofile --norc -eo pipefail {0}"

backend/alembic/versions/1c3f8a7b5d4e_add_python_tool.py (new file, 73 lines)
@@ -0,0 +1,73 @@
"""add_python_tool

Revision ID: 1c3f8a7b5d4e
Revises: 505c488f6662
Create Date: 2025-02-14 00:00:00

"""

from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision = "1c3f8a7b5d4e"
down_revision = "505c488f6662"
branch_labels = None
depends_on = None


PYTHON_TOOL = {
    "name": "PythonTool",
    "display_name": "Code Interpreter",
    "description": (
        "The Code Interpreter Action lets assistants execute Python in an isolated runtime. "
        "It can process staged files, read and write artifacts, stream stdout and stderr, "
        "and return generated outputs for the chat session."
    ),
    "in_code_tool_id": "PythonTool",
}


def upgrade() -> None:
    conn = op.get_bind()
    conn.execute(sa.text("BEGIN"))
    try:
        existing = conn.execute(
            sa.text("SELECT id FROM tool WHERE in_code_tool_id = :in_code_tool_id"),
            PYTHON_TOOL,
        ).fetchone()

        if existing:
            conn.execute(
                sa.text(
                    """
                    UPDATE tool
                    SET name = :name,
                        display_name = :display_name,
                        description = :description
                    WHERE in_code_tool_id = :in_code_tool_id
                    """
                ),
                PYTHON_TOOL,
            )
        else:
            conn.execute(
                sa.text(
                    """
                    INSERT INTO tool (name, display_name, description, in_code_tool_id)
                    VALUES (:name, :display_name, :description, :in_code_tool_id)
                    """
                ),
                PYTHON_TOOL,
            )

        conn.execute(sa.text("COMMIT"))
    except Exception:
        conn.execute(sa.text("ROLLBACK"))
        raise


def downgrade() -> None:
    # Do not delete the tool entry on downgrade; leaving it is safe and keeps migrations idempotent.
    pass

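For reference, one quick way to confirm the seeded row after the migration runs; this is an editorial sketch (the connection string is hypothetical), not part of the change:

from sqlalchemy import create_engine, text

engine = create_engine("postgresql://localhost/onyx")  # hypothetical DSN
with engine.connect() as conn:
    row = conn.execute(
        text("SELECT name, display_name FROM tool WHERE in_code_tool_id = 'PythonTool'")
    ).fetchone()
    # Expected after upgrade(): ("PythonTool", "Code Interpreter")
    print(row)
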
@@ -39,6 +39,7 @@ def decision_router(state: MainState) -> list[Send | Hashable] | DRPath | str:
             DRPath.WEB_SEARCH,
             DRPath.KNOWLEDGE_GRAPH,
             DRPath.IMAGE_GENERATION,
+            DRPath.PYTHON_TOOL,
         )
         and len(state.query_list) == 0
     ):

@@ -21,6 +21,7 @@ AVERAGE_TOOL_COSTS: dict[DRPath, float] = {
     DRPath.WEB_SEARCH: 1.5,
     DRPath.IMAGE_GENERATION: 3.0,
     DRPath.GENERIC_TOOL: 1.5,  # TODO: see todo in OrchestratorTool
+    DRPath.PYTHON_TOOL: 2.0,
     DRPath.CLOSER: 0.0,
 }

@@ -27,6 +27,7 @@ class DRPath(str, Enum):
     WEB_SEARCH = "Web Search"
     IMAGE_GENERATION = "Image Generation"
     GENERIC_INTERNAL_TOOL = "Generic Internal Tool"
+    PYTHON_TOOL = "Python"
     CLOSER = "Closer"
     LOGGER = "Logger"
     END = "End"

@@ -26,6 +26,9 @@ from onyx.agents.agent_search.dr.sub_agents.image_generation.dr_image_generation
 from onyx.agents.agent_search.dr.sub_agents.kg_search.dr_kg_search_graph_builder import (
     dr_kg_search_graph_builder,
 )
+from onyx.agents.agent_search.dr.sub_agents.python_tool.dr_python_tool_graph_builder import (
+    dr_python_tool_graph_builder,
+)
 from onyx.agents.agent_search.dr.sub_agents.web_search.dr_ws_graph_builder import (
     dr_ws_graph_builder,
 )

@@ -58,12 +61,15 @@ def dr_graph_builder() -> StateGraph:
     image_generation_graph = dr_image_generation_graph_builder().compile()
     graph.add_node(DRPath.IMAGE_GENERATION, image_generation_graph)

-    custom_tool_graph = dr_custom_tool_graph_builder().compile()
-    graph.add_node(DRPath.GENERIC_TOOL, custom_tool_graph)
-
     generic_internal_tool_graph = dr_generic_internal_tool_graph_builder().compile()
     graph.add_node(DRPath.GENERIC_INTERNAL_TOOL, generic_internal_tool_graph)

+    python_tool_graph = dr_python_tool_graph_builder().compile()
+    graph.add_node(DRPath.PYTHON_TOOL, python_tool_graph)
+
+    custom_tool_graph = dr_custom_tool_graph_builder().compile()
+    graph.add_node(DRPath.GENERIC_TOOL, custom_tool_graph)
+
     graph.add_node(DRPath.CLOSER, closer)
     graph.add_node(DRPath.LOGGER, logging)

@@ -81,6 +87,7 @@ def dr_graph_builder() -> StateGraph:
     graph.add_edge(start_key=DRPath.IMAGE_GENERATION, end_key=DRPath.ORCHESTRATOR)
     graph.add_edge(start_key=DRPath.GENERIC_TOOL, end_key=DRPath.ORCHESTRATOR)
     graph.add_edge(start_key=DRPath.GENERIC_INTERNAL_TOOL, end_key=DRPath.ORCHESTRATOR)
+    graph.add_edge(start_key=DRPath.PYTHON_TOOL, end_key=DRPath.ORCHESTRATOR)

     graph.add_conditional_edges(DRPath.CLOSER, completeness_router)
     graph.add_edge(start_key=DRPath.LOGGER, end_key=END)

@@ -69,6 +69,7 @@ from onyx.tools.tool_implementations.images.image_generation_tool import (
 from onyx.tools.tool_implementations.knowledge_graph.knowledge_graph_tool import (
     KnowledgeGraphTool,
 )
+from onyx.tools.tool_implementations.python.python_tool import PythonTool
 from onyx.tools.tool_implementations.search.search_tool import SearchTool
 from onyx.tools.tool_implementations.web_search.web_search_tool import (
     WebSearchTool,

@@ -134,6 +135,9 @@ def _get_available_tools(
                 continue
             llm_path = DRPath.KNOWLEDGE_GRAPH.value
             path = DRPath.KNOWLEDGE_GRAPH
+        elif isinstance(tool, PythonTool):
+            llm_path = DRPath.PYTHON_TOOL.value
+            path = DRPath.PYTHON_TOOL
         elif isinstance(tool, ImageGenerationTool):
             llm_path = DRPath.IMAGE_GENERATION.value
             path = DRPath.IMAGE_GENERATION

@@ -778,6 +782,6 @@ def clarifier(
         active_source_types_descriptions="\n".join(active_source_types_descriptions),
         assistant_system_prompt=assistant_system_prompt,
         assistant_task_prompt=assistant_task_prompt,
-        uploaded_test_context=uploaded_text_context,
+        uploaded_text_context=uploaded_text_context,
         uploaded_image_context=uploaded_image_context,
     )

@@ -140,7 +140,7 @@ def orchestrator(

     available_tools = state.available_tools or {}

-    uploaded_context = state.uploaded_test_context or ""
+    uploaded_context = state.uploaded_text_context or ""
     uploaded_image_context = state.uploaded_image_context or []

     questions = [

@@ -228,7 +228,7 @@ def closer(
     assistant_system_prompt = state.assistant_system_prompt
     assistant_task_prompt = state.assistant_task_prompt

-    uploaded_context = state.uploaded_test_context or ""
+    uploaded_context = state.uploaded_text_context or ""

     clarification = state.clarification
     prompt_question = get_prompt_question(base_question, clarification)

@@ -46,7 +46,7 @@ class OrchestrationSetup(OrchestrationUpdate):
     active_source_types_descriptions: str | None = None
     assistant_system_prompt: str | None = None
     assistant_task_prompt: str | None = None
-    uploaded_test_context: str | None = None
+    uploaded_text_context: str | None = None
     uploaded_image_context: list[dict[str, Any]] | None = None

@@ -0,0 +1 @@
"""Python Tool sub-agent for deep research."""

@@ -0,0 +1,36 @@
from datetime import datetime

from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter

from onyx.agents.agent_search.dr.states import LoggerUpdate
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
from onyx.agents.agent_search.shared_graph_utils.utils import (
    get_langgraph_node_log_string,
)
from onyx.utils.logger import setup_logger

logger = setup_logger()


def python_tool_branch(
    state: SubAgentInput, config: RunnableConfig, writer: StreamWriter = lambda _: None
) -> LoggerUpdate:
    """Log the beginning of a Python Tool branch."""

    node_start_time = datetime.now()
    iteration_nr = state.iteration_nr

    logger.debug(
        f"Python Tool branch start for iteration {iteration_nr} at {datetime.now()}"
    )

    return LoggerUpdate(
        log_messages=[
            get_langgraph_node_log_string(
                graph_component="python_tool",
                node_name="branching",
                node_start_time=node_start_time,
            )
        ],
    )

@@ -0,0 +1,257 @@
import base64
import json
from datetime import datetime
from pathlib import Path
from typing import Any
from typing import cast

from langchain_core.messages import AIMessage
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter

from onyx.agents.agent_search.dr.sub_agents.states import BranchInput
from onyx.agents.agent_search.dr.sub_agents.states import BranchUpdate
from onyx.agents.agent_search.dr.sub_agents.states import IterationAnswer
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.utils import (
    get_langgraph_node_log_string,
)
from onyx.configs.agent_configs import TF_DR_TIMEOUT_SHORT
from onyx.configs.constants import MessageType
from onyx.file_store.models import ChatFileType
from onyx.file_store.models import InMemoryChatFile
from onyx.llm.utils import build_content_with_imgs
from onyx.prompts.dr_prompts import CUSTOM_TOOL_PREP_PROMPT
from onyx.prompts.dr_prompts import PYTHON_TOOL_USE_RESPONSE_PROMPT
from onyx.tools.tool_implementations.python.python_tool import PythonTool
from onyx.tools.tool_implementations.python.python_tool import PythonToolResult
from onyx.utils.logger import setup_logger

logger = setup_logger()


def _serialize_chat_files(chat_files: list[InMemoryChatFile]) -> list[dict[str, Any]]:
    serialized_files: list[dict[str, Any]] = []
    for chat_file in chat_files:
        file_payload: dict[str, Any] = {
            "id": str(chat_file.file_id),
            "name": chat_file.filename,
            "type": chat_file.file_type.value,
        }
        if chat_file.file_type == ChatFileType.IMAGE:
            file_payload["content"] = chat_file.to_base64()
            file_payload["is_base64"] = True
        elif chat_file.file_type.is_text_file():
            file_payload["content"] = chat_file.content.decode(
                "utf-8", errors="replace"
            )
            file_payload["is_base64"] = False
        else:
            file_payload["content"] = base64.b64encode(chat_file.content).decode(
                "utf-8"
            )
            file_payload["is_base64"] = True
        serialized_files.append(file_payload)

    return serialized_files

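As a rough illustration of what the serializer above emits, here is the shape of one entry for a plain-text file; the field values are hypothetical and the exact "type" string depends on the ChatFileType enum:

# Hypothetical output entry of _serialize_chat_files for one text file
example_payload = {
    "id": "3a6e…",             # str(chat_file.file_id)
    "name": "sales.csv",
    "type": "plain_text",      # assumed ChatFileType value; varies by file
    "content": "region,revenue\nEMEA,120\n",
    "is_base64": False,        # images and binary files carry base64 content instead
}
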
def python_tool_act(
    state: BranchInput,
    config: RunnableConfig,
    writer: StreamWriter = lambda _: None,
) -> BranchUpdate:
    """Execute the Python Tool with any files supplied by the user."""

    node_start_time = datetime.now()
    iteration_nr = state.iteration_nr
    parallelization_nr = state.parallelization_nr

    if not state.available_tools:
        raise ValueError("available_tools is not set")

    tool_key = state.tools_used[-1]
    python_tool_info = state.available_tools[tool_key]
    python_tool = cast(PythonTool | None, python_tool_info.tool_object)

    if python_tool is None:
        raise ValueError("python_tool is not set")

    branch_query = state.branch_question
    if not branch_query:
        raise ValueError("branch_query is not set")

    graph_config = cast(GraphConfig, config["metadata"]["config"])
    base_question = graph_config.inputs.prompt_builder.raw_user_query
    files = graph_config.inputs.files

    logger.debug(
        "Tool call start for %s %s.%s at %s",
        python_tool.llm_name,
        iteration_nr,
        parallelization_nr,
        datetime.now(),
    )

    tool_args: dict[str, Any] | None = None
    if graph_config.tooling.using_tool_calling_llm:
        tool_use_prompt = CUSTOM_TOOL_PREP_PROMPT.build(
            query=branch_query,
            base_question=base_question,
            tool_description=python_tool_info.description,
        )

        content_with_files = build_content_with_imgs(
            message=tool_use_prompt,
            files=files,
            message_type=MessageType.USER,
        )

        tool_prompt_message: dict[str, Any] = {
            "role": "user",
            "content": content_with_files,
        }
        if files:
            tool_prompt_message["files"] = _serialize_chat_files(files)

        tool_calling_msg = graph_config.tooling.primary_llm.invoke(
            [tool_prompt_message],
            tools=[python_tool.tool_definition()],
            tool_choice="required",
            timeout_override=TF_DR_TIMEOUT_SHORT,
        )

        if isinstance(tool_calling_msg, AIMessage) and tool_calling_msg.tool_calls:
            tool_args = tool_calling_msg.tool_calls[0].get("args")
        else:
            logger.warning("Tool-calling LLM did not emit a tool call for Python Tool")

    if tool_args is None:
        tool_args = python_tool.get_args_for_non_tool_calling_llm(
            query=branch_query,
            history=[],
            llm=graph_config.tooling.primary_llm,
            force_run=True,
        )

    if tool_args is None:
        raise ValueError("Failed to obtain tool arguments from LLM")

    if "files" in tool_args:
        tool_args = {key: value for key, value in tool_args.items() if key != "files"}

    override_kwargs = {"files": files or []}

    tool_responses = list(python_tool.run(override_kwargs=override_kwargs, **tool_args))

    python_tool_result: PythonToolResult | None = None
    for response in tool_responses:
        if isinstance(response.response, PythonToolResult):
            python_tool_result = response.response
            break

    if python_tool_result is None:
        raise ValueError("Python tool did not return a valid result")

    final_result = python_tool.final_result(*tool_responses)
    tool_result_str = json.dumps(final_result, ensure_ascii=False)

    tool_summary_prompt = PYTHON_TOOL_USE_RESPONSE_PROMPT.build(
        base_question=base_question,
        tool_response=tool_result_str,
    )

    initial_files = list(files or [])
    generated_files: list[InMemoryChatFile] = []
    for artifact in python_tool_result.artifacts:
        if not artifact.file_id:
            continue

        chat_file = python_tool._available_files.get(artifact.file_id)
        if not chat_file:
            logger.warning(
                "Generated artifact with id %s not found in available files",
                artifact.file_id,
            )
            continue

        filename = (
            chat_file.filename
            or artifact.display_name
            or artifact.path
            or str(artifact.file_id)
        )
        filename = Path(filename).name or str(artifact.file_id)
        if not filename.startswith("generated_"):
            filename = f"generated_{filename}"

        generated_files.append(
            InMemoryChatFile(
                file_id=chat_file.file_id,
                content=chat_file.content,
                file_type=chat_file.file_type,
                filename=filename,
            )
        )

    summary_files = initial_files + generated_files
    summary_content = build_content_with_imgs(
        message=tool_summary_prompt,
        files=summary_files,
        message_type=MessageType.USER,
    )

    summary_message: dict[str, Any] = {
        "role": "user",
        "content": summary_content,
    }
    if summary_files:
        summary_message["files"] = _serialize_chat_files(summary_files)

    answer_string = str(
        graph_config.tooling.primary_llm.invoke(
            [summary_message],
            timeout_override=TF_DR_TIMEOUT_SHORT,
        ).content
    ).strip()

    artifact_file_ids = [
        artifact.file_id
        for artifact in python_tool_result.artifacts
        if artifact.file_id
    ]

    logger.debug(
        "Tool call end for %s %s.%s at %s",
        python_tool.llm_name,
        iteration_nr,
        parallelization_nr,
        datetime.now(),
    )

    return BranchUpdate(
        branch_iteration_responses=[
            IterationAnswer(
                tool=python_tool.llm_name,
                tool_id=python_tool_info.tool_id,
                iteration_nr=iteration_nr,
                parallelization_nr=parallelization_nr,
                question=branch_query,
                answer=answer_string,
                claims=[],
                cited_documents={},
                reasoning="",
                additional_data=None,
                response_type="json",
                data=final_result,
                file_ids=artifact_file_ids or None,
            )
        ],
        log_messages=[
            get_langgraph_node_log_string(
                graph_component="python_tool",
                node_name="tool_calling",
                node_start_time=node_start_time,
            )
        ],
    )

@@ -0,0 +1,76 @@
from datetime import datetime

from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter

from onyx.agents.agent_search.dr.sub_agents.states import SubAgentMainState
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentUpdate
from onyx.agents.agent_search.shared_graph_utils.utils import (
    get_langgraph_node_log_string,
)
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.server.query_and_chat.streaming_models import CustomToolDelta
from onyx.server.query_and_chat.streaming_models import CustomToolStart
from onyx.server.query_and_chat.streaming_models import SectionEnd
from onyx.utils.logger import setup_logger

logger = setup_logger()


def python_tool_reducer(
    state: SubAgentMainState,
    config: RunnableConfig,
    writer: StreamWriter = lambda _: None,
) -> SubAgentUpdate:
    """Stream the Python Tool result back to the client."""

    node_start_time = datetime.now()
    current_step_nr = state.current_step_nr
    branch_updates = state.branch_iteration_responses
    current_iteration = state.iteration_nr

    new_updates = [
        update for update in branch_updates if update.iteration_nr == current_iteration
    ]

    for new_update in new_updates:
        if not new_update.response_type:
            raise ValueError("Response type is not returned.")

        write_custom_event(
            current_step_nr,
            CustomToolStart(
                tool_name=new_update.tool,
            ),
            writer,
        )

        write_custom_event(
            current_step_nr,
            CustomToolDelta(
                tool_name=new_update.tool,
                response_type=new_update.response_type,
                data=new_update.data,
                file_ids=new_update.file_ids,
            ),
            writer,
        )

        write_custom_event(
            current_step_nr,
            SectionEnd(),
            writer,
        )

        current_step_nr += 1

    return SubAgentUpdate(
        iteration_responses=new_updates,
        log_messages=[
            get_langgraph_node_log_string(
                graph_component="python_tool",
                node_name="consolidation",
                node_start_time=node_start_time,
            )
        ],
    )

@@ -0,0 +1,26 @@
from collections.abc import Hashable

from langgraph.types import Send

from onyx.agents.agent_search.dr.sub_agents.states import BranchInput
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput


def branching_router(state: SubAgentInput) -> list[Send | Hashable]:
    """Forward the current query to the Python Tool executor."""

    return [
        Send(
            "act",
            BranchInput(
                iteration_nr=state.iteration_nr,
                parallelization_nr=parallelization_nr,
                branch_question=query,
                context="",
                active_source_types=state.active_source_types,
                tools_used=state.tools_used,
                available_tools=state.available_tools,
            ),
        )
        for parallelization_nr, query in enumerate(state.query_list[:1])
    ]

@@ -0,0 +1,38 @@
from langgraph.graph import END
from langgraph.graph import START
from langgraph.graph import StateGraph

from onyx.agents.agent_search.dr.sub_agents.python_tool.dr_python_tool_1_branch import (
    python_tool_branch,
)
from onyx.agents.agent_search.dr.sub_agents.python_tool.dr_python_tool_2_act import (
    python_tool_act,
)
from onyx.agents.agent_search.dr.sub_agents.python_tool.dr_python_tool_3_reduce import (
    python_tool_reducer,
)
from onyx.agents.agent_search.dr.sub_agents.python_tool.dr_python_tool_conditional_edges import (
    branching_router,
)
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentMainState
from onyx.utils.logger import setup_logger

logger = setup_logger()


def dr_python_tool_graph_builder() -> StateGraph:
    """LangGraph graph builder for the Python Tool sub-agent."""

    graph = StateGraph(state_schema=SubAgentMainState, input=SubAgentInput)

    graph.add_node("branch", python_tool_branch)
    graph.add_node("act", python_tool_act)
    graph.add_node("reducer", python_tool_reducer)

    graph.add_edge(start_key=START, end_key="branch")
    graph.add_conditional_edges("branch", branching_router)
    graph.add_edge(start_key="act", end_key="reducer")
    graph.add_edge(start_key="reducer", end_key=END)

    return graph

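For orientation, the sub-agent compiles into a simple three-node pipeline; a minimal sketch of exercising the builder on its own (no real state inputs shown):

# Hedged sketch: the compiled topology is branch -> act (fanned out via Send) -> reducer -> END
compiled = dr_python_tool_graph_builder().compile()
# compiled.invoke(...) would require a fully populated SubAgentInput; this only
# illustrates that the builder returns a StateGraph ready for .compile()
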
@@ -414,8 +414,14 @@ def monitor_document_set_taskset(
         get_document_set_by_id(db_session=db_session, document_set_id=document_set_id),
     )  # casting since we "know" a document set with this ID exists
     if document_set:
-        if not document_set.connector_credential_pairs:
-            # if there are no connectors, then delete the document set.
+        has_connector_pairs = bool(document_set.connector_credential_pairs)
+        # Federated connectors should keep a document set alive even without cc pairs.
+        has_federated_connectors = bool(
+            getattr(document_set, "federated_connectors", [])
+        )
+
+        if not has_connector_pairs and not has_federated_connectors:
+            # If there are no connectors of any kind, delete the document set.
             delete_document_set(document_set_row=document_set, db_session=db_session)
             task_logger.info(
                 f"Successfully deleted document set: document_set={document_set_id}"

@@ -758,6 +758,24 @@ AZURE_DALLE_DEPLOYMENT_NAME = os.environ.get("AZURE_DALLE_DEPLOYMENT_NAME")
 # configurable image model
 IMAGE_MODEL_NAME = os.environ.get("IMAGE_MODEL_NAME", "gpt-image-1")

+CODE_INTERPRETER_BASE_URL = os.environ.get("CODE_INTERPRETER_BASE_URL")
+_CODE_INTERPRETER_DEFAULT_TIMEOUT_MS_RAW = os.environ.get(
+    "CODE_INTERPRETER_DEFAULT_TIMEOUT_MS"
+)
+CODE_INTERPRETER_DEFAULT_TIMEOUT_MS = (
+    int(_CODE_INTERPRETER_DEFAULT_TIMEOUT_MS_RAW)
+    if _CODE_INTERPRETER_DEFAULT_TIMEOUT_MS_RAW
+    else 30_000
+)
+_CODE_INTERPRETER_REQUEST_TIMEOUT_SECONDS_RAW = os.environ.get(
+    "CODE_INTERPRETER_REQUEST_TIMEOUT_SECONDS"
+)
+CODE_INTERPRETER_REQUEST_TIMEOUT_SECONDS = (
+    int(_CODE_INTERPRETER_REQUEST_TIMEOUT_SECONDS_RAW)
+    if _CODE_INTERPRETER_REQUEST_TIMEOUT_SECONDS_RAW
+    else 30
+)
+
 # Use managed Vespa (Vespa Cloud). If set, must also set VESPA_CLOUD_URL, VESPA_CLOUD_CERT_PATH and VESPA_CLOUD_KEY_PATH
 MANAGED_VESPA = os.environ.get("MANAGED_VESPA", "").lower() == "true"

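Both timeout settings follow the same parse-or-default pattern; a minimal equivalent of the block above (helper name is ours, not the repo's):

import os

def _int_env(name: str, default: int) -> int:
    # Unset or empty env vars fall back to the default
    raw = os.environ.get(name)
    return int(raw) if raw else default

timeout_ms = _int_env("CODE_INTERPRETER_DEFAULT_TIMEOUT_MS", 30_000)
request_timeout_s = _int_env("CODE_INTERPRETER_REQUEST_TIMEOUT_SECONDS", 30)
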
@@ -1029,7 +1029,7 @@ class SharepointConnector(

         # Filter pages based on time window if specified
         if start is not None or end is not None:
-            filtered_pages = []
+            filtered_pages: list[dict[str, Any]] = []
             for page in all_pages:
                 page_modified = page.get("lastModifiedDateTime")
                 if page_modified:

@@ -761,7 +761,7 @@ class SlackConnector(
         Step 2: Loop through each channel. For each channel:
         Step 2.1: Get messages within the time range.
         Step 2.2: Process messages in parallel, yield back docs.
-        Step 2.3: Update checkpoint with new_latest, seen_thread_ts, and current_channel.
+        Step 2.3: Update checkpoint with new_oldest, seen_thread_ts, and current_channel.
             Slack returns messages from newest to oldest, so we need to keep track of
             the latest message we've seen in each channel.
         Step 2.4: If there are no more messages in the channel, switch the current

@@ -837,7 +837,8 @@ class SlackConnector(

         channel_message_ts = checkpoint.channel_completion_map.get(channel_id)
         if channel_message_ts:
-            latest = channel_message_ts
+            # Set oldest to the checkpoint timestamp to resume from where we left off
+            oldest = channel_message_ts

         logger.debug(
             f"Getting messages for channel {channel} within range {oldest} - {latest}"

@@ -855,7 +856,8 @@ class SlackConnector(
             f"{latest=}"
         )

-        new_latest = message_batch[-1]["ts"] if message_batch else latest
+        # message_batch[0] is the newest message (Slack returns newest to oldest)
+        new_oldest = message_batch[0]["ts"] if message_batch else latest

         num_threads_start = len(seen_thread_ts)

@@ -906,15 +908,14 @@ class SlackConnector(
         num_threads_processed = len(seen_thread_ts) - num_threads_start

-        # calculate a percentage progress for the current channel by determining
-        # our viable range start and end, and the latest timestamp we are querying
-        # up to
-        new_latest_seconds_epoch = SecondsSinceUnixEpoch(new_latest)
-        if new_latest_seconds_epoch > end:
+        # how much of the time range we've processed so far
+        new_oldest_seconds_epoch = SecondsSinceUnixEpoch(new_oldest)
+        range_start = start if start else max(0, channel_created)
+        if new_oldest_seconds_epoch < range_start:
             range_complete = 0.0
         else:
-            range_complete = end - new_latest_seconds_epoch
+            range_complete = new_oldest_seconds_epoch - range_start

-        range_start = max(0, channel_created)
         range_total = end - range_start
         if range_total <= 0:
             range_total = 1

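To make the updated arithmetic concrete (timestamps are hypothetical; the resulting fraction is presumably what feeds the progress report):

range_start = 1_000                  # window start (seconds since epoch)
end = 2_000                          # window end
new_oldest_seconds_epoch = 1_250     # checkpointed oldest-message timestamp

range_complete = new_oldest_seconds_epoch - range_start   # 250
range_total = end - range_start                           # 1_000
fraction = range_complete / range_total                   # 0.25
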
@@ -935,7 +936,7 @@ class SlackConnector(
         )

         checkpoint.seen_thread_ts = list(seen_thread_ts)
-        checkpoint.channel_completion_map[channel["id"]] = new_latest
+        checkpoint.channel_completion_map[channel["id"]] = new_oldest

         # bypass channels where the first set of messages seen are all bots
         # check at least MIN_BOT_MESSAGE_THRESHOLD messages are in the batch

@@ -1,5 +1,6 @@
 import copy
+import time
 from collections.abc import Callable
 from collections.abc import Iterator
 from typing import Any
 from typing import cast
@@ -14,6 +15,9 @@ from onyx.configs.constants import DocumentSource
 from onyx.connectors.cross_connector_utils.miscellaneous_utils import (
     time_str_to_utc,
 )
+from onyx.connectors.cross_connector_utils.rate_limit_wrapper import (
+    rate_limit_builder,
+)
 from onyx.connectors.exceptions import ConnectorValidationError
 from onyx.connectors.exceptions import CredentialExpiredError
 from onyx.connectors.exceptions import InsufficientPermissionsError

@@ -47,14 +51,30 @@ class ZendeskCredentialsNotSetUpError(PermissionError):


 class ZendeskClient:
-    def __init__(self, subdomain: str, email: str, token: str):
+    def __init__(
+        self,
+        subdomain: str,
+        email: str,
+        token: str,
+        calls_per_minute: int | None = None,
+    ):
         self.base_url = f"https://{subdomain}.zendesk.com/api/v2"
         self.auth = (f"{email}/token", token)
+        self.make_request = request_with_rate_limit(self, calls_per_minute)
+
+
+def request_with_rate_limit(
+    client: ZendeskClient, max_calls_per_minute: int | None = None
+) -> Callable[[str, dict[str, Any]], dict[str, Any]]:
     @retry_builder()
-    def make_request(self, endpoint: str, params: dict[str, Any]) -> dict[str, Any]:
+    @(
+        rate_limit_builder(max_calls=max_calls_per_minute, period=60)
+        if max_calls_per_minute
+        else lambda x: x
+    )
+    def make_request(endpoint: str, params: dict[str, Any]) -> dict[str, Any]:
         response = requests.get(
-            f"{self.base_url}/{endpoint}", auth=self.auth, params=params
+            f"{client.base_url}/{endpoint}", auth=client.auth, params=params
         )

         if response.status_code == 429:
@@ -72,6 +92,8 @@ class ZendeskClient:
         response.raise_for_status()
         return response.json()

+    return make_request
+

 class ZendeskPageResponse(BaseModel):
     data: list[dict[str, Any]]

@@ -359,11 +381,13 @@ class ZendeskConnector(
     def __init__(
         self,
         content_type: str = "articles",
+        calls_per_minute: int | None = None,
     ) -> None:
         self.content_type = content_type
         self.subdomain = ""
         # Fetch all tags ahead of time
         self.content_tags: dict[str, str] = {}
+        self.calls_per_minute = calls_per_minute

     def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
         # Subdomain is actually the whole URL

@@ -375,7 +399,10 @@ class ZendeskConnector(
         self.subdomain = subdomain

         self.client = ZendeskClient(
-            subdomain, credentials["zendesk_email"], credentials["zendesk_token"]
+            subdomain,
+            credentials["zendesk_email"],
+            credentials["zendesk_token"],
+            calls_per_minute=self.calls_per_minute,
         )
         return None

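A usage sketch of the new throttling knob; subdomain, credentials, endpoint, and params are all placeholders:

client = ZendeskClient(
    subdomain="acme",              # hypothetical
    email="bot@example.com",       # hypothetical
    token="<api-token>",
    calls_per_minute=60,           # caps outbound requests at 60 per minute
)
data = client.make_request("help_center/articles", {"per_page": 100})  # hypothetical endpoint/params
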
@@ -72,7 +72,6 @@ def _get_answer(
         else None
     )
     research_type = ResearchType(eval_input.get("research_type", "THOUGHTFUL"))
-    print(eval_input)
     request = prepare_chat_message_request(
         message_text=eval_input["message"],
         user=user,

@@ -18,6 +18,7 @@ from onyx.db.engine.sql_engine import SqlEngine
 from onyx.evals.eval import run_eval
 from onyx.evals.models import EvalationAck
 from onyx.evals.models import EvalConfigurationOptions
+from onyx.evals.tracing import setup_braintrust


 def setup_session_factory() -> None:

@@ -54,6 +55,7 @@ def run_local(
         EvalationAck: The evaluation result
     """
     setup_session_factory()
+    setup_braintrust()

     if search_permissions_email is None:
         raise ValueError("search_permissions_email is required for local evaluation")

@@ -32,7 +32,7 @@ class EvalConfigurationOptions(BaseModel):
     llm: LLMOverride = LLMOverride(
         model_provider="Default",
         model_version="gpt-4.1",
-        temperature=0.5,
+        temperature=0.0,
     )
     search_permissions_email: str
     dataset_name: str

@@ -31,6 +31,14 @@
     sys.exit(1)


+def column_letter_to_index(column_letter: str) -> int:
+    """Convert Google Sheets column letter (A, B, C, etc.) to 0-based index."""
+    result = 0
+    for char in column_letter.upper():
+        result = result * 26 + (ord(char) - ord("A") + 1)
+    return result - 1
+
+
 def parse_csv_file(csv_path: str) -> List[Dict[str, Any]]:
     """Parse the CSV file and extract relevant records."""
     records = []

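A few spot checks of the helper above, following directly from its base-26 arithmetic:

assert column_letter_to_index("A") == 0    # 1 - 1
assert column_letter_to_index("C") == 2    # the "Should we use it?" column
assert column_letter_to_index("Z") == 25
assert column_letter_to_index("AA") == 26  # 1 * 26 + 1 - 1
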
@@ -49,38 +57,81 @@ def parse_csv_file(csv_path: str) -> List[Dict[str, Any]]:
     # Parse the CSV data starting from the data_start line
     csv_reader = csv.reader(lines[data_start:])

+    # Define Google Sheets column references for easy modification
+    SHOULD_USE_COL = "C"  # "Should we use it?"
+    QUESTION_COL = "H"  # "Question"
+    EXPECTED_DEPTH_COL = "J"  # "Expected Depth"
+    CATEGORIES_COL = "M"  # "Categories"
+    OPENAI_DEEP_COL = "AA"  # "OpenAI Deep Answer"
+    OPENAI_THINKING_COL = "O"  # "OpenAI Thinking Answer"
+
     for row_num, row in enumerate(csv_reader, start=data_start + 1):
-        if len(row) < 13:  # Ensure we have enough columns
+        if len(row) < 15:  # Ensure we have enough columns
             continue

-        # Extract relevant fields based on CSV structure
-        should_use = row[2].strip().upper() if len(row) > 2 else ""
-        question = row[7].strip() if len(row) > 7 else ""
-        expected_depth = row[9].strip() if len(row) > 9 else ""
-        categories = row[12].strip() if len(row) > 12 else ""
+        # Extract relevant fields using Google Sheets column references
+        should_use = (
+            row[column_letter_to_index(SHOULD_USE_COL)].strip().upper()
+            if len(row) > column_letter_to_index(SHOULD_USE_COL)
+            else ""
+        )
+        question = (
+            row[column_letter_to_index(QUESTION_COL)].strip()
+            if len(row) > column_letter_to_index(QUESTION_COL)
+            else ""
+        )
+        expected_depth = (
+            row[column_letter_to_index(EXPECTED_DEPTH_COL)].strip()
+            if len(row) > column_letter_to_index(EXPECTED_DEPTH_COL)
+            else ""
+        )
+        categories = (
+            row[column_letter_to_index(CATEGORIES_COL)].strip()
+            if len(row) > column_letter_to_index(CATEGORIES_COL)
+            else ""
+        )
+        openai_deep_answer = (
+            row[column_letter_to_index(OPENAI_DEEP_COL)].strip()
+            if len(row) > column_letter_to_index(OPENAI_DEEP_COL)
+            else ""
+        )
+        openai_thinking_answer = (
+            row[column_letter_to_index(OPENAI_THINKING_COL)].strip()
+            if len(row) > column_letter_to_index(OPENAI_THINKING_COL)
+            else ""
+        )

         # Filter records: should_use = TRUE and categories contains "web-only"
-        if should_use == "TRUE" and question:  # Ensure question is not empty
-
-            records.extend(
-                [
-                    {
-                        "question": question
-                        + ". All info is contained in the question. DO NOT ask any clarifying questions.",
-                        "research_type": "DEEP",
-                        "categories": categories,
-                        "expected_depth": expected_depth,
-                        "row_number": row_num,
-                    },
-                    {
-                        "question": question,
-                        "research_type": "THOUGHTFUL",
-                        "categories": categories,
-                        "expected_depth": expected_depth,
-                        "row_number": row_num,
-                    },
-                ]
-            )
+        if (
+            should_use == "TRUE" and "web-only" in categories and question
+        ):  # Ensure question is not empty
+            if expected_depth == "Deep":
+                records.extend(
+                    [
+                        {
+                            "question": question
+                            + ". All info is contained in the question. DO NOT ask any clarifying questions.",
+                            "research_type": "DEEP",
+                            "categories": categories,
+                            "expected_depth": expected_depth,
+                            "expected_answer": openai_deep_answer,
+                            "row_number": row_num,
+                        }
+                    ]
+                )
+            else:
+                records.extend(
+                    [
+                        {
+                            "question": question,
+                            "research_type": "THOUGHTFUL",
+                            "categories": categories,
+                            "expected_depth": expected_depth,
+                            "expected_answer": openai_thinking_answer,
+                            "row_number": row_num,
+                        }
+                    ]
+                )

     return records

@@ -107,6 +158,7 @@ def create_braintrust_dataset(records: List[Dict[str, Any]], dataset_name: str)
             print(f"Record {i}/{len(records)}:")
             print(f"  Question: {record['question'][:100]}...")
             print(f"  Research Type: {record['research_type']}")
+            print(f"  Expected Answer: {record['expected_answer'][:100]}...")
             print()
         return

@@ -118,11 +170,13 @@ def create_braintrust_dataset(records: List[Dict[str, Any]], dataset_name: str)
     # Insert records into the dataset
     for i, record in enumerate(records, 1):
         record_id = dataset.insert(
-            {"message": record["question"], "research_type": record["research_type"]}
+            {"message": record["question"], "research_type": record["research_type"]},
+            expected=record["expected_answer"],
         )
         print(f"Inserted record {i}/{len(records)}: ID {record_id}")
         print(f"  Question: {record['question'][:100]}...")
         print(f"  Research Type: {record['research_type']}")
+        print(f"  Expected Answer: {record['expected_answer'][:100]}...")
         print()

     # Flush to ensure all records are sent

backend/onyx/evals/tracing.py (new file, 35 lines)
@@ -0,0 +1,35 @@
from typing import Any

import braintrust
from braintrust_langchain import set_global_handler
from braintrust_langchain.callbacks import BraintrustCallbackHandler

from onyx.configs.app_configs import BRAINTRUST_API_KEY
from onyx.configs.app_configs import BRAINTRUST_PROJECT

MASKING_LENGTH = 20000


def _truncate_str(s: str) -> str:
    tail = MASKING_LENGTH // 5
    head = MASKING_LENGTH - tail
    return f"{s[:head]}…{s[-tail:]}[TRUNCATED {len(s)} chars to {MASKING_LENGTH}]"


def _mask(data: Any) -> Any:
    """Mask data if it exceeds the maximum length threshold."""
    if len(str(data)) <= MASKING_LENGTH:
        return data
    return _truncate_str(str(data))


def setup_braintrust() -> None:
    """Initialize Braintrust logger and set up global callback handler."""

    braintrust.init_logger(
        project=BRAINTRUST_PROJECT,
        api_key=BRAINTRUST_API_KEY,
    )
    braintrust.set_masking_function(_mask)
    handler = BraintrustCallbackHandler()
    set_global_handler(handler)

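Given MASKING_LENGTH = 20000, the truncation keeps head = 16,000 and tail = 4,000 characters. A small sketch of the resulting behavior:

s = "a" * 50_000
out = _mask(s)
# _truncate_str keeps s[:16000], an ellipsis, then s[-4000:], plus a marker suffix
assert out.endswith("[TRUNCATED 50000 chars to 20000]")
assert _mask("short string") == "short string"  # under the threshold, returned unchanged
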
@@ -1,3 +1,4 @@
+import gc
 import io
 import json
 import os
@@ -17,8 +18,10 @@ from typing import NamedTuple
 from zipfile import BadZipFile

 import chardet
+import openpyxl
 from markitdown import FileConversionException
 from markitdown import MarkItDown
+from markitdown import StreamInfo
 from markitdown import UnsupportedFormatException
 from PIL import Image
 from pypdf import PdfReader
@@ -30,6 +33,8 @@ from onyx.file_processing.file_validation import TEXT_MIME_TYPE
 from onyx.file_processing.html_utils import parse_html_page_basic
 from onyx.file_processing.unstructured import get_unstructured_api_key
 from onyx.file_processing.unstructured import unstructured_to_text
+from onyx.utils.file_types import PRESENTATION_MIME_TYPE
+from onyx.utils.file_types import WORD_PROCESSING_MIME_TYPE
 from onyx.utils.logger import setup_logger

 logger = setup_logger()

@@ -80,6 +85,20 @@ IMAGE_MEDIA_TYPES = [
     "image/webp",
 ]

+_MARKITDOWN_CONVERTER: MarkItDown | None = None
+
+KNOWN_OPENPYXL_BUGS = [
+    "Value must be either numerical or a string containing a wildcard",
+    "File contains no valid workbook part",
+]
+
+
+def get_markitdown_converter() -> MarkItDown:
+    global _MARKITDOWN_CONVERTER
+    if _MARKITDOWN_CONVERTER is None:
+        _MARKITDOWN_CONVERTER = MarkItDown(enable_plugins=False)
+    return _MARKITDOWN_CONVERTER
+
+
 class OnyxExtensionType(IntFlag):
     Plain = auto()

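The module-level cache means the MarkItDown converter is constructed once per process rather than per call; repeated calls hand back the same instance:

md_a = get_markitdown_converter()
md_b = get_markitdown_converter()
assert md_a is md_b  # singleton: built lazily on first use, then reused
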
@@ -338,9 +357,11 @@ def docx_to_text_and_images(
     of avoiding materializing the list of images in memory.
     The images list returned is empty in this case.
     """
-    md = MarkItDown(enable_plugins=False)
+    md = get_markitdown_converter()
     try:
-        doc = md.convert(to_bytesio(file))
+        doc = md.convert(
+            to_bytesio(file), stream_info=StreamInfo(mimetype=WORD_PROCESSING_MIME_TYPE)
+        )
     except (
         BadZipFile,
         ValueError,

@@ -372,9 +393,12 @@


 def pptx_to_text(file: IO[Any], file_name: str = "") -> str:
-    md = MarkItDown(enable_plugins=False)
+    md = get_markitdown_converter()
+    stream_info = StreamInfo(
+        mimetype=PRESENTATION_MIME_TYPE, filename=file_name or None, extension=".pptx"
+    )
     try:
-        presentation = md.convert(to_bytesio(file))
+        presentation = md.convert(to_bytesio(file), stream_info=stream_info)
     except (
         BadZipFile,
         ValueError,

@@ -388,23 +412,69 @@ def pptx_to_text(file: IO[Any], file_name: str = "") -> str:

 def xlsx_to_text(file: IO[Any], file_name: str = "") -> str:
-    md = MarkItDown(enable_plugins=False)
+    # TODO: switch back to this approach in a few months when markitdown
+    # fixes their handling of excel files
+
+    # md = get_markitdown_converter()
+    # stream_info = StreamInfo(
+    #     mimetype=SPREADSHEET_MIME_TYPE, filename=file_name or None, extension=".xlsx"
+    # )
+    # try:
+    #     workbook = md.convert(to_bytesio(file), stream_info=stream_info)
+    # except (
+    #     BadZipFile,
+    #     ValueError,
+    #     FileConversionException,
+    #     UnsupportedFormatException,
+    # ) as e:
+    #     error_str = f"Failed to extract text from {file_name or 'xlsx file'}: {e}"
+    #     if file_name.startswith("~"):
+    #         logger.debug(error_str + " (this is expected for files with ~)")
+    #     else:
+    #         logger.warning(error_str)
+    #     return ""
+    # return workbook.markdown
     try:
-        workbook = md.convert(to_bytesio(file))
-    except (
-        BadZipFile,
-        ValueError,
-        FileConversionException,
-        UnsupportedFormatException,
-    ) as e:
+        workbook = openpyxl.load_workbook(file, read_only=True)
+    except BadZipFile as e:
         error_str = f"Failed to extract text from {file_name or 'xlsx file'}: {e}"
         if file_name.startswith("~"):
             logger.debug(error_str + " (this is expected for files with ~)")
         else:
             logger.warning(error_str)
         return ""
+    except Exception as e:
+        if any(s in str(e) for s in KNOWN_OPENPYXL_BUGS):
+            logger.error(
+                f"Failed to extract text from {file_name or 'xlsx file'}. This happens due to a bug in openpyxl. {e}"
+            )
+            return ""
+        raise e

-    return workbook.markdown
+    text_content = []
+    for sheet in workbook.worksheets:
+        rows = []
+        num_empty_consecutive_rows = 0
+        for row in sheet.iter_rows(min_row=1, values_only=True):
+            row_str = ",".join(str(cell or "") for cell in row)
+
+            # Only add the row if there are any values in the cells
+            if len(row_str) >= len(row):
+                rows.append(row_str)
+                num_empty_consecutive_rows = 0
+            else:
+                num_empty_consecutive_rows += 1
+
+            if num_empty_consecutive_rows > 100:
+                # handle massive excel sheets with mostly empty cells
+                logger.warning(
+                    f"Found {num_empty_consecutive_rows} empty rows in {file_name},"
+                    " skipping rest of file"
+                )
+                break
+        sheet_str = "\n".join(rows)
+        text_content.append(sheet_str)
+    return TEXT_SECTION_SEPARATOR.join(text_content)


 def eml_to_text(file: IO[Any]) -> str:

@@ -531,6 +601,23 @@ def extract_text_and_images(
     Primary new function for the updated connector.
     Returns structured extraction result with text content, embedded images, and metadata.
     """
+    res = _extract_text_and_images(
+        file, file_name, pdf_pass, content_type, image_callback
+    )
+    # Clean up any temporary objects and force garbage collection
+    unreachable = gc.collect()
+    logger.info(f"Unreachable objects: {unreachable}")
+
+    return res
+
+
+def _extract_text_and_images(
+    file: IO[Any],
+    file_name: str,
+    pdf_pass: str | None = None,
+    content_type: str | None = None,
+    image_callback: Callable[[bytes, str], None] | None = None,
+) -> ExtractionResult:
     file.seek(0)

     if get_unstructured_api_key():
@@ -556,7 +643,6 @@
     # Default processing
     try:
         extension = get_file_ext(file_name)
-
         # docx example for embedded images
         if extension == ".docx":
             text_content, images = docx_to_text_and_images(

@@ -26,6 +26,7 @@ from langchain_core.messages.tool import ToolMessage
 from langchain_core.prompt_values import PromptValue
 from litellm.utils import get_supported_openai_params

+from onyx.configs.app_configs import BRAINTRUST_ENABLED
 from onyx.configs.app_configs import LOG_DANSWER_MODEL_INTERACTIONS
 from onyx.configs.app_configs import MOCK_LLM_RESPONSE
 from onyx.configs.chat_configs import QA_TIMEOUT
@@ -42,7 +43,6 @@ from onyx.server.utils import mask_string
 from onyx.utils.logger import setup_logger
 from onyx.utils.long_term_log import LongTermLogger

-
 logger = setup_logger()

 # If a user configures a different model and it doesn't support all the same
@@ -50,6 +50,9 @@ logger = setup_logger()
 litellm.drop_params = True
 litellm.telemetry = False

+if BRAINTRUST_ENABLED:
+    litellm.callbacks = ["braintrust"]
+
 _LLM_PROMPT_LONG_TERM_LOG_CATEGORY = "llm_prompt"
 VERTEX_CREDENTIALS_FILE_KWARG = "vertex_credentials"
 VERTEX_LOCATION_KWARG = "vertex_location"

@@ -2,10 +2,10 @@ import abc
 from collections.abc import Iterator
 from typing import Literal

+from braintrust import traced
 from langchain.schema.language_model import LanguageModelInput
 from langchain_core.messages import AIMessageChunk
 from langchain_core.messages import BaseMessage
-from langsmith.run_helpers import traceable
 from pydantic import BaseModel

 from onyx.configs.app_configs import DISABLE_GENERATIVE_AI

@@ -53,6 +53,11 @@ def log_prompt(prompt: LanguageModelInput) -> None:
             else:
                 log_msg = ""
             logger.debug(f"Message {ind}:\n{log_msg}")
+        elif isinstance(msg, dict):
+            log_msg = msg.get("content", "")
+            if "files" in msg:
+                log_msg = f"{log_msg}\nfiles: {msg['files']}"
+            logger.debug(f"Message {ind}:\n{log_msg}")
         else:
             logger.debug(f"Message {ind}:\n{msg.content}")
     if isinstance(prompt, str):

@@ -87,7 +92,7 @@ class LLM(abc.ABC):
         if LOG_DANSWER_MODEL_INTERACTIONS:
             log_prompt(prompt)

-    @traceable(run_type="llm")
+    @traced(name="invoke llm", type="llm")
     def invoke(
         self,
         prompt: LanguageModelInput,

@@ -121,7 +126,7 @@ class LLM(abc.ABC):
     ) -> BaseMessage:
         raise NotImplementedError

-    @traceable(run_type="llm")
+    @traced(name="stream llm", type="llm")
     def stream(
         self,
         prompt: LanguageModelInput,

@@ -35,6 +35,7 @@ from onyx.configs.app_configs import APP_HOST
 from onyx.configs.app_configs import APP_PORT
 from onyx.configs.app_configs import AUTH_RATE_LIMITING_ENABLED
 from onyx.configs.app_configs import AUTH_TYPE
+from onyx.configs.app_configs import BRAINTRUST_ENABLED
 from onyx.configs.app_configs import DISABLE_GENERATIVE_AI
 from onyx.configs.app_configs import LOG_ENDPOINT_LATENCY
 from onyx.configs.app_configs import OAUTH_CLIENT_ID
@@ -51,6 +52,7 @@ from onyx.configs.constants import POSTGRES_WEB_APP_NAME
 from onyx.db.engine.connection_warmup import warm_up_connections
 from onyx.db.engine.sql_engine import get_session_with_current_tenant
 from onyx.db.engine.sql_engine import SqlEngine
+from onyx.evals.tracing import setup_braintrust
 from onyx.file_store.file_store import get_default_file_store
 from onyx.server.api_key.api import router as api_key_router
 from onyx.server.auth_check import check_router_auth
@@ -252,6 +254,10 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
     if DISABLE_GENERATIVE_AI:
         logger.notice("Generative AI Q&A disabled")

+    if BRAINTRUST_ENABLED:
+        setup_braintrust()
+        logger.notice("Braintrust tracing initialized")
+
     # fill up Postgres connection pools
     await warm_up_connections()

@@ -5,7 +5,6 @@ from copy import copy

 from tokenizers import Encoding  # type: ignore
 from tokenizers import Tokenizer  # type: ignore
-from transformers import logging as transformer_logging  # type:ignore

 from onyx.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE
 from onyx.configs.model_configs import DOCUMENT_ENCODER_MODEL
@@ -16,10 +15,8 @@ from shared_configs.enums import EmbeddingProvider
 TRIM_SEP_PAT = "\n... {n} tokens removed...\n"

 logger = setup_logger()
-transformer_logging.set_verbosity_error()
 os.environ["TOKENIZERS_PARALLELISM"] = "false"
 os.environ["HF_HUB_DISABLE_TELEMETRY"] = "1"
-os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "1"


 class BaseTokenizer(ABC):

@@ -97,6 +97,15 @@ some sub-answers, but the question sent to the graph must be a question that can
 entity and relationship types and their attributes that will be communicated to you.
 """

+TOOL_DESCRIPTION[DRPath.PYTHON_TOOL] = (
+    """\
+This tool executes Python code within an isolated Code Interpreter environment. \
+Use it when the user explicitly requests code execution, data analysis, statistical computations, \
+or operations on uploaded files. The tool has access to any files the user supplied and can generate \
+new artifacts (for example, charts or datasets) that can be downloaded afterwards.\
+""".strip()
+)
+
 TOOL_DESCRIPTION[
     DRPath.CLOSER
 ] = f"""\

@@ -948,6 +957,31 @@ ANSWER:
 )


+PYTHON_TOOL_USE_RESPONSE_PROMPT = PromptTemplate(
+    f"""\
+Here is the base question that ultimately needs to be answered:
+{SEPARATOR_LINE}
+---base_question---
+{SEPARATOR_LINE}
+
+You just used the `python` tool in an attempt to answer the base question. The response from the \
+tool is:
+
+{SEPARATOR_LINE}
+---tool_response---
+{SEPARATOR_LINE}
+
+If any files were uploaded by the user, they are attached. If any files were generated by the \
+`python` tool, they are also attached, and prefixed with `generated_`.
+
+Please respond with a concise answer to the base question. If you believe additional `python` \
+tool calls are necessary, please describe what HIGH LEVEL TASK we would like to accomplish next. \
+Do NOT describe any specific Python code that we would like to execute.
+""".strip()
+)
+
+
 TEST_INFO_COMPLETE_PROMPT = PromptTemplate(
     f"""\
 You are an expert at trying to determine whether \

@@ -10,6 +10,9 @@ from onyx.tools.tool_implementations.knowledge_graph.knowledge_graph_tool import
 from onyx.tools.tool_implementations.okta_profile.okta_profile_tool import (
     OktaProfileTool,
 )
+from onyx.tools.tool_implementations.python.python_tool import (
+    PythonTool,
+)
 from onyx.tools.tool_implementations.search.search_tool import SearchTool
 from onyx.tools.tool_implementations.web_search.web_search_tool import (
     WebSearchTool,
@@ -20,7 +23,12 @@ logger = setup_logger()


 BUILT_IN_TOOL_TYPES = Union[
-    SearchTool, ImageGenerationTool, WebSearchTool, KnowledgeGraphTool, OktaProfileTool
+    SearchTool,
+    ImageGenerationTool,
+    WebSearchTool,
+    KnowledgeGraphTool,
+    OktaProfileTool,
+    PythonTool,
 ]

 # same as d09fc20a3c66_seed_builtin_tools.py
@@ -30,6 +38,7 @@ BUILT_IN_TOOL_MAP: dict[str, Type[BUILT_IN_TOOL_TYPES]] = {
     WebSearchTool.__name__: WebSearchTool,
     KnowledgeGraphTool.__name__: KnowledgeGraphTool,
     OktaProfileTool.__name__: OktaProfileTool,
+    PythonTool.__name__: PythonTool,
 }

@@ -13,6 +13,9 @@ from onyx.configs.app_configs import AZURE_DALLE_API_BASE
 from onyx.configs.app_configs import AZURE_DALLE_API_KEY
 from onyx.configs.app_configs import AZURE_DALLE_API_VERSION
 from onyx.configs.app_configs import AZURE_DALLE_DEPLOYMENT_NAME
+from onyx.configs.app_configs import CODE_INTERPRETER_BASE_URL
+from onyx.configs.app_configs import CODE_INTERPRETER_DEFAULT_TIMEOUT_MS
+from onyx.configs.app_configs import CODE_INTERPRETER_REQUEST_TIMEOUT_SECONDS
 from onyx.configs.app_configs import IMAGE_MODEL_NAME
 from onyx.configs.app_configs import OAUTH_CLIENT_ID
 from onyx.configs.app_configs import OAUTH_CLIENT_SECRET
@@ -55,6 +58,9 @@ from onyx.tools.tool_implementations.mcp.mcp_tool import MCPTool
 from onyx.tools.tool_implementations.okta_profile.okta_profile_tool import (
     OktaProfileTool,
 )
+from onyx.tools.tool_implementations.python.python_tool import (
+    PythonTool,
+)
 from onyx.tools.tool_implementations.search.search_tool import SearchTool
 from onyx.tools.tool_implementations.web_search.web_search_tool import (
     WebSearchTool,

@@ -193,7 +199,9 @@ def construct_tools(
     custom_tool_config: CustomToolConfig | None = None,
     allowed_tool_ids: list[int] | None = None,
 ) -> dict[int, list[Tool]]:
-    """Constructs tools based on persona configuration and available APIs"""
+    """Constructs tools based on persona configuration and available APIs.
+
+    Will simply skip tools that are not allowed/available."""
     tool_dict: dict[int, list[Tool]] = {}

     mcp_tool_cache: dict[int, dict[int, MCPTool]] = {}

@@ -210,6 +218,21 @@ def construct_tools(
         if db_tool_model.in_code_tool_id:
             tool_cls = get_built_in_tool_by_id(db_tool_model.in_code_tool_id)

+            try:
+                tool_is_available = tool_cls.is_available(db_session)
+            except Exception:
+                logger.exception(
+                    "Failed checking availability for tool %s", tool_cls.__name__
+                )
+                tool_is_available = False
+
+            if not tool_is_available:
+                logger.debug(
+                    "Skipping tool %s because it is not available",
+                    tool_cls.__name__,
+                )
+                continue
+
             # Handle Search Tool
             if (
                 tool_cls.__name__ == SearchTool.__name__

@@ -263,6 +286,27 @@ def construct_tools(
|
||||
)
|
||||
]
|
||||
|
||||
# Handle Python Tool
|
||||
elif tool_cls.__name__ == PythonTool.__name__:
|
||||
if not CODE_INTERPRETER_BASE_URL:
|
||||
raise ValueError(
|
||||
"Code Interpreter base URL must be configured to use the Python tool"
|
||||
)
|
||||
|
||||
available_files = []
|
||||
if search_tool_config and search_tool_config.latest_query_files:
|
||||
available_files = search_tool_config.latest_query_files
|
||||
|
||||
tool_dict[db_tool_model.id] = [
|
||||
PythonTool(
|
||||
tool_id=db_tool_model.id,
|
||||
base_url=CODE_INTERPRETER_BASE_URL,
|
||||
default_timeout_ms=CODE_INTERPRETER_DEFAULT_TIMEOUT_MS,
|
||||
request_timeout_seconds=CODE_INTERPRETER_REQUEST_TIMEOUT_SECONDS,
|
||||
available_files=available_files,
|
||||
)
|
||||
]
|
||||
|
||||
# Handle Internet Search Tool
|
||||
elif tool_cls.__name__ == WebSearchTool.__name__:
|
||||
if not internet_search_tool_config:
|
||||
|
||||
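The three CODE_INTERPRETER_* settings imported above are the only deployment knobs the Python tool needs. A minimal sketch of how they might be defined in app_configs, assuming the usual environment-variable pattern; the default values here are illustrative, not the repo's actual defaults:

import os

# unset by default, which keeps PythonTool.is_available() False
CODE_INTERPRETER_BASE_URL = os.environ.get("CODE_INTERPRETER_BASE_URL")
# per-execution budget handed to the sandbox (milliseconds); default is assumed
CODE_INTERPRETER_DEFAULT_TIMEOUT_MS = int(
    os.environ.get("CODE_INTERPRETER_DEFAULT_TIMEOUT_MS", "30000")
)
# HTTP client timeout for the /execute round trip (seconds); default is assumed
CODE_INTERPRETER_REQUEST_TIMEOUT_SECONDS = int(
    os.environ.get("CODE_INTERPRETER_REQUEST_TIMEOUT_SECONDS", "60")
)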
644 backend/onyx/tools/tool_implementations/python/python_tool.py Normal file
@@ -0,0 +1,644 @@
from __future__ import annotations

import base64
import binascii
import json
import mimetypes
from collections.abc import Generator
from contextlib import closing
from io import BytesIO
from pathlib import Path
from typing import Any

import requests
from pydantic import BaseModel
from pydantic import ConfigDict
from pydantic import Field
from pydantic import ValidationError
from sqlalchemy.orm import Session
from typing_extensions import override

from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
from onyx.configs.app_configs import CODE_INTERPRETER_BASE_URL
from onyx.configs.constants import FileOrigin
from onyx.file_store.file_store import get_default_file_store
from onyx.file_store.models import ChatFileType
from onyx.file_store.models import InMemoryChatFile
from onyx.llm.interfaces import LLM
from onyx.llm.models import PreviousMessage
from onyx.server.query_and_chat.chat_utils import mime_type_to_chat_file_type
from onyx.tools.base_tool import BaseTool
from onyx.tools.message import ToolCallSummary
from onyx.tools.models import ToolResponse
from onyx.utils.logger import setup_logger
from onyx.utils.special_types import JSON_ro


logger = setup_logger()


PYTHON_TOOL_RESPONSE_ID = "python_tool_result"
_EXECUTE_PATH = "/execute"
_DEFAULT_RESULT_CHAR_LIMIT = 1200


class PythonToolArgs(BaseModel):
    code: str = Field(
        ...,
        description="Python source code to execute inside the Code Interpreter service",
    )
    stdin: str | None = Field(
        default=None,
        description="Optional standard input payload supplied to the process",
    )
    timeout_ms: int | None = Field(
        default=None,
        ge=1_000,
        description="Optional execution timeout override in milliseconds",
    )


class ExecuteRequestFilePayload(BaseModel):
    path: str
    content_base64: str


class ExecuteRequestPayload(BaseModel):
    code: str
    timeout_ms: int
    stdin: str | None = None
    files: list[ExecuteRequestFilePayload] = Field(default_factory=list)


class WorkspaceFilePayload(BaseModel):
    model_config = ConfigDict(populate_by_name=True)

    path: str
    kind: str = Field(default="file")
    content_base64: str | None = None
    mime_type: str | None = None
    size_bytes: int | None = Field(default=None, alias="size")


class PythonToolArtifact(BaseModel):
    model_config = ConfigDict(use_enum_values=True)

    path: str
    kind: str
    file_id: str | None = None
    display_name: str | None = None
    mime_type: str | None = None
    size_bytes: int | None = None
    error: str | None = None
    chat_file_type: ChatFileType | None = None


class PythonToolResult(BaseModel):
    model_config = ConfigDict(use_enum_values=True)

    stdout: str
    stderr: str | None = None
    exit_code: int | None = None
    execution_time_ms: int | None = None
    timeout_ms: int
    input_files: list[dict[str, Any]] = Field(default_factory=list)
    artifacts: list[PythonToolArtifact] = Field(default_factory=list)
    metadata: dict[str, Any] | None = None


_PYTHON_DESCRIPTION = """
When you send a message containing Python code to python, it will be executed in an isolated environment.

The code you write will be placed into a file called `__main__.py` and run like `python __main__.py`. \
All other files present in the conversation will also be present in this same directory. E.g. if the \
user has uploaded three files, `analytics.csv`, `word.txt`, and `cat.png`, the directory structure will look like:

workspace/
    __main__.py
    analytics.csv
    word.txt
    cat.png

python will respond with the stdout of the `__main__.py` script as well as any new/edited files in the `workspace` directory. \
That means, if you want to get access to some result of the script you can either: 1) `print(<result>)` or 2) save \
a result to a file in the `workspace` directory.

Internet access for this session is disabled. Do not make external web requests or API calls as they will fail.

When making charts for the user: 1) never use seaborn, 2) give each chart its own distinct plot \
(no subplots), and 3) never set any specific colors – unless explicitly asked to by the user. \
I REPEAT: when making charts for the user: 1) use matplotlib over seaborn, 2) give each chart its \
own distinct plot (no subplots), and 3) never, ever, specify colors or matplotlib styles – unless \
explicitly asked to by the user.
""".strip()


class PythonTool(BaseTool):
    _NAME = "run_code_interpreter"
    _DESCRIPTION = _PYTHON_DESCRIPTION
    _DISPLAY_NAME = "Code Interpreter"

    def __init__(
        self,
        tool_id: int,
        base_url: str,
        default_timeout_ms: int,
        request_timeout_seconds: int,
        available_files: list[InMemoryChatFile] | None = None,
    ) -> None:
        if not base_url:
            raise ValueError(
                "Code Interpreter base URL must be configured to use the Python tool"
            )

        self._id = tool_id
        self._base_url = base_url.rstrip("/")
        self._default_timeout_ms = default_timeout_ms
        self._request_timeout_seconds = request_timeout_seconds
        self._available_files: dict[str, InMemoryChatFile] = {
            chat_file.file_id: chat_file for chat_file in available_files or []
        }

    @property
    def id(self) -> int:
        return self._id

    @property
    def name(self) -> str:
        return self._NAME

    @property
    def description(self) -> str:
        return self._DESCRIPTION

    @property
    def display_name(self) -> str:
        return self._DISPLAY_NAME

    @override
    @classmethod
    def is_available(cls, db_session: Session) -> bool:
        return bool(CODE_INTERPRETER_BASE_URL)

    def tool_definition(self) -> dict:
        return {
            "type": "function",
            "function": {
                "name": self.name,
                "description": self.description,
                "parameters": {
                    "type": "object",
                    "properties": {
                        "code": {
                            "type": "string",
                            "description": "Python source code to execute",
                        },
                    },
                    "required": ["code"],
                },
            },
        }

    def get_args_for_non_tool_calling_llm(
        self,
        query: str,
        history: list[PreviousMessage],
        llm: LLM,
        force_run: bool = False,
    ) -> dict[str, Any] | None:
        # Not supported for non-tool calling LLMs
        return None

    def run(
        self,
        override_kwargs: dict[str, Any] | None = None,
        **kwargs: Any,
    ) -> Generator[ToolResponse, None, None]:
        raw_requested_files = kwargs.pop("files", None)

        try:
            parsed_args = PythonToolArgs.model_validate(kwargs)
        except ValidationError as exc:
            logger.exception("Invalid arguments passed to PythonTool")
            raise ValueError("Invalid arguments supplied to Code Interpreter") from exc

        timeout_ms = parsed_args.timeout_ms or self._default_timeout_ms
        override_files = self._coerce_override_files(override_kwargs)
        requested_files = self._load_requested_files(raw_requested_files)
        request_files, input_metadata = self._prepare_request_files(
            override_files + requested_files
        )

        request_payload = ExecuteRequestPayload(
            code=parsed_args.code,
            timeout_ms=timeout_ms,
            stdin=parsed_args.stdin,
            files=request_files,
        )

        try:
            response = requests.post(
                f"{self._base_url}{_EXECUTE_PATH}",
                json=request_payload.model_dump(exclude_none=True),
                timeout=self._request_timeout_seconds,
            )
            response.raise_for_status()
        except requests.RequestException as exc:
            logger.exception("Code Interpreter execution failed")
            raise ValueError(
                "Failed to reach the Code Interpreter service. Please try again later."
            ) from exc

        response_data = self._parse_response(response)
        artifacts = self._persist_artifacts(response_data.get("files", []))

        python_result = PythonToolResult(
            stdout=self._ensure_text(response_data.get("stdout", "")),
            stderr=self._optional_text(response_data.get("stderr")),
            exit_code=self._safe_int(response_data.get("exit_code")),
            execution_time_ms=self._safe_int(
                response_data.get("execution_time_ms")
                or response_data.get("duration_ms")
            ),
            timeout_ms=timeout_ms,
            input_files=input_metadata,
            artifacts=artifacts,
            metadata=self._extract_metadata(response_data.get("metadata")),
        )

        yield ToolResponse(id=PYTHON_TOOL_RESPONSE_ID, response=python_result)

    def build_tool_message_content(
        self, *args: ToolResponse
    ) -> str | list[str | dict[str, Any]]:
        result = self._extract_result(args)
        sections: list[str] = []

        if result.stdout:
            sections.append(self._format_section("stdout", result.stdout))
        if result.stderr:
            sections.append(self._format_section("stderr", result.stderr))

        if result.artifacts:
            file_lines = []
            for artifact in result.artifacts:
                display_name = artifact.display_name or Path(artifact.path).name
                if artifact.error:
                    file_lines.append(f"- {display_name}: {artifact.error}")
                elif artifact.file_id:
                    file_lines.append(
                        f"- {display_name} (file_id={artifact.file_id}, mime={artifact.mime_type or 'unknown'})"
                    )
                else:
                    file_lines.append(f"- {display_name}")
            sections.append("Generated files:\n" + "\n".join(file_lines))

        return (
            "\n\n".join(sections) if sections else "Execution completed with no output."
        )

    def final_result(self, *args: ToolResponse) -> JSON_ro:
        return self._extract_result(args).model_dump()

    def build_next_prompt(
        self,
        prompt_builder: AnswerPromptBuilder,
        tool_call_summary: ToolCallSummary,
        tool_responses: list[ToolResponse],
        using_tool_calling_llm: bool,
    ) -> AnswerPromptBuilder:
        updated_prompt_builder = super().build_next_prompt(
            prompt_builder, tool_call_summary, tool_responses, using_tool_calling_llm
        )

        result = self._extract_result(tool_responses)
        if result.artifacts:
            existing_ids = {
                file.file_id for file in updated_prompt_builder.raw_user_uploaded_files
            }
            for artifact in result.artifacts:
                if not artifact.file_id:
                    continue
                chat_file = self._available_files.get(artifact.file_id)
                if not chat_file or chat_file.file_id in existing_ids:
                    continue
                if chat_file.file_type.is_text_file():
                    updated_prompt_builder.raw_user_uploaded_files.append(chat_file)
                    existing_ids.add(chat_file.file_id)
        return updated_prompt_builder

    def _load_requested_files(self, raw_files: Any) -> list[InMemoryChatFile]:
        if not raw_files:
            return []

        if not isinstance(raw_files, list):
            logger.warning(
                "Ignoring non-list 'files' argument passed to PythonTool: %r",
                type(raw_files),
            )
            return []

        loaded_files: list[InMemoryChatFile] = []
        for descriptor in raw_files:
            file_id: str | None = None
            desired_name: str | None = None

            if isinstance(descriptor, dict):
                raw_id = descriptor.get("id") or descriptor.get("file_id")
                if raw_id is not None:
                    file_id = str(raw_id)
                desired_name = (
                    descriptor.get("path")
                    or descriptor.get("name")
                    or descriptor.get("filename")
                )
            elif isinstance(descriptor, str):
                file_id = descriptor
            else:
                logger.warning(
                    "Skipping unsupported file descriptor passed to PythonTool: %r",
                    descriptor,
                )
                continue

            if not file_id:
                logger.warning("Skipping file descriptor without an id: %r", descriptor)
                continue

            chat_file = self._available_files.get(file_id)
            if not chat_file:
                chat_file = self._load_file_from_store(file_id)

            if not chat_file:
                logger.warning("Requested file '%s' could not be found", file_id)
                continue

            if desired_name and desired_name != chat_file.filename:
                chat_file = InMemoryChatFile(
                    file_id=chat_file.file_id,
                    content=chat_file.content,
                    filename=desired_name,
                    file_type=chat_file.file_type,
                )

            loaded_files.append(chat_file)

        return loaded_files

    def _load_file_from_store(self, file_id: str) -> InMemoryChatFile | None:
        try:
            file_store = get_default_file_store()
        except Exception as exc:  # pragma: no cover - defensive guard
            logger.exception("Failed to obtain file store instance: %s", exc)
            return None

        try:
            file_record = file_store.read_file_record(file_id)
        except Exception:
            logger.exception(
                "Failed to read file record for '%s' from file store", file_id
            )
            return None

        if not file_record:
            logger.warning("No file record found for requested file '%s'", file_id)
            return None

        try:
            file_io = file_store.read_file(file_id, mode="b")
        except Exception:
            logger.exception(
                "Failed to read file content for '%s' from file store", file_id
            )
            return None

        with closing(file_io):
            data = file_io.read() if hasattr(file_io, "read") else file_io

        if not isinstance(data, (bytes, bytearray)):
            logger.warning("File store returned non-bytes payload for '%s'", file_id)
            return None

        mime_type = (
            getattr(file_record, "file_type", None) or "application/octet-stream"
        )
        display_name = getattr(file_record, "display_name", None) or file_id
        chat_file_type = mime_type_to_chat_file_type(mime_type)

        return InMemoryChatFile(
            file_id=file_id,
            content=bytes(data),
            filename=display_name,
            file_type=chat_file_type,
        )

    def _coerce_override_files(
        self, override_kwargs: dict[str, Any] | None
    ) -> list[InMemoryChatFile]:
        if not override_kwargs:
            return []

        raw_files = override_kwargs.get("files")
        if not raw_files:
            return []

        coerced_files: list[InMemoryChatFile] = []
        for raw_file in raw_files:
            if isinstance(raw_file, InMemoryChatFile):
                coerced_files.append(raw_file)
            else:
                logger.warning(
                    "Ignoring non InMemoryChatFile entry passed to PythonTool override: %r",
                    type(raw_file),
                )
        return coerced_files

    def _prepare_request_files(
        self, chat_files: list[InMemoryChatFile]
    ) -> tuple[list[ExecuteRequestFilePayload], list[dict[str, Any]]]:
        request_files: list[ExecuteRequestFilePayload] = []
        input_metadata: list[dict[str, Any]] = []
        seen_ids: set[str] = set()

        for chat_file in chat_files:
            file_id = str(chat_file.file_id)
            if file_id in seen_ids:
                continue
            seen_ids.add(file_id)

            self._available_files[file_id] = chat_file
            target_path = self._resolve_target_path(chat_file)

            request_files.append(
                ExecuteRequestFilePayload(
                    path=target_path,
                    content_base64=base64.b64encode(chat_file.content).decode("utf-8"),
                )
            )

            input_metadata.append(
                {
                    "id": file_id,
                    "path": target_path,
                    "name": chat_file.filename or file_id,
                    "chat_file_type": chat_file.file_type.value,
                }
            )

        return request_files, input_metadata

    def _persist_artifacts(self, files: list[Any]) -> list[PythonToolArtifact]:
        artifacts: list[PythonToolArtifact] = []
        if not files:
            return artifacts

        file_store = get_default_file_store()

        for raw_file in files:
            try:
                workspace_file = WorkspaceFilePayload.model_validate(raw_file)
            except ValidationError as exc:
                logger.warning("Skipping malformed workspace file entry: %s", exc)
                continue

            artifact = PythonToolArtifact(
                path=workspace_file.path, kind=workspace_file.kind
            )
            if workspace_file.kind != "file" or not workspace_file.content_base64:
                artifacts.append(artifact)
                continue

            try:
                binary = base64.b64decode(workspace_file.content_base64)
            except binascii.Error as exc:
                artifact.error = f"Failed to decode base64 content: {exc}"
                artifacts.append(artifact)
                continue

            mime_type = self._infer_mime_type(workspace_file, binary)
            display_name = Path(workspace_file.path).name or workspace_file.path

            try:
                file_id = file_store.save_file(
                    content=BytesIO(binary),
                    display_name=display_name,
                    file_origin=FileOrigin.GENERATED_REPORT,
                    file_type=mime_type,
                )
            except Exception as exc:
                logger.exception(
                    "Failed to persist Code Interpreter artifact '%s'",
                    workspace_file.path,
                )
                artifact.error = f"Failed to persist artifact: {exc}"
                artifacts.append(artifact)
                continue

            chat_file_type = mime_type_to_chat_file_type(mime_type)
            chat_file = InMemoryChatFile(
                file_id=file_id,
                content=binary,
                filename=display_name,
                file_type=chat_file_type,
            )
            self._available_files[file_id] = chat_file

            artifact.file_id = file_id
            artifact.display_name = display_name
            artifact.mime_type = mime_type
            artifact.size_bytes = len(binary)
            artifact.chat_file_type = chat_file_type
            artifacts.append(artifact)

        return artifacts

    def _parse_response(self, response: requests.Response) -> dict[str, Any]:
        try:
            response_data = response.json()
        except json.JSONDecodeError as exc:
            logger.exception("Code Interpreter returned invalid JSON")
            raise ValueError(
                "Code Interpreter returned an invalid JSON response"
            ) from exc

        error_payload = response_data.get("error")
        if error_payload:
            if isinstance(error_payload, dict):
                message = error_payload.get("message") or json.dumps(error_payload)
            else:
                message = str(error_payload)
            raise ValueError(f"Code Interpreter reported an error: {message}")

        return response_data

    def _extract_result(self, responses: tuple[ToolResponse, ...]) -> PythonToolResult:
        for response in responses:
            if response.id == PYTHON_TOOL_RESPONSE_ID:
                if isinstance(response.response, PythonToolResult):
                    return response.response
                return PythonToolResult.model_validate(response.response)
        raise ValueError("No python tool result found in tool responses")

    def _resolve_target_path(self, chat_file: InMemoryChatFile) -> str:
        if chat_file.filename:
            return Path(chat_file.filename).name
        return chat_file.file_id

    def _infer_mime_type(
        self, workspace_file: WorkspaceFilePayload, binary: bytes
    ) -> str:
        if workspace_file.mime_type:
            return workspace_file.mime_type

        guess, _ = mimetypes.guess_type(workspace_file.path)
        if guess:
            return guess

        if self._looks_like_text(binary):
            return "text/plain"

        return "application/octet-stream"

    @staticmethod
    def _ensure_text(value: Any) -> str:
        if value is None:
            return ""
        return str(value)

    @staticmethod
    def _optional_text(value: Any) -> str | None:
        if value is None:
            return None
        text_value = str(value)
        return text_value if text_value else None

    @staticmethod
    def _safe_int(value: Any) -> int | None:
        try:
            return int(value) if value is not None else None
        except (TypeError, ValueError):
            return None

    @staticmethod
    def _extract_metadata(value: Any) -> dict[str, Any] | None:
        if isinstance(value, dict):
            return value
        return None

    @staticmethod
    def _looks_like_text(payload: bytes, sample_size: int = 2048) -> bool:
        sample = payload[:sample_size]
        try:
            sample.decode("utf-8")
        except UnicodeDecodeError:
            return False
        return True

    @staticmethod
    def _format_section(title: str, body: str) -> str:
        truncated = (
            body
            if len(body) <= _DEFAULT_RESULT_CHAR_LIMIT
            else body[: _DEFAULT_RESULT_CHAR_LIMIT - 3] + "..."
        )
        return f"{title}:\n{truncated}"
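Taken together, ExecuteRequestPayload and WorkspaceFilePayload pin down the wire contract the tool assumes. A hand-rolled round trip against a running Code Interpreter service might look like the following; the localhost URL and the exact response fields are assumptions inferred from the models above, not a documented API:

import base64
import requests

payload = {
    "code": "import csv\nprint(sum(1 for _ in open('analytics.csv')))",
    "timeout_ms": 30_000,
    "files": [
        {
            # staged into the workspace next to __main__.py
            "path": "analytics.csv",
            "content_base64": base64.b64encode(b"a,b\n1,2\n").decode("utf-8"),
        }
    ],
}

resp = requests.post("http://localhost:8000/execute", json=payload, timeout=60)
resp.raise_for_status()
data = resp.json()

print(data.get("stdout"))            # captured stdout of __main__.py
for f in data.get("files", []):      # new/edited workspace files, base64-encoded
    print(f["path"], f.get("mime_type"), f.get("size"))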
@@ -1,3 +1,16 @@
PRESENTATION_MIME_TYPE = (
    "application/vnd.openxmlformats-officedocument.presentationml.presentation"
)

SPREADSHEET_MIME_TYPE = (
    "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet"
)
WORD_PROCESSING_MIME_TYPE = (
    "application/vnd.openxmlformats-officedocument.wordprocessingml.document"
)
PDF_MIME_TYPE = "application/pdf"


class UploadMimeTypes:
    IMAGE_MIME_TYPES = {"image/jpeg", "image/png", "image/webp"}
    CSV_MIME_TYPES = {"text/csv"}
@@ -13,10 +26,10 @@ class UploadMimeTypes:
        "application/x-yaml",
    }
    DOCUMENT_MIME_TYPES = {
        "application/pdf",
        "application/vnd.openxmlformats-officedocument.wordprocessingml.document",
        "application/vnd.openxmlformats-officedocument.presentationml.presentation",
        "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
        PDF_MIME_TYPE,
        WORD_PROCESSING_MIME_TYPE,
        PRESENTATION_MIME_TYPE,
        SPREADSHEET_MIME_TYPE,
        "message/rfc822",
        "application/epub+zip",
    }
53 backend/onyx/utils/memory_logger.py Normal file
@@ -0,0 +1,53 @@
# # leaving this here for future mem debugging efforts
# import os
# from typing import Any

# import psutil
# from pympler import asizeof

# from onyx.utils.logger import setup_logger

# logger = setup_logger()

#
# def log_memory_usage(
#     label: str,
#     specific_object: Any = None,
#     object_label: str = "",
# ) -> None:
#     """Log current process memory usage and optionally the size of a specific object.

#     Args:
#         label: A descriptive label for the current location/operation in code
#         specific_object: Optional object to measure the size of
#         object_label: Optional label describing the specific object
#     """
#     try:
#         # Get current process memory info
#         process = psutil.Process(os.getpid())
#         memory_info = process.memory_info()

#         # Convert to MB for readability
#         rss_mb = memory_info.rss / (1024 * 1024)
#         vms_mb = memory_info.vms / (1024 * 1024)

#         log_parts = [f"MEMORY[{label}]", f"RSS: {rss_mb:.2f}MB", f"VMS: {vms_mb:.2f}MB"]

#         # Add object size if provided
#         if specific_object is not None:
#             try:
#                 # recursively calculate the size of the object
#                 obj_size = asizeof.asizeof(specific_object)
#                 obj_size_mb = obj_size / (1024 * 1024)
#                 obj_desc = f"[{object_label}]" if object_label else "[object]"
#                 log_parts.append(f"OBJ{obj_desc}: {obj_size_mb:.2f}MB")
#             except Exception as e:
#                 log_parts.append(f"OBJ_SIZE_ERROR: {str(e)}")

#         logger.info(" | ".join(log_parts))

#     except Exception as e:
#         logger.warning(f"Failed to log memory usage for {label}: {str(e)}")

# For example, use this like:
# log_memory_usage("my_operation", my_large_object, "my_large_object")
@@ -13,6 +13,10 @@ exclude = [
module = "alembic.versions.*"
disable_error_code = ["var-annotated"]

[[tool.mypy.overrides]]
module = "braintrust_langchain.*"
ignore_missing_imports = true

[[tool.mypy.overrides]]
module = "alembic_tenants.versions.*"
disable_error_code = ["var-annotated"]
@@ -51,6 +51,7 @@ nltk==3.9.1
Office365-REST-Python-Client==2.5.9
oauthlib==3.2.2
openai==1.99.5
openpyxl==3.1.5
passlib==1.7.4
playwright==1.41.2
psutil==5.9.5
@@ -60,6 +61,7 @@ pyairtable==3.0.1
pycryptodome==3.19.1
pydantic==2.11.7
PyGithub==2.5.0
pympler==1.1
python-dateutil==2.8.2
python-gitlab==5.6.0
python-pptx==0.6.23
@@ -75,7 +77,6 @@ requests==2.32.5
requests-oauthlib==1.3.1
retry==0.9.2 # This pulls in py which is in CVE-2022-42969, must remove py from image
rfc3986==1.5.0
setfit==1.1.1
simple-salesforce==1.12.6
slack-sdk==3.20.2
SQLAlchemy[mypy]==2.0.15
@@ -84,7 +85,7 @@ supervisor==4.2.5
RapidFuzz==3.13.0
tiktoken==0.7.0
timeago==1.0.16
transformers==4.49.0
types-openpyxl==3.1.5.20250919
unstructured==0.15.1
unstructured-client==0.25.4
uvicorn==0.35.0
@@ -106,3 +107,4 @@ voyageai==0.2.3
cohere==5.6.1
exa_py==1.15.4
braintrust==0.2.6
braintrust-langchain==0.0.4
@@ -0,0 +1,112 @@
from __future__ import annotations

import os
from uuid import uuid4

from sqlalchemy.orm import Session

from onyx.chat.models import AnswerStreamPart
from onyx.chat.models import MessageResponseIDInfo
from onyx.chat.models import StreamingError
from onyx.chat.process_message import stream_chat_message_objects
from onyx.context.search.models import RetrievalDetails
from onyx.db.chat import create_chat_session
from onyx.db.llm import fetch_existing_llm_providers
from onyx.db.llm import remove_llm_provider
from onyx.db.llm import update_default_provider
from onyx.db.llm import upsert_llm_provider
from onyx.server.manage.llm.models import LLMProviderUpsertRequest
from onyx.server.manage.llm.models import ModelConfigurationUpsertRequest
from onyx.server.query_and_chat.models import CreateChatMessageRequest
from onyx.server.query_and_chat.streaming_models import MessageDelta
from onyx.server.query_and_chat.streaming_models import MessageStart
from onyx.server.query_and_chat.streaming_models import Packet
from tests.external_dependency_unit.conftest import create_test_user


def test_answer_with_only_anthropic_provider(
    db_session: Session,
    full_deployment_setup: None,
    mock_external_deps: None,
) -> None:
    """Ensure chat still streams answers when only an Anthropic provider is configured."""

    anthropic_api_key = os.environ.get("ANTHROPIC_API_KEY")
    assert anthropic_api_key, "ANTHROPIC_API_KEY environment variable must be set"

    # Drop any existing providers so that only Anthropic is available.
    for provider in fetch_existing_llm_providers(db_session):
        remove_llm_provider(db_session, provider.id)

    anthropic_model = "claude-3-5-sonnet-20240620"
    provider_name = f"anthropic-test-{uuid4().hex}"

    anthropic_provider = upsert_llm_provider(
        LLMProviderUpsertRequest(
            name=provider_name,
            provider="anthropic",
            api_key=anthropic_api_key,
            default_model_name=anthropic_model,
            fast_default_model_name=anthropic_model,
            is_public=True,
            groups=[],
            model_configurations=[
                ModelConfigurationUpsertRequest(name=anthropic_model, is_visible=True)
            ],
            api_key_changed=True,
        ),
        db_session=db_session,
    )

    try:
        update_default_provider(anthropic_provider.id, db_session)

        test_user = create_test_user(db_session, email_prefix="anthropic_only")
        chat_session = create_chat_session(
            db_session=db_session,
            description="Anthropic only chat",
            user_id=test_user.id,
            persona_id=0,
        )

        chat_request = CreateChatMessageRequest(
            chat_session_id=chat_session.id,
            parent_message_id=None,
            message="hello",
            file_descriptors=[],
            search_doc_ids=None,
            retrieval_options=RetrievalDetails(),
        )

        response_stream: list[AnswerStreamPart] = []
        for packet in stream_chat_message_objects(
            new_msg_req=chat_request,
            user=test_user,
            db_session=db_session,
        ):
            response_stream.append(packet)

        assert response_stream, "Should receive streamed packets"
        assert not any(
            isinstance(packet, StreamingError) for packet in response_stream
        ), "No streaming errors expected with Anthropic provider"

        has_message_id = any(
            isinstance(packet, MessageResponseIDInfo) for packet in response_stream
        )
        assert has_message_id, "Should include reserved assistant message ID"

        has_message_start = any(
            isinstance(packet, Packet) and isinstance(packet.obj, MessageStart)
            for packet in response_stream
        )
        assert has_message_start, "Stream should have a MessageStart packet"

        has_message_delta = any(
            isinstance(packet, Packet) and isinstance(packet.obj, MessageDelta)
            for packet in response_stream
        )
        assert has_message_delta, "Stream should have a MessageDelta packet"

    finally:
        remove_llm_provider(db_session, anthropic_provider.id)
@@ -87,6 +87,8 @@ COPY ./requirements/dev.txt /tmp/dev-requirements.txt
RUN pip install --no-cache-dir --upgrade \
    -r /tmp/dev-requirements.txt
COPY ./tests/integration /app/tests/integration
# copies all files, but not folders, in the tests directory
COPY ./tests/* /app/tests/

ENV PYTHONPATH=/app
@@ -26,6 +26,7 @@ def mock_zendesk_client() -> MagicMock:
    mock = MagicMock(spec=ZendeskClient)
    mock.base_url = "https://test.zendesk.com/api/v2"
    mock.auth = ("test@example.com/token", "test_token")
    mock.make_request = MagicMock()
    return mock
@@ -0,0 +1,83 @@
from __future__ import annotations

import types
from typing import Any
from typing import Dict

import pytest


class _FakeTime:
    """A controllable time module replacement.

    - monotonic(): returns an internal counter (seconds)
    - sleep(x): advances the internal counter by x seconds
    """

    def __init__(self) -> None:
        self._t = 0.0

    def monotonic(self) -> float:  # type: ignore[override]
        return self._t

    def sleep(self, seconds: float) -> None:  # type: ignore[override]
        # advance time without real waiting
        self._t += float(seconds)


class _FakeResponse:
    def __init__(self, json_payload: Dict[str, Any], status_code: int = 200) -> None:
        self._json = json_payload
        self.status_code = status_code
        self.headers: Dict[str, str] = {}

    def json(self) -> Dict[str, Any]:
        return self._json

    def raise_for_status(self) -> None:
        # simulate OK
        return None


@pytest.mark.unit
def test_zendesk_client_per_minute_rate_limiting(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    # Import here to allow monkeypatching modules safely
    from onyx.connectors.zendesk.connector import ZendeskClient
    import onyx.connectors.cross_connector_utils.rate_limit_wrapper as rlw
    import onyx.connectors.zendesk.connector as zendesk_mod

    fake_time = _FakeTime()

    # Patch time in both the rate limit wrapper and the zendesk connector module
    monkeypatch.setattr(rlw, "time", fake_time, raising=True)
    monkeypatch.setattr(zendesk_mod, "time", fake_time, raising=True)

    # Stub out requests.get to avoid network and return a minimal valid payload
    calls: list[str] = []

    def _fake_get(url: str, auth: Any, params: Dict[str, Any]) -> _FakeResponse:
        calls.append(url)
        # minimal Zendesk list response (articles path)
        return _FakeResponse({"articles": [], "meta": {"has_more": False}})

    monkeypatch.setattr(
        zendesk_mod, "requests", types.SimpleNamespace(get=_fake_get), raising=True
    )

    # Build client with a small limit: 2 calls per 60 seconds
    client = ZendeskClient("subd", "e", "t", calls_per_minute=2)

    # Make three calls in quick succession. The third should be rate limited
    client.make_request("help_center/articles", {"page[size]": 1})
    client.make_request("help_center/articles", {"page[size]": 1})

    # At this point we've used up the 2 allowed calls within the 60s window
    # The next call should trigger sleeps with exponential backoff until >60s elapsed
    client.make_request("help_center/articles", {"page[size]": 1})

    # Ensure we did not actually wait in real time but logically advanced beyond a minute
    assert fake_time.monotonic() >= 60
    # Ensure the HTTP function was invoked three times
    assert len(calls) == 3
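For reference, the sliding-window behavior this test pins down can be sketched as a decorator; rate_limit_builder is the helper the wrapper module exposes, but this body and signature are assumptions distilled from the test (the real implementation reportedly backs off in steps), not the repo's code:

import time
from collections import deque
from collections.abc import Callable
from functools import wraps
from typing import Any

def rate_limit_builder(max_calls: int, period: float = 60.0) -> Callable:
    def decorator(func: Callable) -> Callable:
        call_times: deque[float] = deque()

        @wraps(func)
        def wrapper(*args: Any, **kwargs: Any) -> Any:
            now = time.monotonic()
            # evict timestamps that fell out of the window
            while call_times and now - call_times[0] >= period:
                call_times.popleft()
            if len(call_times) >= max_calls:
                # block until the oldest call exits the window; because the
                # module-level `time` is what gets monkeypatched, the fake
                # clock in the test advances here instead of really sleeping
                time.sleep(period - (now - call_times[0]))
            call_times.append(time.monotonic())
            return func(*args, **kwargs)

        return wrapper

    return decorator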
98 backend/tests/unit/onyx/utils/test_vespa_tasks.py Normal file
@@ -0,0 +1,98 @@
from types import SimpleNamespace
from typing import Any

from onyx.background.celery.tasks.vespa import tasks as vespa_tasks


class _StubRedisDocumentSet:
    """Lightweight stand-in for RedisDocumentSet used by monitor tests."""

    reset_called = False

    @staticmethod
    def get_id_from_fence_key(key: str) -> str | None:
        parts = key.split("_")
        return parts[-1] if len(parts) == 3 else None

    def __init__(self, tenant_id: str, object_id: str) -> None:
        self.taskset_key = f"documentset_taskset_{object_id}"
        self._payload = 0

    @property
    def fenced(self) -> bool:
        return True

    @property
    def payload(self) -> int:
        return self._payload

    def reset(self) -> None:
        self.__class__.reset_called = True


def _setup_common_patches(monkeypatch: Any, document_set: Any) -> dict[str, bool]:
    calls: dict[str, bool] = {"deleted": False, "synced": False}

    monkeypatch.setattr(vespa_tasks, "RedisDocumentSet", _StubRedisDocumentSet)

    monkeypatch.setattr(
        vespa_tasks,
        "get_document_set_by_id",
        lambda db_session, document_set_id: document_set,
    )

    def _delete(document_set_row: Any, db_session: Any) -> None:
        calls["deleted"] = True

    monkeypatch.setattr(vespa_tasks, "delete_document_set", _delete)

    def _mark(document_set_id: Any, db_session: Any) -> None:
        calls["synced"] = True

    monkeypatch.setattr(vespa_tasks, "mark_document_set_as_synced", _mark)

    monkeypatch.setattr(
        vespa_tasks,
        "update_sync_record_status",
        lambda db_session, entity_id, sync_type, sync_status, num_docs_synced: None,
    )

    return calls


def test_monitor_preserves_federated_only_document_set(monkeypatch: Any) -> None:
    document_set = SimpleNamespace(
        connector_credential_pairs=[],
        federated_connectors=[object()],
    )

    calls = _setup_common_patches(monkeypatch, document_set)

    vespa_tasks.monitor_document_set_taskset(
        tenant_id="tenant",
        key_bytes=b"documentset_fence_1",
        r=SimpleNamespace(scard=lambda key: 0),  # type: ignore[arg-type]
        db_session=SimpleNamespace(),  # type: ignore[arg-type]
    )

    assert calls["synced"] is True
    assert calls["deleted"] is False


def test_monitor_deletes_document_set_with_no_connectors(monkeypatch: Any) -> None:
    document_set = SimpleNamespace(
        connector_credential_pairs=[],
        federated_connectors=[],
    )

    calls = _setup_common_patches(monkeypatch, document_set)

    vespa_tasks.monitor_document_set_taskset(
        tenant_id="tenant",
        key_bytes=b"documentset_fence_2",
        r=SimpleNamespace(scard=lambda key: 0),  # type: ignore[arg-type]
        db_session=SimpleNamespace(),  # type: ignore[arg-type]
    )

    assert calls["deleted"] is True
    assert calls["synced"] is False
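Condensed, the branch these two tests pin down looks roughly like this; monitor_document_set_taskset does considerably more bookkeeping (fences, sync records), so treat this hypothetical helper as an assumption-laden distillation of the assertions above:

from typing import Any

def _finalize_document_set(document_set: Any, db_session: Any) -> str:
    # hypothetical helper; mark_document_set_as_synced / delete_document_set
    # are the real functions the tests patch
    if document_set.connector_credential_pairs or document_set.federated_connectors:
        # a document set still backed by any connector (regular or federated) is kept
        mark_document_set_as_synced(document_set.id, db_session)
        return "synced"
    # nothing references it anymore, so it is removed outright
    delete_document_set(document_set, db_session)
    return "deleted"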
@@ -88,6 +88,7 @@ import { FilePickerModal } from "@/app/chat/my-documents/components/FilePicker";
import { useDocumentsContext } from "@/app/chat/my-documents/DocumentsContext";

import {
  PYTHON_TOOL_ID,
  IMAGE_GENERATION_TOOL_ID,
  SEARCH_TOOL_ID,
  WEB_SEARCH_TOOL_ID,
@@ -111,6 +112,10 @@ function findWebSearchTool(tools: ToolSnapshot[]) {
  return tools.find((tool) => tool.in_code_tool_id === WEB_SEARCH_TOOL_ID);
}

function findPythonTool(tools: ToolSnapshot[]) {
  return tools.find((tool) => tool.in_code_tool_id === PYTHON_TOOL_ID);
}

function SubLabel({ children }: { children: string | JSX.Element }) {
  return (
    <div
@@ -213,13 +218,15 @@ export function AssistantEditor({
  const searchTool = findSearchTool(tools);
  const imageGenerationTool = findImageGenerationTool(tools);
  const webSearchTool = findWebSearchTool(tools);
  const pythonTool = findPythonTool(tools);

  // Separate MCP tools from regular custom tools
  const allCustomTools = tools.filter(
    (tool) =>
      tool.in_code_tool_id !== searchTool?.in_code_tool_id &&
      tool.in_code_tool_id !== imageGenerationTool?.in_code_tool_id &&
      tool.in_code_tool_id !== webSearchTool?.in_code_tool_id
      tool.in_code_tool_id !== webSearchTool?.in_code_tool_id &&
      tool.in_code_tool_id !== pythonTool?.in_code_tool_id
  );

  const mcpTools = allCustomTools.filter((tool) => tool.mcp_server_id);
@@ -290,6 +297,7 @@ export function AssistantEditor({
    ...mcpTools, // Include MCP tools for form logic
    ...(searchTool ? [searchTool] : []),
    ...(imageGenerationTool ? [imageGenerationTool] : []),
    ...(pythonTool ? [pythonTool] : []),
    ...(webSearchTool ? [webSearchTool] : []),
  ];
  const enabledToolsMap: { [key: number]: boolean } = {};
@@ -1211,6 +1219,16 @@ export function AssistantEditor({
        </>
      )}

      {pythonTool && (
        <>
          <BooleanFormField
            name={`enabled_tools_map.${pythonTool.id}`}
            label={pythonTool.display_name}
            subtext="Execute Python code against staged files using the Code Interpreter sandbox"
          />
        </>
      )}

      {webSearchTool && (
        <>
          <BooleanFormField
@@ -272,6 +272,9 @@ function ToolToggle({
    if (tool.in_code_tool_id === "ImageGenerationTool") {
      return "Add an OpenAI LLM provider with an API key under Admin → Configuration → LLM.";
    }
    if (tool.in_code_tool_id === "PythonTool") {
      return "Set CODE_INTERPRETER_BASE_URL on the server and restart to enable Code Interpreter.";
    }
    return "Not configured.";
  })();
  return (
@@ -59,6 +59,7 @@ import { CreateStdOAuthCredential } from "@/components/credentials/actions/Creat…
import { Spinner } from "@/components/Spinner";
import { Button } from "@/components/ui/button";
import { deleteConnector } from "@/lib/connector";
import ConnectorDocsLink from "@/components/admin/connectors/ConnectorDocsLink";

export interface AdvancedConfig {
  refreshFreq: number;
@@ -640,6 +641,7 @@ export default function AddConnector({
              null
            }
          />
          <ConnectorDocsLink sourceType={connector} />
        </CardSection>
      )}
@@ -1,6 +1,9 @@
import { FiImage, FiSearch } from "react-icons/fi";
import { FiImage, FiSearch, FiTerminal } from "react-icons/fi";
import { Persona } from "../admin/assistants/interfaces";
import { SEARCH_TOOL_ID } from "../chat/components/tools/constants";
import {
  PYTHON_TOOL_ID,
  SEARCH_TOOL_ID,
} from "../chat/components/tools/constants";

export function AssistantTools({
  assistant,
@@ -69,6 +72,26 @@ export function AssistantTools({
        </div>
      </div>
    );
  } else if (tool.name === PYTHON_TOOL_ID) {
    return (
      <div
        key={ind}
        className={`
          px-1.5
          py-1
          rounded-lg
          border
          border-border
          w-fit
          flex
          ${list ? "bg-background-125" : "bg-background-100"}`}
      >
        <div className="flex items-center gap-x-1">
          <FiTerminal className="ml-1 my-auto h-3 w-3" />
          {tool.display_name || "Code Interpreter"}
        </div>
      </div>
    );
  } else {
    return (
      <div
@@ -2,8 +2,10 @@
export const SEARCH_TOOL_NAME = "run_search";
export const INTERNET_SEARCH_TOOL_NAME = "run_internet_search";
export const IMAGE_GENERATION_TOOL_NAME = "run_image_generation";
export const PYTHON_TOOL_NAME = "run_code_interpreter";

// In-code tool IDs that also correspond to the tool's name when associated with a persona
export const SEARCH_TOOL_ID = "SearchTool";
export const IMAGE_GENERATION_TOOL_ID = "ImageGenerationTool";
export const WEB_SEARCH_TOOL_ID = "WebSearchTool";
export const PYTHON_TOOL_ID = "PythonTool";
@@ -32,6 +32,14 @@ const isImageGenerationTool = (tool: ToolSnapshot): boolean => {
  );
};

const isPythonTool = (tool: ToolSnapshot): boolean => {
  return (
    tool.in_code_tool_id === "PythonTool" ||
    tool.in_code_tool_id === "AnalysisTool" ||
    tool.display_name?.toLowerCase().includes("code interpreter")
  );
};

const isKnowledgeGraphTool = (tool: ToolSnapshot): boolean => {
  return (
    tool.in_code_tool_id === "KnowledgeGraphTool" ||
@@ -55,6 +63,8 @@ export function getIconForAction(
    return GlobeIcon;
  } else if (isImageGenerationTool(action)) {
    return ImageIcon;
  } else if (isPythonTool(action)) {
    return CpuIcon;
  } else if (isKnowledgeGraphTool(action)) {
    return DatabaseIcon;
  } else if (isOktaProfileTool(action)) {
34 web/src/components/admin/connectors/ConnectorDocsLink.tsx Normal file
@@ -0,0 +1,34 @@
import { ValidSources } from "@/lib/types";
import { getSourceDocLink } from "@/lib/sources";

export default function ConnectorDocsLink({
  sourceType,
  className,
}: {
  sourceType: ValidSources;
  className?: string;
}) {
  const docsLink = getSourceDocLink(sourceType);

  if (!docsLink) {
    return null;
  }

  const paragraphClass = ["text-sm", className].filter(Boolean).join(" ");

  return (
    <p className={paragraphClass}>
      Check out
      <a
        className="text-blue-600 hover:underline"
        target="_blank"
        rel="noopener"
        href={docsLink}
      >
        {" "}
        our docs{" "}
      </a>
      for more info on configuring this connector.
    </p>
  );
}
@@ -6,7 +6,6 @@ import { submitCredential } from "@/components/admin/connectors/CredentialForm";
import { TextFormField } from "@/components/Field";
import { Form, Formik, FormikHelpers } from "formik";
import { PopupSpec } from "@/components/admin/connectors/Popup";
import { getSourceDocLink } from "@/lib/sources";
import GDriveMain from "@/app/admin/connectors/[connector]/pages/gdrive/GoogleDrivePage";
import { Connector } from "@/lib/connectors/connectors";
import { Credential, credentialTemplates } from "@/lib/connectors/credentials";
@@ -24,6 +23,7 @@ import { useUser } from "@/components/user/UserProvider";
import CardSection from "@/components/admin/CardSection";
import { CredentialFieldsRenderer } from "./CredentialFieldsRenderer";
import { TypedFile } from "@/lib/connectors/fileTypes";
import ConnectorDocsLink from "@/components/admin/connectors/ConnectorDocsLink";

const CreateButton = ({
  onClick,
@@ -213,20 +213,7 @@ export default function CreateCredential({

  return (
    <Form className="w-full flex items-stretch">
      {!hideSource && (
        <p className="text-sm">
          Check our
          <a
            className="text-blue-600 hover:underline"
            target="_blank"
            href={getSourceDocLink(sourceType) || ""}
          >
            {" "}
            docs{" "}
          </a>
          for information on setting up this connector.
        </p>
      )}
      {!hideSource && <ConnectorDocsLink sourceType={sourceType} />}
      <CardSection className="w-full items-start dark:bg-neutral-900 mt-4 flex flex-col gap-y-6">
        <TextFormField
          name="name"
@@ -1071,7 +1071,16 @@ For example, specifying .*-support.* as a "channel" will cause the connector to
      default: "articles",
    },
  ],
  advanced_values: [],
  advanced_values: [
    {
      type: "number",
      label: "API Calls per Minute",
      name: "calls_per_minute",
      optional: true,
      description:
        "Restricts how many Zendesk API calls this connector can make per minute (applies only to this connector). See defaults: https://developer.zendesk.com/api-reference/introduction/rate-limits/",
    },
  ],
},
linear: {
  description: "Configure Linear connector",
@@ -1770,7 +1779,10 @@ export interface XenforoConfig {
  base_url: string;
}

export interface ZendeskConfig {}
export interface ZendeskConfig {
  content_type?: "articles" | "tickets";
  calls_per_minute?: number;
}

export interface DropboxConfig {}