Compare commits

...

2 Commits

Author SHA1 Message Date
pablodanswer
9633152f4d post rebase cleanup 2024-11-23 12:06:39 -08:00
pablodanswer
0370023e08 add search result capability to custom tools 2024-11-23 12:00:20 -08:00
25 changed files with 1083 additions and 106 deletions

View File

@@ -81,6 +81,7 @@ from danswer.tools.tool_implementations.custom.custom_tool import (
CUSTOM_TOOL_RESPONSE_ID,
)
from danswer.tools.tool_implementations.custom.custom_tool import CustomToolCallSummary
from danswer.tools.tool_implementations.custom.models import CustomToolResponseType
from danswer.tools.tool_implementations.images.image_generation_tool import (
IMAGE_GENERATION_RESPONSE_ID,
)
@@ -160,7 +161,6 @@ def _handle_search_tool_response_summary(
]
else:
reference_db_search_docs = selected_search_docs
response_docs = [
translate_db_search_doc_to_server_search_doc(db_search_doc)
for db_search_doc in reference_db_search_docs
@@ -603,6 +603,8 @@ def stream_chat_message_objects(
chat_session_id=chat_session_id,
message_id=user_message.id if user_message else None,
additional_headers=custom_tool_additional_headers,
answer_style_config=answer_style_config,
prompt_config=prompt_config,
),
)
tools: list[Tool] = []
@@ -659,6 +661,7 @@ def stream_chat_message_objects(
else False
),
)
yield qa_docs_response
elif packet.id == SECTION_RELEVANCE_LIST_ID:
relevance_sections = packet.response
@@ -715,15 +718,17 @@ def stream_chat_message_objects(
custom_tool_response = cast(CustomToolCallSummary, packet.response)
if (
custom_tool_response.response_type == "image"
or custom_tool_response.response_type == "csv"
custom_tool_response.response_type == CustomToolResponseType.CSV
or custom_tool_response.response_type
== CustomToolResponseType.IMAGE
):
file_ids = custom_tool_response.tool_result.file_ids
ai_message_files = [
FileDescriptor(
id=str(file_id),
type=ChatFileType.IMAGE
if custom_tool_response.response_type == "image"
if custom_tool_response.response_type
== CustomToolResponseType.IMAGE
else ChatFileType.CSV,
)
for file_id in file_ids

View File

@@ -88,6 +88,7 @@ DANSWER_REDIS_FUNCTION_LOCK_PREFIX = "da_function_lock:"
class DocumentSource(str, Enum):
# Special case, document passed in via Danswer APIs without specifying a source type
CUSTOM_TOOL = "custom_tool"
INGESTION_API = "ingestion_api"
SLACK = "slack"
WEB = "web"

View File

@@ -125,6 +125,8 @@ class CustomToolConfig(BaseModel):
chat_session_id: UUID | None = None
message_id: int | None = None
additional_headers: dict[str, str] | None = None
answer_style_config: AnswerStyleConfig | None = None
prompt_config: PromptConfig | None = None
def construct_tools(
@@ -228,6 +230,8 @@ def construct_tools(
custom_tool_config.additional_headers or {}
)
),
answer_style_config=custom_tool_config.answer_style_config,
prompt_config=custom_tool_config.prompt_config,
),
)

View File

@@ -12,17 +12,25 @@ from typing import List
import requests
from langchain_core.messages import HumanMessage
from langchain_core.messages import SystemMessage
from pydantic import BaseModel
from requests import JSONDecodeError
from danswer.chat.chat_utils import llm_doc_from_inference_section
from danswer.chat.models import LlmDoc
from danswer.chat.models import SectionRelevancePiece
from danswer.configs.constants import FileOrigin
from danswer.db.engine import get_session_with_default_tenant
from danswer.file_store.file_store import get_default_file_store
from danswer.file_store.models import ChatFileType
from danswer.file_store.models import InMemoryChatFile
from danswer.key_value_store.interface import JSON_ro
from danswer.llm.answering.models import AnswerStyleConfig
from danswer.llm.answering.models import PreviousMessage
from danswer.llm.answering.models import PromptConfig
from danswer.llm.answering.prompts.build import AnswerPromptBuilder
from danswer.llm.interfaces import LLM
from danswer.search.models import DocumentSource
from danswer.search.models import IndexFilters
from danswer.search.models import InferenceChunk
from danswer.search.models import InferenceSection
from danswer.tools.base_tool import BaseTool
from danswer.tools.message import ToolCallSummary
from danswer.tools.models import CHAT_SESSION_ID_PLACEHOLDER
@@ -42,6 +50,12 @@ from danswer.tools.tool_implementations.custom.custom_tool_prompts import (
TOOL_ARG_USER_PROMPT,
)
from danswer.tools.tool_implementations.custom.custom_tool_prompts import USE_TOOL
from danswer.tools.tool_implementations.custom.models import CUSTOM_TOOL_RESPONSE_ID
from danswer.tools.tool_implementations.custom.models import CustomToolCallSummary
from danswer.tools.tool_implementations.custom.models import CustomToolFileResponse
from danswer.tools.tool_implementations.custom.models import CustomToolResponseType
from danswer.tools.tool_implementations.custom.models import CustomToolSearchResponse
from danswer.tools.tool_implementations.custom.models import CustomToolSearchResult
from danswer.tools.tool_implementations.custom.openapi_parsing import MethodSpec
from danswer.tools.tool_implementations.custom.openapi_parsing import (
openapi_to_method_specs,
@@ -51,35 +65,39 @@ from danswer.tools.tool_implementations.custom.openapi_parsing import REQUEST_BO
from danswer.tools.tool_implementations.custom.openapi_parsing import (
validate_openapi_schema,
)
from danswer.tools.tool_implementations.custom.prompt import (
build_custom_image_generation_user_prompt,
from danswer.tools.tool_implementations.file_like_tool_utils import (
build_next_prompt_for_file_like_tool,
)
from danswer.tools.tool_implementations.search.search_tool import (
FINAL_CONTEXT_DOCUMENTS_ID,
)
from danswer.tools.tool_implementations.search.search_tool import (
SEARCH_RESPONSE_SUMMARY_ID,
)
from danswer.tools.tool_implementations.search.search_tool import SearchResponseSummary
from danswer.tools.tool_implementations.search.search_tool import (
SECTION_RELEVANCE_LIST_ID,
)
from danswer.tools.tool_implementations.search.search_utils import llm_doc_to_dict
from danswer.tools.tool_implementations.search_like_tool_utils import (
build_next_prompt_for_search_like_tool,
)
from danswer.utils.headers import header_list_to_header_dict
from danswer.utils.headers import HeaderItemDict
from danswer.utils.logger import setup_logger
from danswer.utils.special_types import JSON_ro
logger = setup_logger()
CUSTOM_TOOL_RESPONSE_ID = "custom_tool_response"
class CustomToolFileResponse(BaseModel):
file_ids: List[str] # References to saved images or CSVs
class CustomToolCallSummary(BaseModel):
tool_name: str
response_type: str # e.g., 'json', 'image', 'csv', 'graph'
tool_result: Any # The response data
class CustomTool(BaseTool):
def __init__(
self,
method_spec: MethodSpec,
base_url: str,
answer_style_config: AnswerStyleConfig | None = None,
custom_headers: list[HeaderItemDict] | None = None,
prompt_config: PromptConfig | None = None,
) -> None:
self._base_url = base_url
self._method_spec = method_spec
@@ -90,6 +108,8 @@ class CustomTool(BaseTool):
self.headers = (
header_list_to_header_dict(custom_headers) if custom_headers else {}
)
self.answer_style_config = answer_style_config
self.prompt_config = prompt_config
@property
def name(self) -> str:
@@ -111,14 +131,44 @@ class CustomTool(BaseTool):
def build_tool_message_content(
self, *args: ToolResponse
) -> str | list[str | dict[str, Any]]:
response = cast(CustomToolCallSummary, args[0].response)
final_context_docs_response = next(
(
response
for response in args
if response.id == FINAL_CONTEXT_DOCUMENTS_ID
),
None,
)
if response.response_type == "image" or response.response_type == "csv":
image_response = cast(CustomToolFileResponse, response.tool_result)
return json.dumps({"file_ids": image_response.file_ids})
# Handle the search type response
if final_context_docs_response:
final_context_docs = cast(
list[LlmDoc], final_context_docs_response.response
)
return json.dumps(
{
"search_results": [
llm_doc_to_dict(doc, ind)
for ind, doc in enumerate(final_context_docs)
]
}
)
# For JSON or other responses, return as-is
return json.dumps(response.tool_result)
# Handle other response types
response = args[0].response
if isinstance(response, CustomToolCallSummary):
if response.response_type in [
CustomToolResponseType.IMAGE,
CustomToolResponseType.CSV,
]:
file_response = cast(CustomToolFileResponse, response.tool_result)
return json.dumps({"file_ids": file_response.file_ids})
# For JSON or other responses, return as-is
return json.dumps(response.tool_result)
# If it's not a CustomToolCallSummary or search result, return as-is
return json.dumps(response)
"""For LLMs which do NOT support explicit tool calling"""
@@ -219,6 +269,75 @@ class CustomTool(BaseTool):
"""Actual execution of the tool"""
def _handle_search_like_tool_response(
    self, tool_result: CustomToolSearchResponse
) -> Generator[ToolResponse, None, None]:
    """Translate a custom tool's search payload into the packet sequence the
    built-in search tool emits (summary, relevance list, final context docs)
    so downstream chat handling can treat both identically.
    """
    # Convert results to InferenceSections
    inference_sections = []
    for result in tool_result.results:
        chunk = InferenceChunk(
            document_id=result.document_id,
            content=result.content,
            semantic_identifier=result.document_id,  # using document_id as semantic_identifier
            blurb=result.blurb,
            source_type=DocumentSource.CUSTOM_TOOL,
            # Use defaults
            chunk_id=0,
            source_links={},
            section_continuation=False,
            title=result.title,
            boost=0,
            recency_bias=0,  # Default recency bias
            hidden=False,
            score=0,
            metadata={},
            match_highlights=[],
            updated_at=result.updated_at,
        )
        # We assume that each search result belongs to different documents
        section = InferenceSection(
            center_chunk=chunk,
            chunks=[chunk],
            combined_content=chunk.content,
        )
        inference_sections.append(section)
    search_response_summary = SearchResponseSummary(
        rephrased_query=None,
        top_sections=inference_sections,
        predicted_flow=None,
        predicted_search=None,
        final_filters=IndexFilters(access_control_list=None),
        recency_bias_multiplier=0.0,
    )
    yield ToolResponse(
        id=SEARCH_RESPONSE_SUMMARY_ID,
        response=search_response_summary,
    )
    # Build selected sections for relevance (assuming all are relevant)
    # NOTE(review): chunk_id is hard-coded to 0 for every section — relies on
    # the one-chunk-per-document assumption above; confirm relevance handling
    # keys on (document_id, chunk_id) pairs.
    selected_sections = [
        SectionRelevancePiece(
            relevant=True,
            document_id=section.center_chunk.document_id,
            chunk_id=0,
        )
        for section in inference_sections
    ]
    yield ToolResponse(
        id=SECTION_RELEVANCE_LIST_ID,
        response=selected_sections,
    )
    llm_docs = [
        llm_doc_from_inference_section(section) for section in inference_sections
    ]
    yield ToolResponse(id=FINAL_CONTEXT_DOCUMENTS_ID, response=llm_docs)
def run(self, **kwargs: Any) -> Generator[ToolResponse, None, None]:
request_body = kwargs.get(REQUEST_BODY)
@@ -243,44 +362,67 @@ class CustomTool(BaseTool):
content_type = response.headers.get("Content-Type", "")
tool_result: Any
response_type: str
response_type: CustomToolResponseType
if "text/csv" in content_type:
file_ids = self._save_and_get_file_references(
response.content, content_type
)
tool_result = CustomToolFileResponse(file_ids=file_ids)
response_type = "csv"
response_type = CustomToolResponseType.CSV
elif "image/" in content_type:
file_ids = self._save_and_get_file_references(
response.content, content_type
)
tool_result = CustomToolFileResponse(file_ids=file_ids)
response_type = "image"
response_type = CustomToolResponseType.IMAGE
elif content_type == "application/json":
tool_result = response.json()
# Check if the response is a search result
if isinstance(tool_result, list) and all(
"content" in item for item in tool_result
):
# Process as search results
search_results = [
CustomToolSearchResult(**item) for item in tool_result
]
tool_result = CustomToolSearchResponse(results=search_results)
response_type = CustomToolResponseType.SEARCH
else:
# Process as generic JSON
response_type = CustomToolResponseType.JSON
else:
# Default to JSON if content type is not specified
try:
tool_result = response.json()
response_type = "json"
response_type = CustomToolResponseType.JSON
except JSONDecodeError:
logger.exception(
f"Failed to parse response as JSON for tool '{self._name}'"
)
tool_result = response.text
response_type = "text"
response_type = CustomToolResponseType.TEXT
logger.info(
f"Returning tool response for {self._name} with type {response_type}"
)
if response_type == CustomToolResponseType.SEARCH and isinstance(
tool_result, CustomToolSearchResponse
):
yield from self._handle_search_like_tool_response(tool_result)
yield ToolResponse(
id=CUSTOM_TOOL_RESPONSE_ID,
response=CustomToolCallSummary(
tool_name=self._name,
response_type=response_type,
tool_result=tool_result,
),
)
else:
yield ToolResponse(
id=CUSTOM_TOOL_RESPONSE_ID,
response=CustomToolCallSummary(
tool_name=self._name,
response_type=response_type,
tool_result=tool_result,
),
)
def build_next_prompt(
self,
@@ -289,10 +431,26 @@ class CustomTool(BaseTool):
tool_responses: list[ToolResponse],
using_tool_calling_llm: bool,
) -> AnswerPromptBuilder:
response = cast(CustomToolCallSummary, tool_responses[0].response)
response = tool_responses[0].response
if isinstance(response, SearchResponseSummary) and self.prompt_config:
if not self.answer_style_config or self.answer_style_config.citation_config:
raise ValueError("Citation config is required for search tools")
return build_next_prompt_for_search_like_tool(
prompt_builder=prompt_builder,
tool_call_summary=tool_call_summary,
tool_responses=tool_responses,
using_tool_calling_llm=using_tool_calling_llm,
answer_style_config=self.answer_style_config,
prompt_config=self.prompt_config,
)
# Handle non-file responses using parent class behavior
if response.response_type not in ["image", "csv"]:
if response.response_type not in [
CustomToolResponseType.IMAGE,
CustomToolResponseType.CSV,
]:
return super().build_next_prompt(
prompt_builder,
tool_call_summary,
@@ -300,54 +458,47 @@ class CustomTool(BaseTool):
using_tool_calling_llm,
)
# Handle image and CSV file responses
# Handle file responses
file_type = (
ChatFileType.IMAGE
if response.response_type == "image"
if response.response_type == CustomToolResponseType.IMAGE
else ChatFileType.CSV
)
# Load files from storage
files = []
with get_session_with_default_tenant() as db_session:
file_store = get_default_file_store(db_session)
for file_id in response.tool_result.file_ids:
try:
file_io = file_store.read_file(file_id, mode="b")
files.append(
InMemoryChatFile(
file_id=file_id,
filename=file_id,
content=file_io.read(),
file_type=file_type,
)
)
except Exception:
logger.exception(f"Failed to read file {file_id}")
# Update prompt with file content
prompt_builder.update_user_prompt(
build_custom_image_generation_user_prompt(
query=prompt_builder.get_user_message_content(),
files=files,
file_type=file_type,
)
)
return prompt_builder
return build_next_prompt_for_file_like_tool(
prompt_builder,
response.tool_result.file_ids,
file_type=file_type,
)
def final_result(self, *args: ToolResponse) -> JSON_ro:
response = cast(CustomToolCallSummary, args[0].response)
if isinstance(response.tool_result, CustomToolFileResponse):
return response.tool_result.model_dump()
return response.tool_result
if hasattr(response, "tool_result"):
if isinstance(response.tool_result, CustomToolFileResponse):
return response.tool_result.model_dump()
return response.tool_result
else:
final_docs = cast(
list[LlmDoc],
next(
arg.response for arg in args if arg.id == FINAL_CONTEXT_DOCUMENTS_ID
),
)
# NOTE: need to do this json.loads(doc.json()) stuff because there are some
# subfields that are not serializable by default (datetime)
# this forces pydantic to make them JSON serializable for us
return [json.loads(doc.model_dump_json()) for doc in final_docs]
# return response.tool_result
def build_custom_tools_from_openapi_schema_and_headers(
openapi_schema: dict[str, Any],
answer_style_config: AnswerStyleConfig | None = None,
custom_headers: list[HeaderItemDict] | None = None,
dynamic_schema_info: DynamicSchemaInfo | None = None,
prompt_config: PromptConfig | None = None,
) -> list[CustomTool]:
if dynamic_schema_info:
# Process dynamic schema information
@@ -366,7 +517,8 @@ def build_custom_tools_from_openapi_schema_and_headers(
url = openapi_to_url(openapi_schema)
method_specs = openapi_to_method_specs(openapi_schema)
return [
CustomTool(method_spec, url, custom_headers) for method_spec in method_specs
CustomTool(method_spec, url, answer_style_config, custom_headers, prompt_config)
for method_spec in method_specs
]

View File

@@ -0,0 +1,39 @@
from datetime import datetime
from enum import Enum
from typing import Any
from typing import List
from pydantic import BaseModel
CUSTOM_TOOL_RESPONSE_ID = "custom_tool_response"
class CustomToolResponseType(str, Enum):
    """Discriminator for the payload carried in a custom tool's call summary.

    Mixes in ``str`` (matching the ``DocumentSource(str, Enum)`` convention
    used elsewhere in this codebase) so members serialize cleanly inside
    pydantic models and compare equal to their raw string values — fully
    backward-compatible with existing member-based comparisons.
    """

    IMAGE = "image"  # binary image persisted to the file store (file_ids payload)
    CSV = "csv"  # CSV export persisted to the file store (file_ids payload)
    JSON = "json"  # arbitrary JSON payload returned as-is
    TEXT = "text"  # plain-text fallback when JSON parsing fails
    SEARCH = "search"  # list of search results mapped to inference sections
class CustomToolFileResponse(BaseModel):
    """File-backed tool output: references into the file store, not raw bytes."""

    file_ids: List[str]  # References to saved images or CSVs
class CustomToolCallSummary(BaseModel):
    """Envelope describing the outcome of one custom-tool invocation."""

    tool_name: str
    # Discriminates how tool_result must be interpreted downstream
    # (file-backed vs. JSON vs. text vs. search results).
    response_type: CustomToolResponseType
    tool_result: Any  # The response data
class CustomToolSearchResult(BaseModel):
    """One search hit returned by a search-capable custom tool."""

    document_id: str
    content: str  # content of the search result
    blurb: str  # used to display in the UI
    title: str
    link: str | None = None  # optional source link
    updated_at: datetime | None = None
class CustomToolSearchResponse(BaseModel):
    """Container for all results of a single custom-tool search call."""

    results: List[CustomToolSearchResult]

View File

@@ -12,7 +12,7 @@ Can you please summarize it in a sentence or two? Do NOT include image urls or b
"""
def build_custom_image_generation_user_prompt(
def build_custom_file_generation_user_prompt(
query: str, file_type: ChatFileType, files: list[InMemoryChatFile] | None = None
) -> HumanMessage:
return HumanMessage(

View File

@@ -0,0 +1,64 @@
from langchain_core.messages import HumanMessage
from danswer.db.engine import get_session_with_tenant
from danswer.file_store.file_store import get_default_file_store
from danswer.file_store.models import ChatFileType
from danswer.file_store.models import InMemoryChatFile
from danswer.llm.answering.prompts.build import AnswerPromptBuilder
from danswer.llm.utils import build_content_with_imgs
from danswer.utils.logger import setup_logger
FINAL_CONTEXT_DOCUMENTS_ID = "final_context_documents"
logger = setup_logger()
IMG_GENERATION_SUMMARY_PROMPT = """
You have just created the attached images in response to the following query: "{query}".
Can you please summarize them in a sentence or two? Do NOT include image urls or bulleted lists.
"""
def build_file_generation_user_prompt(
    query: str, files: list[InMemoryChatFile] | None = None
) -> HumanMessage:
    """Build the follow-up user message asking the LLM to summarize files
    generated in response to *query*.
    """
    summary_request = IMG_GENERATION_SUMMARY_PROMPT.format(query=query).strip()
    return HumanMessage(
        content=build_content_with_imgs(message=summary_request, files=files)
    )
def build_next_prompt_for_file_like_tool(
    prompt_builder: AnswerPromptBuilder,
    file_ids: list[str],
    file_type: ChatFileType,
) -> AnswerPromptBuilder:
    """Load the referenced files from the file store and rewrite the user
    prompt to ask the LLM for a summary of them.

    Files that cannot be read are logged and skipped (best-effort).
    """
    loaded_files: list[InMemoryChatFile] = []
    with get_session_with_tenant() as db_session:
        store = get_default_file_store(db_session)
        for file_id in file_ids:
            try:
                file_bytes = store.read_file(file_id, mode="b").read()
            except Exception:
                # Deliberately best-effort: a missing file should not abort
                # the whole prompt build.
                logger.exception(f"Failed to read file {file_id}")
                continue
            loaded_files.append(
                InMemoryChatFile(
                    file_id=file_id,
                    filename=file_id,
                    content=file_bytes,
                    file_type=file_type,
                )
            )
    # Update prompt with file content
    prompt_builder.update_user_prompt(
        build_file_generation_user_prompt(
            query=prompt_builder.get_user_message_content(),
            files=loaded_files,
        )
    )
    return prompt_builder

View File

@@ -0,0 +1,153 @@
from uuid import uuid4
import requests
from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.constants import GENERAL_HEADERS
from tests.integration.common_utils.test_models import DATestTool
from tests.integration.common_utils.test_models import DATestUser
class ToolManager:
    """Integration-test helper wrapping the custom-tool admin HTTP API.

    All methods are static; requests are sent with the supplied user's
    headers, falling back to GENERAL_HEADERS when no user is given.
    """

    @staticmethod
    def create(
        name: str | None = None,
        description: str | None = None,
        definition: dict | None = None,
        custom_headers: list[dict[str, str]] | None = None,
        user_performing_action: DATestUser | None = None,
    ) -> DATestTool:
        """Validate a tool definition server-side, then create the tool.

        Raises:
            Exception: if validation fails or creation returns HTTP 400.
        """
        name = name or f"test-tool-{uuid4()}"
        description = description or f"Description for {name}"
        definition = definition or {}
        # Validate the tool definition before creating
        validate_response = requests.post(
            f"{API_SERVER_URL}/admin/tool/custom/validate",
            json={"definition": definition},
            headers=user_performing_action.headers
            if user_performing_action
            else GENERAL_HEADERS,
        )
        if validate_response.status_code != 200:
            raise Exception(
                f"Tool validation failed: {validate_response.json().get('detail', 'Unknown error')}"
            )
        tool_creation_request: dict = {
            "name": name,
            "description": description,
            "definition": definition,
        }
        # Only send the key when explicitly provided so the server default applies
        if custom_headers is not None:
            tool_creation_request["custom_headers"] = custom_headers
        response = requests.post(
            f"{API_SERVER_URL}/admin/tool/custom",
            json=tool_creation_request,
            headers=user_performing_action.headers
            if user_performing_action
            else GENERAL_HEADERS,
        )
        if response.status_code == 400:
            raise Exception(
                f"Tool creation failed: {response.json().get('detail', 'Unknown error')}"
            )
        response.raise_for_status()
        tool_data = response.json()
        return DATestTool(
            id=tool_data["id"],
            name=tool_data["name"],
            description=tool_data["description"],
            definition=tool_data["definition"],
            custom_headers=tool_data.get("custom_headers", []),
        )

    @staticmethod
    def edit(
        tool: DATestTool,
        name: str | None = None,
        description: str | None = None,
        definition: dict | None = None,
        custom_headers: list[dict[str, str]] | None = None,
        user_performing_action: DATestUser | None = None,
    ) -> DATestTool:
        """Update a tool via PUT; any unspecified field keeps its current value.

        Note: the ``custom_headers`` annotation was corrected from
        ``dict[str, str]`` to the list-of-{key, value}-dicts shape that
        ``create`` accepts.
        """
        # Fall back to the tool's existing headers. Those are pydantic models
        # (DATestToolHeader), which requests cannot JSON-serialize directly,
        # so dump them to plain dicts first; caller-provided plain dicts pass
        # through unchanged.
        effective_headers = [
            header.model_dump() if hasattr(header, "model_dump") else header
            for header in (custom_headers or tool.custom_headers)
        ]
        tool_update_request = {
            "name": name or tool.name,
            "description": description or tool.description,
            "definition": definition or tool.definition,
            "custom_headers": effective_headers,
        }
        response = requests.put(
            f"{API_SERVER_URL}/admin/tool/custom/{tool.id}",
            json=tool_update_request,
            headers=user_performing_action.headers
            if user_performing_action
            else GENERAL_HEADERS,
        )
        response.raise_for_status()
        updated_tool_data = response.json()
        return DATestTool(
            id=updated_tool_data["id"],
            name=updated_tool_data["name"],
            description=updated_tool_data["description"],
            definition=updated_tool_data["definition"],
            custom_headers=updated_tool_data.get("custom_headers", []),
        )

    @staticmethod
    def get_all(
        user_performing_action: DATestUser | None = None,
    ) -> list[DATestTool]:
        """Fetch every tool visible to the user."""
        response = requests.get(
            f"{API_SERVER_URL}/tool",
            headers=user_performing_action.headers
            if user_performing_action
            else GENERAL_HEADERS,
        )
        response.raise_for_status()
        return [
            DATestTool(
                id=tool["id"],
                name=tool["name"],
                description=tool.get("description", ""),
                definition=tool.get("definition") or {},
                custom_headers=tool.get("custom_headers") or [],
            )
            for tool in response.json()
        ]

    @staticmethod
    def verify(
        tool: DATestTool,
        user_performing_action: DATestUser | None = None,
    ) -> bool:
        """Return True iff a fetched tool with the same id matches all fields."""
        all_tools = ToolManager.get_all(user_performing_action)
        for fetched_tool in all_tools:
            if fetched_tool.id == tool.id:
                return (
                    fetched_tool.name == tool.name
                    and fetched_tool.description == tool.description
                    and fetched_tool.definition == tool.definition
                    and fetched_tool.custom_headers == tool.custom_headers
                )
        return False

    @staticmethod
    def delete(
        tool: DATestTool,
        user_performing_action: DATestUser | None = None,
    ) -> bool:
        """Delete the tool; True when the server responds with a 2xx status."""
        response = requests.delete(
            f"{API_SERVER_URL}/admin/tool/custom/{tool.id}",
            headers=user_performing_action.headers
            if user_performing_action
            else GENERAL_HEADERS,
        )
        return response.ok

View File

@@ -150,3 +150,16 @@ class StreamedResponse(BaseModel):
relevance_summaries: list[dict[str, Any]] | None = None
tool_result: Any | None = None
user: str | None = None
class DATestToolHeader(BaseModel):
    """A single custom HTTP header (key/value pair) attached to a custom tool."""

    key: str
    value: str
class DATestTool(BaseModel):
    """Test-side representation of a custom tool as returned by the admin API."""

    id: int
    name: str
    description: str
    definition: dict[str, Any]  # OpenAPI schema describing the tool
    custom_headers: list[DATestToolHeader]

View File

@@ -0,0 +1,88 @@
import requests_mock
from tests.integration.common_utils.managers.user import UserManager
from tests.integration.common_utils.test_models import DATestUser
from tests.integration.tests.custom_tools.test_custom_tool_utils import (
build_and_run_tool,
)
from tests.integration.tests.custom_tools.test_custom_tool_utils import (
create_and_verify_custom_tool,
)
from tests.integration.tests.custom_tools.test_custom_tool_utils import (
process_tool_responses,
)
# Absolute imports for helper functions
def test_custom_csv_tool(reset: None) -> None:
    """Exercises a custom tool whose mocked endpoint returns text/csv content."""
    # Create admin user directly in the test
    admin_user: DATestUser = UserManager.create(name="admin_user")

    # OpenAPI schema describing a CSV export endpoint
    tool_definition = {
        "openapi": "3.0.0",
        "info": {
            "title": "Sample CSV Export API",
            "version": "1.0.0",
            "description": "An API for exporting data in CSV format",
        },
        "paths": {
            "/export": {
                "get": {
                    "summary": "Export data in CSV format",
                    "operationId": "exportCsv",
                    "responses": {
                        "200": {
                            "description": "Successful CSV export",
                            "content": {
                                "text/csv": {
                                    "schema": {
                                        "type": "string",
                                        "format": "binary",
                                    }
                                }
                            },
                        }
                    },
                    "parameters": [
                        {
                            "in": "query",
                            "name": "date",
                            "schema": {"type": "string", "format": "date"},
                            "required": True,
                            "description": "The date for which to export data",
                        }
                    ],
                }
            }
        },
        "servers": [{"url": "http://localhost:8000"}],
    }

    create_and_verify_custom_tool(
        tool_definition=tool_definition,
        user=admin_user,
        name="Sample CSV Export Tool",
        description="A sample tool that exports data in CSV format",
    )

    mock_csv_content = "id,name,age\n1,Alice,30\n2,Bob,25\n3,Charlie,35\n"
    with requests_mock.Mocker() as mock_requests:
        # Serve the canned CSV for the tool's GET call
        mock_requests.get(
            "http://localhost:8000/export",
            text=mock_csv_content,
            headers={"Content-Type": "text/csv"},
        )
        # Build, run, and process the tool inside the mocked HTTP context
        tool_args = {"date": "2023-10-01"}
        tool_responses = build_and_run_tool(tool_definition, tool_args)
        process_tool_responses(tool_responses)

    print("Custom CSV tool ran successfully with processed response")

View File

@@ -0,0 +1,97 @@
import requests_mock
from tests.integration.common_utils.managers.user import UserManager
from tests.integration.common_utils.test_models import DATestUser
from tests.integration.tests.custom_tools.test_custom_tool_utils import (
build_and_run_tool,
)
from tests.integration.tests.custom_tools.test_custom_tool_utils import (
create_and_verify_custom_tool,
)
from tests.integration.tests.custom_tools.test_custom_tool_utils import (
process_tool_responses,
)
# Absolute imports for helper functions
def test_custom_image_tool(reset: None) -> None:
    """Exercises a custom tool whose mocked endpoint returns image content."""
    # Create admin user directly in the test
    admin_user: DATestUser = UserManager.create(name="admin_user")

    # OpenAPI schema describing an image-fetch endpoint
    tool_definition = {
        "openapi": "3.0.0",
        "info": {
            "title": "Sample Image Fetch API",
            "version": "1.0.0",
            "description": "An API for fetching images",
        },
        "paths": {
            "/image/{image_id}": {
                "get": {
                    "summary": "Fetch an image",
                    "operationId": "fetchImage",
                    "responses": {
                        "200": {
                            "description": "Successful image fetch",
                            "content": {
                                "image/png": {
                                    "schema": {
                                        "type": "string",
                                        "format": "binary",
                                    }
                                },
                                "image/jpeg": {
                                    "schema": {
                                        "type": "string",
                                        "format": "binary",
                                    }
                                },
                            },
                        }
                    },
                    "parameters": [
                        {
                            "in": "path",
                            "name": "image_id",
                            "schema": {"type": "string"},
                            "required": True,
                            "description": "The ID of the image to fetch",
                        }
                    ],
                }
            }
        },
        "servers": [{"url": "http://localhost:8000"}],
    }

    # Create and verify the custom tool
    create_and_verify_custom_tool(
        tool_definition=tool_definition,
        user=admin_user,
        name="Sample Image Fetch Tool",
        description="A sample tool that fetches an image",
    )

    # Minimal PNG-like payload for the mocked API response
    mock_image_content = b"\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\x01..."
    with requests_mock.Mocker() as mock_requests:
        mock_requests.get(
            "http://localhost:8000/image/123",
            content=mock_image_content,
            headers={"Content-Type": "image/png"},
        )
        # Build, run, and process the tool inside the mocked HTTP context
        tool_args = {"image_id": "123"}
        tool_responses = build_and_run_tool(tool_definition, tool_args)
        process_tool_responses(tool_responses)

    print("Custom image tool ran successfully with processed response")

View File

@@ -0,0 +1,180 @@
import datetime
from typing import cast
import requests_mock
from danswer.tools.models import ToolResponse
from danswer.tools.tool_implementations.custom.custom_tool import (
CUSTOM_TOOL_RESPONSE_ID,
)
from danswer.tools.tool_implementations.custom.custom_tool import CustomToolCallSummary
from tests.integration.common_utils.managers.user import UserManager
from tests.integration.common_utils.test_models import DATestUser
from tests.integration.tests.custom_tools.test_custom_tool_utils import (
build_and_run_tool,
)
from tests.integration.tests.custom_tools.test_custom_tool_utils import (
create_and_verify_custom_tool,
)
def test_custom_search_tool(reset: None) -> None:
    """Runs a search-style custom tool against a mocked endpoint and checks
    the contents of the CUSTOM_TOOL_RESPONSE packet."""
    admin_user: DATestUser = UserManager.create(name="admin_user")
    # OpenAPI schema whose /search endpoint returns an array of
    # CustomToolSearchResult objects
    tool_definition = {
        "info": {
            "title": "Sample Search API",
            "version": "1.0.0",
            "description": "An API for performing searches and returning sample results",
        },
        "paths": {
            "/search": {
                "get": {
                    "summary": "Perform a search",
                    "responses": {
                        "200": {
                            "description": "Successful response",
                            "content": {
                                "application/json": {
                                    "schema": {
                                        "type": "array",
                                        "items": {
                                            "$ref": "#/components/schemas/CustomToolSearchResult"
                                        },
                                    }
                                }
                            },
                        }
                    },
                    "parameters": [
                        {
                            "in": "query",
                            "name": "query",
                            "schema": {"type": "string"},
                            "required": True,
                            "description": "The search query",
                        }
                    ],
                    "operationId": "searchApi",
                }
            }
        },
        "openapi": "3.0.0",
        "servers": [{"url": "http://localhost:8000"}],
        "components": {
            "schemas": {
                "CustomToolSearchResult": {
                    "type": "object",
                    "required": [
                        "content",
                        "document_id",
                        "semantic_identifier",
                        "blurb",
                        "source_type",
                        "score",
                    ],
                    "properties": {
                        "content": {"type": "string"},
                        "document_id": {"type": "string"},
                        "semantic_identifier": {"type": "string"},
                        "blurb": {"type": "string"},
                        "source_type": {"type": "string"},
                        "score": {"type": "number", "format": "float"},
                        "link": {"type": "string", "nullable": True},
                        "title": {"type": "string", "nullable": True},
                        "updated_at": {
                            "type": "string",
                            "format": "date-time",
                            "nullable": True,
                        },
                    },
                }
            }
        },
    }
    create_and_verify_custom_tool(
        tool_definition=tool_definition,
        user=admin_user,
        name="Sample Search Tool",
        description="A sample tool that performs a search",
    )
    # Canned search hits that the mocked endpoint will return
    mock_results = [
        {
            "title": "Dog Article",
            "content": """Dogs are known for their loyalty and affection towards humans.
They come in various breeds, each with unique characteristics.""",
            "document_id": "dog1",
            "semantic_identifier": "Dog Basics",
            "blurb": "An overview of dogs as pets and their general traits.",
            "source_type": "Pet Information",
            "score": 0.95,
            "link": "http://example.com/dogs/basics",
            "updated_at": datetime.datetime.now().isoformat(),
        },
        {
            "title": "Cat Article",
            "content": """Cats are independent and often aloof pets.
They are known for their grooming habits and ability to hunt small prey.""",
            "document_id": "cat1",
            "semantic_identifier": "Cat Basics",
            "blurb": "An introduction to cats as pets and their typical behaviors.",
            "source_type": "Pet Information",
            "score": 0.92,
            "link": "http://example.com/cats/basics",
            "updated_at": datetime.datetime.now().isoformat(),
        },
        {
            "title": "Hamster Article",
            "content": """Hamsters are small rodents that make popular pocket pets.
They are nocturnal and require a cage with exercise wheels and tunnels.""",
            "document_id": "hamster1",
            "semantic_identifier": "Hamster Care",
            "blurb": "Essential information for keeping hamsters as pets.",
            "source_type": "Pet Information",
            "score": 0.88,
            "link": "http://example.com/hamsters/care",
            "updated_at": datetime.datetime.now().isoformat(),
        },
    ]
    with requests_mock.Mocker() as m:
        m.get(
            "http://localhost:8000/search",
            json=mock_results,
        )
        tool_args = {"query": "pet information"}
        tool_responses = build_and_run_tool(tool_definition, tool_args)
        assert len(tool_responses) > 0
        for packet in tool_responses:
            if isinstance(packet, ToolResponse):
                if packet.id == CUSTOM_TOOL_RESPONSE_ID:
                    custom_tool_response = cast(CustomToolCallSummary, packet.response)
                    print(f"Custom tool response: {custom_tool_response.tool_result}")
                    # NOTE(review): assumes tool_result is a plain list of result
                    # dicts; if search JSON payloads get wrapped (e.g. in a
                    # CustomToolSearchResponse), this would no longer hold —
                    # confirm against build_and_run_tool's behavior.
                    assert isinstance(custom_tool_response.tool_result, list)
                    assert len(custom_tool_response.tool_result) == 3
                    expected_document_ids = ["dog1", "cat1", "hamster1"]
                    expected_titles = ["Dog Article", "Cat Article", "Hamster Article"]
                    for index, expected_id in enumerate(expected_document_ids):
                        result = custom_tool_response.tool_result[index]
                        assert result["document_id"] == expected_id
                        assert result["title"] == expected_titles[index]
                        assert "content" in result
                        assert "semantic_identifier" in result
                        assert "blurb" in result
                        assert "source_type" in result
                        assert "score" in result
                        assert "link" in result
                        assert "updated_at" in result
                else:
                    # Other packet ids (search summary, relevance list,
                    # final context docs) are not asserted here
                    pass
            else:
                print(f"Received unexpected packet type: {type(packet)}")
    print("Custom search tool ran successfully with processed response")

View File

@@ -0,0 +1,71 @@
import pytest
from requests.exceptions import HTTPError
from tests.integration.common_utils.managers.tool import ToolManager
from tests.integration.common_utils.managers.user import UserManager
from tests.integration.common_utils.test_models import DATestTool
from tests.integration.common_utils.test_models import DATestUser
def test_custom_search_tool(reset: None) -> None:
    """Create a minimal custom tool as the first (auto-admin) user and verify it."""
    # The first user created is automatically promoted to admin.
    admin_user: DATestUser = UserManager.create(name="admin_user")

    # Minimal OpenAPI 3.0 schema: one GET /hello operation with an optional
    # `name` query parameter and a JSON response body.
    hello_get_operation = {
        "summary": "Say hello",
        "responses": {
            "200": {
                "description": "Successful response",
                "content": {
                    "application/json": {
                        "schema": {
                            "type": "object",
                            "properties": {"message": {"type": "string"}},
                        }
                    }
                },
            }
        },
        "parameters": [
            {
                "in": "query",
                "name": "name",
                "schema": {"type": "string"},
                "required": False,
                "description": "Name to greet",
            }
        ],
        "operationId": "sayHello",
    }
    tool_definition = {
        "info": {
            "title": "Simple API",
            "version": "1.0.0",
            "description": "A minimal API for demonstration",
        },
        "paths": {"/hello": {"get": hello_get_operation}},
        "openapi": "3.0.0",
        "servers": [{"url": "http://localhost:8000"}],
    }

    # Creation goes through the backend API and may raise an HTTPError;
    # surface that as a test failure instead of an unhandled exception.
    try:
        print("Attempting to create custom tool...")
        custom_tool: DATestTool = ToolManager.create(
            name="Sample Search Tool",
            description="A sample tool that performs a search",
            definition=tool_definition,
            custom_headers=[],
            user_performing_action=admin_user,
        )
        print(f"Custom tool created successfully with ID: {custom_tool.id}")
    except HTTPError as e:
        pytest.fail(f"Failed to create custom tool: {e}")

    # Confirm the backend persisted the tool as defined.
    assert ToolManager.verify(custom_tool, user_performing_action=admin_user)
    print("Custom tool verified successfully")

View File

@@ -0,0 +1,102 @@
from typing import Any
from typing import cast
from typing import Dict
from typing import List
import pytest
from requests.exceptions import HTTPError
from danswer.file_store.models import ChatFileType
from danswer.file_store.models import FileDescriptor
from danswer.tools.models import ToolResponse
from danswer.tools.tool_implementations.custom.custom_tool import (
build_custom_tools_from_openapi_schema_and_headers,
)
from danswer.tools.tool_implementations.custom.custom_tool import (
CUSTOM_TOOL_RESPONSE_ID,
)
from danswer.tools.tool_implementations.custom.custom_tool import CustomToolCallSummary
from danswer.tools.tool_implementations.custom.custom_tool import CustomToolResponseType
from danswer.tools.tool_runner import ToolRunner
from tests.integration.common_utils.managers.tool import ToolManager
from tests.integration.common_utils.test_models import DATestTool
from tests.integration.common_utils.test_models import DATestUser
def create_and_verify_custom_tool(
    tool_definition: Dict[str, Any], user: DATestUser, name: str, description: str
) -> DATestTool:
    """Create a custom tool via the API, verify it persisted, and return it.

    Fails the current test (via ``pytest.fail``) if tool creation raises an
    HTTP error.
    """
    print(f"Attempting to create custom tool '{name}'...")
    # Only the API call itself sits inside the try: an HTTPError here means
    # the backend rejected the tool definition.
    try:
        created_tool: DATestTool = ToolManager.create(
            name=name,
            description=description,
            definition=tool_definition,
            custom_headers=[],
            user_performing_action=user,
        )
    except HTTPError as e:
        pytest.fail(f"Failed to create custom tool '{name}': {e}")
    print(f"Custom tool '{name}' created successfully with ID: {created_tool.id}")

    # Confirm the backend stored the tool exactly as requested.
    assert ToolManager.verify(created_tool, user_performing_action=user)
    print(f"Custom tool '{name}' verified successfully")
    return created_tool
def build_and_run_tool(
    tool_definition: Dict[str, Any], tool_args: Dict[str, Any]
) -> List[Any]:
    """Build the single tool described by *tool_definition*, run it with
    *tool_args*, and return every emitted response packet as a list."""
    # Construct tool objects from the OpenAPI definition (no custom headers).
    built_tools = build_custom_tools_from_openapi_schema_and_headers(
        openapi_schema=tool_definition,
        custom_headers=[],
    )
    # The definitions used by these tests describe exactly one operation.
    assert len(built_tools) == 1

    # Execute the tool and drain its response generator eagerly so callers
    # can iterate the packets multiple times.
    runner = ToolRunner(tool=built_tools[0], args=tool_args)
    return list(runner.tool_responses())
def process_tool_responses(tool_responses: List[Any]) -> None:
    """Inspect packets from a custom tool run and assert they are well-formed.

    File-type responses (CSV/image) must carry at least one string file ID;
    every other response type must carry a non-None tool name and result.
    Packets that are not custom-tool summaries are only logged.
    """
    assert len(tool_responses) > 0
    for packet in tool_responses:
        # Guard clause: skip anything that is not a custom-tool summary packet.
        if not (
            isinstance(packet, ToolResponse) and packet.id == CUSTOM_TOOL_RESPONSE_ID
        ):
            print(f"Received unexpected packet: {packet}")
            continue

        summary = cast(CustomToolCallSummary, packet.response)
        if summary.response_type in (
            CustomToolResponseType.CSV,
            CustomToolResponseType.IMAGE,
        ):
            # File-backed result: mirror how the chat layer turns file IDs
            # into FileDescriptor objects for the AI message.
            file_ids = summary.tool_result.file_ids
            descriptors = []
            for file_id in file_ids:
                descriptor_type = (
                    ChatFileType.IMAGE
                    if summary.response_type == CustomToolResponseType.IMAGE
                    else ChatFileType.CSV
                )
                descriptors.append(FileDescriptor(id=str(file_id), type=descriptor_type))
            print(f"Received file IDs: {[str(file_id) for file_id in file_ids]}")
            assert len(file_ids) > 0
            assert all(isinstance(file_id, str) for file_id in file_ids)
            assert len(descriptors) == len(file_ids)
            # Additional processing can be added here if needed
        else:
            # Plain (e.g. JSON / search-result) response.
            tool_result = summary.tool_result
            tool_name = summary.tool_name
            print(f"Custom tool '{tool_name}' response: {tool_result}")
            assert tool_result is not None
            assert tool_name is not None
            # Additional processing can be added here if needed

View File

@@ -14,6 +14,7 @@ from danswer.tools.tool_implementations.custom.custom_tool import (
CUSTOM_TOOL_RESPONSE_ID,
)
from danswer.tools.tool_implementations.custom.custom_tool import CustomToolCallSummary
from danswer.tools.tool_implementations.custom.custom_tool import CustomToolResponseType
from danswer.tools.tool_implementations.custom.custom_tool import (
validate_openapi_schema,
)
@@ -215,7 +216,7 @@ class TestCustomTool(unittest.TestCase):
mock_response = ToolResponse(
id=CUSTOM_TOOL_RESPONSE_ID,
response=CustomToolCallSummary(
response_type="json",
response_type=CustomToolResponseType.JSON,
tool_name="getAssistant",
tool_result={"id": "789", "name": "Final Assistant"},
),

View File

@@ -9,6 +9,7 @@ import {
buildDocumentSummaryDisplay,
} from "@/components/search/DocumentDisplay";
import { InternetSearchIcon } from "@/components/InternetSearchIcon";
import { nonSearchableSources, ValidSources } from "@/lib/types";
interface DocumentDisplayProps {
document: DanswerDocument;
@@ -101,13 +102,14 @@ export function ChatDocumentDisplay({
</div>
)}
{!isInternet && (
<DocumentSelector
isSelected={isSelected}
handleSelect={() => handleSelect(document.document_id)}
isDisabled={tokenLimitReached && !isSelected}
/>
)}
{!isInternet &&
!nonSearchableSources.includes(document.source_type) && (
<DocumentSelector
isSelected={isSelected}
handleSelect={() => handleSelect(document.document_id)}
isDisabled={tokenLimitReached && !isSelected}
/>
)}
</div>
<div>
<div className="mt-1">

View File

@@ -10,6 +10,7 @@ export function SourceIcon({
sourceType: ValidSources;
iconSize: number;
}) {
console.log("sourceType is ", sourceType);
return getSourceMetadata(sourceType).icon({
size: iconSize,
});

View File

@@ -6,7 +6,7 @@ import React, {
useEffect,
} from "react";
import { usePathname, useRouter, useSearchParams } from "next/navigation";
import { ValidSources } from "@/lib/types";
import { ConfigurableSources } from "@/lib/types";
interface FormContextType {
formStep: number;
@@ -15,7 +15,7 @@ interface FormContextType {
nextFormStep: (contract?: string) => void;
prevFormStep: () => void;
formStepToLast: () => void;
connector: ValidSources;
connector: ConfigurableSources;
setFormStep: React.Dispatch<React.SetStateAction<number>>;
allowAdvanced: boolean;
setAllowAdvanced: React.Dispatch<React.SetStateAction<boolean>>;
@@ -27,7 +27,7 @@ const FormContext = createContext<FormContextType | undefined>(undefined);
export const FormProvider: React.FC<{
children: ReactNode;
connector: ValidSources;
connector: ConfigurableSources;
}> = ({ children, connector }) => {
const router = useRouter();
const searchParams = useSearchParams();

View File

@@ -1,6 +1,6 @@
"use client";
import { ValidSources } from "@/lib/types";
import { ConfigurableSources, ValidSources } from "@/lib/types";
import useSWR, { mutate } from "swr";
import { errorHandlingFetcher } from "@/lib/fetcher";
import { FaSwatchbook } from "react-icons/fa";
@@ -35,7 +35,7 @@ export default function CredentialSection({
refresh,
}: {
ccPair: CCPairFullInfo;
sourceType: ValidSources;
sourceType: ConfigurableSources;
refresh: () => void;
}) {
const makeShowCreateCredential = () => {

View File

@@ -1,6 +1,6 @@
import React, { useState } from "react";
import { Button } from "@/components/ui/button";
import { ValidSources } from "@/lib/types";
import { ConfigurableSources } from "@/lib/types";
import { FaAccusoft } from "react-icons/fa";
import { submitCredential } from "@/components/admin/connectors/CredentialForm";
import { TextFormField } from "@/components/admin/connectors/Field";
@@ -71,7 +71,7 @@ export default function CreateCredential({
}: {
// Source information
hideSource?: boolean; // hides docs link
sourceType: ValidSources;
sourceType: ConfigurableSources;
setPopup: (popupSpec: PopupSpec | null) => void;

View File

@@ -3,7 +3,7 @@ import { Modal } from "@/components/Modal";
import { Button } from "@/components/ui/button";
import Text from "@/components/ui/text";
import { Badge } from "@/components/ui/badge";
import { ValidSources } from "@/lib/types";
import { ConfigurableSources, ValidSources } from "@/lib/types";
import {
EditIcon,
NewChatIcon,
@@ -160,7 +160,7 @@ export default function ModifyCredential({
defaultedCredential?: Credential<any>;
credentials: Credential<any>[];
editableCredentials: Credential<any>[];
source: ValidSources;
source: ConfigurableSources;
onSwitch?: (newCredential: Credential<any>) => void;
onSwap?: (newCredential: Credential<any>, connectorId: number) => void;

View File

@@ -1129,7 +1129,7 @@ export const defaultRefreshFreqMinutes = 30; // 30 minutes
// CONNECTORS
export interface ConnectorBase<T> {
name: string;
source: ValidSources;
source: ConfigurableSources;
input_type: ValidInputTypes;
connector_specific_config: T;
refresh_freq: number | null;

View File

@@ -1,4 +1,4 @@
import { ValidSources } from "../types";
import { ConfigurableSources, ValidSources } from "../types";
export interface CredentialBase<T> {
credential_json: T;
@@ -195,7 +195,7 @@ export interface FirefliesCredentialJson {
export interface MediaWikiCredentialJson {}
export interface WikipediaCredentialJson extends MediaWikiCredentialJson {}
export const credentialTemplates: Record<ValidSources, any> = {
export const credentialTemplates: Record<ConfigurableSources, any> = {
github: { github_access_token: "" } as GithubCredentialJson,
gitlab: {
gitlab_url: "",
@@ -304,8 +304,6 @@ export const credentialTemplates: Record<ValidSources, any> = {
wikipedia: null,
mediawiki: null,
web: null,
not_applicable: null,
ingestion_api: null,
// NOTE: These are Special Cases
google_drive: { google_tokens: "" } as GoogleDriveCredentialJson,

View File

@@ -40,12 +40,9 @@ import {
FirefliesIcon,
} from "@/components/icons/icons";
import { ValidSources } from "./types";
import {
DanswerDocument,
SourceCategory,
SourceMetadata,
} from "./search/interfaces";
import { SourceCategory, SourceMetadata } from "./search/interfaces";
import { Persona } from "@/app/admin/assistants/interfaces";
import { FaHammer } from "react-icons/fa";
interface PartialSourceMetadata {
icon: React.FC<{ size?: number; className?: string }>;
@@ -59,6 +56,12 @@ type SourceMap = {
};
const SOURCE_METADATA_MAP: SourceMap = {
custom_tool: {
icon: FaHammer,
displayName: "Custom Tool",
category: SourceCategory.Other,
docs: "https://docs.danswer.dev/connectors/custom_tool",
},
web: {
icon: GlobeIcon,
displayName: "Web",

View File

@@ -254,6 +254,7 @@ export interface UserGroup {
}
const validSources = [
"custom_tool",
"web",
"github",
"gitlab",
@@ -301,9 +302,11 @@ export type ValidSources = (typeof validSources)[number];
// The valid sources that are actually valid to select in the UI
export type ConfigurableSources = Exclude<
ValidSources,
"not_applicable" | "ingestion_api"
"not_applicable" | "ingestion_api" | "custom_tool"
>;
export const nonSearchableSources: ValidSources[] = ["custom_tool"];
// The sources that have auto-sync support on the backend
export const validAutoSyncSources = [
"confluence",