mirror of
https://github.com/onyx-dot-app/onyx.git
synced 2026-02-28 05:05:48 +00:00
Compare commits
6 Commits
embed_imag
...
experiment
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
f023599618 | ||
|
|
c55cb899f7 | ||
|
|
9b8a6e60b7 | ||
|
|
dd9d201b51 | ||
|
|
c545819aa6 | ||
|
|
960ee228bf |
@@ -4,8 +4,9 @@ This client runs `opencode acp` directly in the sandbox pod via kubernetes exec,
|
||||
using stdin/stdout for JSON-RPC communication. This bypasses the HTTP server
|
||||
and uses the native ACP subprocess protocol.
|
||||
|
||||
This module includes comprehensive logging for debugging ACP communication.
|
||||
Enable logging by setting LOG_LEVEL=DEBUG or BUILD_PACKET_LOGGING=true.
|
||||
When multiple API server replicas share the same sandbox pod, this client
|
||||
uses ACP session resumption (session/list + session/resume) to maintain
|
||||
conversation context across replicas.
|
||||
|
||||
Usage:
|
||||
client = ACPExecClient(
|
||||
@@ -144,6 +145,7 @@ class ACPExecClient:
|
||||
self._reader_thread: threading.Thread | None = None
|
||||
self._stop_reader = threading.Event()
|
||||
self._k8s_client: client.CoreV1Api | None = None
|
||||
self._prompt_count: int = 0 # Track how many prompts sent on this client
|
||||
|
||||
def _get_k8s_client(self) -> client.CoreV1Api:
|
||||
"""Get or create kubernetes client."""
|
||||
@@ -176,6 +178,8 @@ class ACPExecClient:
|
||||
# Start opencode acp via exec
|
||||
exec_command = ["opencode", "acp", "--cwd", cwd]
|
||||
|
||||
logger.info(f"[ACP] Starting client: pod={self._pod_name} cwd={cwd}")
|
||||
|
||||
try:
|
||||
self._ws_client = k8s_stream(
|
||||
k8s.connect_get_namespaced_pod_exec,
|
||||
@@ -204,12 +208,25 @@ class ACPExecClient:
|
||||
# Initialize ACP connection
|
||||
self._initialize(timeout=timeout)
|
||||
|
||||
# Create session
|
||||
session_id = self._create_session(cwd=cwd, timeout=timeout)
|
||||
# Try to resume an existing session first (handles multi-replica).
|
||||
# When multiple API server replicas connect to the same sandbox
|
||||
# pod, a previous replica may have already created a session for
|
||||
# this workspace. Resuming preserves conversation context.
|
||||
session_id = self._try_resume_existing_session(cwd, timeout)
|
||||
resumed = session_id is not None
|
||||
|
||||
if not session_id:
|
||||
# No existing session found — create a new one
|
||||
session_id = self._create_session(cwd=cwd, timeout=timeout)
|
||||
|
||||
logger.info(
|
||||
f"[ACP] Client started: pod={self._pod_name} "
|
||||
f"acp_session={session_id} resumed={resumed}"
|
||||
)
|
||||
return session_id
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[ACP] Client start failed: pod={self._pod_name} error={e}")
|
||||
self.stop()
|
||||
raise RuntimeError(f"Failed to start ACP exec client: {e}") from e
|
||||
|
||||
@@ -217,63 +234,157 @@ class ACPExecClient:
|
||||
"""Background thread to read responses from the exec stream."""
|
||||
buffer = ""
|
||||
packet_logger = get_packet_logger()
|
||||
messages_read = 0
|
||||
# Track how many consecutive read cycles the buffer has had
|
||||
# unterminated data (no trailing newline) with no new data arriving.
|
||||
buffer_stale_cycles = 0
|
||||
# Track empty read cycles for periodic buffer state logging
|
||||
empty_read_cycles = 0
|
||||
|
||||
while not self._stop_reader.is_set():
|
||||
if self._ws_client is None:
|
||||
break
|
||||
logger.debug(f"[ACP] Reader thread started for pod={self._pod_name}")
|
||||
|
||||
try:
|
||||
if self._ws_client.is_open():
|
||||
# Read available data
|
||||
self._ws_client.update(timeout=0.1)
|
||||
|
||||
# Read stdout (channel 1)
|
||||
data = self._ws_client.read_stdout(timeout=0.1)
|
||||
if data:
|
||||
buffer += data
|
||||
|
||||
# Process complete lines
|
||||
while "\n" in buffer:
|
||||
line, buffer = buffer.split("\n", 1)
|
||||
line = line.strip()
|
||||
if line:
|
||||
try:
|
||||
message = json.loads(line)
|
||||
# Log the raw incoming message
|
||||
packet_logger.log_jsonrpc_raw_message(
|
||||
"IN", message, context="k8s"
|
||||
)
|
||||
self._response_queue.put(message)
|
||||
except json.JSONDecodeError:
|
||||
packet_logger.log_raw(
|
||||
"JSONRPC-PARSE-ERROR-K8S",
|
||||
{
|
||||
"raw_line": line[:500],
|
||||
"error": "JSON decode failed",
|
||||
},
|
||||
)
|
||||
logger.warning(
|
||||
f"Invalid JSON from agent: {line[:100]}"
|
||||
)
|
||||
|
||||
else:
|
||||
packet_logger.log_raw(
|
||||
"K8S-WEBSOCKET-CLOSED",
|
||||
{"pod": self._pod_name, "namespace": self._namespace},
|
||||
)
|
||||
try:
|
||||
while not self._stop_reader.is_set():
|
||||
if self._ws_client is None:
|
||||
break
|
||||
|
||||
except Exception as e:
|
||||
if not self._stop_reader.is_set():
|
||||
packet_logger.log_raw(
|
||||
"K8S-READER-ERROR",
|
||||
{"error": str(e), "pod": self._pod_name},
|
||||
try:
|
||||
if self._ws_client.is_open():
|
||||
self._ws_client.update(timeout=0.1)
|
||||
|
||||
# Read stderr - log any agent errors
|
||||
stderr_data = self._ws_client.read_stderr(timeout=0.01)
|
||||
if stderr_data:
|
||||
logger.warning(
|
||||
f"[ACP] stderr pod={self._pod_name}: "
|
||||
f"{stderr_data.strip()[:500]}"
|
||||
)
|
||||
|
||||
# Read stdout
|
||||
data = self._ws_client.read_stdout(timeout=0.1)
|
||||
if data:
|
||||
buffer += data
|
||||
buffer_stale_cycles = 0
|
||||
empty_read_cycles = 0
|
||||
|
||||
while "\n" in buffer:
|
||||
line, buffer = buffer.split("\n", 1)
|
||||
line = line.strip()
|
||||
if line:
|
||||
try:
|
||||
message = json.loads(line)
|
||||
messages_read += 1
|
||||
packet_logger.log_jsonrpc_raw_message(
|
||||
"IN", message, context="k8s"
|
||||
)
|
||||
self._response_queue.put(message)
|
||||
except json.JSONDecodeError:
|
||||
logger.warning(
|
||||
f"[ACP] Invalid JSON from agent: "
|
||||
f"{line[:100]}"
|
||||
)
|
||||
else:
|
||||
empty_read_cycles += 1
|
||||
|
||||
# No new data arrived this cycle. If the buffer
|
||||
# has unterminated content, track how long it's
|
||||
# been sitting there. After a few cycles (~0.5s)
|
||||
# try to parse it — the agent may have sent the
|
||||
# last message without a trailing newline.
|
||||
if buffer.strip():
|
||||
buffer_stale_cycles += 1
|
||||
if buffer_stale_cycles == 1:
|
||||
logger.info(
|
||||
f"[ACP] Buffer has unterminated data: "
|
||||
f"{len(buffer)} bytes, "
|
||||
f"preview={buffer.strip()[:200]}"
|
||||
)
|
||||
if buffer_stale_cycles >= 3:
|
||||
logger.info(
|
||||
f"[ACP] Attempting stale buffer parse: "
|
||||
f"{len(buffer)} bytes, "
|
||||
f"cycles={buffer_stale_cycles}"
|
||||
)
|
||||
try:
|
||||
message = json.loads(buffer.strip())
|
||||
messages_read += 1
|
||||
packet_logger.log_jsonrpc_raw_message(
|
||||
"IN",
|
||||
message,
|
||||
context="k8s-unterminated",
|
||||
)
|
||||
self._response_queue.put(message)
|
||||
buffer = ""
|
||||
buffer_stale_cycles = 0
|
||||
logger.info(
|
||||
"[ACP] Stale buffer parsed successfully"
|
||||
)
|
||||
except json.JSONDecodeError:
|
||||
# Not valid JSON yet, keep waiting
|
||||
logger.debug(
|
||||
f"[ACP] Stale buffer not valid JSON: "
|
||||
f"{buffer.strip()[:100]}"
|
||||
)
|
||||
|
||||
# Periodic log: every ~5s (50 cycles at 0.1s each)
|
||||
# when we're idle with an empty buffer — helps
|
||||
# confirm the reader is alive and waiting.
|
||||
if empty_read_cycles % 50 == 0:
|
||||
logger.info(
|
||||
f"[ACP] Reader idle: "
|
||||
f"empty_cycles={empty_read_cycles} "
|
||||
f"buffer={len(buffer)} bytes "
|
||||
f"messages_read={messages_read} "
|
||||
f"pod={self._pod_name}"
|
||||
)
|
||||
|
||||
else:
|
||||
logger.warning(
|
||||
f"[ACP] WebSocket closed: pod={self._pod_name}, "
|
||||
f"messages_read={messages_read}"
|
||||
)
|
||||
break
|
||||
|
||||
except Exception as e:
|
||||
if not self._stop_reader.is_set():
|
||||
logger.warning(f"[ACP] Reader error: {e}, pod={self._pod_name}")
|
||||
break
|
||||
finally:
|
||||
# Flush any remaining data in buffer
|
||||
remaining = buffer.strip()
|
||||
if remaining:
|
||||
logger.info(
|
||||
f"[ACP] Flushing buffer on exit: {len(remaining)} bytes, "
|
||||
f"preview={remaining[:200]}"
|
||||
)
|
||||
try:
|
||||
message = json.loads(remaining)
|
||||
packet_logger.log_jsonrpc_raw_message(
|
||||
"IN", message, context="k8s-flush"
|
||||
)
|
||||
logger.debug(f"Reader error: {e}")
|
||||
break
|
||||
self._response_queue.put(message)
|
||||
except json.JSONDecodeError:
|
||||
logger.warning(
|
||||
f"[ACP] Buffer flush failed (not JSON): " f"{remaining[:200]}"
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"[ACP] Reader thread exiting: pod={self._pod_name}, "
|
||||
f"messages_read={messages_read}, "
|
||||
f"empty_read_cycles={empty_read_cycles}"
|
||||
)
|
||||
|
||||
def stop(self) -> None:
|
||||
"""Stop the exec session and clean up."""
|
||||
acp_session = (
|
||||
self._state.current_session.session_id
|
||||
if self._state.current_session
|
||||
else "none"
|
||||
)
|
||||
logger.info(
|
||||
f"[ACP] Stopping client: pod={self._pod_name} "
|
||||
f"acp_session={acp_session} prompts_sent={self._prompt_count}"
|
||||
)
|
||||
self._stop_reader.set()
|
||||
|
||||
if self._ws_client is not None:
|
||||
@@ -404,6 +515,105 @@ class ACPExecClient:
|
||||
|
||||
return session_id
|
||||
|
||||
def _list_sessions(self, cwd: str, timeout: float = 10.0) -> list[dict[str, Any]]:
|
||||
"""List available ACP sessions, filtered by working directory.
|
||||
|
||||
Returns:
|
||||
List of session info dicts with keys like 'sessionId', 'cwd', 'title'.
|
||||
Empty list if session/list is not supported or fails.
|
||||
"""
|
||||
try:
|
||||
request_id = self._send_request("session/list", {"cwd": cwd})
|
||||
result = self._wait_for_response(request_id, timeout)
|
||||
sessions = result.get("sessions", [])
|
||||
logger.info(f"[ACP] session/list: {len(sessions)} sessions for cwd={cwd}")
|
||||
return sessions
|
||||
except Exception as e:
|
||||
logger.info(f"[ACP] session/list unavailable: {e}")
|
||||
return []
|
||||
|
||||
def _resume_session(self, session_id: str, cwd: str, timeout: float = 30.0) -> str:
|
||||
"""Resume an existing ACP session.
|
||||
|
||||
Args:
|
||||
session_id: The ACP session ID to resume
|
||||
cwd: Working directory for the session
|
||||
timeout: Timeout for the resume request
|
||||
|
||||
Returns:
|
||||
The session ID
|
||||
|
||||
Raises:
|
||||
RuntimeError: If resume fails
|
||||
"""
|
||||
params = {
|
||||
"sessionId": session_id,
|
||||
"cwd": cwd,
|
||||
"mcpServers": [],
|
||||
}
|
||||
|
||||
request_id = self._send_request("session/resume", params)
|
||||
result = self._wait_for_response(request_id, timeout)
|
||||
|
||||
# The response should contain the session ID
|
||||
resumed_id = result.get("sessionId", session_id)
|
||||
self._state.current_session = ACPSession(session_id=resumed_id, cwd=cwd)
|
||||
|
||||
logger.info(f"[ACP] Resumed session: acp_session={resumed_id} cwd={cwd}")
|
||||
return resumed_id
|
||||
|
||||
def _try_resume_existing_session(self, cwd: str, timeout: float) -> str | None:
|
||||
"""Try to find and resume an existing session for this workspace.
|
||||
|
||||
When multiple API server replicas connect to the same sandbox pod,
|
||||
a previous replica may have already created an ACP session for this
|
||||
workspace. This method discovers and resumes that session so the
|
||||
agent retains conversation context.
|
||||
|
||||
Args:
|
||||
cwd: Working directory to search for sessions
|
||||
timeout: Timeout for ACP requests
|
||||
|
||||
Returns:
|
||||
The resumed session ID, or None if no session could be resumed
|
||||
"""
|
||||
# Check if the agent supports session/list + session/resume
|
||||
session_caps = self._state.agent_capabilities.get("sessionCapabilities", {})
|
||||
supports_list = session_caps.get("list") is not None
|
||||
supports_resume = session_caps.get("resume") is not None
|
||||
|
||||
if not supports_list or not supports_resume:
|
||||
logger.debug("[ACP] Agent does not support session resume")
|
||||
return None
|
||||
|
||||
# List sessions for this workspace directory
|
||||
sessions = self._list_sessions(cwd, timeout=min(timeout, 10.0))
|
||||
if not sessions:
|
||||
return None
|
||||
|
||||
# Pick the most recent session (first in list, assuming sorted)
|
||||
target = sessions[0]
|
||||
target_id = target.get("sessionId")
|
||||
if not target_id:
|
||||
logger.warning(
|
||||
"[ACP-LIFECYCLE] session/list returned session without sessionId"
|
||||
)
|
||||
return None
|
||||
|
||||
logger.info(
|
||||
f"[ACP] Resuming existing session: acp_session={target_id} "
|
||||
f"(found {len(sessions)})"
|
||||
)
|
||||
|
||||
try:
|
||||
return self._resume_session(target_id, cwd, timeout)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"[ACP] session/resume failed for {target_id}: {e}, "
|
||||
f"falling back to session/new"
|
||||
)
|
||||
return None
|
||||
|
||||
def send_message(
|
||||
self,
|
||||
message: str,
|
||||
@@ -423,21 +633,26 @@ class ACPExecClient:
|
||||
|
||||
session_id = self._state.current_session.session_id
|
||||
packet_logger = get_packet_logger()
|
||||
self._prompt_count += 1
|
||||
prompt_num = self._prompt_count
|
||||
|
||||
# Log the start of message processing
|
||||
packet_logger.log_raw(
|
||||
"ACP-SEND-MESSAGE-START-K8S",
|
||||
{
|
||||
"session_id": session_id,
|
||||
"pod": self._pod_name,
|
||||
"namespace": self._namespace,
|
||||
"message_preview": (
|
||||
message[:200] + "..." if len(message) > 200 else message
|
||||
),
|
||||
"timeout": timeout,
|
||||
},
|
||||
logger.info(
|
||||
f"[ACP] Prompt #{prompt_num} start: "
|
||||
f"acp_session={session_id} pod={self._pod_name}"
|
||||
)
|
||||
|
||||
# Drain leftover messages from the queue (e.g., session_info_update
|
||||
# that arrived between prompts).
|
||||
drained_count = 0
|
||||
while not self._response_queue.empty():
|
||||
try:
|
||||
self._response_queue.get_nowait()
|
||||
drained_count += 1
|
||||
except Empty:
|
||||
break
|
||||
if drained_count > 0:
|
||||
logger.debug(f"[ACP] Drained {drained_count} stale messages")
|
||||
|
||||
prompt_content = [{"type": "text", "text": message}]
|
||||
params = {
|
||||
"sessionId": session_id,
|
||||
@@ -446,44 +661,97 @@ class ACPExecClient:
|
||||
|
||||
request_id = self._send_request("session/prompt", params)
|
||||
start_time = time.time()
|
||||
last_event_time = time.time() # Track time since last event for keepalive
|
||||
last_event_time = time.time()
|
||||
events_yielded = 0
|
||||
messages_processed = 0
|
||||
keepalive_count = 0
|
||||
completion_reason = "unknown"
|
||||
|
||||
while True:
|
||||
remaining = timeout - (time.time() - start_time)
|
||||
if remaining <= 0:
|
||||
packet_logger.log_raw(
|
||||
"ACP-TIMEOUT-K8S",
|
||||
{
|
||||
"session_id": session_id,
|
||||
"elapsed_ms": (time.time() - start_time) * 1000,
|
||||
},
|
||||
completion_reason = "timeout"
|
||||
logger.warning(
|
||||
f"[ACP] Prompt #{prompt_num} timeout: "
|
||||
f"acp_session={session_id} events={events_yielded}"
|
||||
)
|
||||
yield Error(code=-1, message="Timeout waiting for response")
|
||||
break
|
||||
|
||||
try:
|
||||
message_data = self._response_queue.get(timeout=min(remaining, 1.0))
|
||||
last_event_time = time.time() # Reset keepalive timer on event
|
||||
last_event_time = time.time()
|
||||
messages_processed += 1
|
||||
except Empty:
|
||||
# Check if we need to send an SSE keepalive
|
||||
# Check if reader thread is still alive
|
||||
if (
|
||||
self._reader_thread is not None
|
||||
and not self._reader_thread.is_alive()
|
||||
):
|
||||
completion_reason = "reader_thread_dead"
|
||||
# Drain any final messages the reader flushed before dying
|
||||
while not self._response_queue.empty():
|
||||
try:
|
||||
final_msg = self._response_queue.get_nowait()
|
||||
if final_msg.get("id") == request_id:
|
||||
if "error" in final_msg:
|
||||
error_data = final_msg["error"]
|
||||
yield Error(
|
||||
code=error_data.get("code", -1),
|
||||
message=error_data.get(
|
||||
"message", "Unknown error"
|
||||
),
|
||||
)
|
||||
else:
|
||||
result = final_msg.get("result", {})
|
||||
try:
|
||||
yield PromptResponse.model_validate(result)
|
||||
except ValidationError:
|
||||
pass
|
||||
break
|
||||
except Empty:
|
||||
break
|
||||
|
||||
logger.warning(
|
||||
f"[ACP] Reader thread dead: prompt #{prompt_num} "
|
||||
f"acp_session={session_id} events={events_yielded}"
|
||||
)
|
||||
break
|
||||
|
||||
# Send SSE keepalive if idle
|
||||
idle_time = time.time() - last_event_time
|
||||
if idle_time >= SSE_KEEPALIVE_INTERVAL:
|
||||
packet_logger.log_raw(
|
||||
"SSE-KEEPALIVE-YIELD",
|
||||
{
|
||||
"session_id": session_id,
|
||||
"idle_seconds": idle_time,
|
||||
},
|
||||
)
|
||||
keepalive_count += 1
|
||||
if keepalive_count % 3 == 0:
|
||||
reader_alive = (
|
||||
self._reader_thread is not None
|
||||
and self._reader_thread.is_alive()
|
||||
)
|
||||
elapsed_s = time.time() - start_time
|
||||
logger.info(
|
||||
f"[ACP] Prompt #{prompt_num} waiting: "
|
||||
f"keepalives={keepalive_count} "
|
||||
f"elapsed={elapsed_s:.0f}s "
|
||||
f"events={events_yielded} "
|
||||
f"reader_alive={reader_alive} "
|
||||
f"queue_size={self._response_queue.qsize()}"
|
||||
)
|
||||
yield SSEKeepalive()
|
||||
last_event_time = time.time() # Reset after yielding keepalive
|
||||
last_event_time = time.time()
|
||||
continue
|
||||
|
||||
# Check for response to our prompt request
|
||||
if message_data.get("id") == request_id:
|
||||
# Check for JSON-RPC response to our prompt request.
|
||||
msg_id = message_data.get("id")
|
||||
is_response = "method" not in message_data and (
|
||||
msg_id == request_id
|
||||
or (msg_id is not None and str(msg_id) == str(request_id))
|
||||
)
|
||||
if is_response:
|
||||
completion_reason = "jsonrpc_response"
|
||||
if "error" in message_data:
|
||||
error_data = message_data["error"]
|
||||
completion_reason = "jsonrpc_error"
|
||||
logger.warning(f"[ACP] Prompt #{prompt_num} error: {error_data}")
|
||||
packet_logger.log_jsonrpc_response(
|
||||
request_id, error=error_data, context="k8s"
|
||||
)
|
||||
@@ -498,26 +766,16 @@ class ACPExecClient:
|
||||
)
|
||||
try:
|
||||
prompt_response = PromptResponse.model_validate(result)
|
||||
packet_logger.log_acp_event_yielded(
|
||||
"prompt_response", prompt_response
|
||||
)
|
||||
events_yielded += 1
|
||||
yield prompt_response
|
||||
except ValidationError as e:
|
||||
packet_logger.log_raw(
|
||||
"ACP-VALIDATION-ERROR-K8S",
|
||||
{"type": "prompt_response", "error": str(e)},
|
||||
)
|
||||
logger.error(f"[ACP] PromptResponse validation failed: {e}")
|
||||
|
||||
# Log completion summary
|
||||
elapsed_ms = (time.time() - start_time) * 1000
|
||||
packet_logger.log_raw(
|
||||
"ACP-SEND-MESSAGE-COMPLETE-K8S",
|
||||
{
|
||||
"session_id": session_id,
|
||||
"events_yielded": events_yielded,
|
||||
"elapsed_ms": elapsed_ms,
|
||||
},
|
||||
logger.info(
|
||||
f"[ACP] Prompt #{prompt_num} complete: "
|
||||
f"reason={completion_reason} acp_session={session_id} "
|
||||
f"events={events_yielded} elapsed={elapsed_ms:.0f}ms"
|
||||
)
|
||||
break
|
||||
|
||||
@@ -526,25 +784,29 @@ class ACPExecClient:
|
||||
params_data = message_data.get("params", {})
|
||||
update = params_data.get("update", {})
|
||||
|
||||
# Log the notification
|
||||
packet_logger.log_jsonrpc_notification(
|
||||
"session/update",
|
||||
{"update_type": update.get("sessionUpdate")},
|
||||
context="k8s",
|
||||
)
|
||||
|
||||
prompt_complete = False
|
||||
for event in self._process_session_update(update):
|
||||
events_yielded += 1
|
||||
# Log each yielded event
|
||||
event_type = self._get_event_type_name(event)
|
||||
packet_logger.log_acp_event_yielded(event_type, event)
|
||||
yield event
|
||||
if isinstance(event, PromptResponse):
|
||||
prompt_complete = True
|
||||
break
|
||||
|
||||
if prompt_complete:
|
||||
completion_reason = "prompt_response_via_notification"
|
||||
elapsed_ms = (time.time() - start_time) * 1000
|
||||
logger.info(
|
||||
f"[ACP] Prompt #{prompt_num} complete: "
|
||||
f"reason={completion_reason} acp_session={session_id} "
|
||||
f"events={events_yielded} elapsed={elapsed_ms:.0f}ms"
|
||||
)
|
||||
break
|
||||
|
||||
# Handle requests from agent - send error response
|
||||
elif "method" in message_data and "id" in message_data:
|
||||
packet_logger.log_raw(
|
||||
"ACP-UNSUPPORTED-REQUEST-K8S",
|
||||
{"method": message_data["method"], "id": message_data["id"]},
|
||||
logger.debug(
|
||||
f"[ACP] Unsupported agent request: "
|
||||
f"method={message_data['method']}"
|
||||
)
|
||||
self._send_error_response(
|
||||
message_data["id"],
|
||||
@@ -552,113 +814,43 @@ class ACPExecClient:
|
||||
f"Method not supported: {message_data['method']}",
|
||||
)
|
||||
|
||||
def _get_event_type_name(self, event: ACPEvent) -> str:
|
||||
"""Get the type name for an ACP event."""
|
||||
if isinstance(event, AgentMessageChunk):
|
||||
return "agent_message_chunk"
|
||||
elif isinstance(event, AgentThoughtChunk):
|
||||
return "agent_thought_chunk"
|
||||
elif isinstance(event, ToolCallStart):
|
||||
return "tool_call_start"
|
||||
elif isinstance(event, ToolCallProgress):
|
||||
return "tool_call_progress"
|
||||
elif isinstance(event, AgentPlanUpdate):
|
||||
return "agent_plan_update"
|
||||
elif isinstance(event, CurrentModeUpdate):
|
||||
return "current_mode_update"
|
||||
elif isinstance(event, PromptResponse):
|
||||
return "prompt_response"
|
||||
elif isinstance(event, Error):
|
||||
return "error"
|
||||
elif isinstance(event, SSEKeepalive):
|
||||
return "sse_keepalive"
|
||||
return "unknown"
|
||||
else:
|
||||
logger.debug(
|
||||
f"[ACP] Unhandled message: "
|
||||
f"id={message_data.get('id')} "
|
||||
f"method={message_data.get('method')}"
|
||||
)
|
||||
|
||||
def _process_session_update(
|
||||
self, update: dict[str, Any]
|
||||
) -> Generator[ACPEvent, None, None]:
|
||||
"""Process a session/update notification and yield typed ACP schema objects."""
|
||||
update_type = update.get("sessionUpdate")
|
||||
packet_logger = get_packet_logger()
|
||||
|
||||
if update_type == "agent_message_chunk":
|
||||
# Map update types to their ACP schema classes
|
||||
type_map: dict[str, type] = {
|
||||
"agent_message_chunk": AgentMessageChunk,
|
||||
"agent_thought_chunk": AgentThoughtChunk,
|
||||
"tool_call": ToolCallStart,
|
||||
"tool_call_update": ToolCallProgress,
|
||||
"plan": AgentPlanUpdate,
|
||||
"current_mode_update": CurrentModeUpdate,
|
||||
"prompt_response": PromptResponse,
|
||||
}
|
||||
|
||||
model_class = type_map.get(update_type) # type: ignore[arg-type]
|
||||
if model_class is not None:
|
||||
try:
|
||||
yield AgentMessageChunk.model_validate(update)
|
||||
yield model_class.model_validate(update)
|
||||
except ValidationError as e:
|
||||
packet_logger.log_raw(
|
||||
"ACP-VALIDATION-ERROR-K8S",
|
||||
{"update_type": update_type, "error": str(e), "update": update},
|
||||
)
|
||||
|
||||
elif update_type == "agent_thought_chunk":
|
||||
try:
|
||||
yield AgentThoughtChunk.model_validate(update)
|
||||
except ValidationError as e:
|
||||
packet_logger.log_raw(
|
||||
"ACP-VALIDATION-ERROR-K8S",
|
||||
{"update_type": update_type, "error": str(e), "update": update},
|
||||
)
|
||||
|
||||
elif update_type == "user_message_chunk":
|
||||
# Echo of user message - skip but log
|
||||
packet_logger.log_raw(
|
||||
"ACP-SKIPPED-UPDATE-K8S", {"type": "user_message_chunk"}
|
||||
)
|
||||
|
||||
elif update_type == "tool_call":
|
||||
try:
|
||||
yield ToolCallStart.model_validate(update)
|
||||
except ValidationError as e:
|
||||
packet_logger.log_raw(
|
||||
"ACP-VALIDATION-ERROR-K8S",
|
||||
{"update_type": update_type, "error": str(e), "update": update},
|
||||
)
|
||||
|
||||
elif update_type == "tool_call_update":
|
||||
try:
|
||||
yield ToolCallProgress.model_validate(update)
|
||||
except ValidationError as e:
|
||||
packet_logger.log_raw(
|
||||
"ACP-VALIDATION-ERROR-K8S",
|
||||
{"update_type": update_type, "error": str(e), "update": update},
|
||||
)
|
||||
|
||||
elif update_type == "plan":
|
||||
try:
|
||||
yield AgentPlanUpdate.model_validate(update)
|
||||
except ValidationError as e:
|
||||
packet_logger.log_raw(
|
||||
"ACP-VALIDATION-ERROR-K8S",
|
||||
{"update_type": update_type, "error": str(e), "update": update},
|
||||
)
|
||||
|
||||
elif update_type == "current_mode_update":
|
||||
try:
|
||||
yield CurrentModeUpdate.model_validate(update)
|
||||
except ValidationError as e:
|
||||
packet_logger.log_raw(
|
||||
"ACP-VALIDATION-ERROR-K8S",
|
||||
{"update_type": update_type, "error": str(e), "update": update},
|
||||
)
|
||||
|
||||
elif update_type == "available_commands_update":
|
||||
# Skip command updates
|
||||
packet_logger.log_raw(
|
||||
"ACP-SKIPPED-UPDATE-K8S", {"type": "available_commands_update"}
|
||||
)
|
||||
|
||||
elif update_type == "session_info_update":
|
||||
# Skip session info updates
|
||||
packet_logger.log_raw(
|
||||
"ACP-SKIPPED-UPDATE-K8S", {"type": "session_info_update"}
|
||||
)
|
||||
|
||||
else:
|
||||
# Unknown update types are logged
|
||||
packet_logger.log_raw(
|
||||
"ACP-UNKNOWN-UPDATE-TYPE-K8S",
|
||||
{"update_type": update_type, "update": update},
|
||||
)
|
||||
logger.warning(f"[ACP] Validation error for {update_type}: {e}")
|
||||
elif update_type not in (
|
||||
"user_message_chunk",
|
||||
"available_commands_update",
|
||||
"session_info_update",
|
||||
"usage_update",
|
||||
):
|
||||
logger.debug(f"[ACP] Unknown update type: {update_type}")
|
||||
|
||||
def _send_error_response(self, request_id: int, code: int, message: str) -> None:
|
||||
"""Send an error response to an agent request."""
|
||||
|
||||
@@ -50,6 +50,7 @@ from pathlib import Path
|
||||
from uuid import UUID
|
||||
from uuid import uuid4
|
||||
|
||||
from acp.schema import PromptResponse
|
||||
from kubernetes import client # type: ignore
|
||||
from kubernetes import config
|
||||
from kubernetes.client.rest import ApiException # type: ignore
|
||||
@@ -97,6 +98,10 @@ from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
# API server pod hostname — used to identify which replica is handling a request.
|
||||
# In K8s, HOSTNAME is set to the pod name (e.g., "api-server-dpgg7").
|
||||
_API_SERVER_HOSTNAME = os.environ.get("HOSTNAME", "unknown")
|
||||
|
||||
# Constants for pod configuration
|
||||
# Note: Next.js ports are dynamically allocated from SANDBOX_NEXTJS_PORT_START to
|
||||
# SANDBOX_NEXTJS_PORT_END range, with one port per session.
|
||||
@@ -353,6 +358,10 @@ class KubernetesSandboxManager(SandboxManager):
|
||||
self._agent_instructions_template_path = build_dir / "AGENTS.template.md"
|
||||
self._skills_path = Path(__file__).parent / "docker" / "skills"
|
||||
|
||||
# Track ACP exec clients in memory - keyed by (sandbox_id, session_id) tuple
|
||||
# Each session within a sandbox has its own ACP client (WebSocket connection)
|
||||
self._acp_clients: dict[tuple[UUID, UUID], ACPExecClient] = {}
|
||||
|
||||
logger.info(
|
||||
f"KubernetesSandboxManager initialized: "
|
||||
f"namespace={self._namespace}, image={self._image}"
|
||||
@@ -1161,6 +1170,20 @@ done
|
||||
Args:
|
||||
sandbox_id: The sandbox ID to terminate
|
||||
"""
|
||||
# Stop all ACP clients for this sandbox (keyed by (sandbox_id, session_id))
|
||||
clients_to_stop = [
|
||||
(key, cl) for key, cl in self._acp_clients.items() if key[0] == sandbox_id
|
||||
]
|
||||
for key, cl in clients_to_stop:
|
||||
try:
|
||||
cl.stop()
|
||||
del self._acp_clients[key]
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Failed to stop ACP client for sandbox {sandbox_id}, "
|
||||
f"session {key[1]}: {e}"
|
||||
)
|
||||
|
||||
# Clean up Kubernetes resources (needs string for pod/service names)
|
||||
self._cleanup_kubernetes_resources(str(sandbox_id))
|
||||
|
||||
@@ -1403,6 +1426,18 @@ echo "Session workspace setup complete"
|
||||
nextjs_port: Optional port where Next.js server is running (unused in K8s,
|
||||
we use PID file instead)
|
||||
"""
|
||||
# Stop ACP client for this session
|
||||
client_key = (sandbox_id, session_id)
|
||||
acp_client = self._acp_clients.pop(client_key, None)
|
||||
if acp_client:
|
||||
try:
|
||||
acp_client.stop()
|
||||
logger.debug(f"Stopped ACP client for session {session_id}")
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Failed to stop ACP client for session {session_id}: {e}"
|
||||
)
|
||||
|
||||
pod_name = self._get_pod_name(str(sandbox_id))
|
||||
session_path = f"/workspace/sessions/{session_id}"
|
||||
|
||||
@@ -1830,34 +1865,77 @@ echo "Session config regeneration complete"
|
||||
pod_name = self._get_pod_name(str(sandbox_id))
|
||||
session_path = f"/workspace/sessions/{session_id}"
|
||||
|
||||
# Log ACP client creation
|
||||
packet_logger.log_acp_client_start(
|
||||
sandbox_id, session_id, session_path, context="k8s"
|
||||
)
|
||||
# Get or create ACP client for this session
|
||||
client_key = (sandbox_id, session_id)
|
||||
client = self._acp_clients.get(client_key)
|
||||
|
||||
exec_client = ACPExecClient(
|
||||
pod_name=pod_name,
|
||||
namespace=self._namespace,
|
||||
container="sandbox",
|
||||
)
|
||||
if client is None or not client.is_running:
|
||||
# Clean up stale client if it exists but is no longer running
|
||||
if client is not None:
|
||||
logger.info(
|
||||
f"[SANDBOX-ACP] Cleaning up stale client: "
|
||||
f"session={session_id} acp_session={client.session_id}"
|
||||
)
|
||||
try:
|
||||
client.stop()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
logger.info(
|
||||
f"[SANDBOX-ACP] Creating ACP client: "
|
||||
f"session={session_id} pod={pod_name} "
|
||||
f"api_pod={_API_SERVER_HOSTNAME}"
|
||||
)
|
||||
|
||||
# Create and start ACP client for this session.
|
||||
# start() will try to resume an existing session from the pod
|
||||
# (handles multi-replica: another API pod may have created
|
||||
# the session earlier).
|
||||
client = ACPExecClient(
|
||||
pod_name=pod_name,
|
||||
namespace=self._namespace,
|
||||
container="sandbox",
|
||||
)
|
||||
client.start(cwd=session_path)
|
||||
self._acp_clients[client_key] = client
|
||||
|
||||
logger.info(
|
||||
f"[SANDBOX-ACP] ACP client ready: "
|
||||
f"session={session_id} acp_session={client.session_id} "
|
||||
f"api_pod={_API_SERVER_HOSTNAME}"
|
||||
)
|
||||
else:
|
||||
logger.info(
|
||||
f"[SANDBOX-ACP] Reusing cached client: "
|
||||
f"session={session_id} acp_session={client.session_id} "
|
||||
f"api_pod={_API_SERVER_HOSTNAME}"
|
||||
)
|
||||
|
||||
# Log the send_message call at sandbox manager level
|
||||
packet_logger.log_session_start(session_id, sandbox_id, message)
|
||||
|
||||
events_count = 0
|
||||
got_prompt_response = False
|
||||
try:
|
||||
exec_client.start(cwd=session_path)
|
||||
for event in exec_client.send_message(message):
|
||||
for event in client.send_message(message):
|
||||
events_count += 1
|
||||
if isinstance(event, PromptResponse):
|
||||
got_prompt_response = True
|
||||
yield event
|
||||
|
||||
# Log successful completion
|
||||
logger.info(
|
||||
f"[SANDBOX-ACP] send_message completed: "
|
||||
f"session={session_id} events={events_count} "
|
||||
f"got_prompt_response={got_prompt_response}"
|
||||
)
|
||||
packet_logger.log_session_end(
|
||||
session_id, success=True, events_count=events_count
|
||||
)
|
||||
except GeneratorExit:
|
||||
# Generator was closed by consumer (client disconnect, timeout, broken pipe)
|
||||
# This is the most common failure mode for SSE streaming
|
||||
logger.warning(
|
||||
f"[SANDBOX-ACP] GeneratorExit: session={session_id} "
|
||||
f"events={events_count}"
|
||||
)
|
||||
packet_logger.log_session_end(
|
||||
session_id,
|
||||
success=False,
|
||||
@@ -1866,7 +1944,10 @@ echo "Session config regeneration complete"
|
||||
)
|
||||
raise
|
||||
except Exception as e:
|
||||
# Log failure from normal exceptions
|
||||
logger.error(
|
||||
f"[SANDBOX-ACP] Exception: session={session_id} "
|
||||
f"events={events_count} error={e}"
|
||||
)
|
||||
packet_logger.log_session_end(
|
||||
session_id,
|
||||
success=False,
|
||||
@@ -1875,19 +1956,16 @@ echo "Session config regeneration complete"
|
||||
)
|
||||
raise
|
||||
except BaseException as e:
|
||||
# Log failure from other base exceptions (SystemExit, KeyboardInterrupt, etc.)
|
||||
exception_type = type(e).__name__
|
||||
logger.error(
|
||||
f"[SANDBOX-ACP] {type(e).__name__}: session={session_id} " f"error={e}"
|
||||
)
|
||||
packet_logger.log_session_end(
|
||||
session_id,
|
||||
success=False,
|
||||
error=f"{exception_type}: {str(e) if str(e) else 'System-level interruption'}",
|
||||
error=f"{type(e).__name__}: {str(e) if str(e) else 'System-level interruption'}",
|
||||
events_count=events_count,
|
||||
)
|
||||
raise
|
||||
finally:
|
||||
exec_client.stop()
|
||||
# Log client stop
|
||||
packet_logger.log_acp_client_stop(sandbox_id, session_id, context="k8s")
|
||||
|
||||
def list_directory(
|
||||
self, sandbox_id: UUID, session_id: UUID, path: str
|
||||
|
||||
0
backend/tests/unit/onyx/server/features/__init__.py
Normal file
0
backend/tests/unit/onyx/server/features/__init__.py
Normal file
@@ -0,0 +1,360 @@
|
||||
"""Unit tests for ACPExecClient caching behavior in KubernetesSandboxManager.
|
||||
|
||||
These tests verify that the KubernetesSandboxManager correctly caches
|
||||
ACPExecClient instances per (sandbox_id, session_id) pair, reuses them
|
||||
across send_message calls, replaces dead clients, and cleans them up
|
||||
on terminate/cleanup.
|
||||
|
||||
All external dependencies (K8s, WebSockets, packet logging) are mocked.
|
||||
"""
|
||||
|
||||
from collections.abc import Generator
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
from uuid import UUID
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# The fully-qualified path to the module under test, used for patching
|
||||
_K8S_MODULE = "onyx.server.features.build.sandbox.kubernetes.kubernetes_sandbox_manager"
|
||||
_ACP_CLIENT_CLASS = f"{_K8S_MODULE}.ACPExecClient"
|
||||
_GET_PACKET_LOGGER = f"{_K8S_MODULE}.get_packet_logger"
|
||||
|
||||
|
||||
def _make_mock_event() -> MagicMock:
|
||||
"""Create a mock ACP event."""
|
||||
return MagicMock(name="mock_acp_event")
|
||||
|
||||
|
||||
def _make_mock_client(is_running: bool = True) -> MagicMock:
|
||||
"""Create a mock ACPExecClient with configurable is_running property."""
|
||||
mock_client = MagicMock()
|
||||
type(mock_client).is_running = property(lambda _self: is_running)
|
||||
mock_client.start.return_value = "mock-session-id"
|
||||
mock_event = _make_mock_event()
|
||||
mock_client.send_message.return_value = iter([mock_event])
|
||||
mock_client.stop.return_value = None
|
||||
return mock_client
|
||||
|
||||
|
||||
def _drain_generator(gen: Generator[Any, None, None]) -> list[Any]:
|
||||
"""Consume a generator and return all yielded values as a list."""
|
||||
return list(gen)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fixture: fresh KubernetesSandboxManager instance
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def manager() -> Generator[Any, None, None]:
|
||||
"""Create a fresh KubernetesSandboxManager instance with all externals mocked.
|
||||
|
||||
This fixture:
|
||||
1. Resets the singleton _instance so each test gets a fresh manager
|
||||
2. Mocks kubernetes.config and kubernetes.client to prevent real K8s calls
|
||||
3. Mocks get_packet_logger to prevent logging side effects
|
||||
"""
|
||||
# Import here so patches are in effect when the class loads
|
||||
with (
|
||||
patch(f"{_K8S_MODULE}.config") as _mock_config,
|
||||
patch(f"{_K8S_MODULE}.client") as _mock_k8s_client,
|
||||
patch(f"{_K8S_MODULE}.k8s_stream"),
|
||||
patch(_GET_PACKET_LOGGER) as mock_get_logger,
|
||||
):
|
||||
# Set up the mock packet logger
|
||||
mock_packet_logger = MagicMock()
|
||||
mock_get_logger.return_value = mock_packet_logger
|
||||
|
||||
# Make config.load_incluster_config succeed (no-op)
|
||||
_mock_config.load_incluster_config.return_value = None
|
||||
_mock_config.ConfigException = Exception
|
||||
|
||||
# Make client constructors return mocks
|
||||
_mock_k8s_client.ApiClient.return_value = MagicMock()
|
||||
_mock_k8s_client.CoreV1Api.return_value = MagicMock()
|
||||
_mock_k8s_client.BatchV1Api.return_value = MagicMock()
|
||||
_mock_k8s_client.NetworkingV1Api.return_value = MagicMock()
|
||||
|
||||
# Reset singleton before importing
|
||||
from onyx.server.features.build.sandbox.kubernetes.kubernetes_sandbox_manager import (
|
||||
KubernetesSandboxManager,
|
||||
)
|
||||
|
||||
KubernetesSandboxManager._instance = None
|
||||
|
||||
mgr = KubernetesSandboxManager()
|
||||
|
||||
# Ensure the _acp_clients dict exists (it should be initialized by
|
||||
# the caching implementation)
|
||||
if not hasattr(mgr, "_acp_clients"):
|
||||
mgr._acp_clients = {}
|
||||
|
||||
yield mgr
|
||||
|
||||
# Reset singleton after test
|
||||
KubernetesSandboxManager._instance = None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_send_message_creates_client_on_first_call(manager: Any) -> None:
|
||||
"""First call to send_message() should create a new ACPExecClient,
|
||||
call start(), cache it, and yield events from send_message()."""
|
||||
sandbox_id: UUID = uuid4()
|
||||
session_id: UUID = uuid4()
|
||||
message = "hello world"
|
||||
|
||||
mock_event = _make_mock_event()
|
||||
mock_client = _make_mock_client(is_running=True)
|
||||
mock_client.send_message.return_value = iter([mock_event])
|
||||
|
||||
with patch(_ACP_CLIENT_CLASS, return_value=mock_client) as MockClass:
|
||||
events = _drain_generator(manager.send_message(sandbox_id, session_id, message))
|
||||
|
||||
# Verify client was constructed
|
||||
MockClass.assert_called_once()
|
||||
|
||||
# Verify start() was called with the correct session path
|
||||
expected_cwd = f"/workspace/sessions/{session_id}"
|
||||
mock_client.start.assert_called_once_with(cwd=expected_cwd)
|
||||
|
||||
# Verify send_message was called on the client
|
||||
mock_client.send_message.assert_called_once_with(message)
|
||||
|
||||
# Verify we got the event
|
||||
assert len(events) >= 1
|
||||
# Find our mock event (filter out any SSEKeepalive or similar)
|
||||
assert mock_event in events
|
||||
|
||||
# Verify client was cached
|
||||
client_key = (sandbox_id, session_id)
|
||||
assert client_key in manager._acp_clients
|
||||
assert manager._acp_clients[client_key] is mock_client
|
||||
|
||||
|
||||
def test_send_message_reuses_cached_client(manager: Any) -> None:
|
||||
"""Second call with the same (sandbox_id, session_id) should NOT create
|
||||
a new client. Should reuse the cached one."""
|
||||
sandbox_id: UUID = uuid4()
|
||||
session_id: UUID = uuid4()
|
||||
message_1 = "first message"
|
||||
message_2 = "second message"
|
||||
|
||||
mock_event_1 = _make_mock_event()
|
||||
mock_event_2 = _make_mock_event()
|
||||
mock_client = _make_mock_client(is_running=True)
|
||||
|
||||
# send_message returns different events for each call
|
||||
mock_client.send_message.side_effect = [
|
||||
iter([mock_event_1]),
|
||||
iter([mock_event_2]),
|
||||
]
|
||||
|
||||
with patch(_ACP_CLIENT_CLASS, return_value=mock_client) as MockClass:
|
||||
events_1 = _drain_generator(
|
||||
manager.send_message(sandbox_id, session_id, message_1)
|
||||
)
|
||||
events_2 = _drain_generator(
|
||||
manager.send_message(sandbox_id, session_id, message_2)
|
||||
)
|
||||
|
||||
# Constructor called only ONCE (on first send_message)
|
||||
MockClass.assert_called_once()
|
||||
|
||||
# start() called only once
|
||||
mock_client.start.assert_called_once()
|
||||
|
||||
# send_message called twice with different messages
|
||||
assert mock_client.send_message.call_count == 2
|
||||
mock_client.send_message.assert_any_call(message_1)
|
||||
mock_client.send_message.assert_any_call(message_2)
|
||||
|
||||
# Both calls yielded events
|
||||
assert mock_event_1 in events_1
|
||||
assert mock_event_2 in events_2
|
||||
|
||||
|
||||
def test_send_message_replaces_dead_client(manager: Any) -> None:
|
||||
"""If cached client has is_running == False, should create a new one,
|
||||
start it, and cache the replacement."""
|
||||
sandbox_id: UUID = uuid4()
|
||||
session_id: UUID = uuid4()
|
||||
message = "test message"
|
||||
|
||||
# Create a dead client (is_running = False) and place it in the cache
|
||||
dead_client = _make_mock_client(is_running=False)
|
||||
client_key = (sandbox_id, session_id)
|
||||
manager._acp_clients[client_key] = dead_client
|
||||
|
||||
# Create the replacement client
|
||||
new_event = _make_mock_event()
|
||||
new_client = _make_mock_client(is_running=True)
|
||||
new_client.send_message.return_value = iter([new_event])
|
||||
|
||||
with patch(_ACP_CLIENT_CLASS, return_value=new_client) as MockClass:
|
||||
events = _drain_generator(manager.send_message(sandbox_id, session_id, message))
|
||||
|
||||
# A new client was constructed (the dead one was replaced)
|
||||
MockClass.assert_called_once()
|
||||
|
||||
# New client was started and used
|
||||
new_client.start.assert_called_once()
|
||||
new_client.send_message.assert_called_once_with(message)
|
||||
|
||||
# Cache now holds the new client
|
||||
assert manager._acp_clients[client_key] is new_client
|
||||
|
||||
# Events from new client were yielded
|
||||
assert new_event in events
|
||||
|
||||
|
||||
def test_send_message_different_sessions_get_different_clients(
|
||||
manager: Any,
|
||||
) -> None:
|
||||
"""Two calls with different session_id values should create two
|
||||
separate clients, each cached under its own key."""
|
||||
sandbox_id: UUID = uuid4()
|
||||
session_id_a: UUID = uuid4()
|
||||
session_id_b: UUID = uuid4()
|
||||
message = "test"
|
||||
|
||||
mock_client_a = _make_mock_client(is_running=True)
|
||||
mock_client_b = _make_mock_client(is_running=True)
|
||||
|
||||
mock_event_a = _make_mock_event()
|
||||
mock_event_b = _make_mock_event()
|
||||
mock_client_a.send_message.return_value = iter([mock_event_a])
|
||||
mock_client_b.send_message.return_value = iter([mock_event_b])
|
||||
|
||||
with patch(
|
||||
_ACP_CLIENT_CLASS, side_effect=[mock_client_a, mock_client_b]
|
||||
) as MockClass:
|
||||
events_a = _drain_generator(
|
||||
manager.send_message(sandbox_id, session_id_a, message)
|
||||
)
|
||||
events_b = _drain_generator(
|
||||
manager.send_message(sandbox_id, session_id_b, message)
|
||||
)
|
||||
|
||||
# Two separate clients were constructed
|
||||
assert MockClass.call_count == 2
|
||||
|
||||
# Both were started
|
||||
mock_client_a.start.assert_called_once()
|
||||
mock_client_b.start.assert_called_once()
|
||||
|
||||
# Each is cached under a different key
|
||||
assert manager._acp_clients[(sandbox_id, session_id_a)] is mock_client_a
|
||||
assert manager._acp_clients[(sandbox_id, session_id_b)] is mock_client_b
|
||||
|
||||
# Events from each client are correct
|
||||
assert mock_event_a in events_a
|
||||
assert mock_event_b in events_b
|
||||
|
||||
|
||||
def test_terminate_stops_all_sandbox_clients(manager: Any) -> None:
|
||||
"""terminate(sandbox_id) should stop all cached clients for that
|
||||
sandbox and remove them from the cache."""
|
||||
sandbox_id: UUID = uuid4()
|
||||
session_id_1: UUID = uuid4()
|
||||
session_id_2: UUID = uuid4()
|
||||
|
||||
client_1 = _make_mock_client(is_running=True)
|
||||
client_2 = _make_mock_client(is_running=True)
|
||||
|
||||
manager._acp_clients[(sandbox_id, session_id_1)] = client_1
|
||||
manager._acp_clients[(sandbox_id, session_id_2)] = client_2
|
||||
|
||||
# Mock _cleanup_kubernetes_resources to prevent actual K8s calls
|
||||
with patch.object(manager, "_cleanup_kubernetes_resources"):
|
||||
manager.terminate(sandbox_id)
|
||||
|
||||
# Both clients should have been stopped
|
||||
client_1.stop.assert_called_once()
|
||||
client_2.stop.assert_called_once()
|
||||
|
||||
# Both should be removed from cache
|
||||
assert (sandbox_id, session_id_1) not in manager._acp_clients
|
||||
assert (sandbox_id, session_id_2) not in manager._acp_clients
|
||||
|
||||
|
||||
def test_terminate_leaves_other_sandbox_clients(manager: Any) -> None:
|
||||
"""terminate(sandbox_id_A) should NOT affect clients cached for
|
||||
sandbox_id_B."""
|
||||
sandbox_id_a: UUID = uuid4()
|
||||
sandbox_id_b: UUID = uuid4()
|
||||
session_id_a: UUID = uuid4()
|
||||
session_id_b: UUID = uuid4()
|
||||
|
||||
client_a = _make_mock_client(is_running=True)
|
||||
client_b = _make_mock_client(is_running=True)
|
||||
|
||||
manager._acp_clients[(sandbox_id_a, session_id_a)] = client_a
|
||||
manager._acp_clients[(sandbox_id_b, session_id_b)] = client_b
|
||||
|
||||
# Terminate only sandbox A
|
||||
with patch.object(manager, "_cleanup_kubernetes_resources"):
|
||||
manager.terminate(sandbox_id_a)
|
||||
|
||||
# Client A stopped and removed
|
||||
client_a.stop.assert_called_once()
|
||||
assert (sandbox_id_a, session_id_a) not in manager._acp_clients
|
||||
|
||||
# Client B untouched
|
||||
client_b.stop.assert_not_called()
|
||||
assert (sandbox_id_b, session_id_b) in manager._acp_clients
|
||||
assert manager._acp_clients[(sandbox_id_b, session_id_b)] is client_b
|
||||
|
||||
|
||||
def test_cleanup_session_stops_session_client(manager: Any) -> None:
|
||||
"""cleanup_session_workspace(sandbox_id, session_id) should stop and
|
||||
remove the specific session's client from the cache."""
|
||||
sandbox_id: UUID = uuid4()
|
||||
session_id: UUID = uuid4()
|
||||
|
||||
cached_client = _make_mock_client(is_running=True)
|
||||
manager._acp_clients[(sandbox_id, session_id)] = cached_client
|
||||
|
||||
# Mock the k8s exec call that runs the cleanup script
|
||||
with patch.object(manager, "_stream_core_api") as mock_stream_api:
|
||||
mock_stream_api.connect_get_namespaced_pod_exec = MagicMock()
|
||||
with patch(f"{_K8S_MODULE}.k8s_stream", return_value="cleanup ok"):
|
||||
manager.cleanup_session_workspace(sandbox_id, session_id)
|
||||
|
||||
# Client should have been stopped
|
||||
cached_client.stop.assert_called_once()
|
||||
|
||||
# Client should be removed from the cache
|
||||
assert (sandbox_id, session_id) not in manager._acp_clients
|
||||
|
||||
|
||||
def test_cleanup_session_handles_no_cached_client(manager: Any) -> None:
|
||||
"""cleanup_session_workspace() should not error when there's no cached
|
||||
client for that session."""
|
||||
sandbox_id: UUID = uuid4()
|
||||
session_id: UUID = uuid4()
|
||||
|
||||
# Ensure no client is cached for this pair
|
||||
assert (sandbox_id, session_id) not in manager._acp_clients
|
||||
|
||||
# Mock the k8s exec call that runs the cleanup script
|
||||
with patch.object(manager, "_stream_core_api") as mock_stream_api:
|
||||
mock_stream_api.connect_get_namespaced_pod_exec = MagicMock()
|
||||
with patch(f"{_K8S_MODULE}.k8s_stream", return_value="cleanup ok"):
|
||||
# This should NOT raise
|
||||
manager.cleanup_session_workspace(sandbox_id, session_id)
|
||||
|
||||
# Cache is still empty for this key
|
||||
assert (sandbox_id, session_id) not in manager._acp_clients
|
||||
Reference in New Issue
Block a user