Compare commits

...

2 Commits

Author SHA1 Message Date
Weves
4fb504e413 testing 2024-12-09 08:20:49 -08:00
Weves
5d6973b873 test 2024-12-09 08:20:49 -08:00
7 changed files with 64 additions and 8 deletions

View File

@@ -138,6 +138,8 @@ class Answer:
) -> AnswerStream:
current_llm_call = llm_calls[-1]
logger.info("Decided on tool call")
# make a dummy tool handler
tool_handler = ToolResponseHandler([tool])

View File

@@ -1,3 +1,4 @@
import time
import traceback
from collections.abc import Callable
from collections.abc import Iterator
@@ -315,6 +316,8 @@ def stream_chat_message_objects(
try:
user_id = user.id if user is not None else None
start = time.monotonic()
chat_session = get_chat_session_by_id(
chat_session_id=new_msg_req.chat_session_id,
user_id=user_id,
@@ -652,6 +655,9 @@ def stream_chat_message_objects(
for tool_list in tool_dict.values():
tools.extend(tool_list)
enter_answer = time.monotonic()
logger.debug(f"Enter answer: {enter_answer - start}")
# LLM prompt building, response capturing, etc.
answer = Answer(
is_connected=is_connected,

View File

@@ -138,8 +138,10 @@ class ToolResponseHandler:
logger.info(f"Selected tool: {selected_tool.name}")
logger.debug(f"Selected tool call request: {selected_tool_call_request}")
self.tool_runner = ToolRunner(selected_tool, selected_tool_call_request["args"])
logger.info("Kicking off tool")
self.tool_kickoff = self.tool_runner.kickoff()
yield self.tool_kickoff
logger.info("Yielded tool kickoff")
for response in self.tool_runner.tool_responses():
self.tool_responses.append(response)

View File

@@ -280,6 +280,7 @@ class SearchTool(Tool):
def run(self, **kwargs: str) -> Generator[ToolResponse, None, None]:
query = cast(str, kwargs["query"])
logger.info(f"Running search tool with query: {query}")
if self.selected_sections:
yield from self._build_response_for_specified_sections(query)
@@ -315,6 +316,7 @@ class SearchTool(Tool):
prompt_config=self.prompt_config,
)
logger.info("Yielding initial search response")
yield ToolResponse(
id=SEARCH_RESPONSE_SUMMARY_ID,
response=SearchResponseSummary(
@@ -326,6 +328,7 @@ class SearchTool(Tool):
recency_bias_multiplier=search_pipeline.search_query.recency_bias_multiplier,
),
)
logger.info("Yielding search doc content")
yield ToolResponse(
id=SEARCH_DOC_CONTENT_ID,

View File

@@ -157,7 +157,7 @@ def get_standard_formatter() -> ColoredFormatter:
"""Returns a standard colored logging formatter."""
return ColoredFormatter(
"%(asctime)s %(filename)30s %(lineno)4s: %(message)s",
datefmt="%m/%d/%Y %I:%M:%S %p",
# datefmt="%m/%d/%Y %I:%M:%S %p",
)

View File

@@ -42,7 +42,7 @@ class LongTermLogger:
# Delete oldest files that exceed the limit
for file in files[self.max_files_per_category :]:
try:
file.unlink()
file.unlink(missing_ok=True)
except Exception as e:
logger.error(f"Error deleting old log file {file}: {e}")
except Exception as e:

View File

@@ -36,8 +36,11 @@ class ChatMetrics:
total_time: float
first_doc_time: float
first_answer_time: float
tool_selection_time: float
tokens_per_second: float
total_tokens: int
request_number: int
message: str
class ChatLoadTester:
@@ -47,11 +50,13 @@ class ChatLoadTester:
api_key: str | None,
num_concurrent: int,
messages_per_session: int,
persona_id: int = 0,
):
self.base_url = base_url
self.headers = {"Authorization": f"Bearer {api_key}"} if api_key else {}
self.num_concurrent = num_concurrent
self.messages_per_session = messages_per_session
self.persona_id = persona_id
self.metrics: list[ChatMetrics] = []
async def create_chat_session(self, session: aiohttp.ClientSession) -> str:
@@ -59,18 +64,21 @@ class ChatLoadTester:
async with session.post(
f"{self.base_url}/chat/create-chat-session",
headers=self.headers,
json={"persona_id": 0, "description": "Load Test"},
json={"persona_id": self.persona_id, "description": "Load Test"},
) as response:
response.raise_for_status()
data = await response.json()
return data["chat_session_id"]
async def process_stream(
self, response: aiohttp.ClientResponse
self,
response: aiohttp.ClientResponse,
chat_session_id: str,
) -> AsyncGenerator[str, None]:
"""Process the SSE stream from the chat response"""
async for chunk in response.content:
chunk_str = chunk.decode()
logger.info(f"Session {chat_session_id}: {chunk_str}")
yield chunk_str
async def send_message(
@@ -78,12 +86,14 @@ class ChatLoadTester:
session: aiohttp.ClientSession,
chat_session_id: str,
message: str,
request_number: int,
parent_message_id: int | None = None,
) -> ChatMetrics:
"""Send a message and measure performance metrics"""
start_time = time.time()
first_doc_time = None
first_answer_time = None
tool_selection_time = None
token_count = 0
async with session.post(
@@ -104,8 +114,11 @@ class ChatLoadTester:
) as response:
response.raise_for_status()
async for chunk in self.process_stream(response):
if "tool_name" in chunk and "run_search" in chunk:
async for chunk in self.process_stream(response, chat_session_id):
if '"tool_name"' in chunk and tool_selection_time is None:
tool_selection_time = time.time() - start_time
if '{"top_documents"' in chunk:
if first_doc_time is None:
first_doc_time = time.time() - start_time
@@ -122,8 +135,11 @@ class ChatLoadTester:
total_time=total_time,
first_doc_time=first_doc_time or 0,
first_answer_time=first_answer_time or 0,
tool_selection_time=tool_selection_time or 0,
tokens_per_second=tokens_per_second,
total_tokens=token_count,
request_number=request_number,
message=message,
)
async def run_chat_session(self) -> None:
@@ -143,10 +159,10 @@ class ChatLoadTester:
for i in range(self.messages_per_session):
message = messages[i % len(messages)]
metrics = await self.send_message(
session, chat_session_id, message, parent_message_id
session, chat_session_id, message, i + 1, parent_message_id
)
self.metrics.append(metrics)
parent_message_id = metrics.total_tokens # Simplified for example
parent_message_id = metrics.total_tokens
except Exception as e:
logger.error(f"Error in chat session: {e}")
@@ -174,15 +190,35 @@ class ChatLoadTester:
avg_first_answer = statistics.mean(
m.first_answer_time for m in self.metrics
)
avg_tool_selection = statistics.mean(
m.tool_selection_time for m in self.metrics
)
avg_tokens_per_sec = statistics.mean(
m.tokens_per_second for m in self.metrics
)
logger.info(f"\nAverage Response Time: {avg_response_time:.2f} seconds")
logger.info(
f"Average Time to Tool Selection: {avg_tool_selection:.2f} seconds"
)
logger.info(f"Average Time to Documents: {avg_first_doc:.2f} seconds")
logger.info(f"Average Time to First Answer: {avg_first_answer:.2f} seconds")
logger.info(f"Average Tokens/Second: {avg_tokens_per_sec:.2f}")
logger.info("\nIndividual Request Times:")
sorted_metrics = sorted(
self.metrics, key=lambda m: (m.session_id, m.request_number)
)
for m in sorted_metrics:
logger.info(
f"Session {m.session_id} - Request {m.request_number} - "
f"Message: '{m.message[:30]}...' - "
f"Total Time: {m.total_time:.2f}s, "
f"Tool Selection: {m.tool_selection_time:.2f}s, "
f"Doc Time: {m.first_doc_time:.2f}s, "
f"First Answer: {m.first_answer_time:.2f}s"
)
def main() -> None:
parser = argparse.ArgumentParser(description="Chat Load Testing Tool")
@@ -209,6 +245,12 @@ def main() -> None:
default=1,
help="Number of messages per chat session",
)
parser.add_argument(
"--persona-id",
type=int,
default=0,
help="Persona ID to use for chat sessions",
)
args = parser.parse_args()
@@ -217,6 +259,7 @@ def main() -> None:
api_key=args.api_key,
num_concurrent=args.concurrent,
messages_per_session=args.messages,
persona_id=args.persona_id,
)
asyncio.run(load_tester.run_load_test())