Source code for nemo_evaluator.sandbox.ecs_fargate

# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""ECS Fargate sandbox.

Supports two transport modes over SSH:

  - **Exec-server mode** — one-way SSH tunnel.  An embedded HTTP exec
    server runs inside the container; the orchestrator drives all
    command execution, uploads, and downloads through it.

  - **Agent-server mode** — two-way SSH tunnel.  The container hosts a
    self-contained agent that reaches the model via a reverse tunnel;
    the orchestrator connects to the agent's API via a forward tunnel.

Includes Docker image building via AWS CodeBuild with ECR caching.
"""

from __future__ import annotations

import atexit
import base64
import hashlib
import io
import json
import os
import random
import re
import shlex
import socket
import subprocess
import tarfile
import tempfile
import threading
import time
import uuid
import zipfile
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Callable, Mapping, TypeVar
from urllib.parse import ParseResult, urlparse

import structlog

from nemo_evaluator.sandbox.base import ExecResult, OutsideEndpoint

log = structlog.get_logger(__name__)
T = TypeVar("T")


# =====================================================================
# Lazy AWS SDK import
# =====================================================================


def _require_aws_sdks():
    """Import boto3/botocore only when actually needed."""
    try:
        import importlib

        boto3 = importlib.import_module("boto3")
        botocore_config = importlib.import_module("botocore.config")
        botocore_exceptions = importlib.import_module("botocore.exceptions")
    except ModuleNotFoundError as e:
        raise RuntimeError(
            "ECS Fargate sandbox requires boto3/botocore. "
            "Install them (`pip install boto3`) or use a different sandbox backend."
        ) from e

    Config = getattr(botocore_config, "Config")
    ClientError = getattr(botocore_exceptions, "ClientError")
    return boto3, Config, ClientError


# =====================================================================
# Config dataclasses
# =====================================================================


def _coerce_list(value: Any) -> list[str]:
    if value is None:
        return []
    if isinstance(value, list):
        return [str(v) for v in value]
    if isinstance(value, str):
        return [value]
    raise TypeError(f"expected list or str, got {type(value)!r}")


def _sanitize_id(value: str, max_len: int = 100) -> str:
    cleaned = re.sub(r"[^a-zA-Z0-9-]+", "-", value).strip("-")
    return cleaned[:max_len] or "task"


[docs] @dataclass(frozen=True) class SshSidecarConfig: """SSH sidecar container configuration. The sidecar runs sshd in an Alpine container alongside the main container, providing SSH-based transport for tunnelling and execution. Two modes depending on ``exec_server_port``: * **Exec-server mode** (``exec_server_port`` is set): One-way SSH tunnel. The sandbox uploads a zero-dependency HTTP exec server into the main container and forwards a local port to it. ``exec()``, ``upload()``, ``download()`` all work. * **Agent-server mode** (``exec_server_port`` is ``None``): Two-way SSH tunnel. A reverse tunnel (``-R``) makes the :class:`OutsideEndpoint` reachable inside the task; a forward tunnel (``-L``) gives the orchestrator access to the agent server. The consumer is responsible for command execution via its own agent API. """ sshd_port: int = 2222 ssh_ready_timeout_sec: float = 120.0 public_key_secret_arn: str = "" # required — pre-provisioned only private_key_secret_arn: str = "" # required — pre-provisioned only image: str | None = None # sidecar image (None → alpine:latest) # Exec server config (exec-server mode; None → agent-server mode) exec_server_port: int | None = None
[docs] @classmethod def from_dict(cls, raw: Mapping[str, Any]) -> SshSidecarConfig: return cls( sshd_port=int(raw.get("sshd_port", 2222)), ssh_ready_timeout_sec=float(raw.get("ssh_ready_timeout_sec", 120.0)), public_key_secret_arn=str(raw.get("public_key_secret_arn", "")), private_key_secret_arn=str(raw.get("private_key_secret_arn", "")), image=raw.get("image"), exec_server_port=( int(raw["exec_server_port"]) if raw.get("exec_server_port") is not None else None ), )
[docs] @dataclass(frozen=True) class EcsFargateConfig: """Configuration for the ECS Fargate sandbox.""" # AWS infrastructure region: str | None = None cluster: str = "" subnets: list[str] = field(default_factory=list) security_groups: list[str] = field(default_factory=list) assign_public_ip: bool = False # Task definition task_definition: str | None = None # base task def to clone task_definition_family_prefix: str = "ecs-sandbox" image_template: str | None = None # supports {task_id}, {task_id_sanitized} container_name: str = "main" container_port: int | None = None # agent-server port (agent-server mode) cpu: str = "4096" memory: str = "8192" ephemeral_storage_gib: int | None = None platform_version: str | None = None execution_role_arn: str | None = None task_role_arn: str | None = None extra_env: dict[str, str] | None = None # Logging log_group: str | None = None log_stream_prefix: str | None = None # Lifecycle max_task_lifetime_sec: int = 14400 # 4 h startup_timeout_sec: float = 300.0 poll_interval_sec: float = 2.0 run_task_max_retries: int = 30 # SSH sidecar ssh_sidecar: SshSidecarConfig | None = None # S3 file staging s3_bucket: str | None = None s3_prefix: str | None = None # Docker build via AWS CodeBuild ecr_repository: str | None = None environment_dir: str | None = None codebuild_project: str | None = None codebuild_service_role: str | None = None codebuild_compute_type: str = "BUILD_GENERAL1_MEDIUM" codebuild_build_timeout: int = 30 dockerhub_secret_arn: str | None = None build_parallelism: int = 50
[docs] @classmethod def from_dict(cls, raw: Mapping[str, Any]) -> EcsFargateConfig: subnets = _coerce_list(raw.get("subnets")) sgs = _coerce_list(raw.get("security_groups")) has_sidecar = isinstance(raw.get("ssh_sidecar"), Mapping) explicit_ip = bool(raw.get("assign_public_ip", False)) assign_public_ip = explicit_ip or has_sidecar if has_sidecar and not explicit_ip: log.info( "assign_public_ip forced True because ssh_sidecar is configured " "(SSH requires a reachable IP)" ) return cls( region=raw.get("region"), cluster=str(raw.get("cluster", "")), subnets=subnets, security_groups=sgs, assign_public_ip=assign_public_ip, task_definition=raw.get("task_definition"), task_definition_family_prefix=str( raw.get("task_definition_family_prefix", "ecs-sandbox") ), image_template=raw.get("image_template"), container_name=str(raw.get("container_name", "main")), container_port=( int(raw["container_port"]) if raw.get("container_port") is not None else None ), cpu=str(raw.get("cpu", "4096")), memory=str(raw.get("memory", "8192")), ephemeral_storage_gib=( int(raw["ephemeral_storage_gib"]) if raw.get("ephemeral_storage_gib") is not None else None ), platform_version=raw.get("platform_version"), execution_role_arn=raw.get("execution_role_arn"), task_role_arn=raw.get("task_role_arn"), extra_env=( {str(k): str(v) for k, v in raw["extra_env"].items()} if isinstance(raw.get("extra_env"), Mapping) else None ), log_group=raw.get("log_group"), log_stream_prefix=raw.get("log_stream_prefix"), max_task_lifetime_sec=int(raw.get("max_task_lifetime_sec", 14400)), startup_timeout_sec=float(raw.get("startup_timeout_sec", 300.0)), poll_interval_sec=float(raw.get("poll_interval_sec", 2.0)), run_task_max_retries=int(raw.get("run_task_max_retries", 30)), ssh_sidecar=( SshSidecarConfig.from_dict(raw["ssh_sidecar"]) if has_sidecar else None ), s3_bucket=raw.get("s3_bucket"), s3_prefix=raw.get("s3_prefix"), ecr_repository=raw.get("ecr_repository"), environment_dir=raw.get("environment_dir"), codebuild_project=raw.get("codebuild_project"), codebuild_service_role=raw.get("codebuild_service_role"), codebuild_compute_type=str( raw.get("codebuild_compute_type", "BUILD_GENERAL1_MEDIUM") ), codebuild_build_timeout=int(raw.get("codebuild_build_timeout", 30)), dockerhub_secret_arn=raw.get("dockerhub_secret_arn"), build_parallelism=max(1, int(raw.get("build_parallelism", 50))), )
# ===================================================================== # Retry utilities # ===================================================================== _RETRYABLE_CODES = frozenset( { "ThrottlingException", "TooManyRequestsException", "ServiceUnavailable", "RequestLimitExceeded", } ) _RETRYABLE_MESSAGES = ( "capacity is unavailable", "rate exceeded", "too many concurrent", "throttl", "connect timeout", "read timeout", "connection reset", "endpointconnectionerror", ) def _is_retryable_error(exc: Exception) -> bool: """Return *True* if *exc* looks like a transient AWS error.""" msg = str(exc).lower() code = "" if hasattr(exc, "response"): code = (exc.response.get("Error") or {}).get("Code", "") # type: ignore[union-attr] return code in _RETRYABLE_CODES or any(m in msg for m in _RETRYABLE_MESSAGES) def _retry_with_backoff( func: Callable[[], T], *, operation_name: str, max_retries: int | None = None, base_delay: float = 1.0, max_delay: float = 60.0, jitter: float = 0.5, ) -> T: """Call *func* with exponential back-off on retryable errors. *max_retries* = ``None`` means retry indefinitely, ``0`` means no retries (single attempt only). """ attempt = 0 while True: try: return func() except Exception as exc: if not _is_retryable_error(exc): raise attempt += 1 if max_retries is not None and attempt > max_retries: log.error(f"{operation_name} failed after {attempt - 1} retries: {exc}") raise delay = min(base_delay * (2 ** (attempt - 1)), max_delay) delay *= 1 + random.uniform(-jitter, jitter) log.warning( f"{operation_name} throttled (attempt {attempt}), retrying in {delay:.1f}s: {exc}" ) time.sleep(delay) # ===================================================================== # SSH helpers — secrets & port allocation # ===================================================================== def _free_port() -> int: """Allocate an ephemeral TCP port.""" with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: s.bind(("127.0.0.1", 0)) return s.getsockname()[1] def download_secret_to_file(secret_arn: str, region: str | None = None) -> str: """Fetch a Secrets Manager secret and write it to a temp file (mode 0600). The caller is responsible for deleting the file. """ key_material = download_secret_to_string(secret_arn, region=region) fd, path = tempfile.mkstemp(prefix="ecs-ssh-", suffix=".key") try: os.write(fd, key_material.encode()) finally: os.close(fd) os.chmod(path, 0o600) log.debug(f"Downloaded SSH key to {path}") return path def download_secret_to_string(secret_arn: str, region: str | None = None) -> str: """Fetch a Secrets Manager secret and return it as a string.""" boto3, *_ = _require_aws_sdks() sm = boto3.client("secretsmanager", region_name=region) resp = sm.get_secret_value(SecretId=secret_arn) return resp["SecretString"] # ===================================================================== # SSH tunnel # =====================================================================
[docs] class SshTunnel: """Manages an ``ssh -N`` subprocess with ``-L`` and/or ``-R`` tunnels. Two usage patterns: **Exec-server mode** — forward a single remote port:: tunnel = SshTunnel(host=ip, port=2222, key_file=key, forward_port=19542) tunnel.open() # tunnel.local_port → auto-allocated ephemeral port **Agent-server mode** — explicit forward + reverse specs:: fwd = _free_port() tunnel = SshTunnel(host=ip, port=2222, key_file=key, forwards=[f"{fwd}:localhost:8000"], reverses=[f"11434:model-host:11434"]) tunnel.open() # tunnel.local_port → fwd """ def __init__( self, *, host: str, port: int = 2222, user: str = "root", key_file: str, forward_port: int | None = None, forwards: list[str] | None = None, reverses: list[str] | None = None, local_port_override: int | None = None, ) -> None: self._host = host self._port = port self._user = user self._key_file = key_file self._simple_forward_port = forward_port self._forwards = list(forwards or []) self._reverses = list(reverses or []) self._local_port: int | None = local_port_override self._proc: subprocess.Popen[bytes] | None = None @property def local_port(self) -> int: if self._local_port is None: raise RuntimeError("Tunnel not open yet — call open() first") return self._local_port @property def is_open(self) -> bool: return self._proc is not None and self._proc.poll() is None
[docs] def open(self, *, max_retries: int = 15, initial_backoff: float = 5.0) -> None: """Start the SSH tunnel with retries for transient connection errors.""" if self.is_open: return # In simple mode, allocate a fresh local port per attempt. use_simple = self._simple_forward_port is not None last_err = "" backoff = initial_backoff for attempt in range(1, max_retries + 1): if use_simple: self._local_port = _free_port() cmd = self._build_ssh_cmd() log.info(f"SSH tunnel attempt {attempt}/{max_retries}: {' '.join(cmd)}") self._proc = subprocess.Popen( cmd, stdout=subprocess.DEVNULL, stderr=subprocess.PIPE, ) time.sleep(3) if self._proc.poll() is None: # Process alive — verify forward port is reachable. if self._local_port: try: self._wait_for_local_port(self._local_port, timeout=15.0) except Exception as port_exc: log.warning( f"SSH alive but forward port {self._local_port} not open: {port_exc}" ) self._kill() last_err = str(port_exc) time.sleep(min(5.0, attempt * 1.5)) continue log.info( f"SSH tunnel started (pid={self._proc.pid}, attempt {attempt}/{max_retries})" ) return stderr = ( self._proc.stderr.read().decode(errors="replace") if self._proc.stderr else "" ) last_err = stderr.strip() self._proc = None if not any( m in last_err for m in ( "Connection refused", "Connection timed out", "No route to host", "Connection reset", ) ): raise RuntimeError( f"SSH tunnel exited immediately (attempt {attempt}): {last_err}" ) log.warning( f"SSH tunnel attempt {attempt}/{max_retries} failed: {last_err} — retrying in {backoff:.0f}s" ) time.sleep(backoff) backoff = min(30.0, backoff * 1.5) raise RuntimeError( f"SSH tunnel failed after {max_retries} attempts: {last_err}" )
[docs] def close(self) -> None: """Terminate the SSH tunnel subprocess.""" self._kill()
[docs] def wait_ready( self, *, health_url: str | None = None, timeout: float = 120.0 ) -> None: """Wait until the tunnel endpoint is reachable. If *health_url* is given, polls ``GET <url>`` for HTTP 200. Otherwise just checks that the local port is open. """ if health_url: self._poll_health(health_url, timeout) elif self._local_port: self._wait_for_local_port(self._local_port, timeout)
[docs] def check_health(self) -> bool: """Return *True* if the SSH process is still alive.""" return self.is_open
# Context manager ------------------------------------------------- def __enter__(self) -> SshTunnel: self.open() return self def __exit__(self, *exc: object) -> None: self.close() # Internals ------------------------------------------------------- def _build_ssh_cmd(self) -> list[str]: cmd = [ "ssh", "-N", "-o", "StrictHostKeyChecking=no", "-o", "UserKnownHostsFile=/dev/null", "-o", "ServerAliveInterval=30", "-o", "ServerAliveCountMax=5", "-o", "ConnectTimeout=15", "-o", "ExitOnForwardFailure=yes", "-o", "LogLevel=ERROR", "-i", self._key_file, "-p", str(self._port), ] if self._simple_forward_port is not None: cmd += [ "-L", f"127.0.0.1:{self._local_port}:127.0.0.1:{self._simple_forward_port}", ] for spec in self._forwards: cmd += ["-L", spec] for spec in self._reverses: cmd += ["-R", spec] cmd.append(f"{self._user}@{self._host}") return cmd def _kill(self) -> None: if self._proc is None: return try: self._proc.terminate() try: self._proc.wait(timeout=5) except subprocess.TimeoutExpired: self._proc.kill() log.info(f"SSH tunnel closed (pid={self._proc.pid})") except ProcessLookupError: pass finally: self._proc = None def _wait_for_local_port(self, port: int, timeout: float = 30.0) -> None: deadline = time.monotonic() + timeout while time.monotonic() < deadline: if self._proc and self._proc.poll() is not None: raise RuntimeError("SSH tunnel process exited while waiting for port") try: with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: s.settimeout(1.0) s.connect(("127.0.0.1", port)) return except OSError: time.sleep(0.3) raise TimeoutError(f"Local port 127.0.0.1:{port} not open after {timeout:.0f}s") def _poll_health(self, url: str, timeout: float) -> None: import urllib.error import urllib.request deadline = time.monotonic() + timeout attempt = 0 while time.monotonic() < deadline: attempt += 1 if not self.is_open: raise RuntimeError("SSH tunnel died while waiting for health endpoint") try: req = urllib.request.Request(url, method="GET") with urllib.request.urlopen(req, timeout=5) as resp: if resp.status == 200: log.info(f"Health endpoint ready (attempt {attempt}): {url}") return except (urllib.error.URLError, OSError, TimeoutError): pass time.sleep(min(3.0, 1.0 + attempt * 0.5)) raise TimeoutError(f"Health endpoint not reachable after {timeout:.0f}s: {url}")
# ===================================================================== # SSH sidecar container builder # ===================================================================== def build_ssh_sidecar_container( sidecar_cfg: SshSidecarConfig, *, public_key_value: str, max_lifetime_sec: int, log_group: str | None = None, log_region: str = "us-east-1", log_stream_prefix: str = "ecs-sandbox", ) -> dict[str, Any]: """Return an ECS container definition dict for the SSH sidecar. The sidecar: - Installs openssh-server on Alpine (< 2 s). - Receives the SSH public key as a plain environment variable (downloaded by the orchestrator from Secrets Manager). - Runs sshd as PID 1 (foreground) with a background watchdog for TTL. - Has a health check (``nc -z localhost <port>``). """ port = sidecar_cfg.sshd_port image = sidecar_cfg.image or "alpine:latest" sshd_cfg = ( f"Port {port}\\n" "PermitRootLogin prohibit-password\\n" "PasswordAuthentication no\\n" "AllowTcpForwarding yes\\n" "PermitListen any\\n" "GatewayPorts clientspecified\\n" "X11Forwarding no\\n" "PrintMotd no\\n" "LogLevel ERROR\\n" "ClientAliveInterval 15\\n" "ClientAliveCountMax 3\\n" "TCPKeepAlive yes\\n" "UseDNS no\\n" "MaxSessions 50\\n" ) watchdog = "" if max_lifetime_sec > 0: watchdog = ( f"( sleep {max_lifetime_sec}; " f"echo 'sidecar watchdog: TTL ({max_lifetime_sec}s) reached'; " "kill 1 2>/dev/null; sleep 3; kill -9 1 2>/dev/null ) & " ) sshd_cmd = ( "set -e; " "apk add --no-cache openssh-server netcat-openbsd; " "mkdir -p /root/.ssh; chmod 700 /root/.ssh; " 'printf "%s\\n" "$SSH_PUBLIC_KEY" > /root/.ssh/authorized_keys; ' "chmod 600 /root/.ssh/authorized_keys; " "ssh-keygen -A; " f"printf '{sshd_cfg}' > /etc/ssh/sshd_config; " f"{watchdog}" f"exec /usr/sbin/sshd -D -e -p {port}" ) container: dict[str, Any] = { "name": "ssh-tunnel", "image": image, "essential": True, "entryPoint": ["sh", "-c"], "command": [sshd_cmd], "environment": [ {"name": "SSH_PUBLIC_KEY", "value": public_key_value}, ], "healthCheck": { "command": ["CMD-SHELL", f"nc -z localhost {port} || exit 1"], "interval": 5, "timeout": 3, "retries": 10, "startPeriod": 30, }, } if log_group: container["logConfiguration"] = { "logDriver": "awslogs", "options": { "awslogs-group": log_group, "awslogs-region": log_region, "awslogs-stream-prefix": f"{log_stream_prefix}-tunnel", "awslogs-create-group": "true", }, } return container # ===================================================================== # Exec server — embedded script + HTTP client # ===================================================================== EXEC_SERVER_SCRIPT = r'''#!/usr/bin/env python3 """Zero-dependency HTTP exec server for sandbox containers. Endpoints: POST /exec {"cmd":"...","timeout":300} -> {"stdout":"...","stderr":"...","rc":0} POST /upload {"path":"/dst","content":"<b64>","mode":"0755"} -> {"ok":true} GET /download?path=/file -> raw bytes GET /health -> {"ok":true} Binds to 127.0.0.1 only (never network-exposed). """ from __future__ import annotations import base64, json, os, subprocess from http.server import BaseHTTPRequestHandler, HTTPServer from urllib.parse import parse_qs, urlparse _PORT = int(os.environ.get("TB_EXEC_PORT", "19542")) _BIND = os.environ.get("TB_EXEC_BIND", "127.0.0.1") class _H(BaseHTTPRequestHandler): def log_message(self, fmt, *a): pass def do_GET(self): p = urlparse(self.path) if p.path == "/health": self._ok({"ok": True}) elif p.path == "/download": qs = parse_qs(p.query) paths = qs.get("path", []) if not paths: self._err(400, "missing ?path=") else: self._dl(paths[0]) else: self._err(404, f"not found: {p.path}") def do_POST(self): p = urlparse(self.path) body = self._body() if p.path == "/exec": self._exec(body) elif p.path == "/upload": self._up(body) else: self._err(404, f"not found: {p.path}") def _exec(self, b): cmd = b.get("cmd") if not cmd: self._err(400, "missing 'cmd'"); return t = b.get("timeout", 300) try: cp = subprocess.run(cmd, shell=True, capture_output=True, timeout=t) self._ok({"stdout": cp.stdout.decode("utf-8", errors="replace"), "stderr": cp.stderr.decode("utf-8", errors="replace"), "rc": cp.returncode}) except subprocess.TimeoutExpired: self._ok({"stdout":"","stderr":f"timed out after {t}s","rc":124}) except Exception as e: self._ok({"stdout":"","stderr":str(e),"rc":-1}) def _up(self, b): path, c = b.get("path"), b.get("content") if not path or c is None: self._err(400, "missing path/content"); return try: data = base64.b64decode(c) os.makedirs(os.path.dirname(path) or ".", exist_ok=True) with open(path, "wb") as f: f.write(data) m = b.get("mode") if m: os.chmod(path, int(m, 8)) self._ok({"ok": True}) except Exception as e: self._err(500, str(e)) def _dl(self, path): if not os.path.isfile(path): self._err(404, f"not found: {path}"); return try: with open(path, "rb") as f: data = f.read() self.send_response(200) self.send_header("Content-Type", "application/octet-stream") self.send_header("Content-Length", str(len(data))) self.end_headers(); self.wfile.write(data) except Exception as e: self._err(500, str(e)) def _body(self): n = int(self.headers.get("Content-Length", 0)) if n == 0: return {} try: return json.loads(self.rfile.read(n)) except Exception: return {} def _ok(self, obj): p = json.dumps(obj).encode() self.send_response(200) self.send_header("Content-Type","application/json") self.send_header("Content-Length",str(len(p))) self.end_headers(); self.wfile.write(p) def _err(self, code, msg): p = json.dumps({"error": msg}).encode() self.send_response(code) self.send_header("Content-Type", "application/json") self.send_header("Content-Length", str(len(p))) self.end_headers(); self.wfile.write(p) if __name__ == "__main__": s = HTTPServer((_BIND, _PORT), _H) print(f"exec_server on {_BIND}:{_PORT}", flush=True) try: s.serve_forever() except KeyboardInterrupt: pass finally: s.server_close() ''' # Transient network errors that the ExecClient should retry. _TRANSIENT_ERRORS = ( ConnectionResetError, ConnectionRefusedError, ConnectionAbortedError, BrokenPipeError, TimeoutError, OSError, )
[docs] class ExecClient: """HTTP client for the exec server running inside the container. Communicates through the SSH tunnel (``127.0.0.1:<local_port>``). Uses only stdlib ``urllib.request`` — no extra dependencies. """ def __init__(self, *, port: int, connect_timeout: float = 30.0) -> None: self._base = f"http://127.0.0.1:{port}" self._timeout = connect_timeout
[docs] def exec(self, cmd: str, *, timeout: int = 300) -> ExecResult: resp = self._post("/exec", {"cmd": cmd, "timeout": timeout}) return ExecResult( stdout=resp.get("stdout", ""), stderr=resp.get("stderr", ""), return_code=resp.get("rc", -1), )
[docs] def upload( self, remote_path: str, data: bytes | Path, *, mode: str | None = None, max_retries: int = 3, ) -> None: if isinstance(data, Path): data = data.read_bytes() body: dict[str, Any] = { "path": remote_path, "content": base64.b64encode(data).decode(), } if mode is not None: body["mode"] = mode payload_mb = len(body["content"]) / (1024 * 1024) upload_timeout = max(self._timeout, 60.0 + payload_mb * 2.0) last_err: Exception | None = None for attempt in range(1, max_retries + 1): try: resp = self._post("/upload", body, timeout_override=upload_timeout) if not resp.get("ok"): raise RuntimeError(f"upload to {remote_path} failed: {resp}") return except (TimeoutError, OSError, RuntimeError) as exc: last_err = exc if attempt < max_retries: log.warning( f"upload {remote_path} attempt {attempt}/{max_retries}: {exc}" ) time.sleep(2.0 * attempt) raise RuntimeError( f"upload to {remote_path} failed after {max_retries} attempts: {last_err}" )
[docs] def download(self, remote_path: str, *, max_retries: int = 3) -> bytes: import urllib.parse url = f"{self._base}/download?path={urllib.parse.quote(remote_path)}" return self._request( label=f"download {remote_path}", url=url, method="GET", timeout=self._timeout, max_retries=max_retries, )
[docs] def health(self) -> bool: try: self._request( label="health", url=f"{self._base}/health", method="GET", timeout=5, max_retries=1, ) return True except (ConnectionError, OSError, TimeoutError, RuntimeError): return False
def _post( self, path: str, body: dict[str, Any], *, timeout_override: float | None = None, max_retries: int = 4, ) -> dict[str, Any]: url = f"{self._base}{path}" payload = json.dumps(body).encode() if timeout_override is not None: http_timeout = timeout_override else: cmd_timeout = body.get("timeout") http_timeout = ( max(self._timeout, cmd_timeout + 30) if isinstance(cmd_timeout, (int, float)) else self._timeout ) raw = self._request( label=f"POST {path}", url=url, method="POST", data=payload, headers={"Content-Type": "application/json"}, timeout=http_timeout, max_retries=max_retries, ) return json.loads(raw) def _request( self, *, label: str, url: str, method: str, data: bytes | None = None, headers: dict[str, str] | None = None, timeout: float, max_retries: int, ) -> bytes: """Issue an HTTP request with retries on transient errors. Returns the raw response body as bytes. Raises ``RuntimeError`` on HTTP errors and ``ConnectionError`` when all retries are exhausted. """ import urllib.error import urllib.request last_err: Exception | None = None for attempt in range(1, max_retries + 1): req = urllib.request.Request(url, data=data, method=method) if headers: for k, v in headers.items(): req.add_header(k, v) try: with urllib.request.urlopen(req, timeout=timeout) as resp: return resp.read() except urllib.error.HTTPError as exc: body = exc.read().decode(errors="replace") raise RuntimeError(f"{label} failed (HTTP {exc.code}): {body}") from exc except (*_TRANSIENT_ERRORS, urllib.error.URLError) as exc: last_err = exc if attempt < max_retries: wait = min(15.0, 2.0 ** (attempt - 1)) log.warning( f"{label} attempt {attempt}/{max_retries}: {exc} — retry in {wait:.1f}s" ) time.sleep(wait) continue raise ConnectionError( f"{label} failed after {max_retries} attempts: {last_err}" ) from last_err raise ConnectionError(f"{label} unreachable")
# ===================================================================== # Image builder — AWS CodeBuild + ECR caching # =====================================================================
[docs] class ImageBuilder: """Build Docker images via AWS CodeBuild and push to ECR. Features: - Content-hash based ECR tags for automatic caching. - Build deduplication across concurrent tasks (only one build per tag). - Semaphore-based concurrency control. """ _lock = threading.Lock() _inflight_builds: dict[str, threading.Event] = {} _build_semaphore: threading.Semaphore | None = None _build_semaphore_size: int = 0
[docs] @staticmethod def get_ecr_image_tag(environment_dir: str | Path, environment_name: str) -> str: """``<name>__<content_hash[:8]>`` — deterministic, cache-friendly.""" h = hashlib.sha256() root = Path(environment_dir) for p in sorted(root.rglob("*")): if not p.is_file(): continue h.update(str(p.relative_to(root)).encode()) h.update(p.read_bytes()) return f"{environment_name}__{h.hexdigest()[:8]}"
[docs] @staticmethod def image_exists_in_ecr( ecr_repository: str, tag: str, region: str | None = None ) -> bool: boto3, _, ClientError = _require_aws_sdks() ecr = boto3.client("ecr", region_name=region) repo_name = ( ecr_repository.split("/", 1)[1] if "/" in ecr_repository else ecr_repository ) try: ecr.describe_images(repositoryName=repo_name, imageIds=[{"imageTag": tag}]) return True except ClientError as exc: code = exc.response.get("Error", {}).get("Code", "") if code in ("ImageNotFoundException", "RepositoryNotFoundException"): return False raise
[docs] @classmethod def ensure_image_built( cls, *, cfg: EcsFargateConfig, environment_name: str, force_build: bool = False, ) -> str: """Build and push if needed. Returns the full ECR image URL. Safe to call from many threads — deduplication and a semaphore ensure only one CodeBuild job runs per content-hash tag. """ ecr_repo = cfg.ecr_repository env_dir = cfg.environment_dir if not ecr_repo or not env_dir: raise ValueError( "ecr_repository and environment_dir are required for image building" ) tag = cls.get_ecr_image_tag(env_dir, environment_name) image_url = f"{ecr_repo}:{tag}" # --- Dedup: check if another thread is already building this tag --- with cls._lock: if tag in cls._inflight_builds: event = cls._inflight_builds[tag] log.info(f"Build already in progress for {tag} — waiting") waiting = True else: waiting = False if waiting: event.wait() log.info(f"Build finished (by another thread): {tag}") return image_url # --- ECR cache check --- if not force_build and cls.image_exists_in_ecr(ecr_repo, tag, cfg.region): log.info(f"ECR cache hit — skipping build: {image_url}") return image_url # --- Register as builder --- event = threading.Event() with cls._lock: if tag in cls._inflight_builds: existing = cls._inflight_builds[tag] else: cls._inflight_builds[tag] = event existing = None if existing is not None: log.info("Build started by another thread — joining") existing.wait() return image_url # --- Initialise semaphore --- with cls._lock: if ( cls._build_semaphore is None or cls._build_semaphore_size != cfg.build_parallelism ): cls._build_semaphore = threading.Semaphore(cfg.build_parallelism) cls._build_semaphore_size = cfg.build_parallelism try: log.info( f"Waiting for CodeBuild slot (parallelism={cfg.build_parallelism})" ) cls._build_semaphore.acquire() # type: ignore[union-attr] try: if not force_build and cls.image_exists_in_ecr( ecr_repo, tag, cfg.region ): log.info(f"ECR cache hit (after slot): {image_url}") return image_url cls._build_and_push( cfg=cfg, environment_name=environment_name, tag=tag, image_url=image_url, ) finally: cls._build_semaphore.release() # type: ignore[union-attr] finally: event.set() with cls._lock: cls._inflight_builds.pop(tag, None) return image_url
@staticmethod def _upload_build_context( cfg: EcsFargateConfig, environment_name: str, nonce: str, ) -> str: """ZIP the environment dir and upload to S3. Returns the S3 key.""" boto3, *_ = _require_aws_sdks() env_dir = Path(cfg.environment_dir or ".") buf = io.BytesIO() with zipfile.ZipFile(buf, "w", zipfile.ZIP_DEFLATED) as zf: for item in env_dir.rglob("*"): if item.is_file(): zf.write(item, arcname=str(item.relative_to(env_dir))) buf.seek(0) s3 = boto3.client("s3", region_name=cfg.region) s3_prefix = cfg.s3_prefix or "ecs-sandbox" s3_key = f"{s3_prefix}/codebuild/{environment_name}-{nonce}.zip" s3.put_object(Bucket=cfg.s3_bucket, Key=s3_key, Body=buf.read()) return s3_key @staticmethod def _resolve_codebuild_project( cfg: EcsFargateConfig, cb: Any, nonce: str, ) -> str: """Return an existing or freshly created CodeBuild project name.""" _, _, ClientError = _require_aws_sdks() if cfg.codebuild_project: return cfg.codebuild_project if not cfg.codebuild_service_role: raise RuntimeError( "codebuild_project or codebuild_service_role is required" ) project_name = f"ecs-sandbox-build-{nonce}" try: cb.create_project( name=project_name, source={"type": "NO_SOURCE", "buildspec": "version: 0.2"}, artifacts={"type": "NO_ARTIFACTS"}, environment={ "type": "LINUX_CONTAINER", "image": "aws/codebuild/amazonlinux-x86_64-standard:5.0", "computeType": cfg.codebuild_compute_type, "privilegedMode": True, }, serviceRole=cfg.codebuild_service_role, timeoutInMinutes=cfg.codebuild_build_timeout, ) except ClientError as e: if "already exists" not in str(e).lower(): raise return project_name @staticmethod def _generate_buildspec( cfg: EcsFargateConfig, repo_name: str, tag: str, image_url: str, ) -> str: """Return an inline CodeBuild buildspec YAML string.""" ecr_registry = (cfg.ecr_repository or "").split("/")[0] pre_build_cmds = [ f"aws ecr get-login-password --region $AWS_DEFAULT_REGION" f" | docker login --username AWS --password-stdin {ecr_registry}", ] if cfg.dockerhub_secret_arn: pre_build_cmds.append( f"DOCKERHUB_CREDS=$(aws secretsmanager get-secret-value" f" --secret-id {cfg.dockerhub_secret_arn}" f" --query SecretString --output text --region $AWS_DEFAULT_REGION)" f' && echo "$DOCKERHUB_CREDS" | python3 -c' """ "import sys,json;c=json.load(sys.stdin);print(c['password'])" """ f'| docker login -u $(echo "$DOCKERHUB_CREDS" | python3 -c' """ "import sys,json;print(json.load(sys.stdin)['username'])") """ f"--password-stdin" ) pre_yaml = "\n".join(f" - {c}" for c in pre_build_cmds) build_cmd = ( f"for i in 1 2 3; do docker build -t {repo_name}:{tag} . && break; " f'echo "build failed ($i/3), retry in 30s"; sleep 30; done' ) return ( "version: 0.2\nphases:\n pre_build:\n commands:\n" f"{pre_yaml}\n build:\n commands:\n" f" - {build_cmd}\n - docker tag {repo_name}:{tag} {image_url}\n" f" post_build:\n commands:\n - docker push {image_url}\n" ) @staticmethod def _poll_codebuild(cb: Any, build_id: str, image_url: str) -> None: """Block until CodeBuild finishes; raise on failure.""" while True: time.sleep(10) status_resp = cb.batch_get_builds(ids=[build_id]) build = status_resp["builds"][0] status = build["buildStatus"] if status == "SUCCEEDED": log.info(f"CodeBuild succeeded: {build_id}") return if status in ("FAILED", "FAULT", "STOPPED", "TIMED_OUT"): phases = build.get("phases", []) failed = [ p for p in phases if p.get("phaseStatus") not in (None, "SUCCEEDED") ] ctx = ( "; ".join( f"{p['phaseType']}: {p.get('phaseStatus')}" for p in failed ) or status ) raise RuntimeError( f"CodeBuild failed for {image_url}: {ctx} (build: {build_id})" ) log.debug( f"CodeBuild {build_id} — phase={build.get('currentPhase')} status={status}" ) @classmethod def _build_and_push( cls, *, cfg: EcsFargateConfig, environment_name: str, tag: str, image_url: str, ) -> None: boto3, *_ = _require_aws_sdks() ecr_repo = cfg.ecr_repository or "" repo_name = ecr_repo.split("/", 1)[1] if "/" in ecr_repo else ecr_repo nonce = uuid.uuid4().hex[:8] log.info(f"Building image via CodeBuild: {image_url}") s3_key = cls._upload_build_context(cfg, environment_name, nonce) cb = boto3.client("codebuild", region_name=cfg.region) project_name = cls._resolve_codebuild_project(cfg, cb, nonce) buildspec = cls._generate_buildspec(cfg, repo_name, tag, image_url) resp = cb.start_build( projectName=project_name, sourceTypeOverride="S3", sourceLocationOverride=f"{cfg.s3_bucket}/{s3_key}", buildspecOverride=buildspec, timeoutInMinutesOverride=cfg.codebuild_build_timeout, privilegedModeOverride=True, environmentTypeOverride="LINUX_CONTAINER", imageOverride="aws/codebuild/amazonlinux-x86_64-standard:5.0", computeTypeOverride=cfg.codebuild_compute_type, ) build_id = resp["build"]["id"] log.info(f"CodeBuild started: {build_id}") cls._poll_codebuild(cb, build_id, image_url)
# ===================================================================== # Core sandbox — ECS task lifecycle + SSH connectivity # ===================================================================== _active_sandboxes: dict[int, Any] = {} _cleanup_lock = threading.RLock() _atexit_registered = False _PROCESS_NONCE = f"{int(time.time())}-{uuid.uuid4().hex[:8]}" _exec_server_url_cache: dict[str, str] = {} def _emergency_cleanup() -> None: with _cleanup_lock: for sb in list(_active_sandboxes.values()): try: sb.stop() except Exception: log.debug( f"Emergency cleanup failed for sandbox {id(sb)}", exc_info=True )
[docs] class EcsFargateSandbox: """ECS Fargate sandbox implementing the :class:`Sandbox` protocol. Supports two modes (determined by ``ssh_sidecar.exec_server_port``): **Exec-server mode** (``exec_server_port`` is set): One-way SSH tunnel + embedded HTTP exec server. ``exec()``, ``upload()``, ``download()`` work through the exec server. **Agent-server mode** (``exec_server_port`` is ``None``): Two-way SSH tunnel. Consumer accesses the hosted agent via ``ssh_tunnel`` / ``local_port``. ``exec()`` / ``upload()`` / ``download()`` raise :class:`RuntimeError`. """ def __init__(self, cfg: EcsFargateConfig, *, task_id: str, run_id: str) -> None: self._cfg = cfg self._task_id = task_id self._run_id = run_id self._task_arn: str | None = None self._task_def_arn: str | None = None self._task_ip: str | None = None self._ssh_key_file: str | None = None self._ssh_tunnel: SshTunnel | None = None self._exec_client: ExecClient | None = None self._started = False self._stopped = False self._ecs: Any = None self._ec2: Any = None # For agent-server mode (two-way tunnel) self._ssh_tunnel_port: int | None = None self._agent_forward_port: int | None = None self._outside_endpoints: list[OutsideEndpoint] = [] # Public API ------------------------------------------------------- @property def task_arn(self) -> str | None: return self._task_arn @property def task_ip(self) -> str | None: return self._task_ip @property def local_port(self) -> int | None: """Local port of the SSH forward tunnel (exec server or agent server).""" if self._ssh_tunnel: try: return self._ssh_tunnel.local_port except RuntimeError: pass return None @property def ssh_tunnel(self) -> SshTunnel | None: return self._ssh_tunnel @property def exec_client(self) -> ExecClient | None: """The exec client (only available in exec-server mode after start).""" return self._exec_client @property def model_tunnel_port(self) -> int | None: """Port the container uses to reach the model (agent-server mode only).""" return self._ssh_tunnel_port @property def is_running(self) -> bool: return self._started and not self._stopped
[docs] def resolve_outside_endpoint(self, url: str) -> str: """Return the URL that processes inside this sandbox should use to reach the outside service at *url* (orchestrator-side). Remaps *url* to the reverse-tunnel address (``127.0.0.1:<tunnel-port>``). Must be called after :meth:`start`. """ if self._ssh_tunnel_port is None: raise RuntimeError( "resolve_outside_endpoint() requires start() to have been called " "with an agent-server SSH sidecar configured." ) parsed = urlparse(url) return parsed._replace(netloc=f"127.0.0.1:{self._ssh_tunnel_port}").geturl()
[docs] def reconnect_tunnel(self) -> None: """Re-open the SSH tunnel if it died (e.g. after a network blip).""" if self._stopped or not self._started: raise RuntimeError("Cannot reconnect tunnel on a stopped/unstarted sandbox") sidecar = self._cfg.ssh_sidecar if sidecar is None: return if self._ssh_tunnel: self._ssh_tunnel.close() self._ssh_tunnel = None self._open_tunnel(sidecar)
[docs] def start( self, *, force_build: bool = False, outside_endpoints: list[OutsideEndpoint] | None = None, ) -> None: if self._started: return self._outside_endpoints = outside_endpoints or [] if len(self._outside_endpoints) > 1: raise ValueError( f"Only one OutsideEndpoint is supported (got {len(self._outside_endpoints)}). " f"The SSH reverse tunnel can only target a single host:port." ) try: self._do_start(force_build=force_build) self._started = True self._register_for_cleanup() except Exception: self._cleanup() raise
[docs] def describe_task(self) -> dict[str, Any] | None: """Return a summary dict of the ECS task's current state, or None.""" if self._task_arn is None or self._ecs is None: return None try: resp = self._ecs.describe_tasks( cluster=self._cfg.cluster, tasks=[self._task_arn] ) tasks = resp.get("tasks") or [] if not tasks: return {"taskArn": self._task_arn} t = tasks[0] return { "taskArn": t.get("taskArn"), "lastStatus": t.get("lastStatus"), "desiredStatus": t.get("desiredStatus"), "stopCode": t.get("stopCode"), "stoppedReason": t.get("stoppedReason"), } except Exception as e: return {"taskArn": self._task_arn, "error": str(e)}
[docs] def stop(self) -> None: if self._stopped: return self._stopped = True self._cleanup() self._unregister_from_cleanup()
[docs] def exec(self, command: str, timeout_sec: float = 180) -> ExecResult: self._require_exec_client() return self._exec_client.exec(command, timeout=int(timeout_sec)) # type: ignore[union-attr]
[docs] def upload(self, local_path: Path, remote_path: str) -> None: self._require_exec_client() local = Path(local_path) if local.is_dir(): for child in local.rglob("*"): if child.is_file(): rel = child.relative_to(local) self.upload(child, f"{remote_path}/{rel}") return if local.stat().st_size > 512 * 1024 and self._cfg.s3_bucket: self._upload_via_s3([local], os.path.dirname(remote_path) or "/tmp") else: self._exec_client.upload(remote_path, local) # type: ignore[union-attr]
[docs] def download(self, remote_path: str, local_path: Path) -> None: self._require_exec_client() data = self._exec_client.download(remote_path) # type: ignore[union-attr] dest = Path(local_path) dest.parent.mkdir(parents=True, exist_ok=True) dest.write_bytes(data)
def __enter__(self) -> EcsFargateSandbox: return self def __exit__(self, *exc: object) -> None: self.stop() # Start implementation -------------------------------------------- def _do_start(self, *, force_build: bool) -> None: cfg = self._cfg sidecar = cfg.ssh_sidecar if sidecar is None: raise ValueError("ssh_sidecar must be configured") # 0. Init boto3 clients self._init_aws_clients() # 1. Docker build (if configured) built_image: str | None = None if cfg.ecr_repository and cfg.environment_dir: built_image = ImageBuilder.ensure_image_built( cfg=cfg, environment_name=_sanitize_id(self._task_id), force_build=force_build, ) # 2. Resolve image image = self._resolve_image(built_image) # 3. Download SSH keys (orchestrator-side, no extra IAM needed on execution role) if not sidecar.private_key_secret_arn: raise ValueError( "ssh_sidecar.private_key_secret_arn is required (pre-provisioned keys only)" ) if not sidecar.public_key_secret_arn: raise ValueError( "ssh_sidecar.public_key_secret_arn is required (pre-provisioned keys only)" ) self._ssh_key_file = download_secret_to_file( sidecar.private_key_secret_arn, cfg.region ) ssh_public_key_value = download_secret_to_string( sidecar.public_key_secret_arn, cfg.region ) # 4. Resolve tunnel port for agent-server mode has_exec_server = sidecar.exec_server_port is not None if not has_exec_server: self._ssh_tunnel_port = self._resolve_ssh_tunnel_port() # 5. Upload exec server to S3 (exec-server mode only) exec_server_url: str | None = None if has_exec_server: exec_server_url = self._upload_exec_server() # 6. Build container command command = self._build_container_command(exec_server_url, sidecar) # 7. Build environment variables env = self._build_env_vars() # 8. Build sidecar container log_region = cfg.region or os.environ.get("AWS_DEFAULT_REGION", "us-east-1") sidecar_def = build_ssh_sidecar_container( sidecar, public_key_value=ssh_public_key_value, max_lifetime_sec=cfg.max_task_lifetime_sec, log_group=cfg.log_group, log_region=log_region, log_stream_prefix=cfg.log_stream_prefix or "ecs-sandbox", ) # 9. Register task definition self._task_def_arn = self._register_task_definition( image=image, command=command, env=env, sidecar_def=sidecar_def, ) # 10. Run task self._task_arn = self._run_task(self._task_def_arn) # 11. Wait for RUNNING self._wait_for_running() # 12. Get task IP self._task_ip = self._get_task_public_ip() # 13. Wait for SSH ready self._wait_for_ssh_ready( self._task_ip, sidecar.sshd_port, sidecar.ssh_ready_timeout_sec ) # 14. Open SSH tunnel self._open_tunnel(sidecar) # 15. Wait for readiness if has_exec_server: health_url = f"http://127.0.0.1:{self._ssh_tunnel.local_port}/health" # type: ignore[union-attr] self._ssh_tunnel.wait_ready( health_url=health_url, timeout=sidecar.ssh_ready_timeout_sec ) # type: ignore[union-attr] self._exec_client = ExecClient(port=self._ssh_tunnel.local_port) # type: ignore[union-attr] # Internal methods ------------------------------------------------ def _init_aws_clients(self) -> None: boto3, Config, _ = _require_aws_sdks() boto_cfg = Config( connect_timeout=30, read_timeout=60, retries={"max_attempts": 8, "mode": "adaptive"}, ) self._ecs = boto3.client("ecs", region_name=self._cfg.region, config=boto_cfg) self._ec2 = boto3.client("ec2", region_name=self._cfg.region, config=boto_cfg) def _resolve_image(self, built_image: str | None = None) -> str: if built_image: return built_image cfg = self._cfg if cfg.image_template: sanitized = _sanitize_id(self._task_id) return cfg.image_template.format( task_id=self._task_id, task_id_sanitized=sanitized ) if not cfg.task_definition: raise ValueError( "No image available: set image_template, ecr_repository + " "environment_dir, or task_definition" ) return "" def _parse_outside_url(self) -> ParseResult: if not self._outside_endpoints: raise ValueError( "Agent-server mode requires at least one OutsideEndpoint " "passed to start(outside_endpoints=...)" ) return urlparse(self._outside_endpoints[0].url) @staticmethod def _port_from_parsed(parsed: ParseResult) -> int: return parsed.port or (443 if parsed.scheme == "https" else 80) def _resolve_ssh_tunnel_port(self) -> int: return self._port_from_parsed(self._parse_outside_url()) def _resolve_tunnel_target(self) -> tuple[str, int]: parsed = self._parse_outside_url() host = parsed.hostname if not host: raise ValueError( f"Cannot resolve hostname from OutsideEndpoint URL: " f"{self._outside_endpoints[0].url}" ) return host, self._port_from_parsed(parsed) def _upload_exec_server(self) -> str: cfg = self._cfg if not cfg.s3_bucket: raise ValueError("s3_bucket is required for exec server upload") cache_key = f"{cfg.s3_bucket}/{self._run_id}" if cache_key in _exec_server_url_cache: return _exec_server_url_cache[cache_key] boto3, *_ = _require_aws_sdks() s3 = boto3.client("s3", region_name=cfg.region) prefix = cfg.s3_prefix or "ecs-sandbox" key = f"{prefix}/{self._run_id}-{_PROCESS_NONCE}/_exec_server/exec_server.py" s3.put_object(Bucket=cfg.s3_bucket, Key=key, Body=EXEC_SERVER_SCRIPT.encode()) url = s3.generate_presigned_url( "get_object", Params={"Bucket": cfg.s3_bucket, "Key": key}, ExpiresIn=21600, ) _exec_server_url_cache[cache_key] = url log.info(f"Uploaded exec server → s3://{cfg.s3_bucket}/{key}") return url def _build_container_command( self, exec_server_url: str | None, sidecar: SshSidecarConfig, ) -> list[str] | None: """Return the main container entrypoint, or None to keep the image default.""" if exec_server_url is None: return None # agent-server mode — image has its own CMD exec_port = sidecar.exec_server_port or 19542 ttl = self._cfg.max_task_lifetime_sec hostname = re.sub(r"[^A-Za-z0-9._-]", "-", self._task_id)[:63] bootstrap = ( "if command -v python >/dev/null 2>&1; then :; " "elif command -v python3 >/dev/null 2>&1; then " ' P=$(command -v python3); ln -sf "$P" /usr/local/bin/python 2>/dev/null || true; ' ' ln -sf "$P" /usr/bin/python 2>/dev/null || true; fi; ' ) setup = ( f"{bootstrap}" f"hostname {shlex.quote(hostname)} 2>/dev/null || true; " f"python3 -c 'import urllib.request as u,sys;" f'u.urlretrieve(sys.argv[1],"/tmp/_exec_server.py")\' ' f"{shlex.quote(exec_server_url)} && " f"TB_EXEC_PORT={exec_port} TB_EXEC_BIND=127.0.0.1 " f"nohup python3 /tmp/_exec_server.py >/tmp/_exec.log 2>&1 & " f"sleep {ttl}" ) return ["sh", "-lc", setup] def _build_env_vars(self) -> dict[str, str]: env: dict[str, str] = {} cfg = self._cfg if cfg.extra_env: for k, v in cfg.extra_env.items(): env[k] = self._render_env_value(v) if self._ssh_tunnel_port and self._outside_endpoints: ep = self._outside_endpoints[0] scheme = urlparse(ep.url).scheme or "http" env[ep.env_var] = f"{scheme}://127.0.0.1:{self._ssh_tunnel_port}" return env def _render_env_value(self, value: str) -> str: if self._ssh_tunnel_port is not None: value = value.replace("{ssh_tunnel_port}", str(self._ssh_tunnel_port)) if self._task_ip: value = value.replace("{task_ip}", self._task_ip) value = value.replace("{task_id}", self._task_id) return value def _register_task_definition( self, *, image: str, command: list[str] | None, env: dict[str, str], sidecar_def: dict[str, Any], ) -> str: cfg = self._cfg log_region = cfg.region or os.environ.get("AWS_DEFAULT_REGION", "us-east-1") log_cfg: dict[str, Any] | None = None if cfg.log_group: log_cfg = { "logDriver": "awslogs", "options": { "awslogs-group": cfg.log_group, "awslogs-region": log_region, "awslogs-stream-prefix": cfg.log_stream_prefix or "ecs-sandbox", "awslogs-create-group": "true", }, } # Try cloning from base task definition _, _, ClientError = _require_aws_sdks() base: dict[str, Any] | None = None if cfg.task_definition: try: resp = self._ecs.describe_task_definition( taskDefinition=cfg.task_definition ) base = resp["taskDefinition"] except ClientError as exc: code = exc.response.get("Error", {}).get("Code", "") if code in ("ClientException",): log.warning( f"Base task definition {cfg.task_definition} not found, will register from scratch" ) base = None else: raise if base is not None: return self._register_from_base( base=base, image=image, command=command, env=env, sidecar_def=sidecar_def, log_cfg=log_cfg, ) return self._register_from_scratch( image=image, command=command, env=env, sidecar_def=sidecar_def, log_cfg=log_cfg, ) def _register_from_base( self, *, base: dict, image: str, command: list[str] | None, env: dict[str, str], sidecar_def: dict, log_cfg: dict | None, ) -> str: cfg = self._cfg containers = list(base.get("containerDefinitions") or []) target = None for cd in containers: if cd.get("name") == cfg.container_name: target = cd break if target is None: names = [c.get("name") for c in containers] raise RuntimeError( f"Base task-def has no container '{cfg.container_name}'. Available: {names}" ) if image: target["image"] = image if command is not None: target["command"] = command target.pop("entryPoint", None) if env: existing = {e["name"]: e["value"] for e in target.get("environment", [])} existing.update(env) target["environment"] = [ {"name": k, "value": v} for k, v in sorted(existing.items()) ] if log_cfg: target["logConfiguration"] = log_cfg target["dependsOn"] = [{"containerName": "ssh-tunnel", "condition": "HEALTHY"}] containers = [c for c in containers if c.get("name") != "ssh-tunnel"] containers.append(sidecar_def) family = self._make_family_name() payload: dict[str, Any] = { "family": family, "networkMode": base.get("networkMode", "awsvpc"), "requiresCompatibilities": base.get("requiresCompatibilities", ["FARGATE"]), "cpu": str(max(int(base.get("cpu") or "256"), int(cfg.cpu))), "memory": str(max(int(base.get("memory") or "512"), int(cfg.memory))), "containerDefinitions": containers, } eph = max( (base.get("ephemeralStorage") or {}).get("sizeInGiB", 20), cfg.ephemeral_storage_gib or 20, ) payload["ephemeralStorage"] = {"sizeInGiB": eph} for k in ("taskRoleArn", "executionRoleArn", "runtimePlatform", "volumes"): if base.get(k) is not None: payload[k] = base[k] if cfg.execution_role_arn: payload["executionRoleArn"] = cfg.execution_role_arn if cfg.task_role_arn: payload["taskRoleArn"] = cfg.task_role_arn return self._do_register(payload) def _register_from_scratch( self, *, image: str, command: list[str] | None, env: dict[str, str], sidecar_def: dict, log_cfg: dict | None, ) -> str: cfg = self._cfg if not cfg.execution_role_arn: raise RuntimeError( "execution_role_arn is required when no base task definition is provided" ) container_def: dict[str, Any] = { "name": cfg.container_name, "essential": True, "dependsOn": [{"containerName": "ssh-tunnel", "condition": "HEALTHY"}], } if image: container_def["image"] = image if command is not None: container_def["command"] = command if cfg.container_port: container_def["portMappings"] = [ {"containerPort": cfg.container_port, "protocol": "tcp"} ] if env: container_def["environment"] = [ {"name": k, "value": v} for k, v in sorted(env.items()) ] if log_cfg: container_def["logConfiguration"] = log_cfg family = self._make_family_name() payload: dict[str, Any] = { "family": family, "networkMode": "awsvpc", "requiresCompatibilities": ["FARGATE"], "cpu": cfg.cpu, "memory": cfg.memory, "executionRoleArn": cfg.execution_role_arn, "containerDefinitions": [container_def, sidecar_def], } if cfg.task_role_arn: payload["taskRoleArn"] = cfg.task_role_arn if cfg.ephemeral_storage_gib: payload["ephemeralStorage"] = {"sizeInGiB": cfg.ephemeral_storage_gib} return self._do_register(payload) def _do_register(self, payload: dict[str, Any]) -> str: resp = _retry_with_backoff( lambda: self._ecs.register_task_definition(**payload), operation_name="register_task_definition", max_retries=25, ) arn = resp["taskDefinition"]["taskDefinitionArn"] log.info(f"Registered task def: {arn}") return arn def _make_family_name(self) -> str: raw = f"{self._cfg.task_definition_family_prefix}-{_sanitize_id(self._task_id)}-{int(time.time())}" family = re.sub(r"[^A-Za-z0-9_-]", "_", raw)[:255] if not family or not re.match(r"^[A-Za-z0-9]", family): family = f"ecs_{family}" return family def _run_task(self, task_def_arn: str) -> str: cfg = self._cfg run_kwargs: dict[str, Any] = { "cluster": cfg.cluster, "taskDefinition": task_def_arn, "launchType": "FARGATE", "networkConfiguration": { "awsvpcConfiguration": { "subnets": cfg.subnets, "securityGroups": cfg.security_groups, "assignPublicIp": "ENABLED" if cfg.assign_public_ip else "DISABLED", } }, } if cfg.platform_version: run_kwargs["platformVersion"] = cfg.platform_version last_failures: Any = None for attempt in range(1, cfg.run_task_max_retries + 1): try: resp = _retry_with_backoff( lambda: self._ecs.run_task(**run_kwargs), operation_name="run_task", max_retries=3, ) except Exception as exc: if not _is_retryable_error(exc) or attempt >= cfg.run_task_max_retries: raise delay = min(60.0, 2.0 ** min(6, attempt - 1)) + random.random() * 2 log.warning( f"run_task failed ({attempt}/{cfg.run_task_max_retries}): {exc} — retry in {delay:.1f}s" ) time.sleep(delay) continue failures = resp.get("failures") or [] if not failures: tasks = resp.get("tasks") or [] if not tasks: raise RuntimeError("run_task returned no tasks") task_arn = tasks[0]["taskArn"] log.info(f"Started ECS task: {task_arn}") return task_arn last_failures = failures reasons = " | ".join(str(f.get("reason", "")) for f in failures) if ( not any(m in reasons.lower() for m in _RETRYABLE_MESSAGES) or attempt >= cfg.run_task_max_retries ): raise RuntimeError(f"run_task failures: {failures}") delay = min(60.0, 2.0 ** min(6, attempt - 1)) + random.random() * 2 log.warning( f"run_task capacity issue ({attempt}/{cfg.run_task_max_retries}): {reasons} — retry in {delay:.1f}s" ) time.sleep(delay) raise RuntimeError( f"run_task failed after {cfg.run_task_max_retries} retries: {last_failures}" ) def _wait_for_running(self) -> None: cfg = self._cfg start = time.monotonic() poll = 5.0 last_status = "" while True: elapsed = time.monotonic() - start if elapsed > cfg.startup_timeout_sec: raise TimeoutError( f"ECS task not RUNNING after {elapsed:.0f}s (last: {last_status})" ) try: resp = self._ecs.describe_tasks( cluster=cfg.cluster, tasks=[self._task_arn] ) except Exception as exc: if _is_retryable_error(exc): time.sleep(poll + random.random() * 3) continue raise tasks = resp.get("tasks") or [] if not tasks: raise RuntimeError("ECS task disappeared") status = tasks[0].get("lastStatus", "UNKNOWN") if status == "RUNNING": log.info(f"ECS task RUNNING after {elapsed:.0f}s") return if status == "STOPPED": raise RuntimeError(f"ECS task stopped: {tasks[0].get('stoppedReason')}") if status != last_status: log.info(f"ECS task {status} ({elapsed:.0f}s)") last_status = status time.sleep(poll + random.random() * 3) poll = min(15.0, poll + 0.5) def _get_task_public_ip(self) -> str: """Resolve the task's public IP from its ENI.""" max_retries = 10 for attempt in range(1, max_retries + 1): try: resp = self._ecs.describe_tasks( cluster=self._cfg.cluster, tasks=[self._task_arn] ) tasks = resp.get("tasks") or [] if not tasks: raise RuntimeError("Task not found") eni_id = None for att in tasks[0].get("attachments") or []: if att.get("type") == "ElasticNetworkInterface": for d in att.get("details") or []: if d.get("name") == "networkInterfaceId": eni_id = d["value"] break if eni_id: break if not eni_id: for att in tasks[0].get("attachments") or []: for d in att.get("details") or []: if d.get("name") == "privateIPv4Address" and d.get("value"): return d["value"] raise RuntimeError("No ENI/IP yet") eni = self._ec2.describe_network_interfaces( NetworkInterfaceIds=[eni_id] ) iface = eni["NetworkInterfaces"][0] pub = (iface.get("Association") or {}).get("PublicIp") if pub: log.info(f"Container public IP: {pub}") return pub priv = iface.get("PrivateIpAddress") if priv: log.info(f"Container private IP: {priv}") return priv raise RuntimeError(f"ENI {eni_id} has no IP") except Exception as exc: if attempt >= max_retries: raise if _is_retryable_error(exc): time.sleep(min(15.0, 2.0**attempt + random.random())) else: log.warning(f"get_task_ip attempt {attempt}/{max_retries}: {exc}") time.sleep(min(15.0, 3.0 + attempt * 2)) raise RuntimeError("get_task_ip exhausted retries") @staticmethod def _wait_for_ssh_ready(host: str, port: int, timeout: float) -> None: deadline = time.monotonic() + timeout log.info(f"Waiting for SSH at {host}:{port}") while time.monotonic() < deadline: try: with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: s.settimeout(5.0) s.connect((host, port)) s.settimeout(5.0) data = s.recv(256) if data and b"SSH" in data: log.info(f"SSH ready at {host}:{port}") return except OSError: pass time.sleep(2.0) raise TimeoutError(f"SSH not ready at {host}:{port} after {timeout:.0f}s") def _open_tunnel(self, sidecar: SshSidecarConfig) -> None: assert self._task_ip is not None assert self._ssh_key_file is not None has_exec_server = sidecar.exec_server_port is not None if has_exec_server: self._ssh_tunnel = SshTunnel( host=self._task_ip, port=sidecar.sshd_port, user="root", key_file=self._ssh_key_file, forward_port=sidecar.exec_server_port, ) self._ssh_tunnel.open() else: remote_host, remote_port = self._resolve_tunnel_target() local_port = self._ssh_tunnel_port assert local_port is not None self._agent_forward_port = _free_port() container_port = self._cfg.container_port if not container_port: raise ValueError("container_port is required in agent-server mode") forwards = [f"{self._agent_forward_port}:localhost:{container_port}"] reverses = [f"{local_port}:{remote_host}:{remote_port}"] self._ssh_tunnel = SshTunnel( host=self._task_ip, port=sidecar.sshd_port, user="root", key_file=self._ssh_key_file, forwards=forwards, reverses=reverses, local_port_override=self._agent_forward_port, ) self._ssh_tunnel.open() # Cleanup --------------------------------------------------------- def _cleanup(self) -> None: if self._ssh_tunnel: try: self._ssh_tunnel.close() except Exception: log.debug("Failed to close SSH tunnel", exc_info=True) self._ssh_tunnel = None if self._task_arn and self._ecs: try: _retry_with_backoff( lambda: self._ecs.stop_task( cluster=self._cfg.cluster, task=self._task_arn, reason="sandbox cleanup", ), operation_name="stop_task", max_retries=10, ) log.info(f"Stopped ECS task: {self._task_arn}") except Exception as exc: log.warning(f"Failed to stop task {self._task_arn}: {exc}") if self._task_def_arn and self._ecs: try: _retry_with_backoff( lambda: self._ecs.deregister_task_definition( taskDefinition=self._task_def_arn ), operation_name="deregister_task_definition", max_retries=5, ) log.info(f"Deregistered task def: {self._task_def_arn}") except Exception as exc: log.warning( f"Failed to deregister task def {self._task_def_arn}: {exc}" ) if self._ssh_key_file: try: os.remove(self._ssh_key_file) except Exception: log.debug( f"Failed to remove SSH key file {self._ssh_key_file}", exc_info=True ) self._ssh_key_file = None def _require_exec_client(self) -> None: if self._exec_client is None: raise RuntimeError( "exec()/upload()/download() require ssh_sidecar.exec_server_port to be set " "(exec-server mode). In agent-server mode, use sandbox.ssh_tunnel directly." ) def _upload_via_s3(self, paths: list[Path], dest_dir: str) -> None: cfg = self._cfg if not cfg.s3_bucket: raise ValueError("s3_bucket is required for S3 staging") buf = io.BytesIO() with tarfile.open(fileobj=buf, mode="w:gz") as tar: for p in paths: if p.is_file(): tar.add(str(p), arcname=p.name) elif p.is_dir(): for child in p.rglob("*"): if child.is_file(): tar.add(str(child), arcname=str(child.relative_to(p))) buf.seek(0) payload = buf.read() boto3, *_ = _require_aws_sdks() s3 = boto3.client("s3", region_name=cfg.region) prefix = cfg.s3_prefix or "ecs-sandbox" nonce = uuid.uuid4().hex[:12] key = f"{prefix}/{self._run_id}/{self._task_id}/upload-{nonce}.tar.gz" s3.put_object(Bucket=cfg.s3_bucket, Key=key, Body=payload) url = s3.generate_presigned_url( "get_object", Params={"Bucket": cfg.s3_bucket, "Key": key}, ExpiresIn=21600, ) dl_cmd = ( f"mkdir -p {shlex.quote(dest_dir)} && " f"TGZ=/tmp/_upload_$$.tar.gz && " f"( curl -sf -L --max-time 300 -o $TGZ {shlex.quote(url)} 2>/dev/null || " f"python3 -c 'import urllib.request as u,sys;u.urlretrieve(sys.argv[1],sys.argv[2])' " f"{shlex.quote(url)} $TGZ ) && " f"tar xzf $TGZ -C {shlex.quote(dest_dir)} && rm -f $TGZ && echo ok" ) result = self._exec_client.exec(dl_cmd, timeout=360) # type: ignore[union-attr] if "ok" not in result.stdout: raise RuntimeError( f"S3 upload extraction failed (rc={result.return_code}): {result.stderr or result.stdout}" ) # Atexit cleanup -------------------------------------------------- def _register_for_cleanup(self) -> None: global _atexit_registered with _cleanup_lock: _active_sandboxes[id(self)] = self if not _atexit_registered: atexit.register(_emergency_cleanup) _atexit_registered = True def _unregister_from_cleanup(self) -> None: with _cleanup_lock: _active_sandboxes.pop(id(self), None)