mirror of
https://github.com/onyx-dot-app/onyx.git
synced 2026-02-24 11:15:47 +00:00
Compare commits
14 Commits
ci_script
...
experiment
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
495e2f7c52 | ||
|
|
189cc5bc3c | ||
|
|
23ec38662e | ||
|
|
57d741c5b3 | ||
|
|
021af74739 | ||
|
|
adaca6a353 | ||
|
|
4cd07d7bbc | ||
|
|
bf3c98142d | ||
|
|
f023599618 | ||
|
|
c55cb899f7 | ||
|
|
9b8a6e60b7 | ||
|
|
dd9d201b51 | ||
|
|
c545819aa6 | ||
|
|
960ee228bf |
@@ -4,8 +4,9 @@ This client runs `opencode acp` directly in the sandbox pod via kubernetes exec,
|
||||
using stdin/stdout for JSON-RPC communication. This bypasses the HTTP server
|
||||
and uses the native ACP subprocess protocol.
|
||||
|
||||
This module includes comprehensive logging for debugging ACP communication.
|
||||
Enable logging by setting LOG_LEVEL=DEBUG or BUILD_PACKET_LOGGING=true.
|
||||
When multiple API server replicas share the same sandbox pod, this client
|
||||
uses ACP session resumption (session/list + session/resume) to maintain
|
||||
conversation context across replicas.
|
||||
|
||||
Usage:
|
||||
client = ACPExecClient(
|
||||
@@ -100,7 +101,7 @@ class ACPClientState:
|
||||
"""Internal state for the ACP client."""
|
||||
|
||||
initialized: bool = False
|
||||
current_session: ACPSession | None = None
|
||||
sessions: dict[str, ACPSession] = field(default_factory=dict)
|
||||
next_request_id: int = 0
|
||||
agent_capabilities: dict[str, Any] = field(default_factory=dict)
|
||||
agent_info: dict[str, Any] = field(default_factory=dict)
|
||||
@@ -155,16 +156,16 @@ class ACPExecClient:
|
||||
self._k8s_client = client.CoreV1Api()
|
||||
return self._k8s_client
|
||||
|
||||
def start(self, cwd: str = "/workspace", timeout: float = 30.0) -> str:
|
||||
"""Start the agent process via exec and initialize a session.
|
||||
def start(self, cwd: str = "/workspace", timeout: float = 30.0) -> None:
|
||||
"""Start the agent process via exec and initialize the ACP connection.
|
||||
|
||||
Only performs the ACP `initialize` handshake. Sessions are created
|
||||
separately via `create_session()` or `resume_session()`.
|
||||
|
||||
Args:
|
||||
cwd: Working directory for the agent
|
||||
cwd: Working directory for the `opencode acp` process
|
||||
timeout: Timeout for initialization
|
||||
|
||||
Returns:
|
||||
The session ID
|
||||
|
||||
Raises:
|
||||
RuntimeError: If startup fails
|
||||
"""
|
||||
@@ -176,6 +177,8 @@ class ACPExecClient:
|
||||
# Start opencode acp via exec
|
||||
exec_command = ["opencode", "acp", "--cwd", cwd]
|
||||
|
||||
logger.info(f"[ACP] Starting client: pod={self._pod_name} cwd={cwd}")
|
||||
|
||||
try:
|
||||
self._ws_client = k8s_stream(
|
||||
k8s.connect_get_namespaced_pod_exec,
|
||||
@@ -201,15 +204,13 @@ class ACPExecClient:
|
||||
# Give process a moment to start
|
||||
time.sleep(0.5)
|
||||
|
||||
# Initialize ACP connection
|
||||
# Initialize ACP connection (no session creation)
|
||||
self._initialize(timeout=timeout)
|
||||
|
||||
# Create session
|
||||
session_id = self._create_session(cwd=cwd, timeout=timeout)
|
||||
|
||||
return session_id
|
||||
logger.info(f"[ACP] Client started: pod={self._pod_name}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[ACP] Client start failed: pod={self._pod_name} error={e}")
|
||||
self.stop()
|
||||
raise RuntimeError(f"Failed to start ACP exec client: {e}") from e
|
||||
|
||||
@@ -217,63 +218,94 @@ class ACPExecClient:
|
||||
"""Background thread to read responses from the exec stream."""
|
||||
buffer = ""
|
||||
packet_logger = get_packet_logger()
|
||||
# Stale cycle counter: when the buffer has unterminated content
|
||||
# and no new data arrives, we try to parse after a few cycles.
|
||||
buffer_stale_cycles = 0
|
||||
|
||||
while not self._stop_reader.is_set():
|
||||
if self._ws_client is None:
|
||||
break
|
||||
|
||||
try:
|
||||
if self._ws_client.is_open():
|
||||
# Read available data
|
||||
self._ws_client.update(timeout=0.1)
|
||||
|
||||
# Read stdout (channel 1)
|
||||
data = self._ws_client.read_stdout(timeout=0.1)
|
||||
if data:
|
||||
buffer += data
|
||||
|
||||
# Process complete lines
|
||||
while "\n" in buffer:
|
||||
line, buffer = buffer.split("\n", 1)
|
||||
line = line.strip()
|
||||
if line:
|
||||
try:
|
||||
message = json.loads(line)
|
||||
# Log the raw incoming message
|
||||
packet_logger.log_jsonrpc_raw_message(
|
||||
"IN", message, context="k8s"
|
||||
)
|
||||
self._response_queue.put(message)
|
||||
except json.JSONDecodeError:
|
||||
packet_logger.log_raw(
|
||||
"JSONRPC-PARSE-ERROR-K8S",
|
||||
{
|
||||
"raw_line": line[:500],
|
||||
"error": "JSON decode failed",
|
||||
},
|
||||
)
|
||||
logger.warning(
|
||||
f"Invalid JSON from agent: {line[:100]}"
|
||||
)
|
||||
|
||||
else:
|
||||
packet_logger.log_raw(
|
||||
"K8S-WEBSOCKET-CLOSED",
|
||||
{"pod": self._pod_name, "namespace": self._namespace},
|
||||
)
|
||||
try:
|
||||
while not self._stop_reader.is_set():
|
||||
if self._ws_client is None:
|
||||
break
|
||||
|
||||
except Exception as e:
|
||||
if not self._stop_reader.is_set():
|
||||
packet_logger.log_raw(
|
||||
"K8S-READER-ERROR",
|
||||
{"error": str(e), "pod": self._pod_name},
|
||||
try:
|
||||
if self._ws_client.is_open():
|
||||
self._ws_client.update(timeout=0.1)
|
||||
|
||||
# Read stderr - log any agent errors
|
||||
stderr_data = self._ws_client.read_stderr(timeout=0.01)
|
||||
if stderr_data:
|
||||
logger.warning(
|
||||
f"[ACP] stderr pod={self._pod_name}: "
|
||||
f"{stderr_data.strip()[:500]}"
|
||||
)
|
||||
|
||||
# Read stdout
|
||||
data = self._ws_client.read_stdout(timeout=0.1)
|
||||
if data:
|
||||
buffer += data
|
||||
buffer_stale_cycles = 0
|
||||
|
||||
while "\n" in buffer:
|
||||
line, buffer = buffer.split("\n", 1)
|
||||
line = line.strip()
|
||||
if line:
|
||||
try:
|
||||
message = json.loads(line)
|
||||
packet_logger.log_jsonrpc_raw_message(
|
||||
"IN", message, context="k8s"
|
||||
)
|
||||
self._response_queue.put(message)
|
||||
except json.JSONDecodeError:
|
||||
logger.warning(
|
||||
f"[ACP] Invalid JSON from agent: "
|
||||
f"{line[:100]}"
|
||||
)
|
||||
elif buffer.strip():
|
||||
# No new data but buffer has unterminated content.
|
||||
# After a few cycles (~0.5s), try to parse it —
|
||||
# the agent may have omitted the trailing newline.
|
||||
buffer_stale_cycles += 1
|
||||
if buffer_stale_cycles >= 3:
|
||||
try:
|
||||
message = json.loads(buffer.strip())
|
||||
packet_logger.log_jsonrpc_raw_message(
|
||||
"IN", message, context="k8s-unterminated"
|
||||
)
|
||||
self._response_queue.put(message)
|
||||
buffer = ""
|
||||
buffer_stale_cycles = 0
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
else:
|
||||
logger.warning(f"[ACP] WebSocket closed: pod={self._pod_name}")
|
||||
break
|
||||
|
||||
except Exception as e:
|
||||
if not self._stop_reader.is_set():
|
||||
logger.warning(f"[ACP] Reader error: {e}, pod={self._pod_name}")
|
||||
break
|
||||
finally:
|
||||
# Flush any remaining data in buffer
|
||||
remaining = buffer.strip()
|
||||
if remaining:
|
||||
try:
|
||||
message = json.loads(remaining)
|
||||
packet_logger.log_jsonrpc_raw_message(
|
||||
"IN", message, context="k8s-flush"
|
||||
)
|
||||
self._response_queue.put(message)
|
||||
except json.JSONDecodeError:
|
||||
logger.warning(
|
||||
f"[ACP] Buffer flush failed (not JSON): {remaining[:200]}"
|
||||
)
|
||||
logger.debug(f"Reader error: {e}")
|
||||
break
|
||||
|
||||
def stop(self) -> None:
|
||||
"""Stop the exec session and clean up."""
|
||||
session_ids = list(self._state.sessions.keys())
|
||||
logger.info(
|
||||
f"[ACP] Stopping client: pod={self._pod_name} " f"sessions={session_ids}"
|
||||
)
|
||||
self._stop_reader.set()
|
||||
|
||||
if self._ws_client is not None:
|
||||
@@ -400,42 +432,196 @@ class ACPExecClient:
|
||||
if not session_id:
|
||||
raise RuntimeError("No session ID returned from session/new")
|
||||
|
||||
self._state.current_session = ACPSession(session_id=session_id, cwd=cwd)
|
||||
self._state.sessions[session_id] = ACPSession(session_id=session_id, cwd=cwd)
|
||||
logger.info(f"[ACP] Created session: acp_session={session_id} cwd={cwd}")
|
||||
|
||||
return session_id
|
||||
|
||||
def _list_sessions(self, cwd: str, timeout: float = 10.0) -> list[dict[str, Any]]:
|
||||
"""List available ACP sessions, filtered by working directory.
|
||||
|
||||
Returns:
|
||||
List of session info dicts with keys like 'sessionId', 'cwd', 'title'.
|
||||
Empty list if session/list is not supported or fails.
|
||||
"""
|
||||
try:
|
||||
request_id = self._send_request("session/list", {"cwd": cwd})
|
||||
result = self._wait_for_response(request_id, timeout)
|
||||
sessions = result.get("sessions", [])
|
||||
logger.info(f"[ACP] session/list: {len(sessions)} sessions for cwd={cwd}")
|
||||
return sessions
|
||||
except Exception as e:
|
||||
logger.info(f"[ACP] session/list unavailable: {e}")
|
||||
return []
|
||||
|
||||
def _resume_session(self, session_id: str, cwd: str, timeout: float = 30.0) -> str:
|
||||
"""Resume an existing ACP session.
|
||||
|
||||
Args:
|
||||
session_id: The ACP session ID to resume
|
||||
cwd: Working directory for the session
|
||||
timeout: Timeout for the resume request
|
||||
|
||||
Returns:
|
||||
The session ID
|
||||
|
||||
Raises:
|
||||
RuntimeError: If resume fails
|
||||
"""
|
||||
params = {
|
||||
"sessionId": session_id,
|
||||
"cwd": cwd,
|
||||
"mcpServers": [],
|
||||
}
|
||||
|
||||
request_id = self._send_request("session/resume", params)
|
||||
result = self._wait_for_response(request_id, timeout)
|
||||
|
||||
# The response should contain the session ID
|
||||
resumed_id = result.get("sessionId", session_id)
|
||||
self._state.sessions[resumed_id] = ACPSession(session_id=resumed_id, cwd=cwd)
|
||||
|
||||
logger.info(f"[ACP] Resumed session: acp_session={resumed_id} cwd={cwd}")
|
||||
return resumed_id
|
||||
|
||||
def _try_resume_existing_session(self, cwd: str, timeout: float) -> str | None:
|
||||
"""Try to find and resume an existing session for this workspace.
|
||||
|
||||
When multiple API server replicas connect to the same sandbox pod,
|
||||
a previous replica may have already created an ACP session for this
|
||||
workspace. This method discovers and resumes that session so the
|
||||
agent retains conversation context.
|
||||
|
||||
Args:
|
||||
cwd: Working directory to search for sessions
|
||||
timeout: Timeout for ACP requests
|
||||
|
||||
Returns:
|
||||
The resumed session ID, or None if no session could be resumed
|
||||
"""
|
||||
# Check if the agent supports session/list + session/resume
|
||||
session_caps = self._state.agent_capabilities.get("sessionCapabilities", {})
|
||||
supports_list = session_caps.get("list") is not None
|
||||
supports_resume = session_caps.get("resume") is not None
|
||||
|
||||
if not supports_list or not supports_resume:
|
||||
logger.debug("[ACP] Agent does not support session resume")
|
||||
return None
|
||||
|
||||
# List sessions for this workspace directory
|
||||
sessions = self._list_sessions(cwd, timeout=min(timeout, 10.0))
|
||||
if not sessions:
|
||||
return None
|
||||
|
||||
# Pick the most recent session (first in list, assuming sorted)
|
||||
target = sessions[0]
|
||||
target_id = target.get("sessionId")
|
||||
if not target_id:
|
||||
logger.warning("[ACP] session/list returned session without sessionId")
|
||||
return None
|
||||
|
||||
logger.info(
|
||||
f"[ACP] Resuming existing session: acp_session={target_id} "
|
||||
f"(found {len(sessions)})"
|
||||
)
|
||||
|
||||
try:
|
||||
return self._resume_session(target_id, cwd, timeout)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"[ACP] session/resume failed for {target_id}: {e}, "
|
||||
f"falling back to session/new"
|
||||
)
|
||||
return None
|
||||
|
||||
def create_session(self, cwd: str, timeout: float = 30.0) -> str:
|
||||
"""Create a new ACP session on this connection.
|
||||
|
||||
Args:
|
||||
cwd: Working directory for the session
|
||||
timeout: Timeout for the request
|
||||
|
||||
Returns:
|
||||
The ACP session ID
|
||||
"""
|
||||
if not self._state.initialized:
|
||||
raise RuntimeError("Client not initialized. Call start() first.")
|
||||
return self._create_session(cwd=cwd, timeout=timeout)
|
||||
|
||||
def resume_session(self, session_id: str, cwd: str, timeout: float = 30.0) -> str:
|
||||
"""Resume an existing ACP session on this connection.
|
||||
|
||||
Args:
|
||||
session_id: The ACP session ID to resume
|
||||
cwd: Working directory for the session
|
||||
timeout: Timeout for the request
|
||||
|
||||
Returns:
|
||||
The ACP session ID
|
||||
"""
|
||||
if not self._state.initialized:
|
||||
raise RuntimeError("Client not initialized. Call start() first.")
|
||||
return self._resume_session(session_id=session_id, cwd=cwd, timeout=timeout)
|
||||
|
||||
def get_or_create_session(self, cwd: str, timeout: float = 30.0) -> str:
|
||||
"""Get an existing session for this cwd, or create/resume one.
|
||||
|
||||
Tries in order:
|
||||
1. Return an already-tracked session for this cwd
|
||||
2. Resume an existing session from opencode's storage (multi-replica)
|
||||
3. Create a new session
|
||||
|
||||
Args:
|
||||
cwd: Working directory for the session
|
||||
timeout: Timeout for ACP requests
|
||||
|
||||
Returns:
|
||||
The ACP session ID
|
||||
"""
|
||||
if not self._state.initialized:
|
||||
raise RuntimeError("Client not initialized. Call start() first.")
|
||||
|
||||
# Check if we already have a session for this cwd
|
||||
for sid, session in self._state.sessions.items():
|
||||
if session.cwd == cwd:
|
||||
logger.info(
|
||||
f"[ACP] Reusing existing session: " f"acp_session={sid} cwd={cwd}"
|
||||
)
|
||||
return sid
|
||||
|
||||
# Try to resume from opencode's persisted storage
|
||||
resumed_id = self._try_resume_existing_session(cwd, timeout)
|
||||
if resumed_id:
|
||||
return resumed_id
|
||||
|
||||
# Create a new session
|
||||
return self._create_session(cwd=cwd, timeout=timeout)
|
||||
|
||||
def send_message(
|
||||
self,
|
||||
message: str,
|
||||
session_id: str,
|
||||
timeout: float = ACP_MESSAGE_TIMEOUT,
|
||||
) -> Generator[ACPEvent, None, None]:
|
||||
"""Send a message and stream response events.
|
||||
"""Send a message to a specific session and stream response events.
|
||||
|
||||
Args:
|
||||
message: The message content to send
|
||||
session_id: The ACP session ID to send the message to
|
||||
timeout: Maximum time to wait for complete response (defaults to ACP_MESSAGE_TIMEOUT env var)
|
||||
|
||||
Yields:
|
||||
Typed ACP schema event objects
|
||||
"""
|
||||
if self._state.current_session is None:
|
||||
raise RuntimeError("No active session. Call start() first.")
|
||||
|
||||
session_id = self._state.current_session.session_id
|
||||
if session_id not in self._state.sessions:
|
||||
raise RuntimeError(
|
||||
f"Unknown session {session_id}. "
|
||||
f"Known sessions: {list(self._state.sessions.keys())}"
|
||||
)
|
||||
packet_logger = get_packet_logger()
|
||||
|
||||
# Log the start of message processing
|
||||
packet_logger.log_raw(
|
||||
"ACP-SEND-MESSAGE-START-K8S",
|
||||
{
|
||||
"session_id": session_id,
|
||||
"pod": self._pod_name,
|
||||
"namespace": self._namespace,
|
||||
"message_preview": (
|
||||
message[:200] + "..." if len(message) > 200 else message
|
||||
),
|
||||
"timeout": timeout,
|
||||
},
|
||||
logger.info(
|
||||
f"[ACP] Sending prompt: " f"acp_session={session_id} pod={self._pod_name}"
|
||||
)
|
||||
|
||||
prompt_content = [{"type": "text", "text": message}]
|
||||
@@ -446,44 +632,88 @@ class ACPExecClient:
|
||||
|
||||
request_id = self._send_request("session/prompt", params)
|
||||
start_time = time.time()
|
||||
last_event_time = time.time() # Track time since last event for keepalive
|
||||
last_event_time = time.time()
|
||||
events_yielded = 0
|
||||
keepalive_count = 0
|
||||
completion_reason = "unknown"
|
||||
|
||||
while True:
|
||||
remaining = timeout - (time.time() - start_time)
|
||||
if remaining <= 0:
|
||||
packet_logger.log_raw(
|
||||
"ACP-TIMEOUT-K8S",
|
||||
{
|
||||
"session_id": session_id,
|
||||
"elapsed_ms": (time.time() - start_time) * 1000,
|
||||
},
|
||||
completion_reason = "timeout"
|
||||
logger.warning(
|
||||
f"[ACP] Prompt timeout: "
|
||||
f"acp_session={session_id} events={events_yielded}, "
|
||||
f"sending session/cancel"
|
||||
)
|
||||
try:
|
||||
self.cancel(session_id=session_id)
|
||||
except Exception as cancel_err:
|
||||
logger.warning(
|
||||
f"[ACP] session/cancel failed on timeout: {cancel_err}"
|
||||
)
|
||||
yield Error(code=-1, message="Timeout waiting for response")
|
||||
break
|
||||
|
||||
try:
|
||||
message_data = self._response_queue.get(timeout=min(remaining, 1.0))
|
||||
last_event_time = time.time() # Reset keepalive timer on event
|
||||
last_event_time = time.time()
|
||||
except Empty:
|
||||
# Check if we need to send an SSE keepalive
|
||||
# Check if reader thread is still alive
|
||||
if (
|
||||
self._reader_thread is not None
|
||||
and not self._reader_thread.is_alive()
|
||||
):
|
||||
completion_reason = "reader_thread_dead"
|
||||
# Drain any final messages the reader flushed before dying
|
||||
while not self._response_queue.empty():
|
||||
try:
|
||||
final_msg = self._response_queue.get_nowait()
|
||||
if final_msg.get("id") == request_id:
|
||||
if "error" in final_msg:
|
||||
error_data = final_msg["error"]
|
||||
yield Error(
|
||||
code=error_data.get("code", -1),
|
||||
message=error_data.get(
|
||||
"message", "Unknown error"
|
||||
),
|
||||
)
|
||||
else:
|
||||
result = final_msg.get("result", {})
|
||||
try:
|
||||
yield PromptResponse.model_validate(result)
|
||||
except ValidationError:
|
||||
pass
|
||||
break
|
||||
except Empty:
|
||||
break
|
||||
|
||||
logger.warning(
|
||||
f"[ACP] Reader thread dead: "
|
||||
f"acp_session={session_id} events={events_yielded}"
|
||||
)
|
||||
break
|
||||
|
||||
# Send SSE keepalive if idle
|
||||
idle_time = time.time() - last_event_time
|
||||
if idle_time >= SSE_KEEPALIVE_INTERVAL:
|
||||
packet_logger.log_raw(
|
||||
"SSE-KEEPALIVE-YIELD",
|
||||
{
|
||||
"session_id": session_id,
|
||||
"idle_seconds": idle_time,
|
||||
},
|
||||
)
|
||||
keepalive_count += 1
|
||||
yield SSEKeepalive()
|
||||
last_event_time = time.time() # Reset after yielding keepalive
|
||||
last_event_time = time.time()
|
||||
continue
|
||||
|
||||
# Check for response to our prompt request
|
||||
if message_data.get("id") == request_id:
|
||||
# Check for JSON-RPC response to our prompt request.
|
||||
msg_id = message_data.get("id")
|
||||
is_response = "method" not in message_data and (
|
||||
msg_id == request_id
|
||||
or (msg_id is not None and str(msg_id) == str(request_id))
|
||||
)
|
||||
if is_response:
|
||||
completion_reason = "jsonrpc_response"
|
||||
if "error" in message_data:
|
||||
error_data = message_data["error"]
|
||||
completion_reason = "jsonrpc_error"
|
||||
logger.warning(f"[ACP] Prompt error: {error_data}")
|
||||
packet_logger.log_jsonrpc_response(
|
||||
request_id, error=error_data, context="k8s"
|
||||
)
|
||||
@@ -498,26 +728,16 @@ class ACPExecClient:
|
||||
)
|
||||
try:
|
||||
prompt_response = PromptResponse.model_validate(result)
|
||||
packet_logger.log_acp_event_yielded(
|
||||
"prompt_response", prompt_response
|
||||
)
|
||||
events_yielded += 1
|
||||
yield prompt_response
|
||||
except ValidationError as e:
|
||||
packet_logger.log_raw(
|
||||
"ACP-VALIDATION-ERROR-K8S",
|
||||
{"type": "prompt_response", "error": str(e)},
|
||||
)
|
||||
logger.error(f"[ACP] PromptResponse validation failed: {e}")
|
||||
|
||||
# Log completion summary
|
||||
elapsed_ms = (time.time() - start_time) * 1000
|
||||
packet_logger.log_raw(
|
||||
"ACP-SEND-MESSAGE-COMPLETE-K8S",
|
||||
{
|
||||
"session_id": session_id,
|
||||
"events_yielded": events_yielded,
|
||||
"elapsed_ms": elapsed_ms,
|
||||
},
|
||||
logger.info(
|
||||
f"[ACP] Prompt complete: "
|
||||
f"reason={completion_reason} acp_session={session_id} "
|
||||
f"events={events_yielded} elapsed={elapsed_ms:.0f}ms"
|
||||
)
|
||||
break
|
||||
|
||||
@@ -526,25 +746,15 @@ class ACPExecClient:
|
||||
params_data = message_data.get("params", {})
|
||||
update = params_data.get("update", {})
|
||||
|
||||
# Log the notification
|
||||
packet_logger.log_jsonrpc_notification(
|
||||
"session/update",
|
||||
{"update_type": update.get("sessionUpdate")},
|
||||
context="k8s",
|
||||
)
|
||||
|
||||
for event in self._process_session_update(update):
|
||||
events_yielded += 1
|
||||
# Log each yielded event
|
||||
event_type = self._get_event_type_name(event)
|
||||
packet_logger.log_acp_event_yielded(event_type, event)
|
||||
yield event
|
||||
|
||||
# Handle requests from agent - send error response
|
||||
elif "method" in message_data and "id" in message_data:
|
||||
packet_logger.log_raw(
|
||||
"ACP-UNSUPPORTED-REQUEST-K8S",
|
||||
{"method": message_data["method"], "id": message_data["id"]},
|
||||
logger.debug(
|
||||
f"[ACP] Unsupported agent request: "
|
||||
f"method={message_data['method']}"
|
||||
)
|
||||
self._send_error_response(
|
||||
message_data["id"],
|
||||
@@ -552,113 +762,47 @@ class ACPExecClient:
|
||||
f"Method not supported: {message_data['method']}",
|
||||
)
|
||||
|
||||
def _get_event_type_name(self, event: ACPEvent) -> str:
|
||||
"""Get the type name for an ACP event."""
|
||||
if isinstance(event, AgentMessageChunk):
|
||||
return "agent_message_chunk"
|
||||
elif isinstance(event, AgentThoughtChunk):
|
||||
return "agent_thought_chunk"
|
||||
elif isinstance(event, ToolCallStart):
|
||||
return "tool_call_start"
|
||||
elif isinstance(event, ToolCallProgress):
|
||||
return "tool_call_progress"
|
||||
elif isinstance(event, AgentPlanUpdate):
|
||||
return "agent_plan_update"
|
||||
elif isinstance(event, CurrentModeUpdate):
|
||||
return "current_mode_update"
|
||||
elif isinstance(event, PromptResponse):
|
||||
return "prompt_response"
|
||||
elif isinstance(event, Error):
|
||||
return "error"
|
||||
elif isinstance(event, SSEKeepalive):
|
||||
return "sse_keepalive"
|
||||
return "unknown"
|
||||
else:
|
||||
logger.warning(
|
||||
f"[ACP] Unhandled message: "
|
||||
f"id={message_data.get('id')} "
|
||||
f"method={message_data.get('method')} "
|
||||
f"keys={list(message_data.keys())}"
|
||||
)
|
||||
|
||||
def _process_session_update(
|
||||
self, update: dict[str, Any]
|
||||
) -> Generator[ACPEvent, None, None]:
|
||||
"""Process a session/update notification and yield typed ACP schema objects."""
|
||||
update_type = update.get("sessionUpdate")
|
||||
packet_logger = get_packet_logger()
|
||||
|
||||
if update_type == "agent_message_chunk":
|
||||
# Map update types to their ACP schema classes.
|
||||
# Note: prompt_response is intentionally excluded here — turn completion
|
||||
# is determined by the JSON-RPC response to session/prompt, not by a
|
||||
# session/update notification. This matches the ACP spec and Zed's impl.
|
||||
type_map: dict[str, type] = {
|
||||
"agent_message_chunk": AgentMessageChunk,
|
||||
"agent_thought_chunk": AgentThoughtChunk,
|
||||
"tool_call": ToolCallStart,
|
||||
"tool_call_update": ToolCallProgress,
|
||||
"plan": AgentPlanUpdate,
|
||||
"current_mode_update": CurrentModeUpdate,
|
||||
}
|
||||
|
||||
model_class = type_map.get(update_type) # type: ignore[arg-type]
|
||||
if model_class is not None:
|
||||
try:
|
||||
yield AgentMessageChunk.model_validate(update)
|
||||
yield model_class.model_validate(update)
|
||||
except ValidationError as e:
|
||||
packet_logger.log_raw(
|
||||
"ACP-VALIDATION-ERROR-K8S",
|
||||
{"update_type": update_type, "error": str(e), "update": update},
|
||||
)
|
||||
|
||||
elif update_type == "agent_thought_chunk":
|
||||
try:
|
||||
yield AgentThoughtChunk.model_validate(update)
|
||||
except ValidationError as e:
|
||||
packet_logger.log_raw(
|
||||
"ACP-VALIDATION-ERROR-K8S",
|
||||
{"update_type": update_type, "error": str(e), "update": update},
|
||||
)
|
||||
|
||||
elif update_type == "user_message_chunk":
|
||||
# Echo of user message - skip but log
|
||||
packet_logger.log_raw(
|
||||
"ACP-SKIPPED-UPDATE-K8S", {"type": "user_message_chunk"}
|
||||
)
|
||||
|
||||
elif update_type == "tool_call":
|
||||
try:
|
||||
yield ToolCallStart.model_validate(update)
|
||||
except ValidationError as e:
|
||||
packet_logger.log_raw(
|
||||
"ACP-VALIDATION-ERROR-K8S",
|
||||
{"update_type": update_type, "error": str(e), "update": update},
|
||||
)
|
||||
|
||||
elif update_type == "tool_call_update":
|
||||
try:
|
||||
yield ToolCallProgress.model_validate(update)
|
||||
except ValidationError as e:
|
||||
packet_logger.log_raw(
|
||||
"ACP-VALIDATION-ERROR-K8S",
|
||||
{"update_type": update_type, "error": str(e), "update": update},
|
||||
)
|
||||
|
||||
elif update_type == "plan":
|
||||
try:
|
||||
yield AgentPlanUpdate.model_validate(update)
|
||||
except ValidationError as e:
|
||||
packet_logger.log_raw(
|
||||
"ACP-VALIDATION-ERROR-K8S",
|
||||
{"update_type": update_type, "error": str(e), "update": update},
|
||||
)
|
||||
|
||||
elif update_type == "current_mode_update":
|
||||
try:
|
||||
yield CurrentModeUpdate.model_validate(update)
|
||||
except ValidationError as e:
|
||||
packet_logger.log_raw(
|
||||
"ACP-VALIDATION-ERROR-K8S",
|
||||
{"update_type": update_type, "error": str(e), "update": update},
|
||||
)
|
||||
|
||||
elif update_type == "available_commands_update":
|
||||
# Skip command updates
|
||||
packet_logger.log_raw(
|
||||
"ACP-SKIPPED-UPDATE-K8S", {"type": "available_commands_update"}
|
||||
)
|
||||
|
||||
elif update_type == "session_info_update":
|
||||
# Skip session info updates
|
||||
packet_logger.log_raw(
|
||||
"ACP-SKIPPED-UPDATE-K8S", {"type": "session_info_update"}
|
||||
)
|
||||
|
||||
else:
|
||||
# Unknown update types are logged
|
||||
packet_logger.log_raw(
|
||||
"ACP-UNKNOWN-UPDATE-TYPE-K8S",
|
||||
{"update_type": update_type, "update": update},
|
||||
)
|
||||
logger.warning(f"[ACP] Validation error for {update_type}: {e}")
|
||||
elif update_type not in (
|
||||
"user_message_chunk",
|
||||
"available_commands_update",
|
||||
"session_info_update",
|
||||
"usage_update",
|
||||
"prompt_response",
|
||||
):
|
||||
logger.debug(f"[ACP] Unknown update type: {update_type}")
|
||||
|
||||
def _send_error_response(self, request_id: int, code: int, message: str) -> None:
|
||||
"""Send an error response to an agent request."""
|
||||
@@ -673,15 +817,24 @@ class ACPExecClient:
|
||||
|
||||
self._ws_client.write_stdin(json.dumps(response) + "\n")
|
||||
|
||||
def cancel(self) -> None:
|
||||
"""Cancel the current operation."""
|
||||
if self._state.current_session is None:
|
||||
return
|
||||
def cancel(self, session_id: str | None = None) -> None:
|
||||
"""Cancel the current operation on a session.
|
||||
|
||||
self._send_notification(
|
||||
"session/cancel",
|
||||
{"sessionId": self._state.current_session.session_id},
|
||||
)
|
||||
Args:
|
||||
session_id: The ACP session ID to cancel. If None, cancels all sessions.
|
||||
"""
|
||||
if session_id:
|
||||
if session_id in self._state.sessions:
|
||||
self._send_notification(
|
||||
"session/cancel",
|
||||
{"sessionId": session_id},
|
||||
)
|
||||
else:
|
||||
for sid in self._state.sessions:
|
||||
self._send_notification(
|
||||
"session/cancel",
|
||||
{"sessionId": sid},
|
||||
)
|
||||
|
||||
def health_check(self, timeout: float = 5.0) -> bool: # noqa: ARG002
|
||||
"""Check if we can exec into the pod."""
|
||||
@@ -708,11 +861,9 @@ class ACPExecClient:
|
||||
return self._ws_client is not None and self._ws_client.is_open()
|
||||
|
||||
@property
|
||||
def session_id(self) -> str | None:
|
||||
"""Get the current session ID, if any."""
|
||||
if self._state.current_session:
|
||||
return self._state.current_session.session_id
|
||||
return None
|
||||
def session_ids(self) -> list[str]:
|
||||
"""Get all tracked session IDs."""
|
||||
return list(self._state.sessions.keys())
|
||||
|
||||
def __enter__(self) -> "ACPExecClient":
|
||||
"""Context manager entry."""
|
||||
|
||||
@@ -50,6 +50,7 @@ from pathlib import Path
|
||||
from uuid import UUID
|
||||
from uuid import uuid4
|
||||
|
||||
from acp.schema import PromptResponse
|
||||
from kubernetes import client # type: ignore
|
||||
from kubernetes import config
|
||||
from kubernetes.client.rest import ApiException # type: ignore
|
||||
@@ -97,6 +98,10 @@ from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
# API server pod hostname — used to identify which replica is handling a request.
|
||||
# In K8s, HOSTNAME is set to the pod name (e.g., "api-server-dpgg7").
|
||||
_API_SERVER_HOSTNAME = os.environ.get("HOSTNAME", "unknown")
|
||||
|
||||
# Constants for pod configuration
|
||||
# Note: Next.js ports are dynamically allocated from SANDBOX_NEXTJS_PORT_START to
|
||||
# SANDBOX_NEXTJS_PORT_END range, with one port per session.
|
||||
@@ -348,6 +353,14 @@ class KubernetesSandboxManager(SandboxManager):
|
||||
self._service_account = SANDBOX_SERVICE_ACCOUNT_NAME
|
||||
self._file_sync_service_account = SANDBOX_FILE_SYNC_SERVICE_ACCOUNT
|
||||
|
||||
# One long-lived ACP client per sandbox (Zed-style architecture).
|
||||
# Multiple craft sessions share one `opencode acp` process per sandbox.
|
||||
self._acp_clients: dict[UUID, ACPExecClient] = {}
|
||||
|
||||
# Maps (sandbox_id, craft_session_id) → ACP session ID.
|
||||
# Each craft session has its own ACP session on the shared client.
|
||||
self._acp_session_ids: dict[tuple[UUID, UUID], str] = {}
|
||||
|
||||
# Load AGENTS.md template path
|
||||
build_dir = Path(__file__).parent.parent.parent # /onyx/server/features/build/
|
||||
self._agent_instructions_template_path = build_dir / "AGENTS.template.md"
|
||||
@@ -1156,11 +1169,28 @@ done
|
||||
def terminate(self, sandbox_id: UUID) -> None:
|
||||
"""Terminate a sandbox and clean up Kubernetes resources.
|
||||
|
||||
Deletes the Service and Pod for the sandbox.
|
||||
Stops the shared ACP client and removes all session mappings for this
|
||||
sandbox, then deletes the Service and Pod.
|
||||
|
||||
Args:
|
||||
sandbox_id: The sandbox ID to terminate
|
||||
"""
|
||||
# Stop the shared ACP client for this sandbox
|
||||
acp_client = self._acp_clients.pop(sandbox_id, None)
|
||||
if acp_client:
|
||||
try:
|
||||
acp_client.stop()
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"[SANDBOX-ACP] Failed to stop ACP client for "
|
||||
f"sandbox {sandbox_id}: {e}"
|
||||
)
|
||||
|
||||
# Remove all session mappings for this sandbox
|
||||
keys_to_remove = [key for key in self._acp_session_ids if key[0] == sandbox_id]
|
||||
for key in keys_to_remove:
|
||||
del self._acp_session_ids[key]
|
||||
|
||||
# Clean up Kubernetes resources (needs string for pod/service names)
|
||||
self._cleanup_kubernetes_resources(str(sandbox_id))
|
||||
|
||||
@@ -1395,7 +1425,8 @@ echo "Session workspace setup complete"
|
||||
) -> None:
|
||||
"""Clean up a session workspace (on session delete).
|
||||
|
||||
Executes kubectl exec to remove the session directory.
|
||||
Removes the ACP session mapping and executes kubectl exec to remove
|
||||
the session directory. The shared ACP client persists for other sessions.
|
||||
|
||||
Args:
|
||||
sandbox_id: The sandbox ID
|
||||
@@ -1403,6 +1434,15 @@ echo "Session workspace setup complete"
|
||||
nextjs_port: Optional port where Next.js server is running (unused in K8s,
|
||||
we use PID file instead)
|
||||
"""
|
||||
# Remove the ACP session mapping (shared client persists)
|
||||
session_key = (sandbox_id, session_id)
|
||||
acp_session_id = self._acp_session_ids.pop(session_key, None)
|
||||
if acp_session_id:
|
||||
logger.info(
|
||||
f"[SANDBOX-ACP] Removed ACP session mapping: "
|
||||
f"session={session_id} acp_session={acp_session_id}"
|
||||
)
|
||||
|
||||
pod_name = self._get_pod_name(str(sandbox_id))
|
||||
session_path = f"/workspace/sessions/{session_id}"
|
||||
|
||||
@@ -1807,6 +1847,94 @@ echo "Session config regeneration complete"
|
||||
)
|
||||
return exec_client.health_check(timeout=timeout)
|
||||
|
||||
def _get_or_create_acp_client(self, sandbox_id: UUID) -> ACPExecClient:
|
||||
"""Get the shared ACP client for a sandbox, creating one if needed.
|
||||
|
||||
One long-lived `opencode acp` process per sandbox (Zed-style).
|
||||
If the existing client's WebSocket has died, replaces it with a new one.
|
||||
|
||||
Args:
|
||||
sandbox_id: The sandbox ID
|
||||
|
||||
Returns:
|
||||
A running ACPExecClient for this sandbox
|
||||
"""
|
||||
acp_client = self._acp_clients.get(sandbox_id)
|
||||
|
||||
if acp_client is not None and acp_client.is_running:
|
||||
return acp_client
|
||||
|
||||
# Client is dead or doesn't exist — clean up stale one
|
||||
if acp_client is not None:
|
||||
logger.warning(
|
||||
f"[SANDBOX-ACP] Stale ACP client for sandbox {sandbox_id}, "
|
||||
f"replacing"
|
||||
)
|
||||
try:
|
||||
acp_client.stop()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Clear session mappings — they're invalid on a new process
|
||||
keys_to_remove = [
|
||||
key for key in self._acp_session_ids if key[0] == sandbox_id
|
||||
]
|
||||
for key in keys_to_remove:
|
||||
del self._acp_session_ids[key]
|
||||
|
||||
pod_name = self._get_pod_name(str(sandbox_id))
|
||||
new_client = ACPExecClient(
|
||||
pod_name=pod_name,
|
||||
namespace=self._namespace,
|
||||
container="sandbox",
|
||||
)
|
||||
new_client.start(cwd="/workspace")
|
||||
self._acp_clients[sandbox_id] = new_client
|
||||
|
||||
logger.info(
|
||||
f"[SANDBOX-ACP] Created shared ACP client: "
|
||||
f"sandbox={sandbox_id} pod={pod_name} "
|
||||
f"api_pod={_API_SERVER_HOSTNAME}"
|
||||
)
|
||||
return new_client
|
||||
|
||||
def _get_or_create_acp_session(
|
||||
self,
|
||||
sandbox_id: UUID,
|
||||
session_id: UUID,
|
||||
acp_client: ACPExecClient,
|
||||
) -> str:
|
||||
"""Get the ACP session ID for a craft session, creating one if needed.
|
||||
|
||||
Uses the session mapping cache first, then falls back to
|
||||
`get_or_create_session()` which handles resume from opencode's
|
||||
persisted storage (multi-replica support).
|
||||
|
||||
Args:
|
||||
sandbox_id: The sandbox ID
|
||||
session_id: The craft session ID
|
||||
acp_client: The shared ACP client for this sandbox
|
||||
|
||||
Returns:
|
||||
The ACP session ID
|
||||
"""
|
||||
session_key = (sandbox_id, session_id)
|
||||
acp_session_id = self._acp_session_ids.get(session_key)
|
||||
|
||||
if acp_session_id and acp_session_id in acp_client.session_ids:
|
||||
return acp_session_id
|
||||
|
||||
# Session not tracked or was lost — get or create it
|
||||
session_path = f"/workspace/sessions/{session_id}"
|
||||
acp_session_id = acp_client.get_or_create_session(cwd=session_path)
|
||||
self._acp_session_ids[session_key] = acp_session_id
|
||||
|
||||
logger.info(
|
||||
f"[SANDBOX-ACP] Session mapped: "
|
||||
f"craft_session={session_id} acp_session={acp_session_id}"
|
||||
)
|
||||
return acp_session_id
|
||||
|
||||
def send_message(
|
||||
self,
|
||||
sandbox_id: UUID,
|
||||
@@ -1815,8 +1943,9 @@ echo "Session config regeneration complete"
|
||||
) -> Generator[ACPEvent, None, None]:
|
||||
"""Send a message to the CLI agent and stream ACP events.
|
||||
|
||||
Runs `opencode acp` via kubectl exec in the sandbox pod.
|
||||
The agent runs in the session-specific workspace.
|
||||
Uses a shared ACP client per sandbox (one `opencode acp` process).
|
||||
Each craft session has its own ACP session ID on that shared process.
|
||||
Switching between sessions is client-side — just use the right sessionId.
|
||||
|
||||
Args:
|
||||
sandbox_id: The sandbox ID
|
||||
@@ -1827,37 +1956,53 @@ echo "Session config regeneration complete"
|
||||
Typed ACP schema event objects
|
||||
"""
|
||||
packet_logger = get_packet_logger()
|
||||
pod_name = self._get_pod_name(str(sandbox_id))
|
||||
session_path = f"/workspace/sessions/{session_id}"
|
||||
|
||||
# Log ACP client creation
|
||||
packet_logger.log_acp_client_start(
|
||||
sandbox_id, session_id, session_path, context="k8s"
|
||||
# Get or create the shared ACP client for this sandbox
|
||||
acp_client = self._get_or_create_acp_client(sandbox_id)
|
||||
|
||||
# Get or create the ACP session for this craft session
|
||||
acp_session_id = self._get_or_create_acp_session(
|
||||
sandbox_id, session_id, acp_client
|
||||
)
|
||||
|
||||
exec_client = ACPExecClient(
|
||||
pod_name=pod_name,
|
||||
namespace=self._namespace,
|
||||
container="sandbox",
|
||||
logger.info(
|
||||
f"[SANDBOX-ACP] Sending message: "
|
||||
f"session={session_id} acp_session={acp_session_id} "
|
||||
f"api_pod={_API_SERVER_HOSTNAME}"
|
||||
)
|
||||
|
||||
# Log the send_message call at sandbox manager level
|
||||
packet_logger.log_session_start(session_id, sandbox_id, message)
|
||||
|
||||
events_count = 0
|
||||
got_prompt_response = False
|
||||
try:
|
||||
exec_client.start(cwd=session_path)
|
||||
for event in exec_client.send_message(message):
|
||||
for event in acp_client.send_message(message, session_id=acp_session_id):
|
||||
events_count += 1
|
||||
if isinstance(event, PromptResponse):
|
||||
got_prompt_response = True
|
||||
yield event
|
||||
|
||||
# Log successful completion
|
||||
logger.info(
|
||||
f"[SANDBOX-ACP] send_message completed: "
|
||||
f"session={session_id} events={events_count} "
|
||||
f"got_prompt_response={got_prompt_response}"
|
||||
)
|
||||
packet_logger.log_session_end(
|
||||
session_id, success=True, events_count=events_count
|
||||
)
|
||||
except GeneratorExit:
|
||||
# Generator was closed by consumer (client disconnect, timeout, broken pipe)
|
||||
# This is the most common failure mode for SSE streaming
|
||||
logger.warning(
|
||||
f"[SANDBOX-ACP] GeneratorExit: session={session_id} "
|
||||
f"events={events_count}, sending session/cancel"
|
||||
)
|
||||
try:
|
||||
acp_client.cancel(session_id=acp_session_id)
|
||||
except Exception as cancel_err:
|
||||
logger.warning(
|
||||
f"[SANDBOX-ACP] session/cancel failed on GeneratorExit: "
|
||||
f"{cancel_err}"
|
||||
)
|
||||
packet_logger.log_session_end(
|
||||
session_id,
|
||||
success=False,
|
||||
@@ -1866,7 +2011,17 @@ echo "Session config regeneration complete"
|
||||
)
|
||||
raise
|
||||
except Exception as e:
|
||||
# Log failure from normal exceptions
|
||||
logger.error(
|
||||
f"[SANDBOX-ACP] Exception: session={session_id} "
|
||||
f"events={events_count} error={e}, sending session/cancel"
|
||||
)
|
||||
try:
|
||||
acp_client.cancel(session_id=acp_session_id)
|
||||
except Exception as cancel_err:
|
||||
logger.warning(
|
||||
f"[SANDBOX-ACP] session/cancel failed on Exception: "
|
||||
f"{cancel_err}"
|
||||
)
|
||||
packet_logger.log_session_end(
|
||||
session_id,
|
||||
success=False,
|
||||
@@ -1875,19 +2030,16 @@ echo "Session config regeneration complete"
|
||||
)
|
||||
raise
|
||||
except BaseException as e:
|
||||
# Log failure from other base exceptions (SystemExit, KeyboardInterrupt, etc.)
|
||||
exception_type = type(e).__name__
|
||||
logger.error(
|
||||
f"[SANDBOX-ACP] {type(e).__name__}: session={session_id} " f"error={e}"
|
||||
)
|
||||
packet_logger.log_session_end(
|
||||
session_id,
|
||||
success=False,
|
||||
error=f"{exception_type}: {str(e) if str(e) else 'System-level interruption'}",
|
||||
error=f"{type(e).__name__}: {str(e) if str(e) else 'System-level interruption'}",
|
||||
events_count=events_count,
|
||||
)
|
||||
raise
|
||||
finally:
|
||||
exec_client.stop()
|
||||
# Log client stop
|
||||
packet_logger.log_acp_client_stop(sandbox_id, session_id, context="k8s")
|
||||
|
||||
def list_directory(
|
||||
self, sandbox_id: UUID, session_id: UUID, path: str
|
||||
|
||||
0
backend/tests/unit/onyx/server/features/__init__.py
Normal file
0
backend/tests/unit/onyx/server/features/__init__.py
Normal file
@@ -0,0 +1,374 @@
|
||||
"""Unit tests for Zed-style ACP session management in KubernetesSandboxManager.
|
||||
|
||||
These tests verify that the KubernetesSandboxManager correctly:
|
||||
- Maintains one shared ACPExecClient per sandbox
|
||||
- Maps craft sessions to ACP sessions on the shared client
|
||||
- Replaces dead clients and re-creates sessions
|
||||
- Cleans up on terminate/cleanup
|
||||
|
||||
All external dependencies (K8s, WebSockets, packet logging) are mocked.
|
||||
"""
|
||||
|
||||
from collections.abc import Generator
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
from uuid import UUID
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# The fully-qualified path to the module under test, used for patching
|
||||
_K8S_MODULE = "onyx.server.features.build.sandbox.kubernetes.kubernetes_sandbox_manager"
|
||||
_ACP_CLIENT_CLASS = f"{_K8S_MODULE}.ACPExecClient"
|
||||
_GET_PACKET_LOGGER = f"{_K8S_MODULE}.get_packet_logger"
|
||||
|
||||
|
||||
def _make_mock_event() -> MagicMock:
|
||||
"""Create a mock ACP event."""
|
||||
return MagicMock(name="mock_acp_event")
|
||||
|
||||
|
||||
def _make_mock_client(
|
||||
is_running: bool = True,
|
||||
session_ids: list[str] | None = None,
|
||||
) -> MagicMock:
|
||||
"""Create a mock ACPExecClient with configurable state.
|
||||
|
||||
Args:
|
||||
is_running: Whether the client appears running
|
||||
session_ids: List of ACP session IDs the client tracks
|
||||
"""
|
||||
mock_client = MagicMock()
|
||||
type(mock_client).is_running = property(lambda _self: is_running)
|
||||
type(mock_client).session_ids = property(
|
||||
lambda _self: session_ids if session_ids is not None else []
|
||||
)
|
||||
mock_client.start.return_value = None
|
||||
mock_client.stop.return_value = None
|
||||
|
||||
# get_or_create_session returns a unique ACP session ID
|
||||
mock_client.get_or_create_session.return_value = f"acp-session-{uuid4().hex[:8]}"
|
||||
|
||||
mock_event = _make_mock_event()
|
||||
mock_client.send_message.return_value = iter([mock_event])
|
||||
return mock_client
|
||||
|
||||
|
||||
def _drain_generator(gen: Generator[Any, None, None]) -> list[Any]:
|
||||
"""Consume a generator and return all yielded values as a list."""
|
||||
return list(gen)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fixture: fresh KubernetesSandboxManager instance
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def manager() -> Generator[Any, None, None]:
|
||||
"""Create a fresh KubernetesSandboxManager instance with all externals mocked."""
|
||||
with (
|
||||
patch(f"{_K8S_MODULE}.config") as _mock_config,
|
||||
patch(f"{_K8S_MODULE}.client") as _mock_k8s_client,
|
||||
patch(f"{_K8S_MODULE}.k8s_stream"),
|
||||
patch(_GET_PACKET_LOGGER) as mock_get_logger,
|
||||
):
|
||||
mock_packet_logger = MagicMock()
|
||||
mock_get_logger.return_value = mock_packet_logger
|
||||
|
||||
_mock_config.load_incluster_config.return_value = None
|
||||
_mock_config.ConfigException = Exception
|
||||
|
||||
_mock_k8s_client.ApiClient.return_value = MagicMock()
|
||||
_mock_k8s_client.CoreV1Api.return_value = MagicMock()
|
||||
_mock_k8s_client.BatchV1Api.return_value = MagicMock()
|
||||
_mock_k8s_client.NetworkingV1Api.return_value = MagicMock()
|
||||
|
||||
from onyx.server.features.build.sandbox.kubernetes.kubernetes_sandbox_manager import (
|
||||
KubernetesSandboxManager,
|
||||
)
|
||||
|
||||
KubernetesSandboxManager._instance = None
|
||||
mgr = KubernetesSandboxManager()
|
||||
|
||||
yield mgr
|
||||
|
||||
KubernetesSandboxManager._instance = None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests: Shared client lifecycle
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_send_message_creates_shared_client_on_first_call(manager: Any) -> None:
|
||||
"""First call to send_message() should create one shared ACPExecClient
|
||||
for the sandbox, create an ACP session, and yield events."""
|
||||
sandbox_id: UUID = uuid4()
|
||||
session_id: UUID = uuid4()
|
||||
message = "hello world"
|
||||
|
||||
mock_event = _make_mock_event()
|
||||
mock_client = _make_mock_client(is_running=True)
|
||||
acp_session_id = "acp-session-abc"
|
||||
mock_client.get_or_create_session.return_value = acp_session_id
|
||||
# session_ids must include the created session for validation
|
||||
type(mock_client).session_ids = property(lambda _: [acp_session_id])
|
||||
mock_client.send_message.return_value = iter([mock_event])
|
||||
|
||||
with patch(_ACP_CLIENT_CLASS, return_value=mock_client) as MockClass:
|
||||
events = _drain_generator(manager.send_message(sandbox_id, session_id, message))
|
||||
|
||||
# Verify shared client was constructed once
|
||||
MockClass.assert_called_once()
|
||||
|
||||
# Verify start() was called with /workspace (not session-specific path)
|
||||
mock_client.start.assert_called_once_with(cwd="/workspace")
|
||||
|
||||
# Verify get_or_create_session was called with the session path
|
||||
expected_cwd = f"/workspace/sessions/{session_id}"
|
||||
mock_client.get_or_create_session.assert_called_once_with(cwd=expected_cwd)
|
||||
|
||||
# Verify send_message was called with correct args
|
||||
mock_client.send_message.assert_called_once_with(message, session_id=acp_session_id)
|
||||
|
||||
# Verify we got the event
|
||||
assert mock_event in events
|
||||
|
||||
# Verify shared client is cached by sandbox_id
|
||||
assert sandbox_id in manager._acp_clients
|
||||
assert manager._acp_clients[sandbox_id] is mock_client
|
||||
|
||||
# Verify session mapping exists
|
||||
assert (sandbox_id, session_id) in manager._acp_session_ids
|
||||
assert manager._acp_session_ids[(sandbox_id, session_id)] == acp_session_id
|
||||
|
||||
|
||||
def test_send_message_reuses_shared_client_for_same_session(manager: Any) -> None:
|
||||
"""Second call with the same session should reuse the shared client
|
||||
and the same ACP session ID."""
|
||||
sandbox_id: UUID = uuid4()
|
||||
session_id: UUID = uuid4()
|
||||
|
||||
mock_event_1 = _make_mock_event()
|
||||
mock_event_2 = _make_mock_event()
|
||||
mock_client = _make_mock_client(is_running=True)
|
||||
acp_session_id = "acp-session-reuse"
|
||||
mock_client.get_or_create_session.return_value = acp_session_id
|
||||
type(mock_client).session_ids = property(lambda _: [acp_session_id])
|
||||
|
||||
mock_client.send_message.side_effect = [
|
||||
iter([mock_event_1]),
|
||||
iter([mock_event_2]),
|
||||
]
|
||||
|
||||
with patch(_ACP_CLIENT_CLASS, return_value=mock_client) as MockClass:
|
||||
events_1 = _drain_generator(
|
||||
manager.send_message(sandbox_id, session_id, "first")
|
||||
)
|
||||
events_2 = _drain_generator(
|
||||
manager.send_message(sandbox_id, session_id, "second")
|
||||
)
|
||||
|
||||
# Constructor called only ONCE (shared client)
|
||||
MockClass.assert_called_once()
|
||||
|
||||
# start() called only once
|
||||
mock_client.start.assert_called_once()
|
||||
|
||||
# get_or_create_session called only once (second call uses cached mapping)
|
||||
mock_client.get_or_create_session.assert_called_once()
|
||||
|
||||
# send_message called twice with same ACP session ID
|
||||
assert mock_client.send_message.call_count == 2
|
||||
|
||||
assert mock_event_1 in events_1
|
||||
assert mock_event_2 in events_2
|
||||
|
||||
|
||||
def test_send_message_different_sessions_share_client(manager: Any) -> None:
|
||||
"""Two different craft sessions on the same sandbox should share the
|
||||
same ACPExecClient but have different ACP sessions."""
|
||||
sandbox_id: UUID = uuid4()
|
||||
session_id_a: UUID = uuid4()
|
||||
session_id_b: UUID = uuid4()
|
||||
|
||||
mock_client = _make_mock_client(is_running=True)
|
||||
acp_session_a = "acp-session-a"
|
||||
acp_session_b = "acp-session-b"
|
||||
mock_client.get_or_create_session.side_effect = [acp_session_a, acp_session_b]
|
||||
type(mock_client).session_ids = property(lambda _: [acp_session_a, acp_session_b])
|
||||
|
||||
mock_event_a = _make_mock_event()
|
||||
mock_event_b = _make_mock_event()
|
||||
mock_client.send_message.side_effect = [
|
||||
iter([mock_event_a]),
|
||||
iter([mock_event_b]),
|
||||
]
|
||||
|
||||
with patch(_ACP_CLIENT_CLASS, return_value=mock_client) as MockClass:
|
||||
events_a = _drain_generator(
|
||||
manager.send_message(sandbox_id, session_id_a, "msg a")
|
||||
)
|
||||
events_b = _drain_generator(
|
||||
manager.send_message(sandbox_id, session_id_b, "msg b")
|
||||
)
|
||||
|
||||
# Only ONE shared client was created
|
||||
MockClass.assert_called_once()
|
||||
|
||||
# get_or_create_session called twice (once per craft session)
|
||||
assert mock_client.get_or_create_session.call_count == 2
|
||||
|
||||
# send_message called with different ACP session IDs
|
||||
mock_client.send_message.assert_any_call("msg a", session_id=acp_session_a)
|
||||
mock_client.send_message.assert_any_call("msg b", session_id=acp_session_b)
|
||||
|
||||
# Both session mappings exist
|
||||
assert manager._acp_session_ids[(sandbox_id, session_id_a)] == acp_session_a
|
||||
assert manager._acp_session_ids[(sandbox_id, session_id_b)] == acp_session_b
|
||||
|
||||
assert mock_event_a in events_a
|
||||
assert mock_event_b in events_b
|
||||
|
||||
|
||||
def test_send_message_replaces_dead_client(manager: Any) -> None:
|
||||
"""If the shared client has is_running == False, should replace it and
|
||||
re-create sessions."""
|
||||
sandbox_id: UUID = uuid4()
|
||||
session_id: UUID = uuid4()
|
||||
|
||||
# Place a dead client in the cache
|
||||
dead_client = _make_mock_client(is_running=False)
|
||||
manager._acp_clients[sandbox_id] = dead_client
|
||||
manager._acp_session_ids[(sandbox_id, session_id)] = "old-acp-session"
|
||||
|
||||
# Create the replacement client
|
||||
new_event = _make_mock_event()
|
||||
new_client = _make_mock_client(is_running=True)
|
||||
new_acp_session = "new-acp-session"
|
||||
new_client.get_or_create_session.return_value = new_acp_session
|
||||
type(new_client).session_ids = property(lambda _: [new_acp_session])
|
||||
new_client.send_message.return_value = iter([new_event])
|
||||
|
||||
with patch(_ACP_CLIENT_CLASS, return_value=new_client):
|
||||
events = _drain_generator(manager.send_message(sandbox_id, session_id, "test"))
|
||||
|
||||
# Dead client was stopped during replacement
|
||||
dead_client.stop.assert_called_once()
|
||||
|
||||
# New client was started
|
||||
new_client.start.assert_called_once()
|
||||
|
||||
# Old session mapping was cleared, new one created
|
||||
assert manager._acp_session_ids[(sandbox_id, session_id)] == new_acp_session
|
||||
|
||||
# Cache holds the new client
|
||||
assert manager._acp_clients[sandbox_id] is new_client
|
||||
|
||||
assert new_event in events
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests: Cleanup
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_terminate_stops_shared_client(manager: Any) -> None:
|
||||
"""terminate(sandbox_id) should stop the shared client and clear
|
||||
all session mappings for that sandbox."""
|
||||
sandbox_id: UUID = uuid4()
|
||||
session_id_1: UUID = uuid4()
|
||||
session_id_2: UUID = uuid4()
|
||||
|
||||
mock_client = _make_mock_client(is_running=True)
|
||||
manager._acp_clients[sandbox_id] = mock_client
|
||||
manager._acp_session_ids[(sandbox_id, session_id_1)] = "acp-1"
|
||||
manager._acp_session_ids[(sandbox_id, session_id_2)] = "acp-2"
|
||||
|
||||
with patch.object(manager, "_cleanup_kubernetes_resources"):
|
||||
manager.terminate(sandbox_id)
|
||||
|
||||
# Shared client was stopped
|
||||
mock_client.stop.assert_called_once()
|
||||
|
||||
# Client removed from cache
|
||||
assert sandbox_id not in manager._acp_clients
|
||||
|
||||
# Session mappings removed
|
||||
assert (sandbox_id, session_id_1) not in manager._acp_session_ids
|
||||
assert (sandbox_id, session_id_2) not in manager._acp_session_ids
|
||||
|
||||
|
||||
def test_terminate_leaves_other_sandbox_untouched(manager: Any) -> None:
|
||||
"""terminate(sandbox_A) should NOT affect sandbox_B's client or sessions."""
|
||||
sandbox_a: UUID = uuid4()
|
||||
sandbox_b: UUID = uuid4()
|
||||
session_a: UUID = uuid4()
|
||||
session_b: UUID = uuid4()
|
||||
|
||||
client_a = _make_mock_client(is_running=True)
|
||||
client_b = _make_mock_client(is_running=True)
|
||||
|
||||
manager._acp_clients[sandbox_a] = client_a
|
||||
manager._acp_clients[sandbox_b] = client_b
|
||||
manager._acp_session_ids[(sandbox_a, session_a)] = "acp-a"
|
||||
manager._acp_session_ids[(sandbox_b, session_b)] = "acp-b"
|
||||
|
||||
with patch.object(manager, "_cleanup_kubernetes_resources"):
|
||||
manager.terminate(sandbox_a)
|
||||
|
||||
# sandbox_a cleaned up
|
||||
client_a.stop.assert_called_once()
|
||||
assert sandbox_a not in manager._acp_clients
|
||||
assert (sandbox_a, session_a) not in manager._acp_session_ids
|
||||
|
||||
# sandbox_b untouched
|
||||
client_b.stop.assert_not_called()
|
||||
assert sandbox_b in manager._acp_clients
|
||||
assert manager._acp_session_ids[(sandbox_b, session_b)] == "acp-b"
|
||||
|
||||
|
||||
def test_cleanup_session_removes_session_mapping(manager: Any) -> None:
|
||||
"""cleanup_session_workspace() should remove the session mapping but
|
||||
leave the shared client alive for other sessions."""
|
||||
sandbox_id: UUID = uuid4()
|
||||
session_id: UUID = uuid4()
|
||||
|
||||
mock_client = _make_mock_client(is_running=True)
|
||||
manager._acp_clients[sandbox_id] = mock_client
|
||||
manager._acp_session_ids[(sandbox_id, session_id)] = "acp-session-xyz"
|
||||
|
||||
with patch.object(manager, "_stream_core_api") as mock_stream_api:
|
||||
mock_stream_api.connect_get_namespaced_pod_exec = MagicMock()
|
||||
with patch(f"{_K8S_MODULE}.k8s_stream", return_value="cleanup ok"):
|
||||
manager.cleanup_session_workspace(sandbox_id, session_id)
|
||||
|
||||
# Session mapping removed
|
||||
assert (sandbox_id, session_id) not in manager._acp_session_ids
|
||||
|
||||
# Shared client is NOT stopped (other sessions may use it)
|
||||
mock_client.stop.assert_not_called()
|
||||
assert sandbox_id in manager._acp_clients
|
||||
|
||||
|
||||
def test_cleanup_session_handles_no_mapping(manager: Any) -> None:
|
||||
"""cleanup_session_workspace() should not error when there's no
|
||||
session mapping."""
|
||||
sandbox_id: UUID = uuid4()
|
||||
session_id: UUID = uuid4()
|
||||
|
||||
assert (sandbox_id, session_id) not in manager._acp_session_ids
|
||||
|
||||
with patch.object(manager, "_stream_core_api") as mock_stream_api:
|
||||
mock_stream_api.connect_get_namespaced_pod_exec = MagicMock()
|
||||
with patch(f"{_K8S_MODULE}.k8s_stream", return_value="cleanup ok"):
|
||||
manager.cleanup_session_workspace(sandbox_id, session_id)
|
||||
|
||||
assert (sandbox_id, session_id) not in manager._acp_session_ids
|
||||
Reference in New Issue
Block a user