Source code for nemo_run.core.execution.skypilot_jobs

import logging
import os
import subprocess
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Dict, Optional, Type, Union

from invoke.context import Context

from nemo_run.config import RUNDIR_NAME
from nemo_run.core.execution.base import (
    Executor,
    ExecutorMacros,
)
from nemo_run.core.packaging.base import Packager
from nemo_run.core.packaging.git import GitArchivePackager

_SKYPILOT_AVAILABLE: bool = False
try:
    import sky
    import sky.task as skyt
    from sky import backends

    _SKYPILOT_AVAILABLE = True
except ImportError:
    # suppress import error so we don't crash if skypilot is not installed.
    pass

logger = logging.getLogger(__name__)


@dataclass(kw_only=True)
class SkypilotJobsExecutor(Executor):
    """
    Dataclass to configure a Skypilot Jobs Executor.

    This executor launches managed jobs and requires the
    `Skypilot API Server <https://docs.skypilot.co/en/latest/reference/api-server/api-server.html>`_.
    Some familiarity with `Skypilot <https://skypilot.readthedocs.io/en/latest/docs/index.html>`_
    is necessary.

    To use this executor, install NeMo Run with either the `skypilot` extra
    (Kubernetes only) or the `skypilot-all` extra (all clouds).

    Example:

    .. code-block:: python

        skypilot = SkypilotJobsExecutor(
            gpus="A10G",
            gpus_per_node=devices,
            container_image="nvcr.io/nvidia/nemo:dev",
            infra="k8s/my-context",
            network_tier="best",
            cluster_name="nemo_tester",
            file_mounts={
                "nemo_run.whl": "nemo_run.whl",
                "/workspace/code": "/local/path/to/code",
            },
            storage_mounts={
                "/workspace/outputs": {
                    "name": "my-training-outputs",
                    "store": "gcs",  # or "s3", "azure", etc.
                    "mode": "MOUNT",
                    "persistent": True,
                },
                "/workspace/checkpoints": {
                    "name": "model-checkpoints",
                    "store": "s3",
                    "mode": "MOUNT",
                    "persistent": True,
                },
            },
            setup=\"\"\"
        conda deactivate
        nvidia-smi
        ls -al ./
        pip install nemo_run.whl
        cd /opt/NeMo && git pull origin main && pip install .
        \"\"\",
        )
    """

    HEAD_NODE_IP_VAR = "head_node_ip"
    NPROC_PER_NODE_VAR = "SKYPILOT_NUM_GPUS_PER_NODE"
    NUM_NODES_VAR = "num_nodes"
    NODE_RANK_VAR = "SKYPILOT_NODE_RANK"
    HET_GROUP_HOST_VAR = "het_group_host"

    container_image: Optional[str] = None
    cloud: Optional[Union[str, list[str]]] = None
    region: Optional[Union[str, list[str]]] = None
    zone: Optional[Union[str, list[str]]] = None
    gpus: Optional[Union[str, list[str]]] = None
    gpus_per_node: Optional[int] = None
    cpus: Optional[Union[int | float, list[int | float]]] = None
    memory: Optional[Union[int | float, list[int | float]]] = None
    instance_type: Optional[Union[str, list[str]]] = None
    num_nodes: int = 1
    use_spot: Optional[Union[bool, list[bool]]] = None
    disk_size: Optional[Union[int, list[int]]] = None
    disk_tier: Optional[Union[str, list[str]]] = None
    ports: Optional[tuple[str]] = None
    file_mounts: Optional[dict[str, str]] = None
    storage_mounts: Optional[dict[str, dict[str, Any]]] = None
    cluster_name: Optional[str] = None
    setup: Optional[str] = None
    autodown: bool = False
    idle_minutes_to_autostop: Optional[int] = None
    torchrun_nproc_per_node: Optional[int] = None
    cluster_config_overrides: Optional[dict[str, Any]] = None
    infra: Optional[str] = None
    network_tier: Optional[str] = None
    retry_until_up: bool = False
    packager: Packager = field(default_factory=GitArchivePackager)  # type: ignore  # noqa: F821

    def __post_init__(self):
        assert _SKYPILOT_AVAILABLE, (
            'Skypilot is not installed. Please install it using `pip install "nemo_run[skypilot]"`.'
        )
        assert isinstance(self.packager, GitArchivePackager), (
            "Only GitArchivePackager is currently supported for SkypilotExecutor."
        )

        if self.infra is not None:
            assert self.cloud is None, "Cannot specify both `infra` and `cloud` parameters."
            assert self.region is None, "Cannot specify both `infra` and `region` parameters."
            assert self.zone is None, "Cannot specify both `infra` and `zone` parameters."
        elif self.cloud is not None:
            # Warn only when the deprecated `cloud` parameter is actually used.
            logger.info(
                "`cloud` is deprecated and will be removed in a future version. Use `infra` instead."
            )

    @classmethod
    def parse_app(cls: Type["SkypilotJobsExecutor"], app_id: str) -> tuple[str, str, int]:
        app = app_id.split("___")
        cluster, task, job_id = app[0], app[1], app[2]
        assert cluster and task and job_id, f"Invalid app id for Skypilot: {app_id}"

        return cluster, task, int(job_id)

    def to_resources(self) -> Union["sky.Resources", set["sky.Resources"]]:
        from sky.resources import Resources

        resources_cfg = {}
        accelerators = None
        if self.gpus:
            if not self.gpus_per_node:
                self.gpus_per_node = 1
            else:
                assert isinstance(self.gpus_per_node, int)

            gpus = [self.gpus] if isinstance(self.gpus, str) else self.gpus
            accelerators = {}
            for gpu in gpus:
                accelerators[gpu] = self.gpus_per_node

            resources_cfg["accelerators"] = accelerators

        if self.container_image:
            resources_cfg["image_id"] = self.container_image

        any_of = []

        def parse_attr(attr: str):
            if getattr(self, attr, None) is not None:
                value = getattr(self, attr)
                if isinstance(value, list):
                    for i, val in enumerate(value):
                        if len(any_of) < i + 1:
                            any_of.append({})

                        if isinstance(val, str) and val.lower() == "none":
                            any_of[i][attr] = None
                        else:
                            any_of[i][attr] = val
                else:
                    if isinstance(value, str) and value.lower() == "none":
                        resources_cfg[attr] = None
                    else:
                        resources_cfg[attr] = value

        attrs = [
            "cloud",
            "region",
            "zone",
            "cpus",
            "memory",
            "instance_type",
            "use_spot",
            "infra",
            "network_tier",
            "image_id",
            "disk_size",
            "disk_tier",
            "ports",
        ]

        for attr in attrs:
            parse_attr(attr)

        resources_cfg["any_of"] = any_of

        if self.cluster_config_overrides:
            resources_cfg["_cluster_config_overrides"] = self.cluster_config_overrides

        resources = Resources.from_yaml_config(resources_cfg)
        return resources  # type: ignore

    @classmethod
    def status(cls: Type["SkypilotJobsExecutor"], app_id: str) -> Optional[dict]:
        import sky.exceptions as sky_exceptions
        import sky.jobs.client.sdk as sky_jobs
        from sky import stream_and_get

        _, _, job_id = cls.parse_app(app_id)
        try:
            job_details: Dict[str, Any] = stream_and_get(
                sky_jobs.queue(refresh=True, all_users=True, job_ids=[job_id]),
            )[0]
        except sky_exceptions.ClusterNotUpError:
            return None

        return job_details

    @classmethod
    def cancel(cls: Type["SkypilotJobsExecutor"], app_id: str):
        from sky.jobs.client.sdk import cancel

        _, _, job_id = cls.parse_app(app_id=app_id)
        job_details = cls.status(app_id=app_id)
        if not job_details:
            return

        cancel(job_ids=[job_id])

    @classmethod
    def logs(cls: Type["SkypilotJobsExecutor"], app_id: str, fallback_path: Optional[str]):
        import sky.jobs.client.sdk as sky_jobs

        _, _, job_id = cls.parse_app(app_id)
        job_details = cls.status(app_id)

        is_terminal = False
        if job_details and job_details["status"]:
            is_terminal = job_details["status"].is_terminal()
        elif not job_details:
            is_terminal = True

        if fallback_path and is_terminal:
            log_path = os.path.expanduser(os.path.join(fallback_path, "run.log"))
            if os.path.isfile(log_path):
                with open(log_path) as f:
                    for line in f:
                        print(line, end="", flush=True)
                return

        sky_jobs.tail_logs(job_id=job_id)

    @property
    def workdir(self) -> str:
        return os.path.join(self.job_dir, "workdir")

    def package_configs(self, *cfgs: tuple[str, str]) -> list[str]:
        filenames = []
        basepath = os.path.join(self.job_dir, "configs")
        for name, cfg in cfgs:
            filename = os.path.join(basepath, name)
            os.makedirs(os.path.dirname(filename), exist_ok=True)
            with open(filename, "w") as f:
                f.write(cfg)
            filenames.append(
                os.path.join(
                    "/",
                    RUNDIR_NAME,
                    "configs",
                    name,
                )
            )
        return filenames
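    # Usage note: the `app_id` consumed by `parse_app`, `status`, `cancel`, and
    # `logs` above joins three fields with triple underscores,
    # "<cluster>___<task>___<job_id>". For example (values illustrative):
    #
    #     SkypilotJobsExecutor.parse_app("sky-cluster___train___7")
    #     # -> ("sky-cluster", "train", 7)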
    def assign(
        self,
        exp_id: str,
        exp_dir: str,
        task_id: str,
        task_dir: str,
    ):
        self.job_name = task_id
        self.experiment_dir = exp_dir
        self.job_dir = os.path.join(exp_dir, task_dir)
        self.experiment_id = exp_id
    def package(self, packager: Packager, job_name: str):
        assert self.experiment_id, "Executor not assigned to an experiment."
        if isinstance(packager, GitArchivePackager):
            output = subprocess.run(
                ["git", "rev-parse", "--show-toplevel"],
                check=True,
                stdout=subprocess.PIPE,
            )
            path = output.stdout.splitlines()[0].decode()
            base_path = Path(path).absolute()
        else:
            base_path = Path(os.getcwd()).absolute()

        local_pkg = packager.package(base_path, self.job_dir, job_name)
        local_code_extraction_path = os.path.join(self.job_dir, "code")
        ctx = Context()
        ctx.run(f"mkdir -p {local_code_extraction_path}")

        if self.get_launcher().nsys_profile:
            remote_nsys_extraction_path = os.path.join(
                self.job_dir, self.get_launcher().nsys_folder
            )
            ctx.run(f"mkdir -p {remote_nsys_extraction_path}")

        if local_pkg:
            ctx.run(
                f"tar -xvzf {local_pkg} -C {local_code_extraction_path} --ignore-zeros",
                hide=True,
            )
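    # For orientation: `package` extracts the archive produced by
    # GitArchivePackager into `<job_dir>/code`, and `to_task` below file-mounts
    # `job_dir` at `/nemo_run` on the cluster, which is why the generated run
    # command starts with `cd /nemo_run/code`.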
    def nnodes(self) -> int:
        return self.num_nodes
    def nproc_per_node(self) -> int:
        if self.torchrun_nproc_per_node:
            return self.torchrun_nproc_per_node

        return self.gpus_per_node or 1
    def macro_values(self) -> Optional[ExecutorMacros]:
        return ExecutorMacros(
            head_node_ip_var=self.HEAD_NODE_IP_VAR,
            nproc_per_node_var=self.NPROC_PER_NODE_VAR,
            num_nodes_var=self.NUM_NODES_VAR,
            node_rank_var=self.NODE_RANK_VAR,
            het_group_host_var=self.HET_GROUP_HOST_VAR,
        )
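    # These macros map launcher placeholders onto SkyPilot's runtime:
    # SKYPILOT_NODE_RANK and SKYPILOT_NUM_GPUS_PER_NODE are environment
    # variables SkyPilot sets on each node, while `num_nodes` and `head_node_ip`
    # are shell variables computed from $SKYPILOT_NODE_IPS in the run command
    # emitted by `to_task` below.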
    def to_task(
        self,
        name: str,
        cmd: Optional[list[str]] = None,
        env_vars: Optional[dict[str, str]] = None,
    ) -> "skyt.Task":
        from sky.task import Task

        run_cmd = None
        if cmd:
            run_cmd = f"""
conda deactivate
num_nodes=`echo "$SKYPILOT_NODE_IPS" | wc -l`
echo "num_nodes=$num_nodes"
head_node_ip=`echo "$SKYPILOT_NODE_IPS" | head -n1`
echo "head_node_ip=$head_node_ip"

cd /nemo_run/code

{" ".join(cmd)}
"""

        task = Task(
            name=name,
            setup=self.setup if self.setup else "",
            run=run_cmd,
            envs=self.env_vars,
            num_nodes=self.num_nodes,
        )

        # Handle regular file mounts
        file_mounts = self.file_mounts or {}
        file_mounts["/nemo_run"] = self.job_dir
        task.set_file_mounts(file_mounts)

        # Handle storage mounts separately
        if self.storage_mounts:
            from sky.data import Storage

            storage_objects = {}
            for mount_path, config in self.storage_mounts.items():
                # Create a Storage object from the config dict
                storage_obj = Storage.from_yaml_config(config)
                storage_objects[mount_path] = storage_obj

            task.set_storage_mounts(storage_objects)

        task.set_resources(self.to_resources())
        if env_vars:
            task.update_envs(env_vars)

        return task

    def launch(
        self,
        task: "skyt.Task",
        num_nodes: Optional[int] = None,
    ) -> tuple[Optional[int], Optional["backends.ResourceHandle"]]:
        from sky import stream_and_get
        from sky.jobs.client.sdk import launch

        if num_nodes:
            task.num_nodes = num_nodes

        job_id, handle = stream_and_get(launch(task))
        return job_id, handle

    def cleanup(self, handle: str):
        import sky.jobs.client.sdk as sky_jobs

        _, _, path_str = handle.partition("://")
        path = path_str.split("/")
        app_id = path[1]
        _, _, job_id = self.parse_app(app_id)

        sky_jobs.download_logs(
            name=None,
            job_id=job_id,
            refresh=True,
            controller=True,
            local_dir=self.job_dir,
        )
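
# Minimal end-to-end sketch (illustrative only: the image, infra, directories,
# and command are hypothetical, nemo_run's Experiment API normally drives these
# calls, and a running SkyPilot API server is required):
#
#     executor = SkypilotJobsExecutor(
#         container_image="nvcr.io/nvidia/nemo:dev",
#         gpus="A10G",
#         gpus_per_node=8,
#         infra="k8s/my-context",
#     )
#     executor.assign("exp-1", "/tmp/exp-1", "task-0", "task-0")
#     task = executor.to_task("task-0", cmd=["python", "train.py"])
#     job_id, handle = executor.launch(task)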