Source code for nemo_evaluator.sandbox.ecs_fargate

# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from __future__ import annotations

import importlib
import io
import logging
import os
import random
import re
import shlex
import subprocess
import tarfile
import time
import uuid
from contextlib import contextmanager
from dataclasses import dataclass
from pathlib import Path
from typing import ContextManager, Generator, Iterable

from .base import NemoEvaluatorSandbox, NemoSandboxCommand, NemoSandboxSession

log = logging.getLogger(__name__)


@dataclass(frozen=True)
class EcsFargateConfig:
    """Static configuration for the ECS Fargate sandbox backend.

    Grouped roughly as: ECS placement/networking, image selection,
    per-task task-definition registration, logging, task lifetime/retries,
    and S3 staging (used for file transfer and long-command fallback).
    """

    # AWS region; None defers to the default AWS SDK/CLI region resolution.
    region: str | None
    # Target ECS cluster and the task definition used to launch sandbox tasks.
    cluster: str
    task_definition: str
    # Name of the container (within the task definition) that commands run in.
    container_name: str
    # awsvpc networking for the Fargate task.
    subnets: list[str]
    security_groups: list[str]
    assign_public_ip: bool = False
    # Image selection
    image_template: str | None = None  # supports {task_id}
    # If true and image_template is provided, register a per-task task definition and deregister on cleanup.
    register_task_definition_per_task: bool = True
    # Used only when we need to auto-register a task definition from scratch.
    cpu: str = "8192"
    memory: str = "32768"
    execution_role_arn: str | None = None
    task_role_arn: str | None = None
    log_group: str | None = None
    log_stream_prefix: str = "nemo-evaluator"
    # Hard TTL for the sandbox task. Container main process will be `sleep <max_task_lifetime_sec>`.
    max_task_lifetime_sec: int = 180 * 60
    # Retries for ecs.run_task placement/capacity failures.
    run_task_max_retries: int = 30
    # File staging (required for ECS sandbox here)
    s3_bucket: str | None = None
    s3_prefix: str = "nemo-evaluator"
    # Minimum timeout for each `aws ecs execute-command` subprocess call.
    ecs_exec_timeout_sec: int = 180
class AwsCliMissingError(RuntimeError):
    """Raised when the AWS CLI or session-manager-plugin is missing on the harness host."""
class EcsExecError(RuntimeError):
    """Raised when a remote `aws ecs execute-command` invocation fails or reports a non-zero rc."""
def _require_aws_sdks():
    """
    Lazily import boto3/botocore only when ECS sandbox is actually used.

    This avoids requiring AWS dependencies for non-ECS runs.

    Returns:
        Tuple of (boto3 module, botocore Config class, ClientError,
        NoCredentialsError, PartialCredentialsError).

    Raises:
        RuntimeError: if boto3/botocore are not installed.
    """
    try:
        boto3 = importlib.import_module("boto3")
        botocore_config = importlib.import_module("botocore.config")
        botocore_exceptions = importlib.import_module("botocore.exceptions")
    except ModuleNotFoundError as e:
        raise RuntimeError(
            "ECS Fargate sandbox requires AWS SDK dependencies (boto3/botocore).\n"
            "Install them (e.g. `pip install boto3`) or avoid using the ECS backend."
        ) from e
    Config = getattr(botocore_config, "Config")
    ClientError = getattr(botocore_exceptions, "ClientError")
    NoCredentialsError = getattr(botocore_exceptions, "NoCredentialsError")
    PartialCredentialsError = getattr(botocore_exceptions, "PartialCredentialsError")
    return boto3, Config, ClientError, NoCredentialsError, PartialCredentialsError


def _which(name: str) -> str | None:
    # Thin wrapper over shutil.which; import kept local to match the module's
    # lazy-dependency style.
    import shutil

    return shutil.which(name)


def _aws_credentials_preflight(region: str | None) -> str:
    """
    Validate that AWS credentials are present and usable.

    Returns:
        str: account id (for logging).

    Raises:
        RuntimeError: when credentials are missing/partial, or when the
            sts:GetCallerIdentity call fails (invalid/expired credentials or
            no network path to STS).
    """
    boto3, _Config, ClientError, NoCredentialsError, PartialCredentialsError = (
        _require_aws_sdks()
    )
    try:
        sts = boto3.client("sts", region_name=region)
        ident = sts.get_caller_identity()
        return str(ident.get("Account", "unknown"))
    except (NoCredentialsError, PartialCredentialsError) as e:
        raise RuntimeError(
            "AWS credentials not found or incomplete. For ECS sandbox you must provide "
            "valid AWS credentials to BOTH boto3 and the AWS CLI.\n\n"
            "Common fixes:\n"
            "- export AWS_ACCESS_KEY_ID / AWS_SECRET_ACCESS_KEY (/ AWS_SESSION_TOKEN)\n"
            "- or configure a profile and set AWS_PROFILE\n"
            "- or run in an environment with an attached IAM role\n"
        ) from e
    except ClientError as e:
        raise RuntimeError(
            "AWS credential check failed (sts:GetCallerIdentity). This usually means the credentials "
            "are invalid/expired or the environment cannot reach AWS STS.\n\n"
            f"AWS error: {e}\n"
        ) from e


class _EcsExecContainer:
    """
    Minimal docker.Container-like shim used by some agents (installed agents).
    """

    def __init__(self, sandbox: EcsFargateSandbox):
        self._s = sandbox
        # Mimics docker's container.attrs shape for callers that read the user.
        self.attrs = {"Config": {"User": ""}}

    def exec_run(self, cmd: list[str], user: str = ""):
        _ = user  # ECS Exec user switching is handled by remote shell / sudo if needed.
        # check=True means _exec_capture raises EcsExecError on non-zero rc, so
        # a returned ExecResult always carries exit_code 0 by construction.
        out = self._s._exec_capture(cmd=cmd, timeout_sec=180.0, check=True)
        return type("ExecResult", (), {"exit_code": 0, "output": out.encode()})


class EcsFargateTmuxSession(NemoSandboxSession):
    """tmux-backed interactive session running inside a remote Fargate task."""

    # Appended to a blocking command so `tmux wait done` unblocks on completion.
    # NOTE(review): the "done" wait channel is server-global, not per-session —
    # concurrent blocking sessions in one task could cross-signal; confirm
    # single-blocking-session usage.
    _TMUX_COMPLETION_COMMAND = "; tmux wait -S done"
    # Keystroke payloads longer than this are pasted via a tmux buffer instead
    # of `tmux send-keys` (see send_keys).
    _LONG_TEXT_THRESHOLD = 2000

    def __init__(self, *, session_name: str, sandbox: EcsFargateSandbox):
        self._session_name = session_name
        self._sandbox = sandbox
        self.container = _EcsExecContainer(sandbox)
        # Last full scrollback snapshot, used by get_incremental_output diffing.
        self._previous_buffer: str | None = None
        self._logger = logging.getLogger(f"{__name__}.EcsFargateTmuxSession")

    def start(self) -> None:
        """Ensure the tmux session exists inside the remote container.

        We use tmux as a lightweight "PTY" so agents can stream incremental
        output by capturing the pane buffer. The session is created with a
        large history limit so we can diff across polls.
        """
        self._sandbox._check_tmux_version()
        if self._has_session():
            return
        self._sandbox._exec_capture(
            cmd=[
                "bash",
                "-lc",
                (
                    f"tmux new-session -x 160 -y 40 -d -s {shlex.quote(self._session_name)} "
                    f"\\; set-option -t {shlex.quote(self._session_name)} history-limit 50000"
                ),
            ],
            timeout_sec=60.0,
        )

    def _has_session(self) -> bool:
        # `tmux has-session` signals via exit code; echo markers so we can read
        # the result out of the captured payload instead.
        payload = self._sandbox._exec_capture(
            cmd=[
                "bash",
                "-lc",
                (
                    f"tmux has-session -t {shlex.quote(self._session_name)} 2>/dev/null "
                    "&& echo __NEMO_YES__ || echo __NEMO_NO__"
                ),
            ],
            timeout_sec=90.0,
            check=True,
        )
        return "__NEMO_YES__" in payload

    def stop(self) -> None:
        # Intentionally a no-op: the remote task (and its tmux server) is torn
        # down by the sandbox spin-up contextmanager.
        return

    def send_keys(
        self,
        keys: str | list[str],
        block: bool = False,
        min_timeout_sec: float = 0.0,
        max_timeout_sec: float = 180.0,
    ) -> None:
        """Send keystrokes to the tmux session (optionally blocking until completion).

        Notes:
        - Large text payloads are pasted via a tmux buffer (staged through S3)
          to avoid shell/CLI length limits and tmux send-keys slowness.
        - When `block=True` and the last key is Enter, we inject a `tmux wait`
          marker so the caller can reliably wait for command completion without
          polling output.
        """
        if isinstance(keys, str):
            keys = [keys]
        special_keys = {"Enter", "C-m", "KPEnter", "C-j", "^M", "^J"}
        # First pass: paste every oversized payload through the S3-staged
        # tmux buffer path, preserving order relative to other large payloads.
        for k in keys:
            if (
                isinstance(k, str)
                and (len(k) > self._LONG_TEXT_THRESHOLD)
                and (k not in special_keys)
            ):
                self._logger.info(
                    "Large keystroke payload detected (%s chars); using tmux paste-buffer",
                    len(k),
                )
                self._sandbox._tmux_paste_large_text(
                    session_name=self._session_name,
                    text=k,
                    timeout_sec=max(300.0, float(max_timeout_sec) + 60.0),
                )
        # Second pass: drop the already-pasted payloads from the send-keys list.
        keys = [
            k
            for k in keys
            if not (
                isinstance(k, str)
                and (len(k) > self._LONG_TEXT_THRESHOLD)
                and (k not in special_keys)
            )
        ]
        if (
            block
            and keys
            and keys[-1] in ("Enter", "C-m", "KPEnter", "C-j", "^M", "^J")
        ):
            # Replace the trailing Enter with "<completion marker>; Enter" so
            # `tmux wait done` below fires when the command finishes.
            keys = keys[:-1] + [self._TMUX_COMPLETION_COMMAND, "Enter"]
        if keys:
            self._sandbox._exec_capture(
                cmd=["tmux", "send-keys", "-t", self._session_name, *keys],
                timeout_sec=60.0,
            )
        if block:
            # Bound the remote wait with `timeout` and give the CLI extra slack.
            self._sandbox._exec_capture(
                cmd=["timeout", f"{max_timeout_sec}s", "tmux", "wait", "done"],
                timeout_sec=max_timeout_sec + 30.0,
            )
        elif min_timeout_sec > 0:
            time.sleep(min_timeout_sec)

    def send_command(self, command: NemoSandboxCommand) -> None:
        """Send a high-level `NemoSandboxCommand` to the tmux session."""
        keys = [command.command, "Enter"] if command.append_enter else [command.command]
        self.send_keys(
            keys=keys,
            block=command.block,
            min_timeout_sec=command.min_timeout_sec,
            max_timeout_sec=command.max_timeout_sec,
        )

    def capture_pane(self, capture_entire: bool = False) -> str:
        """Capture tmux pane output.

        - `capture_entire=False` returns only the visible screen.
        - `capture_entire=True` returns the full scrollback (up to history-limit).
        """
        cmd = ["tmux", "capture-pane", "-p"]
        if capture_entire:
            cmd.extend(["-S", "-"])
        cmd.extend(["-t", self._session_name])
        return self._sandbox._exec_capture(cmd=cmd, timeout_sec=60.0)

    def is_session_alive(self) -> bool:
        # Best-effort: any exec failure is treated as "session gone".
        try:
            return self._has_session()
        except Exception:
            return False

    def get_asciinema_timestamp(self) -> float:
        # No recording timestamp is tracked on this backend; always 0.0.
        return 0.0

    def _get_visible_screen(self) -> str:
        return self.capture_pane(capture_entire=False)

    def _find_new_content(self, current_buffer: str) -> str | None:
        """Best-effort diff of tmux scrollback between polls.

        We keep the previous buffer and try to locate it as a substring of the
        current scrollback. If found, return only the appended region; otherwise
        return None and let the caller fall back to showing the visible screen.
        """
        if self._previous_buffer is None:
            return None
        pb = self._previous_buffer.strip()
        if pb and pb in current_buffer:
            idx = current_buffer.index(pb)
            if "\n" in pb:
                # Re-include the (possibly partial) last line of the previous
                # buffer so mid-line output is not lost.
                start = idx + pb.rfind("\n") + 1
            else:
                start = idx
            return current_buffer[start:]
        return None

    def get_incremental_output(self) -> str:
        """Return incremental terminal output since the last call.

        The return value is a human-friendly string prefixed with either:
        - "New Terminal Output:" if we can find appended lines, or
        - "Current Terminal Screen:" as a safe fallback.
        """
        current_buffer = self.capture_pane(capture_entire=True)
        if self._previous_buffer is None:
            # First poll: no baseline to diff against yet.
            self._previous_buffer = current_buffer
            return f"Current Terminal Screen:\n{self._get_visible_screen()}"
        new_content = self._find_new_content(current_buffer)
        self._previous_buffer = current_buffer
        if new_content is not None:
            if new_content.strip():
                return f"New Terminal Output:\n{new_content}"
            # Appended region was whitespace-only; show the screen instead.
            return f"Current Terminal Screen:\n{self._get_visible_screen()}"
        return f"Current Terminal Screen:\n{self._get_visible_screen()}"

    def copy_to_sandbox(
        self,
        paths: list[Path] | Path,
        container_dir: str | None = None,
        container_filename: str | None = None,
    ) -> None:
        """Delegate file staging to the owning sandbox (S3-staged tarball)."""
        self._sandbox.copy_to_sandbox(
            paths=paths,
            container_dir=container_dir,
            container_filename=container_filename,
        )
class EcsFargateSandbox(NemoEvaluatorSandbox):
    """
    Sandbox backed by ECS Fargate + ECS Exec.

    No inbound connectivity is required. File transfer is done by uploading a
    tar to S3 and downloading it from inside the container using python stdlib.
    """

    def __init__(
        self,
        *,
        cfg: EcsFargateConfig,
        task_arn: str,
        run_id: str,
        task_id: str,
        trial_name: str,
    ):
        """Bind to an already-RUNNING ECS task (see `spin_up` for creation).

        Raises:
            AwsCliMissingError: if `aws` or `session-manager-plugin` is not on
                the harness host PATH (both are required for ECS Exec).
        """
        if _which("aws") is None:
            raise AwsCliMissingError(
                "AWS CLI ('aws') not found. ECS sandbox requires AWS CLI + session-manager-plugin."
            )
        if _which("session-manager-plugin") is None:
            raise AwsCliMissingError(
                "session-manager-plugin not found. ECS Exec requires session-manager-plugin "
                "to be installed on the harness host (it is invoked by the AWS CLI)."
            )
        self._cfg = cfg
        self._task_arn = task_arn
        self._run_id = run_id
        self._task_id = task_id
        self._trial_name = trial_name
        self._sessions: dict[str, EcsFargateTmuxSession] = {}
        self.container = _EcsExecContainer(self)
        self._logger = logging.getLogger(f"{__name__}.EcsFargateSandbox")
        # Ensures the tmux version probe in _check_tmux_version runs only once.
        self._tmux_version_checked = False

    def _check_tmux_version(self) -> None:
        """Best-effort compatibility warning for tmux output/behavior changes.

        We rely on tmux for:
        - `capture-pane` output diffing (incremental output)
        - `wait`/`wait -S` markers for "block until command finished"

        tmux does not guarantee backward compatibility of all emitted strings
        across minor versions. To reduce surprise, we probe `tmux -V` once and
        emit a loud warning if the version is newer than the known-tested
        minor + 1.
        """
        if self._tmux_version_checked:
            return
        self._tmux_version_checked = True
        # Bump this when we validate against a newer tmux release.
        KNOWN_TESTED_TMUX = (3, 4)  # (major, minor)
        max_ok = (KNOWN_TESTED_TMUX[0], KNOWN_TESTED_TMUX[1] + 1)
        try:
            out = self._exec_capture(
                cmd=["sh", "-lc", "tmux -V 2>/dev/null || true"],
                timeout_sec=30.0,
                check=False,
            ).strip()
        except Exception:
            # Don't block sandbox startup on a version probe.
            return
        m = re.search(r"tmux\s+(\d+)\.(\d+)", out)
        if not m:
            # No parseable version (tmux missing or unusual build string).
            return
        major = int(m.group(1))
        minor = int(m.group(2))
        if (major, minor) > max_ok or major != KNOWN_TESTED_TMUX[0]:
            self._logger.warning(
                "tmux version appears newer than the known-tested range; output parsing or "
                "magic strings may break across tmux minor releases.\n\n"
                "Detected: tmux %s.%s (raw=%r)\n"
                "Known-tested: tmux %s.%s\n"
                "Allowed (safe-ish): up to tmux %s.%s\n",
                major,
                minor,
                out,
                KNOWN_TESTED_TMUX[0],
                KNOWN_TESTED_TMUX[1],
                max_ok[0],
                max_ok[1],
            )
[docs] @classmethod def spin_up( cls, *, cfg: EcsFargateConfig, task_id: str, trial_name: str, run_id: str, pre_upload_paths: Iterable[Path] | None = None, upload_dest_dir: str | None = None, ) -> ContextManager[EcsFargateSandbox]: return _spin_up_ecs_fargate_sandbox( cfg=cfg, task_id=task_id, trial_name=trial_name, run_id=run_id, pre_upload_paths=pre_upload_paths, upload_dest_dir=upload_dest_dir, )
    def _aws_ecs_execute(
        self, *, command: str, timeout_sec: float
    ) -> subprocess.CompletedProcess:
        """Invoke `aws ecs execute-command` for this task/container.

        This uses the AWS CLI (and `session-manager-plugin`) on the harness host
        to run a remote shell command inside the running Fargate task.

        Returns the raw CompletedProcess (check=False); callers inspect output
        and our rc markers rather than the CLI's exit status.
        """
        args = [
            "aws",
            *(["--region", self._cfg.region] if self._cfg.region else []),
            "ecs",
            "execute-command",
            "--cluster",
            self._cfg.cluster,
            "--task",
            self._task_arn,
            "--container",
            self._cfg.container_name,
            "--interactive",
            "--command",
            command,
        ]
        env = os.environ.copy()
        # Opt the CLI into standard retry mode unless the caller already chose.
        env.setdefault("AWS_RETRY_MODE", "standard")
        env.setdefault("AWS_MAX_ATTEMPTS", "12")
        # Never run with less than the configured floor for ECS Exec calls.
        effective_timeout = max(
            float(timeout_sec), float(self._cfg.ecs_exec_timeout_sec)
        )
        return subprocess.run(
            args,
            check=False,
            capture_output=True,
            text=False,
            timeout=effective_timeout,
            env=env,
        )

    def _aws_ecs_execute_with_retry(
        self, *, command: str, timeout_sec: float
    ) -> subprocess.CompletedProcess:
        """Run ECS Exec with retries for common transient failures.

        Retries cover:
        - AWS CLI throttling / rate limiting (exponential-ish backoff)
        - Exec agent not ready / not yet connected shortly after task start
        - Occasional CLI timeouts (treated as retryable for a short window)
        """
        effective_timeout = max(
            float(timeout_sec), float(self._cfg.ecs_exec_timeout_sec)
        )
        start = time.time()
        throttle_sleep = 1.0
        last_cp: subprocess.CompletedProcess | None = None

        def _as_text(x) -> str:
            # Normalize bytes/None/other to str for marker matching.
            if x is None:
                return ""
            if isinstance(x, bytes):
                return x.decode("utf-8", errors="replace")
            return str(x)

        while True:
            try:
                cp = self._aws_ecs_execute(
                    command=command, timeout_sec=effective_timeout
                )
            except subprocess.TimeoutExpired as te:
                self._logger.warning(
                    "ECS Exec timed out after %ss; will retry for a bit",
                    int(effective_timeout),
                )
                # Synthesize a CompletedProcess (rc 124, the `timeout` convention)
                # so the marker matching below can treat it as retryable.
                cp = subprocess.CompletedProcess(
                    args=te.cmd,
                    returncode=124,
                    stdout=_as_text(getattr(te, "stdout", "")),
                    stderr=_as_text(getattr(te, "stderr", ""))
                    or f"TimeoutExpired: command timed out after {effective_timeout} seconds",
                )
            last_cp = cp
            combined = (
                _as_text(getattr(cp, "stdout", ""))
                + "\n"
                + _as_text(getattr(cp, "stderr", ""))
            ).strip()
            is_throttled = (
                "ThrottlingException" in combined
                or "TooManyRequestsException" in combined
                or "Rate exceeded" in combined
            )
            # Both ASCII and typographic apostrophes have been observed in the
            # AWS error text, hence the near-duplicate markers.
            _exec_not_ready_markers = (
                "TargetNotConnectedException",
                "execute command agent isn't running",
                "execute command agent isn’t running",
                "execute command was not enabled",
                "TimeoutExpired",
            )
            is_exec_not_ready = any(m in combined for m in _exec_not_ready_markers)
            # Throttling gets up to 10 minutes of jittered backoff.
            if is_throttled and (time.time() - start) < 600:
                sleep_sec = min(30.0, throttle_sleep) + random.random()
                self._logger.warning(
                    "ECS Exec throttled; backing off for %.1fs", sleep_sec
                )
                time.sleep(sleep_sec)
                throttle_sleep = min(30.0, throttle_sleep * 1.7)
                continue
            # "Exec agent not ready" gets up to 3 minutes of fixed-interval retry.
            if is_exec_not_ready and (time.time() - start) < 180:
                time.sleep(3.0)
                continue
            return last_cp

    def _parse_exec_markers(
        self, *, cp: subprocess.CompletedProcess, check: bool
    ) -> str:
        """Extract payload/rc from our wrapped shell output.

        `_exec_capture` wraps every remote command as:

            __NEMO_BEGIN__
            <payload>
            __NEMO_RC__=<rc>

        This function strips noisy session-manager lines, extracts the payload,
        and raises a helpful `EcsExecError` on non-zero return codes when
        `check=True`.
        """

        def _as_text(x) -> str:
            if x is None:
                return ""
            if isinstance(x, bytes):
                return x.decode("utf-8", errors="replace")
            return str(x)

        combined_lines = (
            _as_text(getattr(cp, "stdout", ""))
            + "\n"
            + _as_text(getattr(cp, "stderr", ""))
        ).splitlines()
        filtered_lines: list[str] = []
        for line in combined_lines:
            # AWS CLI / session-manager-plugin emits a few non-payload chatter lines. The
            # exact casing has been observed to vary across versions, so match loosely.
            ll = line.lower()
            if ll.startswith("the session manager plugin was installed successfully"):
                continue
            if ll.startswith("starting session with sessionid:"):
                continue
            if ll.startswith("exiting session with sessionid:"):
                continue
            filtered_lines.append(line)
        text_out = "\n".join(filtered_lines).strip("\n")
        # Locate the first BEGIN marker and the last RC marker; payload is the
        # region strictly between them.
        begin_idx = None
        for i, line in enumerate(filtered_lines):
            if line.strip() == "__NEMO_BEGIN__":
                begin_idx = i
                break
        rc = None
        rc_line_idx = None
        for i in range(len(filtered_lines) - 1, -1, -1):
            line = filtered_lines[i].strip()
            if line.startswith("__NEMO_RC__="):
                rc_line_idx = i
                try:
                    rc = int(line.split("=", 1)[1])
                except Exception:
                    rc = None
                break
        if (
            begin_idx is not None
            and rc_line_idx is not None
            and rc_line_idx > begin_idx
        ):
            payload_lines = filtered_lines[begin_idx + 1 : rc_line_idx]
            payload = "\n".join(payload_lines).strip("\n")
        else:
            # Markers missing/garbled: fall back to the full filtered output.
            payload = text_out
        if not check:
            return payload
        # Prefer the in-band rc (the CLI itself may return 0 even on failure).
        effective_rc = rc if rc is not None else cp.returncode
        if effective_rc != 0:
            err = text_out.strip()
            if (
                "TargetNotConnectedException" in err
                or "execute command was not enabled" in err
                or "execute command agent" in err
            ):
                raise EcsExecError(
                    "ECS Exec failed. The AWS CLI reports that execute-command is not enabled or the "
                    "exec target is not connected / exec agent is not running.\n\n"
                    "This sandbox runs tasks with enableExecuteCommand=True, so the most common causes are:\n"
                    "- The task is in private subnets without NAT/VPC endpoints to SSM/SSMMessages/EC2Messages\n"
                    "- The task IAM role (taskRoleArn) is missing SSM/SSM Messages permissions required for ECS Exec\n"
                    "- The exec agent is still initializing (retries happen, but it can still time out)\n\n"
                    f"rc={effective_rc}\n"
                    f"OUTPUT:\n{text_out}"
                )
            raise EcsExecError(
                f"ECS Exec failed: rc={effective_rc}\nOUTPUT:\n{text_out}"
            )
        return payload

    def _exec_capture(
        self, *, cmd: list[str], timeout_sec: float, check: bool = True
    ) -> str:
        """Execute a command in the remote container and return captured output.

        Implementation notes:
        - We run via `sh -lc` to get a predictable shell, and wrap output with
          markers so we can recover the true exit code even if the AWS CLI
          returns 0.
        - If the wrapped command exceeds common CLI length limits, we fall back
          to staging the script in S3 and executing it from `/tmp` in the
          container.
        """
        shell = " ".join(shlex.quote(x) for x in cmd)
        wrapped = (
            "printf '__NEMO_BEGIN__\\n'; "
            f"{shell}; "
            "rc=$?; "
            "printf '\\n__NEMO_RC__=%s\\n' \"$rc\""
        )
        command = f"sh -lc {shlex.quote(wrapped)}"

        def _as_text(x) -> str:
            if x is None:
                return ""
            if isinstance(x, bytes):
                return x.decode("utf-8", errors="replace")
            return str(x)

        # Proactive fallback: 6000 chars is a conservative bound below common
        # CLI/argv length limits.
        if len(command) > 6000:
            return self._exec_capture_via_s3_script(
                shell=shell, timeout_sec=timeout_sec, check=check
            )
        cp = self._aws_ecs_execute_with_retry(command=command, timeout_sec=timeout_sec)
        out = (
            _as_text(getattr(cp, "stdout", ""))
            + "\n"
            + _as_text(getattr(cp, "stderr", ""))
        ).lower()
        # Reactive fallback: the service rejected the command as too long.
        if "command too long" in out:
            return self._exec_capture_via_s3_script(
                shell=shell, timeout_sec=timeout_sec, check=check
            )
        return self._parse_exec_markers(cp=cp, check=check)

    def _exec_capture_via_s3_script(
        self, *, shell: str, timeout_sec: float, check: bool
    ) -> str:
        """Fallback for very long commands: stage a script in S3, then download+run it.

        This avoids AWS CLI / shell quoting / argument length limits by shipping
        the command body as a `.sh` file, downloaded inside the container using
        Python stdlib + a presigned URL.
        """
        boto3, _Config, _ClientError, _NoCredentialsError, _PartialCredentialsError = (
            _require_aws_sdks()
        )
        if not self._cfg.s3_bucket:
            raise RuntimeError(
                "ECS Exec command exceeded length limits and S3 staging is not configured.\n"
                "Set s3_bucket to enable long-command fallback."
            )
        # Script reproduces the same BEGIN/RC marker framing as _exec_capture.
        script = (
            "#!/bin/sh\n"
            "printf '__NEMO_BEGIN__\\n'\n"
            f"{shell}\n"
            "rc=$?\n"
            "printf '\\n__NEMO_RC__=%s\\n' \"$rc\"\n"
        )
        s3 = boto3.client("s3", region_name=self._cfg.region)
        key = (
            f"{self._cfg.s3_prefix}/{self._run_id}/{self._task_id}/{self._trial_name}/"
            f"exec/{int(time.time() * 1000)}_{uuid.uuid4().hex[:8]}.sh"
        )
        s3.put_object(Bucket=self._cfg.s3_bucket, Key=key, Body=script.encode("utf-8"))
        url = s3.generate_presigned_url(
            "get_object",
            Params={"Bucket": self._cfg.s3_bucket, "Key": key},
            ExpiresIn=3600,
        )
        # Tiny in-container downloader; uses only python stdlib so no tooling
        # is required in the image.
        py = (
            "import os,urllib.request\n"
            "url=os.environ['NEMO_URL']\n"
            "dst=os.environ['NEMO_DST']\n"
            "with urllib.request.urlopen(url, timeout=180) as r:\n"
            " data=r.read()\n"
            "open(dst,'wb').write(data)\n"
            "print('ok')\n"
        )
        remote = f"/tmp/nemo_exec_{int(time.time() * 1000)}.sh"
        download_and_run = (
            f"PY=python3; command -v python3 >/dev/null 2>&1 || PY=python; "
            f"NEMO_URL={shlex.quote(url)} NEMO_DST={shlex.quote(remote)} "
            f"$PY -c {shlex.quote(py)} >/dev/null 2>&1; "
            f"chmod +x {shlex.quote(remote)} >/dev/null 2>&1 || true; "
            f"sh {shlex.quote(remote)}; "
            f"rm -f {shlex.quote(remote)} >/dev/null 2>&1 || true"
        )
        cp = self._aws_ecs_execute_with_retry(
            command=f"sh -lc {shlex.quote(download_and_run)}",
            timeout_sec=timeout_sec,
        )
        return self._parse_exec_markers(cp=cp, check=check)

    def _s3_stage_text(self, *, text: str, suffix: str) -> str:
        """Upload small text to S3 and return a presigned GET URL."""
        boto3, _Config, _ClientError, _NoCredentialsError, _PartialCredentialsError = (
            _require_aws_sdks()
        )
        if not self._cfg.s3_bucket:
            raise RuntimeError(
                "S3 staging is required for large-text tmux paste fallback. Set s3_bucket."
            )
        s3 = boto3.client("s3", region_name=self._cfg.region)
        key = (
            f"{self._cfg.s3_prefix}/{self._run_id}/{self._task_id}/{self._trial_name}/"
            f"{suffix}/{int(time.time() * 1000)}_{uuid.uuid4().hex[:8]}.txt"
        )
        s3.put_object(Bucket=self._cfg.s3_bucket, Key=key, Body=text.encode("utf-8"))
        return s3.generate_presigned_url(
            "get_object",
            Params={"Bucket": self._cfg.s3_bucket, "Key": key},
            ExpiresIn=3600,
        )

    def _tmux_paste_large_text(
        self, *, session_name: str, text: str, timeout_sec: float
    ) -> None:
        """Paste a large text blob into tmux via S3 staging.

        This is used by `EcsFargateTmuxSession.send_keys` when a keystroke
        payload is too large for reliable `tmux send-keys`.

        Flow: stage text in S3 -> download to /tmp inside the container ->
        `tmux load-buffer` + `paste-buffer` -> best-effort cleanup.
        """
        url = self._s3_stage_text(text=text, suffix="tmux-paste")
        remote = f"/tmp/nemo_paste_{int(time.time() * 1000)}_{uuid.uuid4().hex[:8]}.txt"
        py = (
            "import os,urllib.request\n"
            "url=os.environ['NEMO_URL']\n"
            "dst=os.environ['NEMO_DST']\n"
            "with urllib.request.urlopen(url, timeout=180) as r:\n"
            " data=r.read()\n"
            "open(dst,'wb').write(data)\n"
            "print('ok')\n"
        )
        self._exec_capture(
            cmd=[
                "bash",
                "-lc",
                (
                    f"PY=python3; command -v python3 >/dev/null 2>&1 || PY=python; "
                    f"NEMO_URL={shlex.quote(url)} NEMO_DST={shlex.quote(remote)} "
                    f"$PY -c {shlex.quote(py)} >/dev/null"
                ),
            ],
            timeout_sec=min(600.0, float(timeout_sec)),
        )
        buf_name = f"nemo_paste_{uuid.uuid4().hex[:8]}"
        self._exec_capture(
            cmd=["tmux", "load-buffer", "-b", buf_name, remote], timeout_sec=60.0
        )
        self._exec_capture(
            cmd=["tmux", "paste-buffer", "-b", buf_name, "-t", session_name],
            timeout_sec=60.0,
        )
        # Cleanup is best-effort (check=False): a leaked buffer/temp file is
        # preferable to failing the paste.
        self._exec_capture(
            cmd=["tmux", "delete-buffer", "-b", buf_name], timeout_sec=60.0, check=False
        )
        self._exec_capture(cmd=["rm", "-f", remote], timeout_sec=60.0, check=False)
[docs] def create_session( self, session_name: str, is_active_stream: bool = False, as_configured_user: bool = True, ) -> EcsFargateTmuxSession: """Create (and start) a tmux-backed sandbox session.""" _ = is_active_stream _ = as_configured_user if session_name in self._sessions: raise ValueError(f"Session {session_name} already exists") session = EcsFargateTmuxSession(session_name=session_name, sandbox=self) session.start() self._sessions[session_name] = session return session
    def copy_to_sandbox(
        self,
        *,
        paths: list[Path] | Path,
        container_dir: str | None = None,
        container_filename: str | None = None,
    ) -> None:
        """Copy local files/dirs into the remote container via S3-staged tarball.

        Flow:
        - Tar+gzip the provided paths in-memory on the harness host
        - Upload to S3 and generate a presigned URL
        - Download inside the container and extract into `container_dir`

        Security note: extraction uses tarfile's safety filter on Python 3.12+,
        and a manual path traversal check on Python 3.10-3.11.

        Raises:
            ValueError: if `container_dir` is missing, or `container_filename`
                is given with anything other than exactly one file path.
            RuntimeError: if no S3 bucket is configured for staging.
        """
        boto3, _Config, _ClientError, _NoCredentialsError, _PartialCredentialsError = (
            _require_aws_sdks()
        )
        if container_dir is None:
            raise ValueError("container_dir is required")
        if isinstance(paths, Path):
            paths = [paths]
        buf = io.BytesIO()
        with tarfile.open(fileobj=buf, mode="w:gz") as tar:
            if container_filename is not None:
                if len(paths) != 1 or not paths[0].is_file():
                    raise ValueError(
                        "container_filename requires exactly one file path"
                    )
                tar.add(paths[0], arcname=container_filename)
            else:
                for p in paths:
                    if p.is_file():
                        tar.add(p, arcname=p.name)
                    elif p.is_dir():
                        # Archive directory contents relative to the directory
                        # itself (the top-level dir name is not preserved).
                        for item in p.rglob("*"):
                            tar.add(item, arcname=item.relative_to(p))
        buf.seek(0)
        tar_bytes = buf.read()
        if not self._cfg.s3_bucket:
            raise RuntimeError(
                "ECS sandbox requires S3 staging for bundle upload, but no s3_bucket was provided."
            )
        s3 = boto3.client("s3", region_name=self._cfg.region)
        key = (
            f"{self._cfg.s3_prefix}/{self._run_id}/{self._task_id}/{self._trial_name}/"
            f"{int(time.time() * 1000)}.tar"
        )
        s3.put_object(Bucket=self._cfg.s3_bucket, Key=key, Body=tar_bytes)
        url = s3.generate_presigned_url(
            "get_object",
            Params={"Bucket": self._cfg.s3_bucket, "Key": key},
            ExpiresIn=3600,
        )
        # In-container downloader/extractor; stdlib-only so the image needs no
        # extra tooling. mode='r:*' autodetects the gzip compression.
        py = (
            "import os,tarfile,urllib.request,io\n"
            "url=os.environ['NEMO_URL']\n"
            "dest=os.environ['NEMO_DEST']\n"
            "with urllib.request.urlopen(url, timeout=180) as r:\n"
            " data=r.read()\n"
            "os.makedirs(dest, exist_ok=True)\n"
            "buf=io.BytesIO(data)\n"
            "with tarfile.open(fileobj=buf, mode='r:*') as t:\n"
            " try:\n"
            "  # Python 3.12+: prevent path traversal via tarfile's built-in filter.\n"
            "  t.extractall(dest, filter='data')\n"
            " except TypeError:\n"
            "  # Python 3.10-3.11: validate members before extracting.\n"
            "  dest_real=os.path.realpath(dest)\n"
            "  for m in t.getmembers():\n"
            "   name=m.name\n"
            "   if os.path.isabs(name):\n"
            "    raise RuntimeError('Unsafe tar member (absolute path): %r' % (name,))\n"
            "   target=os.path.realpath(os.path.join(dest, name))\n"
            "   if not (target == dest_real or target.startswith(dest_real + os.sep)):\n"
            "    raise RuntimeError('Unsafe tar member (path traversal): %r' % (name,))\n"
            "  t.extractall(dest)\n"
            "print('ok')\n"
        )
        self._exec_capture(
            cmd=[
                "bash",
                "-lc",
                (
                    f"PY=python3; command -v python3 >/dev/null 2>&1 || PY=python; "
                    f"NEMO_URL={shlex.quote(url)} NEMO_DEST={shlex.quote(container_dir)} "
                    f"$PY -c {shlex.quote(py)}"
                ),
            ],
            timeout_sec=300.0,
        )
    def stop(self) -> None:
        """Best-effort teardown for local session objects (remote task is stopped by the contextmanager)."""
        # Only local bookkeeping is dropped here; the ECS task itself is
        # terminated by _spin_up_ecs_fargate_sandbox's cleanup path.
        self._sessions.clear()
@contextmanager
def _spin_up_ecs_fargate_sandbox(
    *,
    cfg: EcsFargateConfig,
    task_id: str,
    trial_name: str,
    run_id: str,
    pre_upload_paths: Iterable[Path] | None = None,
    upload_dest_dir: str | None = None,
) -> Generator[EcsFargateSandbox, None, None]:
    """Create a short-lived ECS Fargate task and expose it as an `EcsFargateSandbox`.

    High-level flow:

        (optional) register per-task task definition (image override)
            |
        ecs.run_task (retry on capacity)
            |
        wait until RUNNING
            |
        yield sandbox
            |
        stop_task + (optional) deregister temporary task definition

    If `pre_upload_paths` and `upload_dest_dir` are provided, files are staged
    into the container immediately after the task is RUNNING (via
    `copy_to_sandbox`).

    Raises:
        RuntimeError: On unrecoverable AWS failures (registration, run_task,
            early task stop) or exhausted retries.
        TimeoutError: If the task does not reach RUNNING within 300s.
    """
    boto3, Config, ClientError, _NoCredentialsError, _PartialCredentialsError = (
        _require_aws_sdks()
    )
    # Verifies credentials up front and yields the AWS account id for messages.
    account_id = _aws_credentials_preflight(cfg.region)
    ecs = boto3.client(
        "ecs",
        region_name=cfg.region,
        config=Config(retries={"max_attempts": 12, "mode": "standard"}),
    )

    # Shell snippet that symlinks `python` -> `python3` in images that ship
    # only python3, so later exec'd commands can rely on a `python` binary.
    bootstrap = (
        "if command -v python >/dev/null 2>&1; then :; "
        "elif command -v python3 >/dev/null 2>&1; then "
        " P=$(command -v python3); "
        ' ln -sf "$P" /usr/local/bin/python 2>/dev/null || true; '
        ' ln -sf "$P" /usr/bin/python 2>/dev/null || true; '
        "fi"
    )
    # The container's main process is just a bounded sleep; the sandbox talks
    # to it via ECS Exec, and the sleep acts as a hard TTL for the task.
    keepalive_command = [
        "sh",
        "-lc",
        f"{bootstrap}; sleep {int(cfg.max_task_lifetime_sec)}",
    ]
    overrides: dict = {"containerOverrides": [{"name": cfg.container_name}]}
    overrides["containerOverrides"][0]["command"] = list(keepalive_command)
    overrides["containerOverrides"][0]["environment"] = [
        {"name": "TEST_DIR", "value": "/tests"}
    ]

    task_definition_to_run = cfg.task_definition
    registered_task_definition_arn: str | None = None
    # Optionally register a throwaway task definition so each task can run a
    # task-specific image (cfg.image_template with {task_id} substituted).
    if cfg.image_template and cfg.register_task_definition_per_task:
        image = cfg.image_template.format(task_id=task_id)
        if cfg.log_group:
            # Best-effort: creating an already-existing log group raises.
            try:
                logs = boto3.client("logs", region_name=cfg.region)
                logs.create_log_group(logGroupName=cfg.log_group)
            except Exception:
                pass
        # Try to clone the configured base task definition; fall back to
        # building one from scratch if it cannot be described.
        try:
            base = ecs.describe_task_definition(taskDefinition=cfg.task_definition)[
                "taskDefinition"
            ]
        except ClientError:
            base = None
        # Family names must match [A-Za-z0-9_-], start alphanumeric, <=255 chars.
        raw_family = f"nemo-{run_id}-{task_id}-{trial_name}-{int(time.time() * 1000)}-{uuid.uuid4().hex[:8]}"
        family = re.sub(r"[^A-Za-z0-9_-]", "_", raw_family)
        if not family or not re.match(r"^[A-Za-z0-9]", family):
            family = f"nemo_{family}"
        family = family[:255]
        log.info(
            "Registering per-task ECS task definition family=%s (raw=%s)",
            family,
            raw_family,
        )
        if base is not None:
            # Clone the base definition, swapping in our image/command (and
            # awslogs config when a log group is configured) on the target
            # container.
            container_defs = base.get("containerDefinitions") or []
            found = False
            for cd in container_defs:
                if cd.get("name") == cfg.container_name:
                    cd["image"] = image
                    cd["command"] = list(keepalive_command)
                    if cfg.log_group:
                        cd["logConfiguration"] = {
                            "logDriver": "awslogs",
                            "options": {
                                "awslogs-group": cfg.log_group,
                                "awslogs-region": cfg.region or "",
                                "awslogs-stream-prefix": cfg.log_stream_prefix,
                            },
                        }
                    found = True
                    break
            if not found:
                raise RuntimeError(
                    "Base task definition does not contain the configured container_name.\n\n"
                    f"container_name: {cfg.container_name}\n"
                    f"task_definition: {cfg.task_definition}\n"
                    f"available_containers: {[c.get('name') for c in container_defs]}\n"
                )
            register_payload: dict = {
                "family": family,
                "networkMode": base.get("networkMode", "awsvpc"),
                "requiresCompatibilities": base.get(
                    "requiresCompatibilities", ["FARGATE"]
                ),
                "cpu": base.get("cpu"),
                "memory": base.get("memory"),
                "containerDefinitions": container_defs,
            }
            # Carry over optional base-definition fields when present.
            for k in (
                "taskRoleArn",
                "executionRoleArn",
                "ephemeralStorage",
                "runtimePlatform",
                "volumes",
                "placementConstraints",
                "proxyConfiguration",
                "pidMode",
                "ipcMode",
                "inferenceAccelerators",
            ):
                if k in base and base[k] is not None:
                    register_payload[k] = base[k]
        else:
            # No base definition available: synthesize a minimal Fargate task
            # definition from config; requires an execution role ARN.
            execution_role_arn = cfg.execution_role_arn or ""
            if not execution_role_arn:
                raise RuntimeError(
                    "Unable to describe base task definition and no execution role ARN provided.\n\n"
                    "To auto-register per-task Fargate task definitions, provide an execution role ARN.\n"
                    f"task_definition (missing): {cfg.task_definition}\n"
                )
            container_def: dict = {
                "name": cfg.container_name,
                "image": image,
                "essential": True,
                "command": list(keepalive_command),
            }
            if cfg.log_group:
                container_def["logConfiguration"] = {
                    "logDriver": "awslogs",
                    "options": {
                        "awslogs-group": cfg.log_group,
                        "awslogs-region": cfg.region or "",
                        "awslogs-stream-prefix": cfg.log_stream_prefix,
                    },
                }
            register_payload = {
                "family": family,
                "networkMode": "awsvpc",
                "requiresCompatibilities": ["FARGATE"],
                "cpu": str(cfg.cpu),
                "memory": str(cfg.memory),
                "executionRoleArn": execution_role_arn,
                "containerDefinitions": [container_def],
            }
            if cfg.task_role_arn:
                register_payload["taskRoleArn"] = cfg.task_role_arn

        # RegisterTaskDefinition is heavily rate-limited; retry throttling and
        # revision-concurrency errors with capped exponential backoff + jitter.
        reg = None
        last_register_error: Exception | None = None
        for attempt in range(1, 16):
            try:
                reg = ecs.register_task_definition(**register_payload)
                break
            except ClientError as e:
                last_register_error = e
                code = (e.response.get("Error") or {}).get("Code", "")
                msg = str(e)
                retryable = (
                    "Too many concurrent attempts to create a new revision" in msg
                    or "Rate exceeded" in msg
                    or code in {"ThrottlingException", "TooManyRequestsException"}
                )
                if not retryable or attempt >= 15:
                    raise RuntimeError(
                        "Failed to register per-task ECS task definition.\n\n"
                        f"family: {family}\n"
                        f"AWS error: {e}\n"
                    ) from e
                sleep_sec = (
                    min(30.0, 0.75 * (2 ** min(6, attempt - 1))) + random.random()
                )
                log.warning(
                    "ECS RegisterTaskDefinition concurrency limit hit; retrying %s/15 in %.1fs "
                    "(family=%s, code=%s)",
                    attempt,
                    sleep_sec,
                    family,
                    code,
                )
                time.sleep(sleep_sec)
        if reg is None:
            raise RuntimeError(
                "Failed to register per-task ECS task definition after retries.\n\n"
                f"family: {family}\n"
                f"last_error: {last_register_error}\n"
            )
        registered_task_definition_arn = reg["taskDefinition"]["taskDefinitionArn"]
        task_definition_to_run = registered_task_definition_arn

    # run_task with retries: Fargate capacity errors surface either as entries
    # in resp["failures"] or as a ClientError; both paths retry with backoff.
    last_failures = None
    last_client_error: Exception | None = None
    for attempt in range(1, max(1, int(cfg.run_task_max_retries)) + 1):
        try:
            resp = ecs.run_task(
                cluster=cfg.cluster,
                taskDefinition=task_definition_to_run,
                launchType="FARGATE",
                platformVersion="LATEST",
                enableExecuteCommand=True,
                overrides=overrides,
                networkConfiguration={
                    "awsvpcConfiguration": {
                        "subnets": cfg.subnets,
                        "securityGroups": cfg.security_groups,
                        "assignPublicIp": "ENABLED"
                        if cfg.assign_public_ip
                        else "DISABLED",
                    }
                },
            )
            failures = resp.get("failures") or []
            if failures:
                last_failures = failures
                reasons = " | ".join(str(f.get("reason", "")) for f in failures)
                retryable = "Capacity is unavailable" in reasons
                if not retryable or attempt >= int(cfg.run_task_max_retries):
                    raise RuntimeError(
                        f"ECS run_task failures for task_id={task_id}: {failures}"
                    )
                sleep_sec = min(60.0, (2.0 ** min(6, attempt - 1))) + random.random()
                log.warning(
                    "ECS capacity unavailable for task_id=%s; retrying %s/%s in %.1fs",
                    task_id,
                    attempt,
                    cfg.run_task_max_retries,
                    sleep_sec,
                )
                time.sleep(sleep_sec)
                continue
            tasks = resp.get("tasks") or []
            if not tasks:
                raise RuntimeError("ECS run_task returned no tasks")
            task_arn = tasks[0]["taskArn"]
            break
        except ClientError as e:
            last_client_error = e
            msg = str(e)
            retryable = "Capacity is unavailable" in msg
            if not retryable or attempt >= int(cfg.run_task_max_retries):
                raise RuntimeError(
                    "Failed to run ECS task. This is usually caused by missing IAM permissions "
                    "(ecs:RunTask / iam:PassRole), an invalid cluster/task definition, or invalid "
                    "subnets/security groups.\n\n"
                    f"AWS account: {account_id}\n"
                    f"cluster: {cfg.cluster}\n"
                    f"task_definition: {task_definition_to_run}\n"
                    f"container_name: {cfg.container_name}\n"
                    f"task_id: {task_id}\n"
                    f"AWS error: {e}\n"
                ) from e
            sleep_sec = min(60.0, (2.0 ** min(6, attempt - 1))) + random.random()
            log.warning(
                "ECS run_task capacity error for task_id=%s; retrying %s/%s in %.1fs",
                task_id,
                attempt,
                cfg.run_task_max_retries,
                sleep_sec,
            )
            time.sleep(sleep_sec)
            continue
    else:
        # for/else: reached only if the loop exhausted without a `break`,
        # i.e. every attempt ended in a retried failure.
        raise RuntimeError(
            "Failed to run ECS task after retries.\n\n"
            f"task_id: {task_id}\n"
            f"cluster: {cfg.cluster}\n"
            f"task_definition: {task_definition_to_run}\n"
            f"last_failures: {last_failures}\n"
            f"last_client_error: {last_client_error}\n"
        )

    log.info(
        "Started ECS task: %s (account=%s, task_id=%s)", task_arn, account_id, task_id
    )
    # Poll until the task is RUNNING; bail out if it stops early or if the
    # 300-second startup budget is exceeded.
    start = time.time()
    while True:
        d = ecs.describe_tasks(cluster=cfg.cluster, tasks=[task_arn])
        t = (d.get("tasks") or [None])[0]
        if t is None:
            raise RuntimeError("ECS task disappeared")
        status = t.get("lastStatus")
        if status == "RUNNING":
            break
        if status == "STOPPED":
            raise RuntimeError(f"ECS task stopped early: {t.get('stoppedReason')}")
        if time.time() - start > 300:
            raise TimeoutError("Timed out waiting for ECS task to be RUNNING")
        time.sleep(2.0)

    sandbox = EcsFargateSandbox(
        cfg=cfg,
        task_arn=task_arn,
        run_id=run_id,
        task_id=task_id,
        trial_name=trial_name,
    )
    try:
        # Optional pre-staging of files into the container before handing the
        # sandbox to the caller.
        if pre_upload_paths and upload_dest_dir:
            sandbox.copy_to_sandbox(
                paths=list(pre_upload_paths), container_dir=upload_dest_dir
            )
        yield sandbox
    finally:
        # Best-effort teardown: stop the task first, then deregister the
        # per-task definition (if one was created). Failures are logged, not
        # raised, so caller exceptions propagate unchanged.
        try:
            ecs.stop_task(
                cluster=cfg.cluster, task=task_arn, reason="nemo sandbox done"
            )
        except Exception:
            log.warning("Failed to stop ECS task", exc_info=True)
        if registered_task_definition_arn is not None:
            try:
                ecs.deregister_task_definition(
                    taskDefinition=registered_task_definition_arn
                )
            except Exception:
                log.warning(
                    "Failed to deregister temporary ECS task definition", exc_info=True
                )