mirror of
https://github.com/onyx-dot-app/onyx.git
synced 2026-03-01 13:45:44 +00:00
Compare commits
2 Commits
embed_imag
...
experiment
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
c545819aa6 | ||
|
|
960ee228bf |
@@ -218,59 +218,76 @@ class ACPExecClient:
|
||||
buffer = ""
|
||||
packet_logger = get_packet_logger()
|
||||
|
||||
while not self._stop_reader.is_set():
|
||||
if self._ws_client is None:
|
||||
break
|
||||
|
||||
try:
|
||||
if self._ws_client.is_open():
|
||||
# Read available data
|
||||
self._ws_client.update(timeout=0.1)
|
||||
|
||||
# Read stdout (channel 1)
|
||||
data = self._ws_client.read_stdout(timeout=0.1)
|
||||
if data:
|
||||
buffer += data
|
||||
|
||||
# Process complete lines
|
||||
while "\n" in buffer:
|
||||
line, buffer = buffer.split("\n", 1)
|
||||
line = line.strip()
|
||||
if line:
|
||||
try:
|
||||
message = json.loads(line)
|
||||
# Log the raw incoming message
|
||||
packet_logger.log_jsonrpc_raw_message(
|
||||
"IN", message, context="k8s"
|
||||
)
|
||||
self._response_queue.put(message)
|
||||
except json.JSONDecodeError:
|
||||
packet_logger.log_raw(
|
||||
"JSONRPC-PARSE-ERROR-K8S",
|
||||
{
|
||||
"raw_line": line[:500],
|
||||
"error": "JSON decode failed",
|
||||
},
|
||||
)
|
||||
logger.warning(
|
||||
f"Invalid JSON from agent: {line[:100]}"
|
||||
)
|
||||
|
||||
else:
|
||||
packet_logger.log_raw(
|
||||
"K8S-WEBSOCKET-CLOSED",
|
||||
{"pod": self._pod_name, "namespace": self._namespace},
|
||||
)
|
||||
try:
|
||||
while not self._stop_reader.is_set():
|
||||
if self._ws_client is None:
|
||||
break
|
||||
|
||||
except Exception as e:
|
||||
if not self._stop_reader.is_set():
|
||||
packet_logger.log_raw(
|
||||
"K8S-READER-ERROR",
|
||||
{"error": str(e), "pod": self._pod_name},
|
||||
try:
|
||||
if self._ws_client.is_open():
|
||||
# Read available data
|
||||
self._ws_client.update(timeout=0.1)
|
||||
|
||||
# Read stdout (channel 1)
|
||||
data = self._ws_client.read_stdout(timeout=0.1)
|
||||
if data:
|
||||
buffer += data
|
||||
|
||||
# Process complete lines
|
||||
while "\n" in buffer:
|
||||
line, buffer = buffer.split("\n", 1)
|
||||
line = line.strip()
|
||||
if line:
|
||||
try:
|
||||
message = json.loads(line)
|
||||
# Log the raw incoming message
|
||||
packet_logger.log_jsonrpc_raw_message(
|
||||
"IN", message, context="k8s"
|
||||
)
|
||||
self._response_queue.put(message)
|
||||
except json.JSONDecodeError:
|
||||
packet_logger.log_raw(
|
||||
"JSONRPC-PARSE-ERROR-K8S",
|
||||
{
|
||||
"raw_line": line[:500],
|
||||
"error": "JSON decode failed",
|
||||
},
|
||||
)
|
||||
logger.warning(
|
||||
f"Invalid JSON from agent: {line[:100]}"
|
||||
)
|
||||
|
||||
else:
|
||||
packet_logger.log_raw(
|
||||
"K8S-WEBSOCKET-CLOSED",
|
||||
{"pod": self._pod_name, "namespace": self._namespace},
|
||||
)
|
||||
break
|
||||
|
||||
except Exception as e:
|
||||
if not self._stop_reader.is_set():
|
||||
packet_logger.log_raw(
|
||||
"K8S-READER-ERROR",
|
||||
{"error": str(e), "pod": self._pod_name},
|
||||
)
|
||||
logger.debug(f"Reader error: {e}")
|
||||
break
|
||||
finally:
|
||||
# Flush any remaining data in buffer (e.g., PromptResponse without
|
||||
# trailing newline when the WebSocket closes)
|
||||
remaining = buffer.strip()
|
||||
if remaining:
|
||||
try:
|
||||
message = json.loads(remaining)
|
||||
packet_logger.log_jsonrpc_raw_message(
|
||||
"IN", message, context="k8s-flush"
|
||||
)
|
||||
self._response_queue.put(message)
|
||||
except json.JSONDecodeError:
|
||||
packet_logger.log_raw(
|
||||
"K8S-BUFFER-FLUSH-FAILED",
|
||||
{"remaining": remaining[:500]},
|
||||
)
|
||||
logger.debug(f"Reader error: {e}")
|
||||
break
|
||||
|
||||
def stop(self) -> None:
|
||||
"""Stop the exec session and clean up."""
|
||||
@@ -465,7 +482,63 @@ class ACPExecClient:
|
||||
try:
|
||||
message_data = self._response_queue.get(timeout=min(remaining, 1.0))
|
||||
last_event_time = time.time() # Reset keepalive timer on event
|
||||
|
||||
# Diagnostic: log every dequeued message with id comparison
|
||||
msg_id = message_data.get("id")
|
||||
msg_method = message_data.get("method")
|
||||
msg_keys = list(message_data.keys())
|
||||
logger.debug(
|
||||
f"[ACP-DIAG] Dequeued message: id={msg_id} (type={type(msg_id).__name__}), "
|
||||
f"method={msg_method}, keys={msg_keys}, "
|
||||
f"request_id={request_id} (type={type(request_id).__name__}), "
|
||||
f"id_match={msg_id == request_id}"
|
||||
)
|
||||
except Empty:
|
||||
# Check if reader thread is still alive (equivalent to
|
||||
# process.poll() in the local agent client). If the reader
|
||||
# thread died, the WebSocket connection is gone and no more
|
||||
# data will arrive — break instead of emitting keepalives
|
||||
# forever.
|
||||
if (
|
||||
self._reader_thread is not None
|
||||
and not self._reader_thread.is_alive()
|
||||
):
|
||||
# Drain any final messages the reader flushed before dying
|
||||
while not self._response_queue.empty():
|
||||
try:
|
||||
final_msg = self._response_queue.get_nowait()
|
||||
if final_msg.get("id") == request_id:
|
||||
if "error" in final_msg:
|
||||
error_data = final_msg["error"]
|
||||
yield Error(
|
||||
code=error_data.get("code", -1),
|
||||
message=error_data.get(
|
||||
"message", "Unknown error"
|
||||
),
|
||||
)
|
||||
else:
|
||||
result = final_msg.get("result", {})
|
||||
try:
|
||||
yield PromptResponse.model_validate(result)
|
||||
except ValidationError:
|
||||
pass
|
||||
break
|
||||
except Empty:
|
||||
break
|
||||
|
||||
packet_logger.log_raw(
|
||||
"ACP-CONNECTION-LOST-K8S",
|
||||
{
|
||||
"session_id": session_id,
|
||||
"events_yielded": events_yielded,
|
||||
},
|
||||
)
|
||||
logger.warning(
|
||||
f"Reader thread died for session {session_id}, "
|
||||
"ending message stream"
|
||||
)
|
||||
break
|
||||
|
||||
# Check if we need to send an SSE keepalive
|
||||
idle_time = time.time() - last_event_time
|
||||
if idle_time >= SSE_KEEPALIVE_INTERVAL:
|
||||
@@ -480,8 +553,21 @@ class ACPExecClient:
|
||||
last_event_time = time.time() # Reset after yielding keepalive
|
||||
continue
|
||||
|
||||
# Check for response to our prompt request
|
||||
if message_data.get("id") == request_id:
|
||||
# Check for response to our prompt request.
|
||||
# A JSON-RPC response has "id" but no "method" field.
|
||||
# Use str() comparison as a fallback — some ACP servers may echo
|
||||
# the id back as a string even though we sent it as an integer.
|
||||
msg_id = message_data.get("id")
|
||||
is_response = "method" not in message_data and (
|
||||
msg_id == request_id
|
||||
or (msg_id is not None and str(msg_id) == str(request_id))
|
||||
)
|
||||
if is_response and msg_id != request_id:
|
||||
logger.warning(
|
||||
f"[ACP] ID type mismatch: got {type(msg_id).__name__}({msg_id}), "
|
||||
f"expected {type(request_id).__name__}({request_id})"
|
||||
)
|
||||
if is_response:
|
||||
if "error" in message_data:
|
||||
error_data = message_data["error"]
|
||||
packet_logger.log_jsonrpc_response(
|
||||
@@ -533,12 +619,31 @@ class ACPExecClient:
|
||||
context="k8s",
|
||||
)
|
||||
|
||||
prompt_complete = False
|
||||
for event in self._process_session_update(update):
|
||||
events_yielded += 1
|
||||
# Log each yielded event
|
||||
event_type = self._get_event_type_name(event)
|
||||
packet_logger.log_acp_event_yielded(event_type, event)
|
||||
yield event
|
||||
# If PromptResponse arrived via notification, break
|
||||
if isinstance(event, PromptResponse):
|
||||
prompt_complete = True
|
||||
break
|
||||
|
||||
if prompt_complete:
|
||||
# Log completion summary
|
||||
elapsed_ms = (time.time() - start_time) * 1000
|
||||
packet_logger.log_raw(
|
||||
"ACP-SEND-MESSAGE-COMPLETE-K8S",
|
||||
{
|
||||
"session_id": session_id,
|
||||
"events_yielded": events_yielded,
|
||||
"elapsed_ms": elapsed_ms,
|
||||
"via": "session_update_notification",
|
||||
},
|
||||
)
|
||||
break
|
||||
|
||||
# Handle requests from agent - send error response
|
||||
elif "method" in message_data and "id" in message_data:
|
||||
@@ -552,6 +657,15 @@ class ACPExecClient:
|
||||
f"Method not supported: {message_data['method']}",
|
||||
)
|
||||
|
||||
else:
|
||||
# Message didn't match any handler — silently dropped
|
||||
logger.warning(
|
||||
f"[ACP-DIAG] Dropped message: id={message_data.get('id')}, "
|
||||
f"method={message_data.get('method')}, "
|
||||
f"keys={list(message_data.keys())}, "
|
||||
f"request_id={request_id}"
|
||||
)
|
||||
|
||||
def _get_event_type_name(self, event: ACPEvent) -> str:
|
||||
"""Get the type name for an ACP event."""
|
||||
if isinstance(event, AgentMessageChunk):
|
||||
@@ -641,6 +755,20 @@ class ACPExecClient:
|
||||
{"update_type": update_type, "error": str(e), "update": update},
|
||||
)
|
||||
|
||||
elif update_type == "prompt_response":
|
||||
# Some ACP versions send PromptResponse as a session/update notification
|
||||
# rather than (or in addition to) a JSON-RPC response.
|
||||
logger.info(
|
||||
"[ACP] Received prompt_response via session/update notification"
|
||||
)
|
||||
try:
|
||||
yield PromptResponse.model_validate(update)
|
||||
except ValidationError as e:
|
||||
packet_logger.log_raw(
|
||||
"ACP-VALIDATION-ERROR-K8S",
|
||||
{"update_type": update_type, "error": str(e), "update": update},
|
||||
)
|
||||
|
||||
elif update_type == "available_commands_update":
|
||||
# Skip command updates
|
||||
packet_logger.log_raw(
|
||||
|
||||
@@ -353,6 +353,10 @@ class KubernetesSandboxManager(SandboxManager):
|
||||
self._agent_instructions_template_path = build_dir / "AGENTS.template.md"
|
||||
self._skills_path = Path(__file__).parent / "docker" / "skills"
|
||||
|
||||
# Track ACP exec clients in memory - keyed by (sandbox_id, session_id) tuple
|
||||
# Each session within a sandbox has its own ACP client (WebSocket connection)
|
||||
self._acp_clients: dict[tuple[UUID, UUID], ACPExecClient] = {}
|
||||
|
||||
logger.info(
|
||||
f"KubernetesSandboxManager initialized: "
|
||||
f"namespace={self._namespace}, image={self._image}"
|
||||
@@ -1161,6 +1165,20 @@ done
|
||||
Args:
|
||||
sandbox_id: The sandbox ID to terminate
|
||||
"""
|
||||
# Stop all ACP clients for this sandbox (keyed by (sandbox_id, session_id))
|
||||
clients_to_stop = [
|
||||
(key, cl) for key, cl in self._acp_clients.items() if key[0] == sandbox_id
|
||||
]
|
||||
for key, cl in clients_to_stop:
|
||||
try:
|
||||
cl.stop()
|
||||
del self._acp_clients[key]
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Failed to stop ACP client for sandbox {sandbox_id}, "
|
||||
f"session {key[1]}: {e}"
|
||||
)
|
||||
|
||||
# Clean up Kubernetes resources (needs string for pod/service names)
|
||||
self._cleanup_kubernetes_resources(str(sandbox_id))
|
||||
|
||||
@@ -1403,6 +1421,18 @@ echo "Session workspace setup complete"
|
||||
nextjs_port: Optional port where Next.js server is running (unused in K8s,
|
||||
we use PID file instead)
|
||||
"""
|
||||
# Stop ACP client for this session
|
||||
client_key = (sandbox_id, session_id)
|
||||
acp_client = self._acp_clients.pop(client_key, None)
|
||||
if acp_client:
|
||||
try:
|
||||
acp_client.stop()
|
||||
logger.debug(f"Stopped ACP client for session {session_id}")
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Failed to stop ACP client for session {session_id}: {e}"
|
||||
)
|
||||
|
||||
pod_name = self._get_pod_name(str(sandbox_id))
|
||||
session_path = f"/workspace/sessions/{session_id}"
|
||||
|
||||
@@ -1830,24 +1860,41 @@ echo "Session config regeneration complete"
|
||||
pod_name = self._get_pod_name(str(sandbox_id))
|
||||
session_path = f"/workspace/sessions/{session_id}"
|
||||
|
||||
# Log ACP client creation
|
||||
packet_logger.log_acp_client_start(
|
||||
sandbox_id, session_id, session_path, context="k8s"
|
||||
)
|
||||
# Get or create ACP client for this session
|
||||
client_key = (sandbox_id, session_id)
|
||||
client = self._acp_clients.get(client_key)
|
||||
|
||||
exec_client = ACPExecClient(
|
||||
pod_name=pod_name,
|
||||
namespace=self._namespace,
|
||||
container="sandbox",
|
||||
)
|
||||
if client is None or not client.is_running:
|
||||
# Clean up stale client if it exists but is no longer running
|
||||
if client is not None:
|
||||
try:
|
||||
client.stop()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Log ACP client creation
|
||||
packet_logger.log_acp_client_start(
|
||||
sandbox_id, session_id, session_path, context="k8s"
|
||||
)
|
||||
logger.info(
|
||||
f"Creating new ACP client for sandbox {sandbox_id}, session {session_id}"
|
||||
)
|
||||
|
||||
# Create and start ACP client for this session
|
||||
client = ACPExecClient(
|
||||
pod_name=pod_name,
|
||||
namespace=self._namespace,
|
||||
container="sandbox",
|
||||
)
|
||||
client.start(cwd=session_path)
|
||||
self._acp_clients[client_key] = client
|
||||
|
||||
# Log the send_message call at sandbox manager level
|
||||
packet_logger.log_session_start(session_id, sandbox_id, message)
|
||||
|
||||
events_count = 0
|
||||
try:
|
||||
exec_client.start(cwd=session_path)
|
||||
for event in exec_client.send_message(message):
|
||||
for event in client.send_message(message):
|
||||
events_count += 1
|
||||
yield event
|
||||
|
||||
@@ -1884,10 +1931,6 @@ echo "Session config regeneration complete"
|
||||
events_count=events_count,
|
||||
)
|
||||
raise
|
||||
finally:
|
||||
exec_client.stop()
|
||||
# Log client stop
|
||||
packet_logger.log_acp_client_stop(sandbox_id, session_id, context="k8s")
|
||||
|
||||
def list_directory(
|
||||
self, sandbox_id: UUID, session_id: UUID, path: str
|
||||
|
||||
0
backend/tests/unit/onyx/server/features/__init__.py
Normal file
0
backend/tests/unit/onyx/server/features/__init__.py
Normal file
@@ -0,0 +1,360 @@
|
||||
"""Unit tests for ACPExecClient caching behavior in KubernetesSandboxManager.
|
||||
|
||||
These tests verify that the KubernetesSandboxManager correctly caches
|
||||
ACPExecClient instances per (sandbox_id, session_id) pair, reuses them
|
||||
across send_message calls, replaces dead clients, and cleans them up
|
||||
on terminate/cleanup.
|
||||
|
||||
All external dependencies (K8s, WebSockets, packet logging) are mocked.
|
||||
"""
|
||||
|
||||
from collections.abc import Generator
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
from uuid import UUID
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# The fully-qualified path to the module under test, used for patching
|
||||
_K8S_MODULE = "onyx.server.features.build.sandbox.kubernetes.kubernetes_sandbox_manager"
|
||||
_ACP_CLIENT_CLASS = f"{_K8S_MODULE}.ACPExecClient"
|
||||
_GET_PACKET_LOGGER = f"{_K8S_MODULE}.get_packet_logger"
|
||||
|
||||
|
||||
def _make_mock_event() -> MagicMock:
|
||||
"""Create a mock ACP event."""
|
||||
return MagicMock(name="mock_acp_event")
|
||||
|
||||
|
||||
def _make_mock_client(is_running: bool = True) -> MagicMock:
    """Build a mock ACPExecClient whose ``is_running`` reports *is_running*.

    The mock's ``start()`` returns a fixed session id, ``send_message()``
    yields a single mock event, and ``stop()`` returns None.
    """
    client = MagicMock()
    # Every MagicMock instance gets its own subclass, so attaching a
    # property to type(client) affects only this one mock.
    type(client).is_running = property(lambda _self: is_running)
    client.start.return_value = "mock-session-id"
    client.send_message.return_value = iter([_make_mock_event()])
    client.stop.return_value = None
    return client
|
||||
|
||||
|
||||
def _drain_generator(gen: Generator[Any, None, None]) -> list[Any]:
|
||||
"""Consume a generator and return all yielded values as a list."""
|
||||
return list(gen)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fixture: fresh KubernetesSandboxManager instance
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.fixture()
def manager() -> Generator[Any, None, None]:
    """Yield a fresh KubernetesSandboxManager with all externals mocked.

    This fixture:
    1. Resets the singleton ``_instance`` so each test gets a fresh manager
    2. Mocks ``kubernetes.config`` and ``kubernetes.client`` so no real
       cluster calls are made
    3. Mocks ``get_packet_logger`` to prevent logging side effects

    The manager class is imported *inside* the patch context so that the
    patches are already active when the class (and its module globals)
    are resolved.
    """
    # Import here so patches are in effect when the class loads
    with (
        patch(f"{_K8S_MODULE}.config") as _mock_config,
        patch(f"{_K8S_MODULE}.client") as _mock_k8s_client,
        patch(f"{_K8S_MODULE}.k8s_stream"),
        patch(_GET_PACKET_LOGGER) as mock_get_logger,
    ):
        # Set up the mock packet logger
        mock_packet_logger = MagicMock()
        mock_get_logger.return_value = mock_packet_logger

        # Make config.load_incluster_config succeed (no-op)
        _mock_config.load_incluster_config.return_value = None
        # ConfigException is referenced by the manager's init error handling;
        # aliasing it to Exception keeps the except clause valid under mock.
        _mock_config.ConfigException = Exception

        # Make client constructors return mocks
        _mock_k8s_client.ApiClient.return_value = MagicMock()
        _mock_k8s_client.CoreV1Api.return_value = MagicMock()
        _mock_k8s_client.BatchV1Api.return_value = MagicMock()
        _mock_k8s_client.NetworkingV1Api.return_value = MagicMock()

        # Reset singleton before importing
        from onyx.server.features.build.sandbox.kubernetes.kubernetes_sandbox_manager import (
            KubernetesSandboxManager,
        )

        KubernetesSandboxManager._instance = None

        mgr = KubernetesSandboxManager()

        # Ensure the _acp_clients dict exists (it should be initialized by
        # the caching implementation); this guard keeps the other tests
        # runnable even against a build that predates the cache.
        if not hasattr(mgr, "_acp_clients"):
            mgr._acp_clients = {}

        yield mgr

        # Reset singleton after test so later tests never see this instance
        KubernetesSandboxManager._instance = None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_send_message_creates_client_on_first_call(manager: Any) -> None:
    """The first send_message() for a session constructs a new ACPExecClient,
    starts it in the session workspace, caches it, and relays its events."""
    sandbox_id: UUID = uuid4()
    session_id: UUID = uuid4()
    message = "hello world"

    event = _make_mock_event()
    client = _make_mock_client(is_running=True)
    client.send_message.return_value = iter([event])

    with patch(_ACP_CLIENT_CLASS, return_value=client) as mock_class:
        yielded = _drain_generator(
            manager.send_message(sandbox_id, session_id, message)
        )

    # Exactly one client was constructed and started at the session path.
    mock_class.assert_called_once()
    client.start.assert_called_once_with(cwd=f"/workspace/sessions/{session_id}")

    # The message was forwarded to the client.
    client.send_message.assert_called_once_with(message)

    # Our mock event surfaced (other events, e.g. keepalives, may accompany it).
    assert len(yielded) >= 1
    assert event in yielded

    # The client is now cached under the (sandbox_id, session_id) key.
    assert manager._acp_clients.get((sandbox_id, session_id)) is client
|
||||
|
||||
|
||||
def test_send_message_reuses_cached_client(manager: Any) -> None:
    """A second call for the same (sandbox_id, session_id) must reuse the
    cached client rather than constructing a new one."""
    sandbox_id: UUID = uuid4()
    session_id: UUID = uuid4()
    message_1 = "first message"
    message_2 = "second message"

    event_1 = _make_mock_event()
    event_2 = _make_mock_event()
    client = _make_mock_client(is_running=True)
    # Each call to send_message produces its own event stream.
    client.send_message.side_effect = [iter([event_1]), iter([event_2])]

    with patch(_ACP_CLIENT_CLASS, return_value=client) as mock_class:
        first = _drain_generator(
            manager.send_message(sandbox_id, session_id, message_1)
        )
        second = _drain_generator(
            manager.send_message(sandbox_id, session_id, message_2)
        )

    # Constructed and started exactly once despite two send_message calls.
    mock_class.assert_called_once()
    client.start.assert_called_once()

    # Both messages reached the same client.
    assert client.send_message.call_count == 2
    client.send_message.assert_any_call(message_1)
    client.send_message.assert_any_call(message_2)

    # Each call relayed its own event.
    assert event_1 in first
    assert event_2 in second
|
||||
|
||||
|
||||
def test_send_message_replaces_dead_client(manager: Any) -> None:
    """A cached client whose is_running is False must be replaced with a
    freshly constructed, started client, which takes its place in the cache."""
    sandbox_id: UUID = uuid4()
    session_id: UUID = uuid4()
    message = "test message"
    key = (sandbox_id, session_id)

    # Seed the cache with a dead client.
    manager._acp_clients[key] = _make_mock_client(is_running=False)

    # The replacement the patched constructor will hand back.
    replacement_event = _make_mock_event()
    replacement = _make_mock_client(is_running=True)
    replacement.send_message.return_value = iter([replacement_event])

    with patch(_ACP_CLIENT_CLASS, return_value=replacement) as mock_class:
        yielded = _drain_generator(
            manager.send_message(sandbox_id, session_id, message)
        )

    # A new client was constructed (the dead one was replaced).
    mock_class.assert_called_once()

    # The replacement was started and used for the message.
    replacement.start.assert_called_once()
    replacement.send_message.assert_called_once_with(message)

    # The cache now holds the replacement, and its event was yielded.
    assert manager._acp_clients[key] is replacement
    assert replacement_event in yielded
|
||||
|
||||
|
||||
def test_send_message_different_sessions_get_different_clients(
    manager: Any,
) -> None:
    """Two sessions in the same sandbox each get their own client, cached
    under their own (sandbox_id, session_id) key."""
    sandbox_id: UUID = uuid4()
    session_id_a: UUID = uuid4()
    session_id_b: UUID = uuid4()
    message = "test"

    client_a = _make_mock_client(is_running=True)
    client_b = _make_mock_client(is_running=True)
    event_a = _make_mock_event()
    event_b = _make_mock_event()
    client_a.send_message.return_value = iter([event_a])
    client_b.send_message.return_value = iter([event_b])

    # The constructor hands out client_a first, then client_b.
    with patch(
        _ACP_CLIENT_CLASS, side_effect=[client_a, client_b]
    ) as mock_class:
        events_a = _drain_generator(
            manager.send_message(sandbox_id, session_id_a, message)
        )
        events_b = _drain_generator(
            manager.send_message(sandbox_id, session_id_b, message)
        )

    # One construction per session, and both clients were started.
    assert mock_class.call_count == 2
    client_a.start.assert_called_once()
    client_b.start.assert_called_once()

    # Each client is cached under its own key.
    assert manager._acp_clients[(sandbox_id, session_id_a)] is client_a
    assert manager._acp_clients[(sandbox_id, session_id_b)] is client_b

    # Each call relayed the matching client's event.
    assert event_a in events_a
    assert event_b in events_b
|
||||
|
||||
|
||||
def test_terminate_stops_all_sandbox_clients(manager: Any) -> None:
    """terminate(sandbox_id) stops every cached client of that sandbox and
    evicts each one from the cache."""
    sandbox_id: UUID = uuid4()
    keys = [(sandbox_id, uuid4()), (sandbox_id, uuid4())]
    clients = [_make_mock_client(is_running=True) for _ in keys]
    for key, cached in zip(keys, clients):
        manager._acp_clients[key] = cached

    # Prevent any real Kubernetes cleanup during terminate.
    with patch.object(manager, "_cleanup_kubernetes_resources"):
        manager.terminate(sandbox_id)

    # Every client was stopped exactly once and removed from the cache.
    for key, cached in zip(keys, clients):
        cached.stop.assert_called_once()
        assert key not in manager._acp_clients
|
||||
|
||||
|
||||
def test_terminate_leaves_other_sandbox_clients(manager: Any) -> None:
    """Terminating one sandbox must not disturb clients cached for another."""
    sandbox_id_a: UUID = uuid4()
    sandbox_id_b: UUID = uuid4()
    session_id_a: UUID = uuid4()
    session_id_b: UUID = uuid4()

    client_a = _make_mock_client(is_running=True)
    client_b = _make_mock_client(is_running=True)
    manager._acp_clients[(sandbox_id_a, session_id_a)] = client_a
    manager._acp_clients[(sandbox_id_b, session_id_b)] = client_b

    # Terminate only sandbox A, with K8s cleanup stubbed out.
    with patch.object(manager, "_cleanup_kubernetes_resources"):
        manager.terminate(sandbox_id_a)

    # Sandbox A's client was stopped and evicted.
    client_a.stop.assert_called_once()
    assert (sandbox_id_a, session_id_a) not in manager._acp_clients

    # Sandbox B's client is completely untouched.
    client_b.stop.assert_not_called()
    assert (sandbox_id_b, session_id_b) in manager._acp_clients
    assert manager._acp_clients[(sandbox_id_b, session_id_b)] is client_b
|
||||
|
||||
|
||||
def test_cleanup_session_stops_session_client(manager: Any) -> None:
    """cleanup_session_workspace() stops that session's cached client and
    removes it from the cache."""
    sandbox_id: UUID = uuid4()
    session_id: UUID = uuid4()
    key = (sandbox_id, session_id)

    cached = _make_mock_client(is_running=True)
    manager._acp_clients[key] = cached

    # Stub the exec API that would run the in-pod cleanup script.
    with patch.object(manager, "_stream_core_api") as stream_api:
        stream_api.connect_get_namespaced_pod_exec = MagicMock()
        with patch(f"{_K8S_MODULE}.k8s_stream", return_value="cleanup ok"):
            manager.cleanup_session_workspace(sandbox_id, session_id)

    # The client was stopped and evicted.
    cached.stop.assert_called_once()
    assert key not in manager._acp_clients
|
||||
|
||||
|
||||
def test_cleanup_session_handles_no_cached_client(manager: Any) -> None:
    """cleanup_session_workspace() must not raise when no client is cached
    for the given session."""
    sandbox_id: UUID = uuid4()
    session_id: UUID = uuid4()
    key = (sandbox_id, session_id)

    # Precondition: nothing cached for this pair.
    assert key not in manager._acp_clients

    # Stub the exec API that would run the in-pod cleanup script.
    with patch.object(manager, "_stream_core_api") as stream_api:
        stream_api.connect_get_namespaced_pod_exec = MagicMock()
        with patch(f"{_K8S_MODULE}.k8s_stream", return_value="cleanup ok"):
            # Must complete without raising even though nothing is cached.
            manager.cleanup_session_workspace(sandbox_id, session_id)

    # Cache stays empty for this key.
    assert key not in manager._acp_clients
|
||||
Reference in New Issue
Block a user