Compare commits

...

5 Commits

Author SHA1 Message Date
pablodanswer
f81dcfd179 minor method spec update 2024-09-20 18:33:03 -07:00
pablodanswer
ba6ab363f4 update tests - additional coverage 2024-09-20 18:11:13 -07:00
pablodanswer
a573ba6fb3 update ports 2024-09-20 18:04:26 -07:00
pablodanswer
ffc5dd7b49 add custom tool testing 2024-09-20 18:04:26 -07:00
pablodanswer
c90d7da02d add initial testing 2024-09-20 18:04:26 -07:00
7 changed files with 227 additions and 4 deletions

View File

@@ -135,7 +135,7 @@ POSTGRES_PASSWORD = urllib.parse.quote_plus(
os.environ.get("POSTGRES_PASSWORD") or "password"
)
POSTGRES_HOST = os.environ.get("POSTGRES_HOST") or "localhost"
POSTGRES_PORT = os.environ.get("POSTGRES_PORT") or "5432"
POSTGRES_PORT = os.environ.get("POSTGRES_PORT") or "5433"
POSTGRES_DB = os.environ.get("POSTGRES_DB") or "postgres"
# defaults to False

View File

@@ -204,6 +204,7 @@ def build_custom_tools_from_openapi_schema_and_headers(
url = openapi_to_url(openapi_schema)
method_specs = openapi_to_method_specs(openapi_schema)
return [
CustomTool(method_spec, url, custom_headers) for method_spec in method_specs
]

View File

@@ -16,6 +16,7 @@ class MethodSpec(BaseModel):
summary: str
path: str
method: str
body_schema: dict[str, Any] = {}
spec: dict[str, Any]
def get_request_body_schema(self) -> dict[str, Any]:
@@ -87,6 +88,8 @@ class MethodSpec(BaseModel):
tool_definition["function"]["parameters"]["properties"].update(
{param["name"]: param["schema"] for param in path_param_schemas}
)
print(tool_definition)
print("")
return tool_definition
def validate_spec(self) -> None:

View File

@@ -0,0 +1,219 @@
import unittest
from unittest.mock import patch
import pytest
from danswer.tools.custom.custom_tool import (
build_custom_tools_from_openapi_schema_and_headers,
)
from danswer.tools.custom.custom_tool import CUSTOM_TOOL_RESPONSE_ID
from danswer.tools.custom.custom_tool import CustomToolCallSummary
from danswer.tools.custom.custom_tool import validate_openapi_schema
from danswer.tools.models import DynamicSchemaInfo
from danswer.tools.tool import ToolResponse
class TestCustomTool(unittest.TestCase):
"""
Test suite for CustomTool functionality.
This class tests the creation, running, and result handling of custom tools
based on OpenAPI schemas.
"""
def setUp(self):
"""
Set up the test environment before each test method.
Initializes an OpenAPI schema and DynamicSchemaInfo for testing.
"""
self.openapi_schema = {
"openapi": "3.0.0",
"info": {
"version": "1.0.0",
"title": "Assistants API",
"description": "An API for managing assistants",
},
"servers": [
{"url": "http://localhost:8080/CHAT_SESSION_ID/test/MESSAGE_ID"},
],
"paths": {
"/assistant/{assistant_id}": {
"GET": {
"summary": "Get a specific Assistant",
"operationId": "getAssistant",
"parameters": [
{
"name": "assistant_id",
"in": "path",
"required": True,
"schema": {"type": "string"},
}
],
},
"POST": {
"summary": "Create a new Assistant",
"operationId": "createAssistant",
"parameters": [
{
"name": "assistant_id",
"in": "path",
"required": True,
"schema": {"type": "string"},
}
],
"requestBody": {
"required": True,
"content": {
"application/json": {"schema": {"type": "object"}}
},
},
},
}
},
}
validate_openapi_schema(self.openapi_schema)
self.dynamic_schema_info = DynamicSchemaInfo(chat_session_id=10, message_id=20)
@patch("danswer.tools.custom.custom_tool.requests.request")
def test_custom_tool_run_get(self, mock_request):
"""
Test the GET method of a custom tool.
Verifies that the tool correctly constructs the URL and makes the GET request.
"""
tools = build_custom_tools_from_openapi_schema_and_headers(
self.openapi_schema, dynamic_schema_info=self.dynamic_schema_info
)
result = list(tools[0].run(assistant_id="123"))
expected_url = f"http://localhost:8080/{self.dynamic_schema_info.chat_session_id}/test/{self.dynamic_schema_info.message_id}/assistant/123"
mock_request.assert_called_once_with("GET", expected_url, json=None, headers={})
self.assertEqual(
len(result), 1, "Expected exactly one result from the tool run"
)
self.assertEqual(
result[0].id,
CUSTOM_TOOL_RESPONSE_ID,
"Tool response ID does not match expected value",
)
self.assertEqual(
result[0].response.tool_name,
"getAssistant",
"Tool name in response does not match expected value",
)
@patch("danswer.tools.custom.custom_tool.requests.request")
def test_custom_tool_run_post(self, mock_request):
"""
Test the POST method of a custom tool.
Verifies that the tool correctly constructs the URL and makes the POST request with the given body.
"""
tools = build_custom_tools_from_openapi_schema_and_headers(
self.openapi_schema, dynamic_schema_info=self.dynamic_schema_info
)
result = list(tools[1].run(assistant_id="456"))
expected_url = f"http://localhost:8080/{self.dynamic_schema_info.chat_session_id}/test/{self.dynamic_schema_info.message_id}/assistant/456"
mock_request.assert_called_once_with(
"POST", expected_url, json=None, headers={}
)
self.assertEqual(
len(result), 1, "Expected exactly one result from the tool run"
)
self.assertEqual(
result[0].id,
CUSTOM_TOOL_RESPONSE_ID,
"Tool response ID does not match expected value",
)
self.assertEqual(
result[0].response.tool_name,
"createAssistant",
"Tool name in response does not match expected value",
)
@patch("danswer.tools.custom.custom_tool.requests.request")
def test_custom_tool_with_headers(self, mock_request):
"""
Test the custom tool with custom headers.
Verifies that the tool correctly includes the custom headers in the request.
"""
custom_headers = [
{"key": "Authorization", "value": "Bearer token123"},
{"key": "Custom-Header", "value": "CustomValue"},
]
tools = build_custom_tools_from_openapi_schema_and_headers(
self.openapi_schema,
custom_headers=custom_headers,
dynamic_schema_info=self.dynamic_schema_info,
)
list(tools[0].run(assistant_id="123"))
expected_url = f"http://localhost:8080/{self.dynamic_schema_info.chat_session_id}/test/{self.dynamic_schema_info.message_id}/assistant/123"
expected_headers = {
"Authorization": "Bearer token123",
"Custom-Header": "CustomValue",
}
mock_request.assert_called_once_with(
"GET", expected_url, json=None, headers=expected_headers
)
@patch("danswer.tools.custom.custom_tool.requests.request")
def test_custom_tool_with_empty_headers(self, mock_request):
"""
Test the custom tool with an empty list of custom headers.
Verifies that the tool correctly handles an empty list of headers.
"""
custom_headers = []
tools = build_custom_tools_from_openapi_schema_and_headers(
self.openapi_schema,
custom_headers=custom_headers,
dynamic_schema_info=self.dynamic_schema_info,
)
list(tools[0].run(assistant_id="123"))
expected_url = f"http://localhost:8080/{self.dynamic_schema_info.chat_session_id}/test/{self.dynamic_schema_info.message_id}/assistant/123"
mock_request.assert_called_once_with("GET", expected_url, json=None, headers={})
def test_invalid_openapi_schema(self):
"""
Test that an invalid OpenAPI schema raises a ValueError.
"""
invalid_schema = {
"openapi": "3.0.0",
"info": {
"version": "1.0.0",
"title": "Invalid API",
},
# Missing required 'paths' key
}
with self.assertRaises(ValueError) as _:
validate_openapi_schema(invalid_schema)
def test_custom_tool_final_result(self):
"""
Test the final_result method of a custom tool.
Verifies that the method correctly extracts and returns the tool result.
"""
tools = build_custom_tools_from_openapi_schema_and_headers(
self.openapi_schema, dynamic_schema_info=self.dynamic_schema_info
)
mock_response = ToolResponse(
id=CUSTOM_TOOL_RESPONSE_ID,
response=CustomToolCallSummary(
tool_name="getAssistant",
tool_result={"id": "789", "name": "Final Assistant"},
),
)
final_result = tools[0].final_result(mock_response)
self.assertEqual(
final_result,
{"id": "789", "name": "Final Assistant"},
"Final result does not match expected output",
)
if __name__ == "__main__":
pytest.main([__file__])

View File

@@ -293,7 +293,7 @@ services:
- POSTGRES_USER=${POSTGRES_USER:-postgres}
- POSTGRES_PASSWORD=${POSTGRES_PASSWORD:-password}
ports:
- "5432:5432"
- "5433:5432"
volumes:
- db_volume:/var/lib/postgresql/data

View File

@@ -303,7 +303,7 @@ services:
- POSTGRES_USER=${POSTGRES_USER:-postgres}
- POSTGRES_PASSWORD=${POSTGRES_PASSWORD:-password}
ports:
- "5432:5432"
- "5433:5432"
volumes:
- db_volume:/var/lib/postgresql/data

View File

@@ -154,7 +154,7 @@ services:
- POSTGRES_USER=${POSTGRES_USER:-postgres}
- POSTGRES_PASSWORD=${POSTGRES_PASSWORD:-password}
ports:
- "5432"
- "5433"
volumes:
- db_volume:/var/lib/postgresql/data