Source code for nemo_run.core.execution.slurm

# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
import logging
import os
import shlex
import subprocess
import time
import warnings
from dataclasses import asdict, dataclass, field
from pathlib import Path
from typing import Any, Dict, Optional, Type, TypeAlias, Union

import invoke
from invoke.context import Context
from rich.console import Console
from rich.text import Text

from nemo_run.config import RUNDIR_NAME, RUNDIR_SPECIAL_NAME
from nemo_run.core.execution.base import (
    Executor,
    ExecutorMacros,
)
from nemo_run.core.execution.launcher import FaultTolerance, Launcher, SlurmTemplate, Torchrun
from nemo_run.core.execution.utils import fill_template
from nemo_run.core.frontend.console.api import CONSOLE
from nemo_run.core.packaging.base import Packager
from nemo_run.core.packaging.git import GitArchivePackager
from nemo_run.core.serialization.zlib_json import ZlibJSONSerializer
from nemo_run.core.tunnel.client import (
    Callback,
    LocalTunnel,
    PackagingJob,
    SSHConfigFile,
    SSHTunnel,
    Tunnel,
)
from nemo_run.core.tunnel.server import TunnelMetadata, server_dir
from nemo_run.devspace.base import DevSpace

logger = logging.getLogger(__name__)
noquote: TypeAlias = str


@dataclass(kw_only=True)
class SlurmJobDetails:
    """Store details like paths and name related to the slurm job."""

    job_name: Optional[str] = None
    folder: Optional[str] = None
    ray_log_prefix: str = "ray-"

    @property
    def stderr(self) -> Path:
        assert self.folder, self.job_name
        return Path(self.folder) / f"sbatch_{self.job_name}_%j.err"

    @property
    def stdout(self) -> Path:
        assert self.folder, self.job_name
        return Path(self.folder) / f"sbatch_{self.job_name}_%j.out"

    @property
    def srun_stderr(self) -> Path:
        assert self.folder, self.job_name
        return Path(self.folder) / f"log-{self.job_name}_%j_${{SLURM_RESTART_COUNT:-0}}.err"

    @property
    def srun_stdout(self) -> Path:
        assert self.folder, self.job_name
        return Path(self.folder) / f"log-{self.job_name}_%j_${{SLURM_RESTART_COUNT:-0}}.out"

    @property
    def ls_term(self) -> str:
        """This term will be used to fetch the logs.

        The command used to list the files is ls -1 {ls_term} 2> /dev/null
        """
        assert self.folder
        return os.path.join(self.folder, "log*")

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}({self.folder})"


def get_packaging_job_key(experiment_id: str, job_name: str) -> str:
    return f"{experiment_id}:{job_name}"


@dataclass(kw_only=True)
class SlurmExecutor(Executor):
    """
    Dataclass to configure a Slurm Cluster.

    During execution, sbatch related parameters will automatically get parsed to their corresponding sbatch flags.

    .. note::
        We assume that the underlying Slurm cluster has `Pyxis <https://github.com/NVIDIA/pyxis>`_ enabled.
        The slurm executor will fail if the slurm cluster doesn't support pyxis.

    Example:

    .. code-block:: python

        def your_slurm_executor() -> run.SlurmExecutor:
            ssh_tunnel = SSHTunnel(
                host=os.environ["SLURM_HOST"],
                user=os.environ["SLURM_USER"],
                job_dir=os.environ["SLURM_JOBDIR"],
            )
            packager = GitArchivePackager()
            launcher = "torchrun"
            executor = SlurmExecutor(
                account=os.environ["SLURM_ACCT"],
                partition=os.environ["SLURM_PARTITION"],
                nodes=1,
                ntasks_per_node=1,
                tunnel=ssh_tunnel,
                container_image=os.environ["BASE_IMAGE"],
                time="00:30:00",
                packager=packager,
                launcher=launcher,
            )
            return executor

        ...

        your_executor = your_slurm_executor()
    """

    HEAD_NODE_IP_VAR = "head_node_ip"
    NPROC_PER_NODE_VAR = "SLURM_NTASKS_PER_NODE"
    NUM_NODES_VAR = "SLURM_NNODES"
    NODE_RANK_VAR = "SLURM_NODEID"
    HET_GROUP_HOST_VAR = "het_group_host"

    #: List of sbatch flags in snake case
    SBATCH_FLAGS = [
        "account", "acctg_freq", "array", "batch", "clusters", "constraint",
        "container", "container_id", "core_spec", "cpus_per_gpu", "cpus_per_task",
        "comment", "debug", "delay_boot", "dependency", "distribution", "error",
        "exclude", "exclusive", "export", "get_user_env", "gid", "gpu_bind",
        "gpu_freq", "gpus", "gpus_per_node", "gpus_per_socket", "gpus_per_task",
        "gres", "gres_flags", "help", "hold", "ignore_pbs", "input", "job_name",
        "kill_on_invalid_dep", "licenses", "mail_type", "mail_user", "mcs_label",
        "mem", "mem_bind", "mem_per_cpu", "mem_per_gpu", "mincpus", "network",
        "nice", "no_kill", "no_requeue", "nodefile", "nodelist", "nodes", "ntasks",
        "ntasks_per_core", "ntasks_per_gpu", "ntasks_per_node", "ntasks_per_socket",
        "open_mode", "output", "overcommit", "oversubscribe", "parsable",
        "partition", "power", "prefer", "priority", "profile", "propagate", "qos",
        "quiet", "reboot", "requeue", "reservation", "signal", "sockets_per_node",
        "spread_job", "switches", "test_only", "thread_spec", "threads_per_core",
        "time", "time_min", "tmp", "tres_bind", "tres_per_task", "uid", "usage",
        "verbose", "version", "wait", "wait_all_nodes", "wckey", "wrap", "segment",
    ]

    SRUN_ARGS = [
        "account", "partition", "job-name", "time", "nodes", "ntasks",
        "ntasks-per-node", "cpus-per-task", "gpus-per-node", "gpus-per-task",
        "qos", "mem", "mem-per-gpu", "mem-per-cpu", "comment", "constraint",
        "exclude", "gres", "exclusive", "array", "additional-parameters",
        "container-image", "container-mounts", "container-workdir",
    ]

    ALLOC_ARGS = [
        "account", "partition", "job-name", "time", "nodes", "ntasks-per-node",
        "qos", "mem", "mem-per-gpu", "mem-per-cpu",
    ]

    @dataclass(kw_only=True)
    class ResourceRequest:
        packager: Packager
        nodes: int
        ntasks_per_node: int
        container_image: Optional[str] = None
        gpus_per_node: Optional[int] = None
        gpus_per_task: Optional[int] = None
        container_mounts: list[str] = field(default_factory=list)
        container_env: Optional[list[str]] = None
        env_vars: dict[str, str] = field(default_factory=dict)
        srun_args: Optional[list[str]] = None
        job_details: SlurmJobDetails = field(default_factory=SlurmJobDetails)
        het_group_index: Optional[int] = None

    account: str
    partition: Optional[str] = None
    job_name_prefix: Optional[str] = None
    time: str = "00:10:00"
    nodes: int = 1
    ntasks_per_node: int = 1
    cpus_per_task: Optional[int] = None
    cpus_per_gpu: Optional[int] = None
    gpus_per_node: Optional[int] = None
    gpus_per_task: Optional[int] = None
    qos: Optional[str] = None
    mem: Optional[str] = None
    mem_per_gpu: Optional[str] = None
    mem_per_cpu: Optional[str] = None
    comment: Optional[str] = None
    constraint: Optional[str] = None
    exclude: Optional[str] = None
    gres: Optional[str] = None
    signal: Optional[str] = None
    exclusive: Optional[Union[bool, str]] = None
    array: Optional[str] = None
    open_mode: str = "append"
    container_image: Optional[str] = None
    container_mounts: list[str] = field(default_factory=list)
    container_env: Optional[list[str]] = None
    additional_parameters: Optional[dict[str, Any]] = None
    srun_args: Optional[list[str]] = None
    heterogeneous: bool = False
    memory_measure: bool = False
    job_details: SlurmJobDetails = field(default_factory=SlurmJobDetails)
    tunnel: Union[SSHTunnel, LocalTunnel] = field(default_factory=lambda: LocalTunnel(job_dir=""))
    packager: Packager = field(default_factory=lambda: GitArchivePackager())  # type: ignore
    #: List of TorchX app handles that will be parsed and passed to --dependency flag in sbatch.
    dependencies: list[str] = field(default_factory=list)
    dependency_type: str = "afterok"
    #: Optional parameter to explicitly specify nproc_per_node for torchrun like components
    #: if the slurm cluster doesn't support granular resource allocation.
    torchrun_nproc_per_node: Optional[int] = None
    wait_time_for_group_job: int = 30
    monitor_group_job: bool = True
    monitor_group_job_wait_time: int = 60
    setup_lines: Optional[str] = None
    het_group_indices: Optional[list[int]] = None
    segment: Optional[int] = None
    network: Optional[str] = None
    #: Set by the executor; cannot be initialized
    job_name: str = field(init=False, default="nemo-job")
    stderr_to_stdout: bool = field(init=False, default=True)
    resource_group: list[ResourceRequest] = field(init=False, default_factory=list)
    run_as_group: bool = field(init=False, default=False)

    @classmethod
    def merge(
        cls: Type["SlurmExecutor"], executors: list["SlurmExecutor"], num_tasks: int
    ) -> "SlurmExecutor":
        assert len(executors) in [1, num_tasks]
        if len(executors) == 1 and not executors[0].heterogeneous:
            executors[0].run_as_group = True
            return executors[0]

        if len(executors) == 1:
            executors = executors * num_tasks

        main_executor = executors[0]
        main_executor.run_as_group = True

        if main_executor.het_group_indices:
            assert main_executor.heterogeneous, (
                "heterogeneous must be True if het_group_indices is provided"
            )
            assert len(main_executor.het_group_indices) == num_tasks, (
                "het_group_indices must be the same length as the number of tasks"
            )
            assert all(
                x <= y
                for x, y in zip(
                    main_executor.het_group_indices, main_executor.het_group_indices[1:]
                )
            ), "het_group_indices must be equal or increasing than previous"

        main_executor.resource_group = [
            cls.ResourceRequest(
                packager=copy.deepcopy(main_executor.packager),
                nodes=main_executor.nodes,
                ntasks_per_node=main_executor.ntasks_per_node,
                container_image=copy.deepcopy(main_executor.container_image),
                container_mounts=copy.deepcopy(main_executor.container_mounts),
                container_env=copy.deepcopy(main_executor.container_env),
                env_vars=copy.deepcopy(main_executor.env_vars),
                gpus_per_node=main_executor.gpus_per_node,
                gpus_per_task=main_executor.gpus_per_task,
                srun_args=main_executor.srun_args,
                job_details=copy.deepcopy(main_executor.job_details),
                het_group_index=main_executor.het_group_indices[0]
                if main_executor.het_group_indices
                else None,
            )
        ]
        for i, executor in enumerate(executors[1:]):
            main_executor.resource_group.append(
                cls.ResourceRequest(
                    packager=copy.deepcopy(executor.packager),
                    nodes=executor.nodes,
                    ntasks_per_node=executor.ntasks_per_node,
                    container_image=copy.deepcopy(executor.container_image),
                    container_mounts=copy.deepcopy(executor.container_mounts),
                    container_env=copy.deepcopy(executor.container_env),
                    env_vars=copy.deepcopy(executor.env_vars),
                    gpus_per_node=executor.gpus_per_node,
                    gpus_per_task=executor.gpus_per_task,
                    srun_args=executor.srun_args,
                    job_details=copy.deepcopy(executor.job_details),
                    het_group_index=main_executor.het_group_indices[i + 1]
                    if main_executor.het_group_indices
                    else None,
                )
            )

        return main_executor

    def __post_init__(self):
        if self.wait_time_for_group_job < 0:
            self.wait_time_for_group_job = 0

    def info(self) -> str:
        return f"{self.__class__.__qualname__} on {self.tunnel.key}"

    def alloc(self, job_name="interactive"):
        self.job_name = f"{self.job_name_prefix}{job_name}"
        args = [
            f"--{arg}={getattr(self, arg.replace('-', '_'))}"
            for arg in self.ALLOC_ARGS
            if getattr(self, arg.replace("-", "_"), None) is not None
        ]

        self.slurm.run(
            f"salloc {' '.join(args)} && cd {self.job_dir}",
            hide=False,
            echo=True,
            pty=True,
        )

    def srun(
        self,
        cmd: str,
        job_name="interactive",
        flags=None,
        env_vars: Optional[Dict[str, str]] = None,
        arg_dict=None,
        **kwargs,
    ):
        self.job_name = f"{self.job_name_prefix}{job_name}"
        _arg_dict = {
            arg: getattr(self, arg.replace("-", "_"))
            for arg in self.SRUN_ARGS
            if getattr(self, arg.replace("-", "_"), None) is not None
        }
        _arg_dict["container-mounts"] = ",".join(self.container_mounts)
        if env_vars:
            _arg_dict["container-env"] = ",".join(list(env_vars.keys()))
        if arg_dict:
            _arg_dict.update(arg_dict)

        add_quotes = ["container-image", "container-mounts", "container-workdir"]
        if env_vars:
            add_quotes.append("container-env")

        args = []
        for arg, value in _arg_dict.items():
            if arg in add_quotes:
                args.append(f"--{arg}={shlex.quote(value)}")
            else:
                args.append(f"--{arg}={value}")

        if flags:
            args.extend(flags)

        srun = f"srun {' '.join(args)} {cmd}"
        if env_vars:
            srun = (
                " ".join([f"{key}={shlex.quote(val)}" for key, val in env_vars.items()])
                + " "
                + srun
            )
        return self.slurm.run(srun, **kwargs)

    def bash(self, job_name="interactive"):
        self.srun("bash", job_name=job_name)

    def launch_devspace(
        self,
        space: DevSpace,
        job_name="interactive",
        env_vars: Optional[Dict[str, str]] = None,
        add_workspace_to_pythonpath: bool = True,
    ):
        cfg_zlib = ZlibJSONSerializer().serialize(space.__io__)

        _container_dir = f"/workspaces/{space.name}"
        mounts = self.container_mounts
        mounts.append(f"{self.job_dir}:{_container_dir}")
        if add_workspace_to_pythonpath:
            mounts.append(f"{self.job_dir}:/workspaces/.main")

        arg_dict = {}
        arg_dict["container-workdir"] = _container_dir
        arg_dict["container-mounts"] = ",".join(mounts)

        if self.local_is_slurm:
            srun_kwargs = dict(hide=False, echo=True, pty=True)
        else:
            srun_kwargs = dict(warn=True, hide=False, echo=False, asynchronous=True)

        _env_vars = env_vars or {}
        _env_vars["NEMO_DEVSPACE"] = space.name

        srun = self.srun(
            f"nemorun devspace sshserver {cfg_zlib}",
            job_name=job_name,
            env_vars=_env_vars,
            flags=["--no-container-remap-root"],
            arg_dict=arg_dict,
            **srun_kwargs,
        )

        if not self.local_is_slurm:
            return SlurmTunnelCallback(self, space=space, srun=srun)

    def connect_devspace(self, space, tunnel_dir=None):
        return SlurmTunnelCallback(self, space=space, tunnel_dir=tunnel_dir)

    def assign(
        self,
        exp_id: str,
        exp_dir: str,
        task_id: str,
        task_dir: str,
    ):
        self.job_name = task_id
        self.experiment_dir = exp_dir
        self.job_dir = os.path.join(exp_dir, task_dir)
        self.experiment_id = exp_id
        self.tunnel._set_job_dir(self.experiment_id)

    def get_launcher_prefix(self) -> Optional[list[str]]:
        launcher = self.get_launcher()
        if launcher.nsys_profile:
            nsys_prefix = launcher.get_nsys_prefix(profile_dir=f"/{RUNDIR_NAME}")
            if launcher.nsys_gpu_metrics:
                nsys_prefix += ["$GPU_METRICS_FLAG"]
            return nsys_prefix

    def get_nsys_entrypoint(self) -> tuple[str, str]:
        launcher = self.get_launcher()
        entrypoint, postfix = "nsys", ""
        if launcher.nsys_gpu_metrics:
            entrypoint = 'bash -c \'GPU_METRICS_FLAG=""; if echo "${GPU_METRICS_NODES}" | grep -q -w "${SLURM_NODEID}"; then GPU_METRICS_FLAG="--gpu-metrics-devices=${SLURM_LOCALID}"; fi; nsys'
            postfix = "'"
        return (entrypoint, postfix)

    def supports_launcher_transform(self) -> bool:
        return True if isinstance(self.get_launcher(), SlurmTemplate) else False

    def package_configs(self, *cfgs: tuple[str, str]) -> list[str]:
        filenames = []
        basepath = os.path.join(self.job_dir, "configs")
        for name, cfg in cfgs:
            filename = os.path.join(basepath, name)
            os.makedirs(os.path.dirname(filename), exist_ok=True)
            with open(filename, "w") as f:
                f.write(cfg)
            filenames.append(
                os.path.join(
                    "/",
                    RUNDIR_NAME,
                    "configs",
                    name,
                )
            )
        return filenames

    def package(self, packager: Packager, job_name: str):
        assert self.experiment_id, "Executor not assigned to an experiment."
        if (
            get_packaging_job_key(self.experiment_id, job_name) in self.tunnel.packaging_jobs
            and not packager.symlink_from_remote_dir
        ):
            logger.info(
                f"Packaging for job {job_name} in tunnel {self.tunnel.key} already done. Skipping subsequent packagings.\n"
                "This may cause issues if you have multiple tasks with the same name but different packagers, as only the first packager will be used."
            )
            return

        if packager.symlink_from_remote_dir:
            logger.info(
                f"Packager {get_packaging_job_key(self.experiment_id, job_name)} is configured to symlink from remote dir. Skipping packaging."
            )
            if type(packager) is Packager:
                self.tunnel.packaging_jobs[get_packaging_job_key(self.experiment_id, job_name)] = (
                    PackagingJob(symlink=False)
                )
                return

            self.tunnel.packaging_jobs[get_packaging_job_key(self.experiment_id, job_name)] = (
                PackagingJob(
                    symlink=True,
                    src_path=packager.symlink_from_remote_dir,
                    dst_path=os.path.join(self.tunnel.job_dir, Path(self.job_dir).name, "code"),
                )
            )

            # Tunnel job dir is the directory of the experiment id, so the base job dir is two levels up
            base_remote_dir = str(Path(self.tunnel.job_dir).parent.parent)
            base_remote_mount = f"{base_remote_dir}:{base_remote_dir}"
            if base_remote_mount not in self.container_mounts:
                self.container_mounts.append(f"{base_remote_dir}:{base_remote_dir}")
            for req in self.resource_group:
                if base_remote_mount not in req.container_mounts:
                    req.container_mounts.append(base_remote_mount)

            return

        if isinstance(packager, GitArchivePackager):
            output = subprocess.run(
                ["git", "rev-parse", "--show-toplevel"],
                check=True,
                stdout=subprocess.PIPE,
            )
            path = output.stdout.splitlines()[0].decode()
            base_path = Path(path).absolute()
        else:
            base_path = Path(os.getcwd()).absolute()

        local_pkg = packager.package(base_path, self.job_dir, job_name)
        local_code_extraction_path = os.path.join(self.job_dir, "code")
        ctx = Context()
        ctx.run(f"mkdir -p {local_code_extraction_path}")

        if self.get_launcher().nsys_profile:
            remote_nsys_extraction_path = os.path.join(
                self.job_dir, self.get_launcher().nsys_folder
            )
            ctx.run(f"mkdir -p {remote_nsys_extraction_path}")
            # Touch hidden init file
            ctx.run(f"touch {remote_nsys_extraction_path}/.init")
        if local_pkg:
            ctx.run(
                f"tar -xvzf {local_pkg} -C {local_code_extraction_path} --ignore-zeros", hide=True
            )

        self.tunnel.packaging_jobs[get_packaging_job_key(self.experiment_id, job_name)] = (
            PackagingJob(
                symlink=False,
                dst_path=None
                if type(packager) is Packager
                else os.path.join(self.tunnel.job_dir, Path(self.job_dir).name, "code"),
            )
        )

    def parse_deps(self) -> list[str]:
        """
        Helper function to parse a list of TorchX app handles
        and return a list of Slurm Job IDs to use as dependencies.
        """
        deps = []
        for dep in self.dependencies:
            # Parse torchx app handle to get slurm job id
            _, _, path_str = dep.partition("://")
            # path is of the form ["", "app_id", "master", "0"]
            path = path_str.split("/")
            job_id = path[1]
            deps.append(job_id)

        return deps
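
    # Illustrative (hypothetical handle): a dependency handle shaped like
    # "slurm:///1234567/master/0" partitions to "/1234567/master/0", which splits
    # into ["", "1234567", "master", "0"], so parse_deps() yields the job id "1234567".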

    def nnodes(self) -> int:
        return self.nodes if isinstance(self.nodes, int) else self.nodes[0]

    def nproc_per_node(self) -> int:
        if self.torchrun_nproc_per_node:
            return self.torchrun_nproc_per_node

        if self.gpus_per_node and self.ntasks_per_node == 1:
            return self.gpus_per_node
        if self.gpus_per_task:
            return self.gpus_per_task

        return (
            self.ntasks_per_node
            if isinstance(self.ntasks_per_node, int)
            else self.ntasks_per_node[0]
        )
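
    # Illustrative: with torchrun_nproc_per_node unset, gpus_per_node=8 and
    # ntasks_per_node=1, nproc_per_node() resolves to 8; with a Torchrun or
    # FaultTolerance launcher, _setup_launcher() instead pins ntasks_per_node to 1
    # and carries the original value in torchrun_nproc_per_node.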

    def macro_values(self) -> Optional[ExecutorMacros]:
        return ExecutorMacros(
            head_node_ip_var=self.HEAD_NODE_IP_VAR,
            nproc_per_node_var=self.NPROC_PER_NODE_VAR,
            num_nodes_var=self.NUM_NODES_VAR,
            node_rank_var=self.NODE_RANK_VAR,
            het_group_host_var=self.HET_GROUP_HOST_VAR,
        )

    def _setup_launcher(self):
        super()._setup_launcher()
        launcher = self.launcher
        if launcher and isinstance(launcher, (FaultTolerance, Torchrun)):
            self.torchrun_nproc_per_node = self.torchrun_nproc_per_node or self.ntasks_per_node
            self.ntasks_per_node = 1
            CONSOLE.log(
                f"Detected {launcher.__class__.__name__} launcher, setting ntasks_per_node=1 and torchrun_nproc_per_node={self.torchrun_nproc_per_node}"
            )

        if launcher and isinstance(launcher, FaultTolerance):
            base_dir = os.path.join(self.tunnel.job_dir, Path(self.job_dir).name)
            launcher.cfg_path = os.path.join(base_dir, f"{self.job_name}_ft_cfg.yml")
            launcher.finished_flag_file = os.path.join(
                "/", RUNDIR_NAME, f"{self.job_name}_finished_flag"
            )
            launcher.job_results_file = os.path.join(base_dir, f"{self.job_name}_job_results")

    @property
    def local(self) -> LocalTunnel:
        if not hasattr(self, "_local"):
            self._local = LocalTunnel(job_dir=self.tunnel.job_dir)
        return self._local

    @property
    def slurm(self) -> Tunnel:
        if self.local_is_slurm:
            return self.local

        self.tunnel.connect()
        return self.tunnel

    @property
    def local_is_slurm(self) -> bool:
        try:
            self.local.run("which srun", hide=True)
            return True
        except invoke.exceptions.UnexpectedExit:
            return False
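
# Illustrative sketch (not part of the library source): merging two differently
# sized executors into one heterogeneous Slurm job via SlurmExecutor.merge().
# Account and sizes are hypothetical.
#
#   train = SlurmExecutor(account="acct", nodes=4, ntasks_per_node=8, heterogeneous=True)
#   server = SlurmExecutor(account="acct", nodes=1, ntasks_per_node=1)
#   merged = SlurmExecutor.merge([train, server], num_tasks=2)
#   # merged.run_as_group is True and merged.resource_group holds one ResourceRequest
#   # per task; SlurmBatchRequest.materialize() then emits one het group per request,
#   # separated by "#SBATCH hetjob".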


def _as_sbatch_flag(key: str, value: Any) -> str:
    """Convert key value pairs to `#SBATCH --{key}={value}` flags"""
    key = key.replace("_", "-")
    if value is True:
        return f"#SBATCH --{key}"

    value = shlex.quote(str(value))
    return f"#SBATCH --{key}={value}"


@dataclass(kw_only=True)
class SlurmBatchRequest:
    launch_cmd: list[str]
    jobs: list[str]
    command_groups: list[list[str]]
    executor: SlurmExecutor
    max_retries: int
    setup: Optional[list[str]] = None
    extra_env: dict[str, str]
    launcher: Optional[Launcher] = None

    def materialize(self) -> str:
        """Creates the content of an sbatch file with provided parameters

        Parameters
        ----------
        See slurm sbatch documentation for most parameters:
        https://slurm.schedmd.com/sbatch.html

        Below are the parameters that differ from slurm documentation:

        command_groups:
            each command group will be assigned one srun
        folder: str/Path
            folder where print logs and error logs will be written
        setup: list
            a list of commands to run in sbatch before running srun
        additional_parameters: dict
            Forces any parameter to a given value in sbatch. This can be useful
            to add parameters which are not currently available in nemo_launcher.
            Eg: {"mail-user": "blublu@nvidia.com", "mail-type": "BEGIN"}
        srun_args: List[str]
            Add each argument in the list to the srun call

        Raises
        ------
        ValueError
            In case an erroneous keyword argument is added, a list of all
            eligible parameters is printed, with their default values
        """
        args = asdict(self.executor)  # noqa: F821
        parameters = {
            k: v for k, v in args.items() if v is not None and k in SlurmExecutor.SBATCH_FLAGS
        }
        # rename and reformat parameters
        if "cpus_per_gpu" in parameters and "gpus_per_task" not in parameters:
            warnings.warn(  # noqa: F821
                '"cpus_per_gpu" requires to set "gpus_per_task" to work (and not "gpus_per_node")'
            )

        # add necessary parameters
        original_job_name: str = self.jobs[0]  # type: ignore
        job_name_prefix = (
            self.executor.job_name_prefix
            if self.executor.job_name_prefix
            else f"{self.executor.account}-{self.executor.account.split('_')[-1]}."
        )
        job_name = f"{job_name_prefix}{original_job_name}"
        slurm_job_dir = (
            self.executor.tunnel.job_dir if self.executor.tunnel else self.executor.job_dir
        )
        job_directory_name = Path(self.executor.job_dir).name

        job_details = self.executor.job_details
        if not job_details.job_name:
            job_details.job_name = job_name
        if not job_details.folder:
            job_details.folder = os.path.join(slurm_job_dir, job_directory_name)

        parameters["job_name"] = job_details.job_name

        stdout = str(job_details.stdout)
        stderr = str(job_details.stderr)

        if self.executor.array is not None:
            stdout = stdout.replace("%j", "%A_%a")
            stderr = stderr.replace("%j", "%A_%a")

        parameters["output"] = stdout.replace("%t", "0")

        if not self.executor.stderr_to_stdout:
            parameters["error"] = stderr.replace("%t", "0")

        if self.executor.additional_parameters is not None:
            parameters.update(self.executor.additional_parameters)

        # now create
        sbatch_cmd = " ".join([shlex.quote(arg) for arg in self.launch_cmd])
        sbatch_flags = []
        if self.executor.heterogeneous:
            assert len(self.jobs) == len(self.executor.resource_group), (
                f"Number of jobs {len(self.jobs)} must match number of resource group requests {len(self.executor.resource_group)}.\nIf you are just submitting a single job, make sure that heterogeneous=False in the executor."
            )
            final_group_index = len(self.executor.resource_group) - 1
            if self.executor.het_group_indices:
                final_group_index = self.executor.het_group_indices.index(
                    max(self.executor.het_group_indices)
                )
            for i in range(len(self.executor.resource_group)):
                resource_req = self.executor.resource_group[i]
                if resource_req.het_group_index is not None:
                    assert self.executor.resource_group[i - 1].het_group_index is not None, (
                        "het_group_index must be set for all requests in resource_group"
                    )
                if (
                    i > 0
                    and resource_req.het_group_index
                    == self.executor.resource_group[i - 1].het_group_index
                ):
                    continue

                het_parameters = parameters.copy()
                het_parameters["output"] = parameters["output"].replace(
                    original_job_name, self.jobs[i]
                )
                if "error" in parameters:
                    het_parameters["error"] = parameters["error"].replace(
                        original_job_name, self.jobs[i]
                    )
                het_parameters.update(
                    {
                        "job_name": f"{job_details.job_name[:-2] if job_details.job_name.endswith('-0') else job_details.job_name}-{i}",
                        "nodes": resource_req.nodes,
                        "ntasks_per_node": resource_req.ntasks_per_node,
                        "gpus_per_node": resource_req.gpus_per_node,
                        "gpus_per_task": resource_req.gpus_per_task,
                    }
                )
                for k in sorted(parameters):
                    sbatch_flags.append(_as_sbatch_flag(k, het_parameters[k]))

                if i != final_group_index:
                    sbatch_flags.append("#SBATCH hetjob")
        else:
            for k in sorted(parameters):
                sbatch_flags.append(_as_sbatch_flag(k, parameters[k]))

        if self.executor.dependencies:
            slurm_deps = self.executor.parse_deps()
            sbatch_flags.append(
                _as_sbatch_flag(
                    "dependency", f"{self.executor.dependency_type}:{':'.join(slurm_deps)}"
                )
            )

        env_vars = []
        full_env_vars = self.executor.env_vars | self.extra_env
        for key, value in full_env_vars.items():
            env_vars.append(f"export {key.upper()}={value}")

        # commandline (this will run the function and args specified in the file provided as argument)
        # We pass --output and --error here, because the SBATCH command doesn't work as expected with a filename pattern
        # Removed redundant assignment to stderr_flags
        srun_commands = []
        group_env_vars = []
        srun_stdout = noquote(job_details.srun_stdout)
        stderr_flags = (
            [] if self.executor.stderr_to_stdout else ["--error", noquote(job_details.srun_stderr)]
        )

        memory_measure_out = None
        if self.executor.memory_measure:
            memory_measure_out = srun_stdout

        def get_container_flags(
            base_mounts: list[str],
            src_job_dir: str,
            container_image: Optional[str],
            container_env: Optional[list[str]] = None,
        ) -> list[str]:
            _container_flags = ["--container-image", container_image] if container_image else []
            new_mounts = copy.deepcopy(base_mounts)
            for i, mount in enumerate(new_mounts):
                if mount.startswith(RUNDIR_SPECIAL_NAME):
                    new_mounts[i] = mount.replace(RUNDIR_SPECIAL_NAME, src_job_dir, 1)

            new_mounts.append(f"{src_job_dir}:/{RUNDIR_NAME}")
            _mount_arg = ",".join(new_mounts)
            _container_flags += ["--container-mounts", _mount_arg]
            _container_flags += [
                "--container-workdir",
                f"/{RUNDIR_NAME}/code",
            ]

            if container_env:
                _container_flags += ["--container-env", ",".join(container_env)]

            return _container_flags

        for group_ind, command_group in enumerate(self.command_groups):
            if self.executor.run_as_group and len(self.executor.resource_group) == len(
                self.command_groups
            ):
                resource_req = self.executor.resource_group[group_ind]
                if not resource_req.job_details.job_name:
                    resource_req.job_details.job_name = f"{job_name_prefix}{self.jobs[group_ind]}"
                if not resource_req.job_details.folder:
                    resource_req.job_details.folder = os.path.join(
                        slurm_job_dir, job_directory_name
                    )

                cmd_stdout = noquote(resource_req.job_details.srun_stdout)
                cmd_stderr = (
                    []
                    if self.executor.stderr_to_stdout
                    else [
                        "--error",
                        noquote(resource_req.job_details.srun_stderr),
                    ]
                )

                current_env_vars = []
                for key, value in resource_req.env_vars.items():
                    current_env_vars.append(f"export {key.upper()}={value}")
                group_env_vars.append(current_env_vars)

                _container_flags = get_container_flags(
                    base_mounts=resource_req.container_mounts,
                    src_job_dir=os.path.join(
                        slurm_job_dir,
                        job_directory_name,
                    ),
                    container_image=resource_req.container_image,
                    container_env=resource_req.container_env,
                )
                _srun_args = ["--wait=60", "--kill-on-bad-exit=1"]
                _srun_args.extend(resource_req.srun_args or [])
            else:
                cmd_stdout = srun_stdout.replace(original_job_name, self.jobs[group_ind])
                cmd_stderr = stderr_flags.copy()
                if cmd_stderr:
                    cmd_stderr[-1] = cmd_stderr[-1].replace(original_job_name, self.jobs[group_ind])
                _container_flags = get_container_flags(
                    base_mounts=self.executor.container_mounts,
                    src_job_dir=os.path.join(
                        slurm_job_dir,
                        job_directory_name,
                    ),
                    container_image=self.executor.container_image,
                    container_env=self.executor.container_env,
                )
                _srun_args = ["--wait=60", "--kill-on-bad-exit=1"]
                _srun_args.extend(self.executor.srun_args or [])

            if self.executor.run_as_group and self.executor.heterogeneous:
                het_group_index = (
                    self.executor.resource_group[group_ind].het_group_index
                    if self.executor.resource_group[group_ind].het_group_index is not None
                    else group_ind
                )
                het_group_flag = [f"--het-group={het_group_index}"]
            else:
                het_group_flag = []

            srun_cmd = " ".join(
                list(
                    map(
                        lambda arg: arg if isinstance(arg, noquote) else shlex.quote(arg),
                        [
                            "srun",
                            *het_group_flag,
                            "--output",
                            cmd_stdout,
                            *cmd_stderr,
                            *_container_flags,
                            *_srun_args,
                        ],
                    )
                )
            )

            command = " ".join(command_group)
            if self.executor.run_as_group:
                srun_command = f"{srun_cmd} {command} & pids[{group_ind}]=$!"
                if group_ind != len(self.command_groups) - 1:
                    srun_command += f"\n\nsleep {self.executor.wait_time_for_group_job}\n"
            else:
                srun_command = f"{srun_cmd} {command}"
            srun_commands.append(srun_command)

        vars_to_fill = {
            "sbatch_command": sbatch_cmd,
            "sbatch_flags": sbatch_flags,
            "max_retries": self.max_retries,
            "env_vars": env_vars,
            "head_node_ip_var": SlurmExecutor.HEAD_NODE_IP_VAR,
            "setup_lines": self.executor.setup_lines,
            "memory_measure": memory_measure_out,
            "srun_commands": srun_commands,
            "group_env_vars": group_env_vars,
            "heterogeneous": self.executor.heterogeneous,
            "run_as_group": self.executor.run_as_group,
            "monitor_group_job": self.executor.run_as_group and self.executor.monitor_group_job,
            "monitor_group_job_wait_time": self.executor.monitor_group_job_wait_time,
            "het_group_host_var": SlurmExecutor.HET_GROUP_HOST_VAR,
            "ft_enabled": self.launcher and isinstance(self.launcher, FaultTolerance),
        }
        if self.launcher and isinstance(self.launcher, FaultTolerance):
            assert (
                self.launcher.cfg_path
                and self.launcher.finished_flag_file
                and self.launcher.job_results_file
            )
            vars_to_fill["fault_tol_cfg_path"] = self.launcher.cfg_path
            vars_to_fill["fault_tol_finished_flag_file"] = self.launcher.finished_flag_file
            vars_to_fill["fault_tol_job_results_file"] = self.launcher.job_results_file

        sbatch_script = fill_template("slurm.sh.j2", vars_to_fill)
        return sbatch_script

    def __repr__(self) -> str:
        return f"""{" ".join(self.launch_cmd + ["$SBATCH_SCRIPT"])}

#----------------
# SBATCH_SCRIPT
#----------------

{self.materialize()}"""


class SlurmTunnelCallback(Callback):
    def __init__(self, executor: SlurmExecutor, space: DevSpace, srun=None, tunnel_dir=None):
        self.executor = executor
        self.srun = srun
        self.space = space
        self.ssh_config = SSHConfigFile()
        self.console = Console()
        self.editor_started = False
        self.tunnel_dir = tunnel_dir

    def on_start(self):
        if self.srun is not None:
            self.srun_status = self.console.status(
                Text("srun: ", style="bold green"), spinner="dots"
            )
            self.srun_status.start()
            self.srun_is_done = False
        else:
            self.srun_is_done = True

    def on_interval(self):
        from nemo_run.devspace.editor import launch_editor

        if not self.srun_is_done:
            status = self.srun.runner.stderr[-1] if self.srun.runner.stderr else None
            stdout = self.srun.runner.stdout
            if stdout:
                for line in stdout:
                    if (
                        "To connect to the tunnel, run the following command on your local machine:"
                        in line
                    ):
                        if not self.srun_is_done:
                            self.srun_is_done = True
                            self.srun_status.stop()
                            self.console.log(":white_check_mark: Server is launched")
                            self.console.log("[bold green]Devspace is active...")

            if not self.srun_is_done and status:
                self.srun_status.update(Text(status, style="bold green"))
        elif not self.editor_started:
            _tunnel_dir = self.tunnel_dir or server_dir(self.executor.job_dir, self.space.name)
            metadata = TunnelMetadata.restore(_tunnel_dir, tunnel=self.tunnel)
            self.forward_port_context = self.tunnel.session.forward_local(
                int(metadata.port), remote_host=metadata.hostname
            )
            self.forward_port_context.__enter__()
            self.ssh_config.add_entry(
                metadata.user, "localhost", int(metadata.port), self.tunnel_name
            )
            self.ssh_entry_added = True

            with self.console.status("Setting up port forwarding", spinner="dots"):
                time.sleep(3)

            self.console.print(
                f"[bold green]:white_check_mark: Port forwarding established. "
                f"Connect via SSH with: ssh tunnel.{self.tunnel_name}"
            )

            launch_editor(self.tunnel_name, f"/workspaces/{metadata.workspace_name}")
            self.editor_started = True

    def on_stop(self):
        # if hasattr(self, "forward_port_context"):
        #     self.forward_port_context.__exit__()
        if hasattr(self, "ssh_entry_added"):
            self.ssh_config.remove_entry(self.tunnel_name)

    @property
    def tunnel_name(self) -> str:
        workspace_name = self.space.name
        return ".".join([workspace_name, self.space.name])