mirror of
https://github.com/onyx-dot-app/onyx.git
synced 2026-03-15 20:52:39 +00:00
Compare commits
2 Commits
v3.0.0
...
tool_searc
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
9633152f4d | ||
|
|
0370023e08 |
@@ -81,6 +81,7 @@ from danswer.tools.tool_implementations.custom.custom_tool import (
|
||||
CUSTOM_TOOL_RESPONSE_ID,
|
||||
)
|
||||
from danswer.tools.tool_implementations.custom.custom_tool import CustomToolCallSummary
|
||||
from danswer.tools.tool_implementations.custom.models import CustomToolResponseType
|
||||
from danswer.tools.tool_implementations.images.image_generation_tool import (
|
||||
IMAGE_GENERATION_RESPONSE_ID,
|
||||
)
|
||||
@@ -160,7 +161,6 @@ def _handle_search_tool_response_summary(
|
||||
]
|
||||
else:
|
||||
reference_db_search_docs = selected_search_docs
|
||||
|
||||
response_docs = [
|
||||
translate_db_search_doc_to_server_search_doc(db_search_doc)
|
||||
for db_search_doc in reference_db_search_docs
|
||||
@@ -603,6 +603,8 @@ def stream_chat_message_objects(
|
||||
chat_session_id=chat_session_id,
|
||||
message_id=user_message.id if user_message else None,
|
||||
additional_headers=custom_tool_additional_headers,
|
||||
answer_style_config=answer_style_config,
|
||||
prompt_config=prompt_config,
|
||||
),
|
||||
)
|
||||
tools: list[Tool] = []
|
||||
@@ -659,6 +661,7 @@ def stream_chat_message_objects(
|
||||
else False
|
||||
),
|
||||
)
|
||||
|
||||
yield qa_docs_response
|
||||
elif packet.id == SECTION_RELEVANCE_LIST_ID:
|
||||
relevance_sections = packet.response
|
||||
@@ -715,15 +718,17 @@ def stream_chat_message_objects(
|
||||
custom_tool_response = cast(CustomToolCallSummary, packet.response)
|
||||
|
||||
if (
|
||||
custom_tool_response.response_type == "image"
|
||||
or custom_tool_response.response_type == "csv"
|
||||
custom_tool_response.response_type == CustomToolResponseType.CSV
|
||||
or custom_tool_response.response_type
|
||||
== CustomToolResponseType.IMAGE
|
||||
):
|
||||
file_ids = custom_tool_response.tool_result.file_ids
|
||||
ai_message_files = [
|
||||
FileDescriptor(
|
||||
id=str(file_id),
|
||||
type=ChatFileType.IMAGE
|
||||
if custom_tool_response.response_type == "image"
|
||||
if custom_tool_response.response_type
|
||||
== CustomToolResponseType.IMAGE
|
||||
else ChatFileType.CSV,
|
||||
)
|
||||
for file_id in file_ids
|
||||
|
||||
@@ -88,6 +88,7 @@ DANSWER_REDIS_FUNCTION_LOCK_PREFIX = "da_function_lock:"
|
||||
|
||||
class DocumentSource(str, Enum):
|
||||
# Special case, document passed in via Danswer APIs without specifying a source type
|
||||
CUSTOM_TOOL = "custom_tool"
|
||||
INGESTION_API = "ingestion_api"
|
||||
SLACK = "slack"
|
||||
WEB = "web"
|
||||
|
||||
@@ -125,6 +125,8 @@ class CustomToolConfig(BaseModel):
|
||||
chat_session_id: UUID | None = None
|
||||
message_id: int | None = None
|
||||
additional_headers: dict[str, str] | None = None
|
||||
answer_style_config: AnswerStyleConfig | None = None
|
||||
prompt_config: PromptConfig | None = None
|
||||
|
||||
|
||||
def construct_tools(
|
||||
@@ -228,6 +230,8 @@ def construct_tools(
|
||||
custom_tool_config.additional_headers or {}
|
||||
)
|
||||
),
|
||||
answer_style_config=custom_tool_config.answer_style_config,
|
||||
prompt_config=custom_tool_config.prompt_config,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@@ -12,17 +12,25 @@ from typing import List
|
||||
import requests
|
||||
from langchain_core.messages import HumanMessage
|
||||
from langchain_core.messages import SystemMessage
|
||||
from pydantic import BaseModel
|
||||
from requests import JSONDecodeError
|
||||
|
||||
from danswer.chat.chat_utils import llm_doc_from_inference_section
|
||||
from danswer.chat.models import LlmDoc
|
||||
from danswer.chat.models import SectionRelevancePiece
|
||||
from danswer.configs.constants import FileOrigin
|
||||
from danswer.db.engine import get_session_with_default_tenant
|
||||
from danswer.file_store.file_store import get_default_file_store
|
||||
from danswer.file_store.models import ChatFileType
|
||||
from danswer.file_store.models import InMemoryChatFile
|
||||
from danswer.key_value_store.interface import JSON_ro
|
||||
from danswer.llm.answering.models import AnswerStyleConfig
|
||||
from danswer.llm.answering.models import PreviousMessage
|
||||
from danswer.llm.answering.models import PromptConfig
|
||||
from danswer.llm.answering.prompts.build import AnswerPromptBuilder
|
||||
from danswer.llm.interfaces import LLM
|
||||
from danswer.search.models import DocumentSource
|
||||
from danswer.search.models import IndexFilters
|
||||
from danswer.search.models import InferenceChunk
|
||||
from danswer.search.models import InferenceSection
|
||||
from danswer.tools.base_tool import BaseTool
|
||||
from danswer.tools.message import ToolCallSummary
|
||||
from danswer.tools.models import CHAT_SESSION_ID_PLACEHOLDER
|
||||
@@ -42,6 +50,12 @@ from danswer.tools.tool_implementations.custom.custom_tool_prompts import (
|
||||
TOOL_ARG_USER_PROMPT,
|
||||
)
|
||||
from danswer.tools.tool_implementations.custom.custom_tool_prompts import USE_TOOL
|
||||
from danswer.tools.tool_implementations.custom.models import CUSTOM_TOOL_RESPONSE_ID
|
||||
from danswer.tools.tool_implementations.custom.models import CustomToolCallSummary
|
||||
from danswer.tools.tool_implementations.custom.models import CustomToolFileResponse
|
||||
from danswer.tools.tool_implementations.custom.models import CustomToolResponseType
|
||||
from danswer.tools.tool_implementations.custom.models import CustomToolSearchResponse
|
||||
from danswer.tools.tool_implementations.custom.models import CustomToolSearchResult
|
||||
from danswer.tools.tool_implementations.custom.openapi_parsing import MethodSpec
|
||||
from danswer.tools.tool_implementations.custom.openapi_parsing import (
|
||||
openapi_to_method_specs,
|
||||
@@ -51,35 +65,39 @@ from danswer.tools.tool_implementations.custom.openapi_parsing import REQUEST_BO
|
||||
from danswer.tools.tool_implementations.custom.openapi_parsing import (
|
||||
validate_openapi_schema,
|
||||
)
|
||||
from danswer.tools.tool_implementations.custom.prompt import (
|
||||
build_custom_image_generation_user_prompt,
|
||||
from danswer.tools.tool_implementations.file_like_tool_utils import (
|
||||
build_next_prompt_for_file_like_tool,
|
||||
)
|
||||
from danswer.tools.tool_implementations.search.search_tool import (
|
||||
FINAL_CONTEXT_DOCUMENTS_ID,
|
||||
)
|
||||
from danswer.tools.tool_implementations.search.search_tool import (
|
||||
SEARCH_RESPONSE_SUMMARY_ID,
|
||||
)
|
||||
from danswer.tools.tool_implementations.search.search_tool import SearchResponseSummary
|
||||
from danswer.tools.tool_implementations.search.search_tool import (
|
||||
SECTION_RELEVANCE_LIST_ID,
|
||||
)
|
||||
from danswer.tools.tool_implementations.search.search_utils import llm_doc_to_dict
|
||||
from danswer.tools.tool_implementations.search_like_tool_utils import (
|
||||
build_next_prompt_for_search_like_tool,
|
||||
)
|
||||
from danswer.utils.headers import header_list_to_header_dict
|
||||
from danswer.utils.headers import HeaderItemDict
|
||||
from danswer.utils.logger import setup_logger
|
||||
from danswer.utils.special_types import JSON_ro
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
CUSTOM_TOOL_RESPONSE_ID = "custom_tool_response"
|
||||
|
||||
|
||||
class CustomToolFileResponse(BaseModel):
|
||||
file_ids: List[str] # References to saved images or CSVs
|
||||
|
||||
|
||||
class CustomToolCallSummary(BaseModel):
|
||||
tool_name: str
|
||||
response_type: str # e.g., 'json', 'image', 'csv', 'graph'
|
||||
tool_result: Any # The response data
|
||||
|
||||
|
||||
class CustomTool(BaseTool):
|
||||
def __init__(
|
||||
self,
|
||||
method_spec: MethodSpec,
|
||||
base_url: str,
|
||||
answer_style_config: AnswerStyleConfig | None = None,
|
||||
custom_headers: list[HeaderItemDict] | None = None,
|
||||
prompt_config: PromptConfig | None = None,
|
||||
) -> None:
|
||||
self._base_url = base_url
|
||||
self._method_spec = method_spec
|
||||
@@ -90,6 +108,8 @@ class CustomTool(BaseTool):
|
||||
self.headers = (
|
||||
header_list_to_header_dict(custom_headers) if custom_headers else {}
|
||||
)
|
||||
self.answer_style_config = answer_style_config
|
||||
self.prompt_config = prompt_config
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
@@ -111,14 +131,44 @@ class CustomTool(BaseTool):
|
||||
def build_tool_message_content(
|
||||
self, *args: ToolResponse
|
||||
) -> str | list[str | dict[str, Any]]:
|
||||
response = cast(CustomToolCallSummary, args[0].response)
|
||||
final_context_docs_response = next(
|
||||
(
|
||||
response
|
||||
for response in args
|
||||
if response.id == FINAL_CONTEXT_DOCUMENTS_ID
|
||||
),
|
||||
None,
|
||||
)
|
||||
|
||||
if response.response_type == "image" or response.response_type == "csv":
|
||||
image_response = cast(CustomToolFileResponse, response.tool_result)
|
||||
return json.dumps({"file_ids": image_response.file_ids})
|
||||
# Handle the search type response
|
||||
if final_context_docs_response:
|
||||
final_context_docs = cast(
|
||||
list[LlmDoc], final_context_docs_response.response
|
||||
)
|
||||
return json.dumps(
|
||||
{
|
||||
"search_results": [
|
||||
llm_doc_to_dict(doc, ind)
|
||||
for ind, doc in enumerate(final_context_docs)
|
||||
]
|
||||
}
|
||||
)
|
||||
|
||||
# For JSON or other responses, return as-is
|
||||
return json.dumps(response.tool_result)
|
||||
# Handle other response types
|
||||
response = args[0].response
|
||||
if isinstance(response, CustomToolCallSummary):
|
||||
if response.response_type in [
|
||||
CustomToolResponseType.IMAGE,
|
||||
CustomToolResponseType.CSV,
|
||||
]:
|
||||
file_response = cast(CustomToolFileResponse, response.tool_result)
|
||||
return json.dumps({"file_ids": file_response.file_ids})
|
||||
|
||||
# For JSON or other responses, return as-is
|
||||
return json.dumps(response.tool_result)
|
||||
|
||||
# If it's not a CustomToolCallSummary or search result, return as-is
|
||||
return json.dumps(response)
|
||||
|
||||
"""For LLMs which do NOT support explicit tool calling"""
|
||||
|
||||
@@ -219,6 +269,75 @@ class CustomTool(BaseTool):
|
||||
|
||||
"""Actual execution of the tool"""
|
||||
|
||||
def _handle_search_like_tool_response(
|
||||
self, tool_result: CustomToolSearchResponse
|
||||
) -> Generator[ToolResponse, None, None]:
|
||||
# Convert results to InferenceSections
|
||||
inference_sections = []
|
||||
for result in tool_result.results:
|
||||
chunk = InferenceChunk(
|
||||
document_id=result.document_id,
|
||||
content=result.content,
|
||||
semantic_identifier=result.document_id, # using document_id as semantic_identifier
|
||||
blurb=result.blurb,
|
||||
source_type=DocumentSource.CUSTOM_TOOL,
|
||||
# Use defaults
|
||||
chunk_id=0,
|
||||
source_links={},
|
||||
section_continuation=False,
|
||||
title=result.title,
|
||||
boost=0,
|
||||
recency_bias=0, # Default recency bias
|
||||
hidden=False,
|
||||
score=0,
|
||||
metadata={},
|
||||
match_highlights=[],
|
||||
updated_at=result.updated_at,
|
||||
)
|
||||
|
||||
# We assume that each search result belongs to different documents
|
||||
section = InferenceSection(
|
||||
center_chunk=chunk,
|
||||
chunks=[chunk],
|
||||
combined_content=chunk.content,
|
||||
)
|
||||
|
||||
inference_sections.append(section)
|
||||
|
||||
search_response_summary = SearchResponseSummary(
|
||||
rephrased_query=None,
|
||||
top_sections=inference_sections,
|
||||
predicted_flow=None,
|
||||
predicted_search=None,
|
||||
final_filters=IndexFilters(access_control_list=None),
|
||||
recency_bias_multiplier=0.0,
|
||||
)
|
||||
|
||||
yield ToolResponse(
|
||||
id=SEARCH_RESPONSE_SUMMARY_ID,
|
||||
response=search_response_summary,
|
||||
)
|
||||
# Build selected sections for relevance (assuming all are relevant)
|
||||
selected_sections = [
|
||||
SectionRelevancePiece(
|
||||
relevant=True,
|
||||
document_id=section.center_chunk.document_id,
|
||||
chunk_id=0,
|
||||
)
|
||||
for section in inference_sections
|
||||
]
|
||||
|
||||
yield ToolResponse(
|
||||
id=SECTION_RELEVANCE_LIST_ID,
|
||||
response=selected_sections,
|
||||
)
|
||||
|
||||
llm_docs = [
|
||||
llm_doc_from_inference_section(section) for section in inference_sections
|
||||
]
|
||||
|
||||
yield ToolResponse(id=FINAL_CONTEXT_DOCUMENTS_ID, response=llm_docs)
|
||||
|
||||
def run(self, **kwargs: Any) -> Generator[ToolResponse, None, None]:
|
||||
request_body = kwargs.get(REQUEST_BODY)
|
||||
|
||||
@@ -243,44 +362,67 @@ class CustomTool(BaseTool):
|
||||
content_type = response.headers.get("Content-Type", "")
|
||||
|
||||
tool_result: Any
|
||||
response_type: str
|
||||
response_type: CustomToolResponseType
|
||||
if "text/csv" in content_type:
|
||||
file_ids = self._save_and_get_file_references(
|
||||
response.content, content_type
|
||||
)
|
||||
tool_result = CustomToolFileResponse(file_ids=file_ids)
|
||||
response_type = "csv"
|
||||
response_type = CustomToolResponseType.CSV
|
||||
|
||||
elif "image/" in content_type:
|
||||
file_ids = self._save_and_get_file_references(
|
||||
response.content, content_type
|
||||
)
|
||||
tool_result = CustomToolFileResponse(file_ids=file_ids)
|
||||
response_type = "image"
|
||||
response_type = CustomToolResponseType.IMAGE
|
||||
|
||||
elif content_type == "application/json":
|
||||
tool_result = response.json()
|
||||
|
||||
# Check if the response is a search result
|
||||
if isinstance(tool_result, list) and all(
|
||||
"content" in item for item in tool_result
|
||||
):
|
||||
# Process as search results
|
||||
search_results = [
|
||||
CustomToolSearchResult(**item) for item in tool_result
|
||||
]
|
||||
tool_result = CustomToolSearchResponse(results=search_results)
|
||||
response_type = CustomToolResponseType.SEARCH
|
||||
else:
|
||||
# Process as generic JSON
|
||||
response_type = CustomToolResponseType.JSON
|
||||
|
||||
else:
|
||||
# Default to JSON if content type is not specified
|
||||
try:
|
||||
tool_result = response.json()
|
||||
response_type = "json"
|
||||
response_type = CustomToolResponseType.JSON
|
||||
except JSONDecodeError:
|
||||
logger.exception(
|
||||
f"Failed to parse response as JSON for tool '{self._name}'"
|
||||
)
|
||||
tool_result = response.text
|
||||
response_type = "text"
|
||||
response_type = CustomToolResponseType.TEXT
|
||||
|
||||
logger.info(
|
||||
f"Returning tool response for {self._name} with type {response_type}"
|
||||
)
|
||||
if response_type == CustomToolResponseType.SEARCH and isinstance(
|
||||
tool_result, CustomToolSearchResponse
|
||||
):
|
||||
yield from self._handle_search_like_tool_response(tool_result)
|
||||
|
||||
yield ToolResponse(
|
||||
id=CUSTOM_TOOL_RESPONSE_ID,
|
||||
response=CustomToolCallSummary(
|
||||
tool_name=self._name,
|
||||
response_type=response_type,
|
||||
tool_result=tool_result,
|
||||
),
|
||||
)
|
||||
else:
|
||||
yield ToolResponse(
|
||||
id=CUSTOM_TOOL_RESPONSE_ID,
|
||||
response=CustomToolCallSummary(
|
||||
tool_name=self._name,
|
||||
response_type=response_type,
|
||||
tool_result=tool_result,
|
||||
),
|
||||
)
|
||||
|
||||
def build_next_prompt(
|
||||
self,
|
||||
@@ -289,10 +431,26 @@ class CustomTool(BaseTool):
|
||||
tool_responses: list[ToolResponse],
|
||||
using_tool_calling_llm: bool,
|
||||
) -> AnswerPromptBuilder:
|
||||
response = cast(CustomToolCallSummary, tool_responses[0].response)
|
||||
response = tool_responses[0].response
|
||||
|
||||
if isinstance(response, SearchResponseSummary) and self.prompt_config:
|
||||
if not self.answer_style_config or self.answer_style_config.citation_config:
|
||||
raise ValueError("Citation config is required for search tools")
|
||||
|
||||
return build_next_prompt_for_search_like_tool(
|
||||
prompt_builder=prompt_builder,
|
||||
tool_call_summary=tool_call_summary,
|
||||
tool_responses=tool_responses,
|
||||
using_tool_calling_llm=using_tool_calling_llm,
|
||||
answer_style_config=self.answer_style_config,
|
||||
prompt_config=self.prompt_config,
|
||||
)
|
||||
|
||||
# Handle non-file responses using parent class behavior
|
||||
if response.response_type not in ["image", "csv"]:
|
||||
if response.response_type not in [
|
||||
CustomToolResponseType.IMAGE,
|
||||
CustomToolResponseType.CSV,
|
||||
]:
|
||||
return super().build_next_prompt(
|
||||
prompt_builder,
|
||||
tool_call_summary,
|
||||
@@ -300,54 +458,47 @@ class CustomTool(BaseTool):
|
||||
using_tool_calling_llm,
|
||||
)
|
||||
|
||||
# Handle image and CSV file responses
|
||||
# Handle file responses
|
||||
file_type = (
|
||||
ChatFileType.IMAGE
|
||||
if response.response_type == "image"
|
||||
if response.response_type == CustomToolResponseType.IMAGE
|
||||
else ChatFileType.CSV
|
||||
)
|
||||
|
||||
# Load files from storage
|
||||
files = []
|
||||
with get_session_with_default_tenant() as db_session:
|
||||
file_store = get_default_file_store(db_session)
|
||||
|
||||
for file_id in response.tool_result.file_ids:
|
||||
try:
|
||||
file_io = file_store.read_file(file_id, mode="b")
|
||||
files.append(
|
||||
InMemoryChatFile(
|
||||
file_id=file_id,
|
||||
filename=file_id,
|
||||
content=file_io.read(),
|
||||
file_type=file_type,
|
||||
)
|
||||
)
|
||||
except Exception:
|
||||
logger.exception(f"Failed to read file {file_id}")
|
||||
|
||||
# Update prompt with file content
|
||||
prompt_builder.update_user_prompt(
|
||||
build_custom_image_generation_user_prompt(
|
||||
query=prompt_builder.get_user_message_content(),
|
||||
files=files,
|
||||
file_type=file_type,
|
||||
)
|
||||
)
|
||||
|
||||
return prompt_builder
|
||||
return build_next_prompt_for_file_like_tool(
|
||||
prompt_builder,
|
||||
response.tool_result.file_ids,
|
||||
file_type=file_type,
|
||||
)
|
||||
|
||||
def final_result(self, *args: ToolResponse) -> JSON_ro:
|
||||
response = cast(CustomToolCallSummary, args[0].response)
|
||||
if isinstance(response.tool_result, CustomToolFileResponse):
|
||||
return response.tool_result.model_dump()
|
||||
return response.tool_result
|
||||
if hasattr(response, "tool_result"):
|
||||
if isinstance(response.tool_result, CustomToolFileResponse):
|
||||
return response.tool_result.model_dump()
|
||||
|
||||
return response.tool_result
|
||||
else:
|
||||
final_docs = cast(
|
||||
list[LlmDoc],
|
||||
next(
|
||||
arg.response for arg in args if arg.id == FINAL_CONTEXT_DOCUMENTS_ID
|
||||
),
|
||||
)
|
||||
# NOTE: need to do this json.loads(doc.json()) stuff because there are some
|
||||
# subfields that are not serializable by default (datetime)
|
||||
# this forces pydantic to make them JSON serializable for us
|
||||
return [json.loads(doc.model_dump_json()) for doc in final_docs]
|
||||
|
||||
# return response.tool_result
|
||||
|
||||
|
||||
def build_custom_tools_from_openapi_schema_and_headers(
|
||||
openapi_schema: dict[str, Any],
|
||||
answer_style_config: AnswerStyleConfig | None = None,
|
||||
custom_headers: list[HeaderItemDict] | None = None,
|
||||
dynamic_schema_info: DynamicSchemaInfo | None = None,
|
||||
prompt_config: PromptConfig | None = None,
|
||||
) -> list[CustomTool]:
|
||||
if dynamic_schema_info:
|
||||
# Process dynamic schema information
|
||||
@@ -366,7 +517,8 @@ def build_custom_tools_from_openapi_schema_and_headers(
|
||||
url = openapi_to_url(openapi_schema)
|
||||
method_specs = openapi_to_method_specs(openapi_schema)
|
||||
return [
|
||||
CustomTool(method_spec, url, custom_headers) for method_spec in method_specs
|
||||
CustomTool(method_spec, url, answer_style_config, custom_headers, prompt_config)
|
||||
for method_spec in method_specs
|
||||
]
|
||||
|
||||
|
||||
|
||||
39
backend/danswer/tools/tool_implementations/custom/models.py
Normal file
39
backend/danswer/tools/tool_implementations/custom/models.py
Normal file
@@ -0,0 +1,39 @@
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
from typing import List
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
CUSTOM_TOOL_RESPONSE_ID = "custom_tool_response"
|
||||
|
||||
|
||||
class CustomToolResponseType(Enum):
|
||||
IMAGE = "image"
|
||||
CSV = "csv"
|
||||
JSON = "json"
|
||||
TEXT = "text"
|
||||
SEARCH = "search"
|
||||
|
||||
|
||||
class CustomToolFileResponse(BaseModel):
|
||||
file_ids: List[str] # References to saved images or CSVs
|
||||
|
||||
|
||||
class CustomToolCallSummary(BaseModel):
|
||||
tool_name: str
|
||||
response_type: CustomToolResponseType
|
||||
tool_result: Any # The response data
|
||||
|
||||
|
||||
class CustomToolSearchResult(BaseModel):
|
||||
document_id: str
|
||||
content: str # content of the search result
|
||||
blurb: str # used to display in the UI
|
||||
title: str
|
||||
link: str | None = None # optional source link
|
||||
updated_at: datetime | None = None
|
||||
|
||||
|
||||
class CustomToolSearchResponse(BaseModel):
|
||||
results: List[CustomToolSearchResult]
|
||||
@@ -12,7 +12,7 @@ Can you please summarize it in a sentence or two? Do NOT include image urls or b
|
||||
"""
|
||||
|
||||
|
||||
def build_custom_image_generation_user_prompt(
|
||||
def build_custom_file_generation_user_prompt(
|
||||
query: str, file_type: ChatFileType, files: list[InMemoryChatFile] | None = None
|
||||
) -> HumanMessage:
|
||||
return HumanMessage(
|
||||
|
||||
@@ -0,0 +1,64 @@
|
||||
from langchain_core.messages import HumanMessage
|
||||
|
||||
from danswer.db.engine import get_session_with_tenant
|
||||
from danswer.file_store.file_store import get_default_file_store
|
||||
from danswer.file_store.models import ChatFileType
|
||||
from danswer.file_store.models import InMemoryChatFile
|
||||
from danswer.llm.answering.prompts.build import AnswerPromptBuilder
|
||||
from danswer.llm.utils import build_content_with_imgs
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
FINAL_CONTEXT_DOCUMENTS_ID = "final_context_documents"
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
IMG_GENERATION_SUMMARY_PROMPT = """
|
||||
You have just created the attached images in response to the following query: "{query}".
|
||||
|
||||
Can you please summarize them in a sentence or two? Do NOT include image urls or bulleted lists.
|
||||
"""
|
||||
|
||||
|
||||
def build_file_generation_user_prompt(
|
||||
query: str, files: list[InMemoryChatFile] | None = None
|
||||
) -> HumanMessage:
|
||||
return HumanMessage(
|
||||
content=build_content_with_imgs(
|
||||
message=IMG_GENERATION_SUMMARY_PROMPT.format(query=query).strip(),
|
||||
files=files,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def build_next_prompt_for_file_like_tool(
|
||||
prompt_builder: AnswerPromptBuilder,
|
||||
file_ids: list[str],
|
||||
file_type: ChatFileType,
|
||||
) -> AnswerPromptBuilder:
|
||||
with get_session_with_tenant() as db_session:
|
||||
file_store = get_default_file_store(db_session)
|
||||
|
||||
files = []
|
||||
for file_id in file_ids:
|
||||
try:
|
||||
file_io = file_store.read_file(file_id, mode="b")
|
||||
files.append(
|
||||
InMemoryChatFile(
|
||||
file_id=file_id,
|
||||
filename=file_id,
|
||||
content=file_io.read(),
|
||||
file_type=file_type,
|
||||
)
|
||||
)
|
||||
except Exception:
|
||||
logger.exception(f"Failed to read file {file_id}")
|
||||
|
||||
# Update prompt with file content
|
||||
prompt_builder.update_user_prompt(
|
||||
build_file_generation_user_prompt(
|
||||
query=prompt_builder.get_user_message_content(),
|
||||
files=files,
|
||||
)
|
||||
)
|
||||
|
||||
return prompt_builder
|
||||
153
backend/tests/integration/common_utils/managers/tool.py
Normal file
153
backend/tests/integration/common_utils/managers/tool.py
Normal file
@@ -0,0 +1,153 @@
|
||||
from uuid import uuid4
|
||||
|
||||
import requests
|
||||
|
||||
from tests.integration.common_utils.constants import API_SERVER_URL
|
||||
from tests.integration.common_utils.constants import GENERAL_HEADERS
|
||||
from tests.integration.common_utils.test_models import DATestTool
|
||||
from tests.integration.common_utils.test_models import DATestUser
|
||||
|
||||
|
||||
class ToolManager:
|
||||
@staticmethod
|
||||
def create(
|
||||
name: str | None = None,
|
||||
description: str | None = None,
|
||||
definition: dict | None = None,
|
||||
custom_headers: list[dict[str, str]] | None = None,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
) -> DATestTool:
|
||||
name = name or f"test-tool-{uuid4()}"
|
||||
description = description or f"Description for {name}"
|
||||
definition = definition or {}
|
||||
|
||||
# Validate the tool definition before creating
|
||||
validate_response = requests.post(
|
||||
f"{API_SERVER_URL}/admin/tool/custom/validate",
|
||||
json={"definition": definition},
|
||||
headers=user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS,
|
||||
)
|
||||
if validate_response.status_code != 200:
|
||||
raise Exception(
|
||||
f"Tool validation failed: {validate_response.json().get('detail', 'Unknown error')}"
|
||||
)
|
||||
|
||||
tool_creation_request = {
|
||||
"name": name,
|
||||
"description": description,
|
||||
"definition": definition,
|
||||
}
|
||||
|
||||
if custom_headers is not None:
|
||||
tool_creation_request["custom_headers"] = custom_headers
|
||||
|
||||
response = requests.post(
|
||||
f"{API_SERVER_URL}/admin/tool/custom",
|
||||
json=tool_creation_request,
|
||||
headers=user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS,
|
||||
)
|
||||
|
||||
if response.status_code == 400:
|
||||
raise Exception(
|
||||
f"Tool creation failed: {response.json().get('detail', 'Unknown error')}"
|
||||
)
|
||||
|
||||
response.raise_for_status()
|
||||
tool_data = response.json()
|
||||
|
||||
return DATestTool(
|
||||
id=tool_data["id"],
|
||||
name=tool_data["name"],
|
||||
description=tool_data["description"],
|
||||
definition=tool_data["definition"],
|
||||
custom_headers=tool_data.get("custom_headers", []),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def edit(
|
||||
tool: DATestTool,
|
||||
name: str | None = None,
|
||||
description: str | None = None,
|
||||
definition: dict | None = None,
|
||||
custom_headers: dict[str, str] | None = None,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
) -> DATestTool:
|
||||
tool_update_request = {
|
||||
"name": name or tool.name,
|
||||
"description": description or tool.description,
|
||||
"definition": definition or tool.definition,
|
||||
"custom_headers": custom_headers or tool.custom_headers,
|
||||
}
|
||||
|
||||
response = requests.put(
|
||||
f"{API_SERVER_URL}/admin/tool/custom/{tool.id}",
|
||||
json=tool_update_request,
|
||||
headers=user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS,
|
||||
)
|
||||
response.raise_for_status()
|
||||
updated_tool_data = response.json()
|
||||
|
||||
return DATestTool(
|
||||
id=updated_tool_data["id"],
|
||||
name=updated_tool_data["name"],
|
||||
description=updated_tool_data["description"],
|
||||
definition=updated_tool_data["definition"],
|
||||
custom_headers=updated_tool_data.get("custom_headers", []),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_all(
|
||||
user_performing_action: DATestUser | None = None,
|
||||
) -> list[DATestTool]:
|
||||
response = requests.get(
|
||||
f"{API_SERVER_URL}/tool",
|
||||
headers=user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS,
|
||||
)
|
||||
response.raise_for_status()
|
||||
return [
|
||||
DATestTool(
|
||||
id=tool["id"],
|
||||
name=tool["name"],
|
||||
description=tool.get("description", ""),
|
||||
definition=tool.get("definition") or {},
|
||||
custom_headers=tool.get("custom_headers") or [],
|
||||
)
|
||||
for tool in response.json()
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def verify(
|
||||
tool: DATestTool,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
) -> bool:
|
||||
all_tools = ToolManager.get_all(user_performing_action)
|
||||
for fetched_tool in all_tools:
|
||||
if fetched_tool.id == tool.id:
|
||||
return (
|
||||
fetched_tool.name == tool.name
|
||||
and fetched_tool.description == tool.description
|
||||
and fetched_tool.definition == tool.definition
|
||||
and fetched_tool.custom_headers == tool.custom_headers
|
||||
)
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def delete(
|
||||
tool: DATestTool,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
) -> bool:
|
||||
response = requests.delete(
|
||||
f"{API_SERVER_URL}/admin/tool/custom/{tool.id}",
|
||||
headers=user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS,
|
||||
)
|
||||
return response.ok
|
||||
@@ -150,3 +150,16 @@ class StreamedResponse(BaseModel):
|
||||
relevance_summaries: list[dict[str, Any]] | None = None
|
||||
tool_result: Any | None = None
|
||||
user: str | None = None
|
||||
|
||||
|
||||
class DATestToolHeader(BaseModel):
|
||||
key: str
|
||||
value: str
|
||||
|
||||
|
||||
class DATestTool(BaseModel):
|
||||
id: int
|
||||
name: str
|
||||
description: str
|
||||
definition: dict[str, Any]
|
||||
custom_headers: list[DATestToolHeader]
|
||||
|
||||
@@ -0,0 +1,88 @@
|
||||
import requests_mock
|
||||
|
||||
from tests.integration.common_utils.managers.user import UserManager
|
||||
from tests.integration.common_utils.test_models import DATestUser
|
||||
from tests.integration.tests.custom_tools.test_custom_tool_utils import (
|
||||
build_and_run_tool,
|
||||
)
|
||||
from tests.integration.tests.custom_tools.test_custom_tool_utils import (
|
||||
create_and_verify_custom_tool,
|
||||
)
|
||||
from tests.integration.tests.custom_tools.test_custom_tool_utils import (
|
||||
process_tool_responses,
|
||||
)
|
||||
|
||||
# Absolute imports for helper functions
|
||||
|
||||
|
||||
def test_custom_csv_tool(reset: None) -> None:
|
||||
# Create admin user directly in the test
|
||||
admin_user: DATestUser = UserManager.create(name="admin_user")
|
||||
|
||||
# Define the tool using the provided OpenAPI schema for CSV export
|
||||
tool_definition = {
|
||||
"openapi": "3.0.0",
|
||||
"info": {
|
||||
"title": "Sample CSV Export API",
|
||||
"version": "1.0.0",
|
||||
"description": "An API for exporting data in CSV format",
|
||||
},
|
||||
"paths": {
|
||||
"/export": {
|
||||
"get": {
|
||||
"summary": "Export data in CSV format",
|
||||
"operationId": "exportCsv",
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "Successful CSV export",
|
||||
"content": {
|
||||
"text/csv": {
|
||||
"schema": {
|
||||
"type": "string",
|
||||
"format": "binary",
|
||||
}
|
||||
}
|
||||
},
|
||||
}
|
||||
},
|
||||
"parameters": [
|
||||
{
|
||||
"in": "query",
|
||||
"name": "date",
|
||||
"schema": {"type": "string", "format": "date"},
|
||||
"required": True,
|
||||
"description": "The date for which to export data",
|
||||
}
|
||||
],
|
||||
}
|
||||
}
|
||||
},
|
||||
"servers": [{"url": "http://localhost:8000"}],
|
||||
}
|
||||
|
||||
create_and_verify_custom_tool(
|
||||
tool_definition=tool_definition,
|
||||
user=admin_user,
|
||||
name="Sample CSV Export Tool",
|
||||
description="A sample tool that exports data in CSV format",
|
||||
)
|
||||
|
||||
mock_csv_content = "id,name,age\n1,Alice,30\n2,Bob,25\n3,Charlie,35\n"
|
||||
|
||||
with requests_mock.Mocker() as m:
|
||||
m.get(
|
||||
"http://localhost:8000/export",
|
||||
text=mock_csv_content,
|
||||
headers={"Content-Type": "text/csv"},
|
||||
)
|
||||
|
||||
# Prepare arguments for the tool
|
||||
tool_args = {"date": "2023-10-01"}
|
||||
|
||||
# Build and run the tool
|
||||
tool_responses = build_and_run_tool(tool_definition, tool_args)
|
||||
|
||||
# Process the tool responses
|
||||
process_tool_responses(tool_responses)
|
||||
|
||||
print("Custom CSV tool ran successfully with processed response")
|
||||
@@ -0,0 +1,97 @@
|
||||
import requests_mock
|
||||
|
||||
from tests.integration.common_utils.managers.user import UserManager
|
||||
from tests.integration.common_utils.test_models import DATestUser
|
||||
from tests.integration.tests.custom_tools.test_custom_tool_utils import (
|
||||
build_and_run_tool,
|
||||
)
|
||||
from tests.integration.tests.custom_tools.test_custom_tool_utils import (
|
||||
create_and_verify_custom_tool,
|
||||
)
|
||||
from tests.integration.tests.custom_tools.test_custom_tool_utils import (
|
||||
process_tool_responses,
|
||||
)
|
||||
|
||||
# Absolute imports for helper functions
|
||||
|
||||
|
||||
def test_custom_image_tool(reset: None) -> None:
|
||||
# Create admin user directly in the test
|
||||
admin_user: DATestUser = UserManager.create(name="admin_user")
|
||||
|
||||
# Define the tool using the provided OpenAPI schema for image fetching
|
||||
tool_definition = {
|
||||
"openapi": "3.0.0",
|
||||
"info": {
|
||||
"title": "Sample Image Fetch API",
|
||||
"version": "1.0.0",
|
||||
"description": "An API for fetching images",
|
||||
},
|
||||
"paths": {
|
||||
"/image/{image_id}": {
|
||||
"get": {
|
||||
"summary": "Fetch an image",
|
||||
"operationId": "fetchImage",
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "Successful image fetch",
|
||||
"content": {
|
||||
"image/png": {
|
||||
"schema": {
|
||||
"type": "string",
|
||||
"format": "binary",
|
||||
}
|
||||
},
|
||||
"image/jpeg": {
|
||||
"schema": {
|
||||
"type": "string",
|
||||
"format": "binary",
|
||||
}
|
||||
},
|
||||
},
|
||||
}
|
||||
},
|
||||
"parameters": [
|
||||
{
|
||||
"in": "path",
|
||||
"name": "image_id",
|
||||
"schema": {"type": "string"},
|
||||
"required": True,
|
||||
"description": "The ID of the image to fetch",
|
||||
}
|
||||
],
|
||||
}
|
||||
}
|
||||
},
|
||||
"servers": [{"url": "http://localhost:8000"}],
|
||||
}
|
||||
|
||||
# Create and verify the custom tool
|
||||
create_and_verify_custom_tool(
|
||||
tool_definition=tool_definition,
|
||||
user=admin_user,
|
||||
name="Sample Image Fetch Tool",
|
||||
description="A sample tool that fetches an image",
|
||||
)
|
||||
|
||||
# Prepare mock image data for the API response
|
||||
mock_image_content = b"\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\x01..."
|
||||
|
||||
# Mock the requests made by the tool
|
||||
with requests_mock.Mocker() as m:
|
||||
m.get(
|
||||
"http://localhost:8000/image/123",
|
||||
content=mock_image_content,
|
||||
headers={"Content-Type": "image/png"},
|
||||
)
|
||||
|
||||
# Prepare arguments for the tool
|
||||
tool_args = {"image_id": "123"}
|
||||
|
||||
# Build and run the tool
|
||||
tool_responses = build_and_run_tool(tool_definition, tool_args)
|
||||
|
||||
# Process the tool responses
|
||||
process_tool_responses(tool_responses)
|
||||
|
||||
print("Custom image tool ran successfully with processed response")
|
||||
@@ -0,0 +1,180 @@
|
||||
import datetime
|
||||
from typing import cast
|
||||
|
||||
import requests_mock
|
||||
|
||||
from danswer.tools.models import ToolResponse
|
||||
from danswer.tools.tool_implementations.custom.custom_tool import (
|
||||
CUSTOM_TOOL_RESPONSE_ID,
|
||||
)
|
||||
from danswer.tools.tool_implementations.custom.custom_tool import CustomToolCallSummary
|
||||
from tests.integration.common_utils.managers.user import UserManager
|
||||
from tests.integration.common_utils.test_models import DATestUser
|
||||
from tests.integration.tests.custom_tools.test_custom_tool_utils import (
|
||||
build_and_run_tool,
|
||||
)
|
||||
from tests.integration.tests.custom_tools.test_custom_tool_utils import (
|
||||
create_and_verify_custom_tool,
|
||||
)
|
||||
|
||||
|
||||
def test_custom_search_tool(reset: None) -> None:
|
||||
admin_user: DATestUser = UserManager.create(name="admin_user")
|
||||
|
||||
tool_definition = {
|
||||
"info": {
|
||||
"title": "Sample Search API",
|
||||
"version": "1.0.0",
|
||||
"description": "An API for performing searches and returning sample results",
|
||||
},
|
||||
"paths": {
|
||||
"/search": {
|
||||
"get": {
|
||||
"summary": "Perform a search",
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "Successful response",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"$ref": "#/components/schemas/CustomToolSearchResult"
|
||||
},
|
||||
}
|
||||
}
|
||||
},
|
||||
}
|
||||
},
|
||||
"parameters": [
|
||||
{
|
||||
"in": "query",
|
||||
"name": "query",
|
||||
"schema": {"type": "string"},
|
||||
"required": True,
|
||||
"description": "The search query",
|
||||
}
|
||||
],
|
||||
"operationId": "searchApi",
|
||||
}
|
||||
}
|
||||
},
|
||||
"openapi": "3.0.0",
|
||||
"servers": [{"url": "http://localhost:8000"}],
|
||||
"components": {
|
||||
"schemas": {
|
||||
"CustomToolSearchResult": {
|
||||
"type": "object",
|
||||
"required": [
|
||||
"content",
|
||||
"document_id",
|
||||
"semantic_identifier",
|
||||
"blurb",
|
||||
"source_type",
|
||||
"score",
|
||||
],
|
||||
"properties": {
|
||||
"content": {"type": "string"},
|
||||
"document_id": {"type": "string"},
|
||||
"semantic_identifier": {"type": "string"},
|
||||
"blurb": {"type": "string"},
|
||||
"source_type": {"type": "string"},
|
||||
"score": {"type": "number", "format": "float"},
|
||||
"link": {"type": "string", "nullable": True},
|
||||
"title": {"type": "string", "nullable": True},
|
||||
"updated_at": {
|
||||
"type": "string",
|
||||
"format": "date-time",
|
||||
"nullable": True,
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
create_and_verify_custom_tool(
|
||||
tool_definition=tool_definition,
|
||||
user=admin_user,
|
||||
name="Sample Search Tool",
|
||||
description="A sample tool that performs a search",
|
||||
)
|
||||
|
||||
mock_results = [
|
||||
{
|
||||
"title": "Dog Article",
|
||||
"content": """Dogs are known for their loyalty and affection towards humans.
|
||||
They come in various breeds, each with unique characteristics.""",
|
||||
"document_id": "dog1",
|
||||
"semantic_identifier": "Dog Basics",
|
||||
"blurb": "An overview of dogs as pets and their general traits.",
|
||||
"source_type": "Pet Information",
|
||||
"score": 0.95,
|
||||
"link": "http://example.com/dogs/basics",
|
||||
"updated_at": datetime.datetime.now().isoformat(),
|
||||
},
|
||||
{
|
||||
"title": "Cat Article",
|
||||
"content": """Cats are independent and often aloof pets.
|
||||
They are known for their grooming habits and ability to hunt small prey.""",
|
||||
"document_id": "cat1",
|
||||
"semantic_identifier": "Cat Basics",
|
||||
"blurb": "An introduction to cats as pets and their typical behaviors.",
|
||||
"source_type": "Pet Information",
|
||||
"score": 0.92,
|
||||
"link": "http://example.com/cats/basics",
|
||||
"updated_at": datetime.datetime.now().isoformat(),
|
||||
},
|
||||
{
|
||||
"title": "Hamster Article",
|
||||
"content": """Hamsters are small rodents that make popular pocket pets.
|
||||
They are nocturnal and require a cage with exercise wheels and tunnels.""",
|
||||
"document_id": "hamster1",
|
||||
"semantic_identifier": "Hamster Care",
|
||||
"blurb": "Essential information for keeping hamsters as pets.",
|
||||
"source_type": "Pet Information",
|
||||
"score": 0.88,
|
||||
"link": "http://example.com/hamsters/care",
|
||||
"updated_at": datetime.datetime.now().isoformat(),
|
||||
},
|
||||
]
|
||||
|
||||
with requests_mock.Mocker() as m:
|
||||
m.get(
|
||||
"http://localhost:8000/search",
|
||||
json=mock_results,
|
||||
)
|
||||
|
||||
tool_args = {"query": "pet information"}
|
||||
|
||||
tool_responses = build_and_run_tool(tool_definition, tool_args)
|
||||
|
||||
assert len(tool_responses) > 0
|
||||
for packet in tool_responses:
|
||||
if isinstance(packet, ToolResponse):
|
||||
if packet.id == CUSTOM_TOOL_RESPONSE_ID:
|
||||
custom_tool_response = cast(CustomToolCallSummary, packet.response)
|
||||
print(f"Custom tool response: {custom_tool_response.tool_result}")
|
||||
assert isinstance(custom_tool_response.tool_result, list)
|
||||
assert len(custom_tool_response.tool_result) == 3
|
||||
|
||||
expected_document_ids = ["dog1", "cat1", "hamster1"]
|
||||
expected_titles = ["Dog Article", "Cat Article", "Hamster Article"]
|
||||
|
||||
for index, expected_id in enumerate(expected_document_ids):
|
||||
result = custom_tool_response.tool_result[index]
|
||||
assert result["document_id"] == expected_id
|
||||
assert result["title"] == expected_titles[index]
|
||||
assert "content" in result
|
||||
assert "semantic_identifier" in result
|
||||
assert "blurb" in result
|
||||
assert "source_type" in result
|
||||
assert "score" in result
|
||||
assert "link" in result
|
||||
assert "updated_at" in result
|
||||
else:
|
||||
pass
|
||||
else:
|
||||
print(f"Received unexpected packet type: {type(packet)}")
|
||||
|
||||
print("Custom search tool ran successfully with processed response")
|
||||
@@ -0,0 +1,71 @@
|
||||
import pytest
|
||||
from requests.exceptions import HTTPError
|
||||
|
||||
from tests.integration.common_utils.managers.tool import ToolManager
|
||||
from tests.integration.common_utils.managers.user import UserManager
|
||||
from tests.integration.common_utils.test_models import DATestTool
|
||||
from tests.integration.common_utils.test_models import DATestUser
|
||||
|
||||
|
||||
def test_custom_search_tool(reset: None) -> None:
|
||||
# Creating an admin user (first user created is automatically an admin)
|
||||
admin_user: DATestUser = UserManager.create(name="admin_user")
|
||||
|
||||
# Define the tool using the provided OpenAPI schema
|
||||
tool_definition = {
|
||||
"info": {
|
||||
"title": "Simple API",
|
||||
"version": "1.0.0",
|
||||
"description": "A minimal API for demonstration",
|
||||
},
|
||||
"paths": {
|
||||
"/hello": {
|
||||
"get": {
|
||||
"summary": "Say hello",
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "Successful response",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"type": "object",
|
||||
"properties": {"message": {"type": "string"}},
|
||||
}
|
||||
}
|
||||
},
|
||||
}
|
||||
},
|
||||
"parameters": [
|
||||
{
|
||||
"in": "query",
|
||||
"name": "name",
|
||||
"schema": {"type": "string"},
|
||||
"required": False,
|
||||
"description": "Name to greet",
|
||||
}
|
||||
],
|
||||
"operationId": "sayHello",
|
||||
}
|
||||
}
|
||||
},
|
||||
"openapi": "3.0.0",
|
||||
"servers": [{"url": "http://localhost:8000"}],
|
||||
}
|
||||
|
||||
# Create the custom tool with error handling
|
||||
try:
|
||||
print("Attempting to create custom tool...")
|
||||
custom_tool: DATestTool = ToolManager.create(
|
||||
name="Sample Search Tool",
|
||||
description="A sample tool that performs a search",
|
||||
definition=tool_definition,
|
||||
custom_headers=[],
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
print(f"Custom tool created successfully with ID: {custom_tool.id}")
|
||||
except HTTPError as e:
|
||||
pytest.fail(f"Failed to create custom tool: {e}")
|
||||
|
||||
# Verify the tool was created successfully
|
||||
assert ToolManager.verify(custom_tool, user_performing_action=admin_user)
|
||||
print("Custom tool verified successfully")
|
||||
@@ -0,0 +1,102 @@
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
from typing import Dict
|
||||
from typing import List
|
||||
|
||||
import pytest
|
||||
from requests.exceptions import HTTPError
|
||||
|
||||
from danswer.file_store.models import ChatFileType
|
||||
from danswer.file_store.models import FileDescriptor
|
||||
from danswer.tools.models import ToolResponse
|
||||
from danswer.tools.tool_implementations.custom.custom_tool import (
|
||||
build_custom_tools_from_openapi_schema_and_headers,
|
||||
)
|
||||
from danswer.tools.tool_implementations.custom.custom_tool import (
|
||||
CUSTOM_TOOL_RESPONSE_ID,
|
||||
)
|
||||
from danswer.tools.tool_implementations.custom.custom_tool import CustomToolCallSummary
|
||||
from danswer.tools.tool_implementations.custom.custom_tool import CustomToolResponseType
|
||||
from danswer.tools.tool_runner import ToolRunner
|
||||
from tests.integration.common_utils.managers.tool import ToolManager
|
||||
from tests.integration.common_utils.test_models import DATestTool
|
||||
from tests.integration.common_utils.test_models import DATestUser
|
||||
|
||||
|
||||
def create_and_verify_custom_tool(
|
||||
tool_definition: Dict[str, Any], user: DATestUser, name: str, description: str
|
||||
) -> DATestTool:
|
||||
"""Create and verify a custom tool."""
|
||||
try:
|
||||
print(f"Attempting to create custom tool '{name}'...")
|
||||
custom_tool: DATestTool = ToolManager.create(
|
||||
name=name,
|
||||
description=description,
|
||||
definition=tool_definition,
|
||||
custom_headers=[],
|
||||
user_performing_action=user,
|
||||
)
|
||||
print(f"Custom tool '{name}' created successfully with ID: {custom_tool.id}")
|
||||
except HTTPError as e:
|
||||
pytest.fail(f"Failed to create custom tool '{name}': {e}")
|
||||
|
||||
# Verify the tool was created successfully
|
||||
assert ToolManager.verify(custom_tool, user_performing_action=user)
|
||||
print(f"Custom tool '{name}' verified successfully")
|
||||
return custom_tool
|
||||
|
||||
|
||||
def build_and_run_tool(
|
||||
tool_definition: Dict[str, Any], tool_args: Dict[str, Any]
|
||||
) -> List[Any]:
|
||||
"""Build and run the custom tool."""
|
||||
# Build the custom tool from the definition
|
||||
custom_tools = build_custom_tools_from_openapi_schema_and_headers(
|
||||
openapi_schema=tool_definition,
|
||||
custom_headers=[],
|
||||
)
|
||||
|
||||
# Ensure only one tool was built
|
||||
assert len(custom_tools) == 1
|
||||
tool = custom_tools[0]
|
||||
|
||||
# Run the tool using ToolRunner
|
||||
tool_runner = ToolRunner(tool=tool, args=tool_args)
|
||||
return list(tool_runner.tool_responses())
|
||||
|
||||
|
||||
def process_tool_responses(tool_responses: List[Any]) -> None:
|
||||
assert len(tool_responses) > 0
|
||||
for packet in tool_responses:
|
||||
if isinstance(packet, ToolResponse) and packet.id == CUSTOM_TOOL_RESPONSE_ID:
|
||||
custom_tool_response = cast(CustomToolCallSummary, packet.response)
|
||||
|
||||
if custom_tool_response.response_type in (
|
||||
CustomToolResponseType.CSV,
|
||||
CustomToolResponseType.IMAGE,
|
||||
):
|
||||
file_ids = custom_tool_response.tool_result.file_ids
|
||||
ai_message_files = [
|
||||
FileDescriptor(
|
||||
id=str(file_id),
|
||||
type=ChatFileType.IMAGE
|
||||
if custom_tool_response.response_type
|
||||
== CustomToolResponseType.IMAGE
|
||||
else ChatFileType.CSV,
|
||||
)
|
||||
for file_id in file_ids
|
||||
]
|
||||
print(f"Received file IDs: {[str(file_id) for file_id in file_ids]}")
|
||||
assert len(file_ids) > 0
|
||||
assert all(isinstance(file_id, str) for file_id in file_ids)
|
||||
assert len(ai_message_files) == len(file_ids)
|
||||
# Additional processing can be added here if needed
|
||||
else:
|
||||
custom_tool_result = custom_tool_response.tool_result
|
||||
tool_name = custom_tool_response.tool_name
|
||||
print(f"Custom tool '{tool_name}' response: {custom_tool_result}")
|
||||
assert custom_tool_result is not None
|
||||
assert tool_name is not None
|
||||
# Additional processing can be added here if needed
|
||||
else:
|
||||
print(f"Received unexpected packet: {packet}")
|
||||
@@ -14,6 +14,7 @@ from danswer.tools.tool_implementations.custom.custom_tool import (
|
||||
CUSTOM_TOOL_RESPONSE_ID,
|
||||
)
|
||||
from danswer.tools.tool_implementations.custom.custom_tool import CustomToolCallSummary
|
||||
from danswer.tools.tool_implementations.custom.custom_tool import CustomToolResponseType
|
||||
from danswer.tools.tool_implementations.custom.custom_tool import (
|
||||
validate_openapi_schema,
|
||||
)
|
||||
@@ -215,7 +216,7 @@ class TestCustomTool(unittest.TestCase):
|
||||
mock_response = ToolResponse(
|
||||
id=CUSTOM_TOOL_RESPONSE_ID,
|
||||
response=CustomToolCallSummary(
|
||||
response_type="json",
|
||||
response_type=CustomToolResponseType.JSON,
|
||||
tool_name="getAssistant",
|
||||
tool_result={"id": "789", "name": "Final Assistant"},
|
||||
),
|
||||
|
||||
@@ -9,6 +9,7 @@ import {
|
||||
buildDocumentSummaryDisplay,
|
||||
} from "@/components/search/DocumentDisplay";
|
||||
import { InternetSearchIcon } from "@/components/InternetSearchIcon";
|
||||
import { nonSearchableSources, ValidSources } from "@/lib/types";
|
||||
|
||||
interface DocumentDisplayProps {
|
||||
document: DanswerDocument;
|
||||
@@ -101,13 +102,14 @@ export function ChatDocumentDisplay({
|
||||
</div>
|
||||
)}
|
||||
|
||||
{!isInternet && (
|
||||
<DocumentSelector
|
||||
isSelected={isSelected}
|
||||
handleSelect={() => handleSelect(document.document_id)}
|
||||
isDisabled={tokenLimitReached && !isSelected}
|
||||
/>
|
||||
)}
|
||||
{!isInternet &&
|
||||
!nonSearchableSources.includes(document.source_type) && (
|
||||
<DocumentSelector
|
||||
isSelected={isSelected}
|
||||
handleSelect={() => handleSelect(document.document_id)}
|
||||
isDisabled={tokenLimitReached && !isSelected}
|
||||
/>
|
||||
)}
|
||||
</div>
|
||||
<div>
|
||||
<div className="mt-1">
|
||||
|
||||
@@ -10,6 +10,7 @@ export function SourceIcon({
|
||||
sourceType: ValidSources;
|
||||
iconSize: number;
|
||||
}) {
|
||||
console.log("sourceType is ", sourceType);
|
||||
return getSourceMetadata(sourceType).icon({
|
||||
size: iconSize,
|
||||
});
|
||||
|
||||
@@ -6,7 +6,7 @@ import React, {
|
||||
useEffect,
|
||||
} from "react";
|
||||
import { usePathname, useRouter, useSearchParams } from "next/navigation";
|
||||
import { ValidSources } from "@/lib/types";
|
||||
import { ConfigurableSources } from "@/lib/types";
|
||||
|
||||
interface FormContextType {
|
||||
formStep: number;
|
||||
@@ -15,7 +15,7 @@ interface FormContextType {
|
||||
nextFormStep: (contract?: string) => void;
|
||||
prevFormStep: () => void;
|
||||
formStepToLast: () => void;
|
||||
connector: ValidSources;
|
||||
connector: ConfigurableSources;
|
||||
setFormStep: React.Dispatch<React.SetStateAction<number>>;
|
||||
allowAdvanced: boolean;
|
||||
setAllowAdvanced: React.Dispatch<React.SetStateAction<boolean>>;
|
||||
@@ -27,7 +27,7 @@ const FormContext = createContext<FormContextType | undefined>(undefined);
|
||||
|
||||
export const FormProvider: React.FC<{
|
||||
children: ReactNode;
|
||||
connector: ValidSources;
|
||||
connector: ConfigurableSources;
|
||||
}> = ({ children, connector }) => {
|
||||
const router = useRouter();
|
||||
const searchParams = useSearchParams();
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
"use client";
|
||||
|
||||
import { ValidSources } from "@/lib/types";
|
||||
import { ConfigurableSources, ValidSources } from "@/lib/types";
|
||||
import useSWR, { mutate } from "swr";
|
||||
import { errorHandlingFetcher } from "@/lib/fetcher";
|
||||
import { FaSwatchbook } from "react-icons/fa";
|
||||
@@ -35,7 +35,7 @@ export default function CredentialSection({
|
||||
refresh,
|
||||
}: {
|
||||
ccPair: CCPairFullInfo;
|
||||
sourceType: ValidSources;
|
||||
sourceType: ConfigurableSources;
|
||||
refresh: () => void;
|
||||
}) {
|
||||
const makeShowCreateCredential = () => {
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import React, { useState } from "react";
|
||||
import { Button } from "@/components/ui/button";
|
||||
import { ValidSources } from "@/lib/types";
|
||||
import { ConfigurableSources } from "@/lib/types";
|
||||
import { FaAccusoft } from "react-icons/fa";
|
||||
import { submitCredential } from "@/components/admin/connectors/CredentialForm";
|
||||
import { TextFormField } from "@/components/admin/connectors/Field";
|
||||
@@ -71,7 +71,7 @@ export default function CreateCredential({
|
||||
}: {
|
||||
// Source information
|
||||
hideSource?: boolean; // hides docs link
|
||||
sourceType: ValidSources;
|
||||
sourceType: ConfigurableSources;
|
||||
|
||||
setPopup: (popupSpec: PopupSpec | null) => void;
|
||||
|
||||
|
||||
@@ -3,7 +3,7 @@ import { Modal } from "@/components/Modal";
|
||||
import { Button } from "@/components/ui/button";
|
||||
import Text from "@/components/ui/text";
|
||||
import { Badge } from "@/components/ui/badge";
|
||||
import { ValidSources } from "@/lib/types";
|
||||
import { ConfigurableSources, ValidSources } from "@/lib/types";
|
||||
import {
|
||||
EditIcon,
|
||||
NewChatIcon,
|
||||
@@ -160,7 +160,7 @@ export default function ModifyCredential({
|
||||
defaultedCredential?: Credential<any>;
|
||||
credentials: Credential<any>[];
|
||||
editableCredentials: Credential<any>[];
|
||||
source: ValidSources;
|
||||
source: ConfigurableSources;
|
||||
|
||||
onSwitch?: (newCredential: Credential<any>) => void;
|
||||
onSwap?: (newCredential: Credential<any>, connectorId: number) => void;
|
||||
|
||||
@@ -1129,7 +1129,7 @@ export const defaultRefreshFreqMinutes = 30; // 30 minutes
|
||||
// CONNECTORS
|
||||
export interface ConnectorBase<T> {
|
||||
name: string;
|
||||
source: ValidSources;
|
||||
source: ConfigurableSources;
|
||||
input_type: ValidInputTypes;
|
||||
connector_specific_config: T;
|
||||
refresh_freq: number | null;
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
import { ValidSources } from "../types";
|
||||
import { ConfigurableSources, ValidSources } from "../types";
|
||||
|
||||
export interface CredentialBase<T> {
|
||||
credential_json: T;
|
||||
@@ -195,7 +195,7 @@ export interface FirefliesCredentialJson {
|
||||
export interface MediaWikiCredentialJson {}
|
||||
export interface WikipediaCredentialJson extends MediaWikiCredentialJson {}
|
||||
|
||||
export const credentialTemplates: Record<ValidSources, any> = {
|
||||
export const credentialTemplates: Record<ConfigurableSources, any> = {
|
||||
github: { github_access_token: "" } as GithubCredentialJson,
|
||||
gitlab: {
|
||||
gitlab_url: "",
|
||||
@@ -304,8 +304,6 @@ export const credentialTemplates: Record<ValidSources, any> = {
|
||||
wikipedia: null,
|
||||
mediawiki: null,
|
||||
web: null,
|
||||
not_applicable: null,
|
||||
ingestion_api: null,
|
||||
|
||||
// NOTE: These are Special Cases
|
||||
google_drive: { google_tokens: "" } as GoogleDriveCredentialJson,
|
||||
|
||||
@@ -40,12 +40,9 @@ import {
|
||||
FirefliesIcon,
|
||||
} from "@/components/icons/icons";
|
||||
import { ValidSources } from "./types";
|
||||
import {
|
||||
DanswerDocument,
|
||||
SourceCategory,
|
||||
SourceMetadata,
|
||||
} from "./search/interfaces";
|
||||
import { SourceCategory, SourceMetadata } from "./search/interfaces";
|
||||
import { Persona } from "@/app/admin/assistants/interfaces";
|
||||
import { FaHammer } from "react-icons/fa";
|
||||
|
||||
interface PartialSourceMetadata {
|
||||
icon: React.FC<{ size?: number; className?: string }>;
|
||||
@@ -59,6 +56,12 @@ type SourceMap = {
|
||||
};
|
||||
|
||||
const SOURCE_METADATA_MAP: SourceMap = {
|
||||
custom_tool: {
|
||||
icon: FaHammer,
|
||||
displayName: "Custom Tool",
|
||||
category: SourceCategory.Other,
|
||||
docs: "https://docs.danswer.dev/connectors/custom_tool",
|
||||
},
|
||||
web: {
|
||||
icon: GlobeIcon,
|
||||
displayName: "Web",
|
||||
|
||||
@@ -254,6 +254,7 @@ export interface UserGroup {
|
||||
}
|
||||
|
||||
const validSources = [
|
||||
"custom_tool",
|
||||
"web",
|
||||
"github",
|
||||
"gitlab",
|
||||
@@ -301,9 +302,11 @@ export type ValidSources = (typeof validSources)[number];
|
||||
// The valid sources that are actually valid to select in the UI
|
||||
export type ConfigurableSources = Exclude<
|
||||
ValidSources,
|
||||
"not_applicable" | "ingestion_api"
|
||||
"not_applicable" | "ingestion_api" | "custom_tool"
|
||||
>;
|
||||
|
||||
export const nonSearchableSources: ValidSources[] = ["custom_tool"];
|
||||
|
||||
// The sources that have auto-sync support on the backend
|
||||
export const validAutoSyncSources = [
|
||||
"confluence",
|
||||
|
||||
Reference in New Issue
Block a user