# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import contextvars
import copy
import importlib.util
import inspect
import json
import os
import pprint
import shutil
import sys
import time
import traceback
from concurrent.futures import Future, ThreadPoolExecutor, as_completed
from pathlib import Path
from typing import Optional, Type, Union
import fiddle as fdl
import networkx as nx
import rich
from fiddle._src import daglish, diffing
from rich.console import Group
from rich.live import Live
from rich.panel import Panel
from rich.progress import BarColumn, Progress, SpinnerColumn, TaskID, TimeElapsedColumn
from rich.progress import Task as RichTask
from rich.syntax import Syntax
from torchx.specs.api import AppState
import nemo_run as run
from nemo_run.config import (
Config,
ConfigurableMixin,
Partial,
Script,
get_nemorun_home,
get_type_namespace,
)
from nemo_run.core.execution.base import Executor
from nemo_run.core.execution.dgxcloud import DGXCloudExecutor
from nemo_run.core.execution.docker import DockerExecutor
from nemo_run.core.execution.lepton import LeptonExecutor
from nemo_run.core.execution.local import LocalExecutor
from nemo_run.core.execution.skypilot import SkypilotExecutor
from nemo_run.core.execution.skypilot_jobs import SkypilotJobsExecutor
from nemo_run.core.execution.slurm import SlurmExecutor
from nemo_run.core.frontend.console.api import CONSOLE, configure_logging, deconfigure_logging
from nemo_run.core.serialization.zlib_json import ZlibJSONSerializer
from nemo_run.core.tunnel.client import SSHTunnel, Tunnel
from nemo_run.core.tunnel.rsync import rsync
from nemo_run.run.job import Job, JobGroup
from nemo_run.run.plugin import ExperimentPlugin
from nemo_run.run.torchx_backend.runner import get_runner
from nemo_run.run.utils import TeeStdoutStderr
_current_experiment: contextvars.ContextVar["Experiment"] = contextvars.ContextVar(
"nemo_current_experiment"
)
class DummyConsole:
"""A dummy console that mimics rich.console.Console but does nothing."""
def __getattr__(self, name):
"""Return a no-op function for any attribute access."""
def no_op(*args, **kwargs):
pass
return no_op
[docs]
class Experiment(ConfigurableMixin):
"""
A context manager to launch and manage multiple runs, all using pure Python.
run.Experiment provides researchers with
a simple and flexible way to create and manage their ML experiments.
Building on the core blocks of nemo_run,
the Experiment can be used as an umbrella under which a user can
launch different configured functions on multiple remote clusters.
The Experiment takes care of storing the run metadata,
launching it on the specified cluster, and syncing the logs and artifacts.
Additionally, the Experiment also provides management tools to easily inspect and reproduce past experiments.
Some of the use-cases that it enables are listed below:
1. Check the status and logs of a past experiment
2. Reconstruct a past experiment and relaunch it after some changes
3. Compare different runs of the same experiment.
This API allows users to programmatically define their experiments.
To get a glance of the flexibility provided, here are some use cases
which can be supported by the Experiment in just a few lines of code.
1. Launch a benchmarking run on different GPUs at the same time in parallel
2. Launch a sequential data processing pipeline on a CPU heavy cluster
3. Launch hyperparameter grid search runs on a single cluster in parallel
4. Launch hyperparameter search runs distributed across all available clusters
The design is heavily inspired from `XManager <https://github.com/google-deepmind/xmanager/blob/main/docs/xm_launch_api_principles.md>`_.
Under the hood, the Experiment metadata is stored in the local filesystem
inside a user specified directory controlled by get_nemorun_home() env var.
We will explore making the metadata more persistent in the future.
.. note::
`Experiment.add` and `Experiment.run` methods inside Experiment can currently only be used within its context manager.
Examples
--------
.. code-block:: python
# An experiment that runs a pre-configured training example
# on multiple GPU specific clusters (A100 and H100 shown here) in parallel using torchrun
# Assumes that example_to_run is pre-configured using run.Partial
with run.Experiment("example-multiple-gpus", executor="h100_cluster") as exp:
# Set up the run on H100
# Setting up a single task is identical to setting up a single run outside the experiment
h100_cluster: run.SlurmExecutor = exp.executor.clone()
h100_cluster.nodes = 2
# torchrun manages the processes on a single node
h100_cluster.ntasks_per_node = 1
h100_cluster.gpus_per_task = 8
h100_cluster.packager.subpath = "subpath/to/your/code/repo"
h100_cluster.launcher = "torchrun"
exp.add(
"example_h100",
fn=example_to_run,
tail_logs=True,
executor=h100_cluster,
)
# Set up the run on A100
a100_cluster: run.Config[SlurmExecutor] = h100_cluster.clone()
a100_cluster.tunnel = run.Config(
SSHTunnel,
host=os.environ["A100_HOST"],
user="your_user_in_cluster",
identity="path_to_your_ssh_key"
)
exp.add(
"example_a100",
fn=example_to_run,
tail_logs=True,
executor=a100_cluster,
)
# Runs all the task in the experiment.
# By default, all tasks will be run in parallel if all different executors support parallel execution.
# You can set sequential=True to run the tasks sequentially.
exp.run()
# Upon exiting the context manager, the Experiment will automatically wait for all tasks to complete,
# and optionally tail logs for tasks that have tail_logs=True.
# A detach mode (if the executors support it) will be available soon.
# Once all tasks have completed,
# the Experiment will display a status table and clean up resources like ssh tunnels.
# You can also manage the experiment at a later point in time
exp = run.Experiment.from_title("example-multiple-gpus")
exp.status()
exp.logs(task_id="example_a100")
"""
GOODBYE_MESSAGE_PYTHON = """
# The experiment was run with the following tasks: {tasks}
# You can inspect and reconstruct this experiment at a later point in time using:
experiment = run.Experiment.from_id("{exp_id}")
experiment.status() # Gets the overall status
experiment.logs("{tasks[0]}") # Gets the log for the provided task
experiment.cancel("{tasks[0]}") # Cancels the provided task if still running
"""
GOODBYE_MESSAGE_BASH = """
# You can inspect this experiment at a later point in time using the CLI as well:
nemo experiment status {exp_id}
nemo experiment logs {exp_id} 0
nemo experiment cancel {exp_id} 0
"""
_PARALLEL_SUPPORTED_EXECUTORS = (
SlurmExecutor,
LocalExecutor,
SkypilotExecutor,
SkypilotJobsExecutor,
DockerExecutor,
DGXCloudExecutor,
LeptonExecutor,
)
_DETACH_SUPPORTED_EXECUTORS = (
SlurmExecutor,
SkypilotExecutor,
SkypilotJobsExecutor,
DGXCloudExecutor,
LeptonExecutor,
)
_DEPENDENCY_SUPPORTED_EXECUTORS = (SlurmExecutor,)
_RUNNER_DEPENDENT_EXECUTORS = (LocalExecutor,)
_CONFIG_FILE = "_CONFIG"
_VERSION_FILE = "_VERSION"
_TASK_FILE = "_TASKS"
_DONE_FILE = "_DONE"
_TUNNELS_FILE = "_TUNNELS"
_current_experiment_token: Optional[contextvars.Token]
[docs]
@classmethod
def catalog(
cls: Type["Experiment"],
title: str = "",
) -> list[str]:
"""
List all experiments inside get_nemorun_home(), optionally with the provided title.
"""
parent_dir = os.path.join(get_nemorun_home(), "experiments", title)
return _get_sorted_dirs(parent_dir)
@classmethod
def _from_config(cls: Type["Experiment"], exp_dir: str) -> "Experiment":
id = os.path.basename(exp_dir)
with open(os.path.join(exp_dir, cls._CONFIG_FILE), "r") as f:
config = f.read()
if not config:
raise ValueError(f"Experiment {id} not found.")
serializer = ZlibJSONSerializer()
cfg: Config["Experiment"] = fdl.cast(Config, serializer.deserialize(config))
if "id" not in cfg.__arguments__:
cfg.id = id
cfg._reconstruct = True
exp: "Experiment" = fdl.build(cfg)
exp._jobs = exp._load_jobs()
try:
exp.tunnels = exp._load_tunnels()
except Exception as e:
exp.console.log(
f"Exception {e} loading tunnels for experiment {id}, will continue without loading tunnels."
)
return exp
[docs]
@classmethod
def from_id(
cls: Type["Experiment"],
id: str,
) -> "Experiment":
"""
Reconstruct an experiment with the specified id.
"""
title, _, _ = id.rpartition("_")
parent_dir = os.path.join(get_nemorun_home(), "experiments", title)
exp_dir = os.path.join(parent_dir, id)
assert os.path.isdir(exp_dir), f"Experiment {id} not found."
exp = cls._from_config(exp_dir)
return exp
[docs]
@classmethod
def from_title(
cls: Type["Experiment"],
title: str,
) -> "Experiment":
"""
Reconstruct an experiment with the specified title.
"""
parent_dir = os.path.join(get_nemorun_home(), "experiments", title)
exp_dir = _get_latest_dir(parent_dir)
assert os.path.isdir(exp_dir), f"Experiment {id} not found."
exp = cls._from_config(exp_dir)
return exp
def __init__(
self,
title: str,
executor: Executor | None = None, # type: ignore
id: str | None = None,
log_level: str = "INFO",
_reconstruct: bool = False,
jobs: list[Job | JobGroup] | None = None,
base_dir: str | None = None,
clean_mode: bool = False,
enable_goodbye_message: bool = True,
threadpool_workers: int = 16,
skip_status_at_exit: bool = False,
serialize_metadata_for_scripts: bool = True,
) -> None:
"""
Initializes an experiment run by creating its metadata directory and saving the experiment config.
Args:
title: Title or name for the experiment
executor: Any executor that subclasses run.Executor and is supported by NeMo-Run.
This will be used as the default executor for tasks if an explicit one is not specified.
Users can also clone this and make task specific executor changes.
id (Optional): Unique id for the experiment run.
If not specified, will be set automatically based on the current timestamp.
log_level: Set log level for the experiment. Defaults to WARN.
_reconstruct: Generally, the user does not need to specify this flag.
This is only set to True when using run.Experiment.from_dir.
clean_mode: If True, disables all console output (logs, progress bars, etc.). Defaults to False.
enable_goodbye_message: if True, prints goodbye message after submitting job. Defaults to True.
"""
configure_logging(level=log_level)
self._reconstruct = _reconstruct
if _reconstruct:
assert id, "Cannot reconstruct an experiment without id."
self._title = title
self._id = id or f"{title}_{int(time.time())}"
self._enable_goodbye_message = enable_goodbye_message
self._threadpool_workers = threadpool_workers
self._skip_status_at_exit = skip_status_at_exit
self._serialize_metadata_for_scripts = serialize_metadata_for_scripts
base_dir = str(base_dir or get_nemorun_home())
self._exp_dir = os.path.join(base_dir, "experiments", title, self._id)
self.log_level = log_level
self._runner = get_runner(component_defaults=None, experiment=self)
if not _reconstruct:
self.executor = executor if executor else LocalExecutor()
else:
assert isinstance(executor, Executor)
self.executor = executor
self._jobs: list[Job | JobGroup] = jobs or []
self.tunnels: dict[str, Tunnel] = {}
self.console = CONSOLE
self.clean_mode = clean_mode
if self.clean_mode:
self.console = DummyConsole()
self._launched = False
self._live_progress = None
self._current_experiment_token = None
[docs]
def to_config(self) -> Config:
return Config(
self.__class__,
title=self._title,
id=self._id,
executor=self.executor.to_config(),
log_level=self.log_level,
clean_mode=self.clean_mode,
threadpool_workers=self._threadpool_workers,
enable_goodbye_message=self._enable_goodbye_message,
skip_status_at_exit=self._skip_status_at_exit,
serialize_metadata_for_scripts=self._serialize_metadata_for_scripts,
)
def _save_experiment(self, exist_ok: bool = False):
os.makedirs(self._exp_dir, exist_ok=exist_ok)
self._save_config()
def _save_config(self):
with open(os.path.join(self._exp_dir, self.__class__._CONFIG_FILE), "w+") as f:
f.write(ZlibJSONSerializer().serialize(self.to_config()))
with open(os.path.join(self._exp_dir, self.__class__._VERSION_FILE), "w+") as f:
f.write(f"{run.__version__}\n")
def _save_tunnels(self):
serializer = ZlibJSONSerializer()
serialized_tunnels = {
k: serializer.serialize(v.to_config()) for k, v in self.tunnels.items()
}
with open(os.path.join(self._exp_dir, self.__class__._TUNNELS_FILE), "w+") as f:
json.dump(serialized_tunnels, f)
def _load_tunnels(self) -> dict[str, Tunnel]:
with open(os.path.join(self._exp_dir, self.__class__._TUNNELS_FILE)) as f:
serialized_tunnels = json.load(f)
serializer = ZlibJSONSerializer()
return {k: fdl.build(serializer.deserialize(v)) for k, v in serialized_tunnels.items()}
def _save_jobs(self):
serialized_jobs = list(map(lambda job: job.serialize(), self.jobs))
with open(os.path.join(self._exp_dir, self.__class__._TASK_FILE), "w+") as f:
json.dump(serialized_jobs, f)
if "__main__" in sys.modules:
main_module = sys.modules["__main__"]
try:
with open(os.path.join(self._exp_dir, "__main__.py"), "w+") as f:
f.write(inspect.getsource(main_module))
except TypeError:
...
def _load_jobs(self) -> list[Job | JobGroup]:
with open(os.path.join(self._exp_dir, self._TASK_FILE)) as f:
serialized_jobs = json.load(f)
serializer = ZlibJSONSerializer()
jobs = []
for job_cfg, task_cfg in serialized_jobs:
job_cfg = serializer.deserialize(job_cfg)
job: Job | JobGroup = fdl.build(job_cfg)
if isinstance(job, Job):
job.task = task_cfg # type: ignore
elif isinstance(job, JobGroup):
job.tasks = task_cfg # type: ignore
else:
raise ValueError(f"Unknown task type: {task_cfg.__fn_or_cls__}")
jobs.append(job)
return jobs
def _prepare(self, exist_ok: bool = False):
self._save_experiment(exist_ok=exist_ok)
for job in self.jobs:
job.prepare(serialize_metadata_for_scripts=self._serialize_metadata_for_scripts)
self._save_jobs()
def _add_single_job(
self,
task: Union[Partial, Script],
executor: Executor,
name: str = "",
plugins: Optional[list[ExperimentPlugin]] = None,
tail_logs: bool = False,
dependencies: Optional[list[str]] = None,
) -> str:
if isinstance(task, Script):
default_name = task.get_name()
else:
default_name = get_type_namespace(task.__fn_or_cls__)
reuse_job_dir = True if name else False
name = name or default_name
if any(map(lambda job: job.id == name, self.jobs)):
task_id = f"{name}_{len(self.jobs)}"
else:
task_id = name
self._validate_task(task_info=task_id, task=task)
executor = executor.clone()
executor.assign(
self._id,
self._exp_dir,
task_id=task_id,
task_dir=name if reuse_job_dir else task_id,
)
cloned = copy.deepcopy(task) if isinstance(task, Script) else task.clone()
job = Job(
id=task_id,
task=cloned,
executor=executor,
plugins=plugins,
tail_logs=tail_logs,
dependencies=dependencies or [],
)
plugins = plugins or []
for plugin in plugins:
plugin.assign(self._id)
plugin.setup(cloned, executor)
self._jobs.append(job)
return job.id
def _add_job_group(
self,
tasks: list[Partial | Script],
executor: list[Executor] | Executor,
name: str,
plugins: Optional[list[ExperimentPlugin]] = None,
tail_logs: bool = False,
dependencies: Optional[list[str]] = None,
) -> str:
if any(map(lambda task: task.id == name, self.jobs)):
task_id = f"{name}_{len(self.jobs)}"
else:
task_id = name
for i, _task in enumerate(tasks):
self._validate_task(task_info=f"Job Group: {task_id}, job index: {i}", task=_task)
executors = executor if isinstance(executor, list) else [executor]
cloned_executors = []
for executor in executors:
new_executor = executor.clone()
cloned_executors.append(new_executor)
new_executor.assign(self._id, self._exp_dir, task_id, task_dir=name)
cloned_tasks = []
for task in tasks:
cloned_task = copy.deepcopy(task) if isinstance(task, Script) else task.clone()
cloned_tasks.append(cloned_task)
job_group = JobGroup(
id=task_id,
tasks=cloned_tasks,
executors=cloned_executors,
plugins=plugins,
tail_logs=tail_logs,
dependencies=dependencies or [],
)
plugins = plugins or []
for plugin in plugins:
for i, task in enumerate(cloned_tasks):
_executor = job_group.executors if job_group._merge else job_group.executors[i] # type: ignore
assert isinstance(_executor, Executor)
plugin.setup(task, _executor)
self._jobs.append(job_group)
return job_group.id
def _validate_task(self, task_info: str, task: Union[Partial, Script]) -> None:
valid = True
message = ""
if isinstance(task, Partial):
serializer = ZlibJSONSerializer()
serialized = serializer.serialize(task)
deserialized = serializer.deserialize(serialized)
diff = diffing.build_diff(deserialized, task)
diff = {
daglish.path_str(d.target): (d.new_value if hasattr(d, "new_value") else None) # type: ignore
for d in diff.changes
}
if deserialized != task:
valid = False
message += f"""
Deserialized task does not match original task. The following paths in your task need to be wrapped in `run.Config` or `run.Partial`:
{pprint.PrettyPrinter(indent=4).pformat(diff)}
For more information about `run.Config` and `run.Partial`, please refer to https://github.com/NVIDIA-NeMo/Run/blob/main/docs/source/guides/configuration.md.
"""
if not valid:
raise RuntimeError(f"Failed to validate task {task_info}.\n{message}")
[docs]
def add(
self,
task: Union[Partial, Script] | list[Union[Partial, Script]],
executor: Executor | list[Executor] | None = None,
name: str = "",
plugins: Optional[list[ExperimentPlugin]] = None,
tail_logs: bool = False,
dependencies: Optional[list[str]] = None,
) -> str:
"""
Add a configured function along with its executor config to the experiment.
"""
assert _current_experiment.get(None) == self, (
"Using Experiment without it's context manager is not permitted."
)
job_ids = set([job.id for job in self.jobs])
for dep in dependencies or []:
assert dep in job_ids, f"Dependency {dep} not found."
executor = executor or self.executor
if not isinstance(task, list):
assert executor and isinstance(executor, Executor)
job_id = self._add_single_job(
task,
executor,
name,
plugins=plugins,
tail_logs=tail_logs,
dependencies=dependencies.copy() if dependencies else None,
)
else:
assert name, "name is required for task group."
job_id = self._add_job_group(
task,
executor,
name,
plugins=plugins,
tail_logs=tail_logs,
dependencies=dependencies.copy() if dependencies else None,
)
return job_id
[docs]
def dryrun(self, log: bool = True, exist_ok: bool = False, delete_exp_dir: bool = True):
"""
Logs the raw scripts that will be executed for each task.
"""
if log:
self.console.log(f"[bold magenta]Experiment {self._id} dryrun...")
self._prepare(exist_ok=exist_ok)
for job in self.jobs:
if isinstance(job, Job):
if log:
self.console.log(f"[bold magenta]Task {job.id}\n")
elif isinstance(job, JobGroup):
if log:
self.console.log(f"[bold magenta]Task Group {job.id}\n")
job.launch(wait=False, runner=self._runner, dryrun=True, direct=False, log_dryrun=log)
if delete_exp_dir:
shutil.rmtree(self._exp_dir)
[docs]
def run(
self,
sequential: bool = False,
detach: bool = False,
tail_logs: bool = False,
direct: bool = False,
):
"""
Runs all the tasks in the experiment.
By default, all tasks are run in parallel.
If sequential=True, all tasks will be run one after the other.
The order is based on the order in which they were added.
Parallel mode only works if all executors in the experiment support it.
Currently, all executors support parallel mode.
In sequential mode, if all executor supports dependencies, then all tasks will be scheduled at once
by specifying the correct dependencies to each task.
Otherwise, the experiment.run call will block and each task that is scheduled will be executed sequentially.
In this particular case, we cannot guarantee the state of the exeperiment if the process exits in the middle.
Currently, only the slurm executor supports dependencies.
Args:
sequential: If True, runs all tasks sequentially in the order they were added. Defaults to False.
detach: If True, detaches from the process after launching the tasks. Only supported for Slurm and Skypilot. Defaults to False.
tail_logs: If True, tails logs from all tasks in the experiment. If False, relies on task specific setting. Defaults to False.
direct: If True, runs all tasks in the experiment sequentially in the same process. Note that if direct=True, then sequential also will be True. Defaults to False.
"""
assert _current_experiment.get(None) == self, (
"Using Experiment without it's context manager is not permitted."
)
if self._launched:
self.console.log("[bold magenta]Experiment already running...")
return
if self._reconstruct:
self.console.log("[bold magenta]Experiment in inspection mode...")
return
# Prepare experiment before running
# in case of multi-node execution with LocalExecutor+torchrun+slurm, run only on first rank
if int(os.getenv("SLURM_PROCID", 0)) == 0:
self._prepare()
if direct:
self.console.log(
"[bold magenta]Running the experiment with direct=True. "
"This will launch all jobs sequentially in the same process."
)
if not self.jobs:
self.console.log("[bold red]No jobs to run in this experiment.")
return
assert all(map(lambda job: isinstance(job, Job), self.jobs)), (
"Jobs in this experiment contain JobGroup which cannot be run directly for now."
)
assert all(map(lambda job: not job.dependencies, self.jobs)), (
"Jobs in this experiment contain dependencies which cannot be run directly for now."
)
for job in self.jobs:
assert isinstance(job, Job)
with TeeStdoutStderr(
os.path.join(job.executor.job_dir, f"log_{job.id}_direct_run.out")
):
job.launch(wait=True, direct=True, runner=self._runner)
self._save_jobs()
self._launched = any(map(lambda job: job.launched, self.jobs))
self._direct = True
return
executors = set()
for job in self.jobs:
if isinstance(job, Job):
executors.add(job.executor.__class__)
elif isinstance(job, JobGroup):
if isinstance(job.executors, list):
for executor in job.executors:
executors.add(executor.__class__)
else:
executors.add(job.executors.__class__)
if detach and any(map(lambda x: x not in self._DETACH_SUPPORTED_EXECUTORS, executors)):
self.console.log(
"[bold red] Cannot detach from this experiment. Please keep it running until completion."
)
detach = False
is_dag = any(map(lambda job: len(job.dependencies) > 0, self.jobs))
assert not (is_dag and sequential), (
"Jobs in this experiment have dependencies, they cannot be run sequentially. Set sequential=False."
)
if sequential:
for i in range(1, len(self.jobs)):
self.jobs[i].dependencies.append(self.jobs[i - 1].id)
self.dryrun(log=False, exist_ok=True, delete_exp_dir=False)
for tunnel in self.tunnels.values():
if isinstance(tunnel, SSHTunnel):
tunnel.connect()
assert tunnel.session, f"SSH tunnel {tunnel.key} failed to connect."
rsync(tunnel.session, self._exp_dir, os.path.dirname(tunnel.job_dir))
symlink_cmds = []
for packaging_job in tunnel.packaging_jobs.values():
if packaging_job.symlink:
symlink_cmds.append(packaging_job.symlink_cmd())
if symlink_cmds:
tunnel.run(" && ".join(symlink_cmds))
self._save_tunnels()
return self._run_dag(detach=detach, tail_logs=tail_logs, executors=executors)
def _run_dag(self, detach: bool, tail_logs: bool, executors: set[Executor]):
job_map = {job.id: job for job in self._jobs}
graph = nx.DiGraph()
job_ids = set([job.id for job in self.jobs])
for job in self.jobs:
graph.add_node(job.id, job=job)
for dep in job.dependencies:
assert dep in job_ids, f"Dependency {dep} not found in job list {job_ids}."
graph.add_edge(dep, job.id)
assert nx.is_directed_acyclic_graph(graph), "Jobs have cyclic dependencies."
order = [sorted(generation) for generation in nx.topological_generations(graph)]
add_deps = False
if len(order) > 1:
if all(map(lambda x: x in self._DEPENDENCY_SUPPORTED_EXECUTORS, executors)):
wait = False
add_deps = True
self.detach = detach
else:
wait = True
if len(self.jobs) > 1:
self.console.log(
f"[bold cyan]Dependencies not supported for atleast one of {executors}."
"All jobs will be run one after the other based on their dependencies, please keep the process alive."
)
if detach:
self.console.log(
"[bold red] Cannot detach from this experiment. Please keep it running until completion."
)
else:
# All jobs will be executed in parallel
assert all(map(lambda x: x in self._PARALLEL_SUPPORTED_EXECUTORS, executors)), (
f"Parallel mode not supported for atleast one of {executors}. Set sequential=True."
)
wait = False
self.detach = detach
for level in order:
# Launch jobs in this level concurrently since they are independent
def _set_context(ctx: contextvars.Context):
for var, value in ctx.items():
var.set(value)
ctx = contextvars.copy_context()
def _launch(node: str):
job: Job | JobGroup = job_map[node]
self.console.log(f"[bold cyan]Launching job {job.id} for experiment {self._title}")
if tail_logs:
job.tail_logs = True
try:
if add_deps:
deps = []
for dep_id in job.dependencies:
dep = job_map[dep_id]
handle = dep.handle
assert dep.launched and handle, (
f"Dependency {dep.id} for {job.id} not yet launched."
)
deps.append(handle)
job.executor.dependencies = deps # type: ignore
job.launch(wait=False, runner=self._runner)
return job
except Exception as e:
self.console.log(f"Error running job {job.id}: {e}")
raise e
launched_jobs: list[Job | JobGroup] = []
with ThreadPoolExecutor(
initializer=_set_context, initargs=(ctx,), max_workers=self._threadpool_workers
) as pool:
futures = [pool.submit(_launch, node) for node in level]
for future in as_completed(futures):
launched_jobs.append(future.result())
if wait:
self._wait_for_jobs(jobs=launched_jobs)
self._save_jobs()
self._launched = any(map(lambda job: job.launched, self.jobs))
self._waited = wait
def _wait_for_jobs(self, jobs: list[Job | JobGroup]):
def set_context(context: contextvars.Context):
for var, value in context.items():
var.set(value)
context = contextvars.copy_context()
with ThreadPoolExecutor(initializer=set_context, initargs=(context,)) as executor:
futures: dict[Future, Job | JobGroup] = {}
for job in jobs:
if isinstance(job, Job):
handle_exists = job.handle
else:
handle_exists = len(job.handles) > 0 and all(job.handles)
if job.launched and handle_exists:
self._initialize_live_progress()
self._add_progress(job=job)
future = executor.submit(
job.wait,
runner=self._runner
if isinstance(
job.executor,
self._RUNNER_DEPENDENT_EXECUTORS,
)
else get_runner(),
)
futures[future] = job
for future in as_completed(futures.keys()):
job = futures[future]
try:
future.result()
self._update_progress(job, job.state)
except Exception as e:
self.console.log(f"Exception while waiting for Job {job.id}: {e}")
self.console.log(*traceback.format_exception(e))
self._update_progress(job, AppState.UNKNOWN)
finally:
job.cleanup()
def _initialize_tunnels(self, extract_from_executors: bool = False):
if extract_from_executors:
for job in self.jobs:
if (
isinstance(job.executor, SlurmExecutor)
and job.executor.tunnel.key not in self.tunnels
):
self.tunnels[job.executor.tunnel.key] = job.executor.tunnel
for tunnel in self.tunnels.values():
if isinstance(tunnel, SSHTunnel):
tunnel.connect()
assert tunnel.session, f"SSH tunnel {tunnel.key} failed to connect."
[docs]
def status(self, return_dict: bool = False) -> Optional[dict[str, dict[str, str]]]:
"""
Prints a table specifying the status of all tasks.
.. note::
status is not supported for local executor
and the status for a task using the local executor
will be listed as UNKNOWN in most cases
"""
_set_current_experiment = False
if not self._current_experiment_token:
_current_experiment.set(self)
_set_current_experiment = True
def _get_job_info_and_dict(
idx: int, job: Job | JobGroup
) -> tuple[list[str], dict[str, str]]:
job_info = []
job_info.append(f"[bold green]Task {idx}[/bold green]: [bold orange1]{job.id}")
job_info.append(
f"- [bold green]Status[/bold green]: {str(job.status(runner=self._runner))}"
)
job_info.append(f"- [bold green]Executor[/bold green]: {job.executor.info()}")
try:
_, _, path_str = job.handle.partition("://")
path = path_str.split("/")
app_id = path[1]
except Exception:
app_id = ""
job_info.append(f"- [bold green]Job id[/bold green]: {app_id}")
directory_info = [
"- [bold green]Local Directory[/bold green]: " + job.executor.job_dir,
]
job_dict = {
"name": job.id,
"status": job.status(runner=self._runner),
"executor": job.executor.info(),
"job_id": app_id,
"handle": job.handle,
"local_dir": job.executor.job_dir,
}
if isinstance(job.executor, SlurmExecutor) and isinstance(
job.executor.tunnel, SSHTunnel
):
directory_info.extend(
[
"- [bold green]Remote Directory[/bold green]: "
+ os.path.join(
job.executor.tunnel.job_dir,
Path(job.executor.job_dir).name,
),
]
)
job_dict["remote_dir"] = os.path.join(
job.executor.tunnel.job_dir,
Path(job.executor.job_dir).name,
)
job_info.extend(directory_info)
return job_info, job_dict
self._initialize_tunnels(extract_from_executors=True)
try:
result_dict = {}
job_infos: list[Group | None] = [None] * len(self.jobs)
# Parallelize IO-bound status retrieval across jobs
def _collect(arg):
idx, job = arg
job_info, job_dict = _get_job_info_and_dict(idx, job)
return idx, job.id, job_info, job_dict
# Propagate context variables to worker threads so helpers that rely on them keep working
def _set_context(ctx: contextvars.Context):
for var, value in ctx.items():
var.set(value)
ctx = contextvars.copy_context()
with ThreadPoolExecutor(
initializer=_set_context, initargs=(ctx,), max_workers=self._threadpool_workers
) as pool:
futures = [pool.submit(_collect, (idx, job)) for idx, job in enumerate(self.jobs)]
for future in as_completed(futures):
idx, job_id, job_info, job_dict = future.result()
job_infos[idx] = Group(*job_info)
result_dict[job_id] = job_dict
# Remove potential None slots (should not occur)
job_infos = [ji for ji in job_infos if ji is not None]
if return_dict:
return result_dict
self.console.print()
self.console.print(
f"[bold green]Experiment Status for[/bold green] [bold orange1]{self._id}",
new_line_start=True,
)
for job_info in job_infos:
self.console.print(job_info, soft_wrap=True, new_line_start=True, highlight=False)
self.console.print()
finally:
if _set_current_experiment and self._current_experiment_token:
_current_experiment.reset(self._current_experiment_token)
self._current_experiment_token = None
[docs]
def cancel(self, job_id: str):
"""
Cancels an existing job if still running.
"""
_set_current_experiment = False
if not self._current_experiment_token:
_current_experiment.set(self)
_set_current_experiment = True
self.console.log(f"[bold cyan]Cancelling {job_id} if still running")
try:
job = next(filter(lambda x: x.id == job_id, self.jobs))
job.cancel(runner=self._runner)
except StopIteration:
self.console.log(f"[bold red]Job {job_id} not found")
except Exception as e:
self.console.log(f"[bold red]Failed to cancel {job_id}\nError: {e}\n")
self.console.log(*traceback.format_exception(e))
finally:
if _set_current_experiment and self._current_experiment_token:
_current_experiment.reset(self._current_experiment_token)
self._current_experiment_token = None
[docs]
def logs(self, job_id: str, regex: str | None = None):
"""
Prints the logs of the specified job_id, optionally filtered by regex.
"""
_set_current_experiment = False
if not self._current_experiment_token:
_current_experiment.set(self)
_set_current_experiment = True
self.console.log(f"[bold cyan]Fetching logs for {job_id}")
try:
job = next(filter(lambda x: x.id == job_id, self.jobs))
if isinstance(job, Job) and job.handle.endswith("direct_run"):
self.console.log("This job was run with direct=True.")
self.console.log(
f"Logs may be present in task directory at:\n[bold]{job.executor.job_dir}."
)
return
try:
job.logs(runner=self._runner, regex=regex)
except Exception as e:
self.console.log(f"[bold red]Failed to get logs for {job_id}\nError: {e}\n")
self.console.log(
f"Logs may be present in job directory at:\n[bold]{job.executor.job_dir}."
)
except StopIteration:
self.console.log(f"[bold red]Job {job_id} not found")
finally:
if _set_current_experiment and self._current_experiment_token:
_current_experiment.reset(self._current_experiment_token)
self._current_experiment_token = None
[docs]
def reset(self) -> "Experiment":
"""
Resets an experiment to make it ready for a relaunch.
Only works if the current experiment run has already been launched.
"""
if not self._reconstruct and not os.path.isfile(
os.path.join(self._exp_dir, self._DONE_FILE)
):
self.console.log(
f"[bold magenta]Experiment {self._id} has not run yet, skipping reset..."
)
return self
old_id, old_exp_dir, old_launched = self._id, self._exp_dir, self._launched
self._id = f"{self._title}_{int(time.time())}"
self._exp_dir = os.path.join(get_nemorun_home(), "experiments", self._title, self._id)
self._launched = False
self._live_progress = None
jobs = self._jobs
self._jobs = []
serializer = ZlibJSONSerializer()
_set_current_experiment = False
if not self._current_experiment_token:
_current_experiment.set(self)
_set_current_experiment = True
try:
if "__external_main__" not in sys.modules:
maybe_load_external_main(old_exp_dir)
for job in jobs:
if isinstance(job, Job):
if isinstance(job.task, str):
_task = serializer.deserialize(job.task)
if _task.__fn_or_cls__ == Script:
job.task = fdl.build(_task)
else:
job.task = _task # type: ignore
self.add(
job.task,
job.executor,
name=job.id,
tail_logs=job.tail_logs,
)
else:
if isinstance(job.tasks, str):
tasks = serializer.deserialize(job.tasks)
job.tasks = [
fdl.build(task) if task.__fn_or_cls__ == Script else task
for task in tasks
]
self.add(
job.tasks,
job.executors,
name=job.id,
tail_logs=job.tail_logs,
)
except Exception as e:
self.console.log(
f"[bold magenta]Failed resetting Experiment {self._id} due to error: {e}"
)
# Double check exp dir is unchanged
new_path = os.path.join(get_nemorun_home(), "experiments", self._title, self._id)
if self._exp_dir == new_path and new_path != old_exp_dir:
shutil.rmtree(self._exp_dir)
self._id = old_id
self._exp_dir = old_exp_dir
self._launched = old_launched
self._jobs = self._load_jobs()
finally:
if _set_current_experiment and self._current_experiment_token:
_current_experiment.reset(self._current_experiment_token)
self._current_experiment_token = None
self._reconstruct = False
return self
def _initialize_live_progress(self):
if not self._live_progress:
# Disable live progress if we are tailing logs for any task
# as tty output consistency can not be guaranteed as of now
if self.clean_mode or any(map(lambda job: job.tail_logs, self.jobs)):
return
assert isinstance(self.console, rich.console.Console)
self._progress = Progress(
"{task.description}",
SpinnerColumn(),
BarColumn(bar_width=None),
TimeElapsedColumn(),
)
self._exp_panel = Panel(
self._progress,
title=f"[b]{self._id}",
padding=(1, 3),
)
self._task_progress: dict[str, TaskID] = {}
self._live_progress = Live(self._exp_panel, console=self.console, refresh_per_second=10)
self._live_progress.start(refresh=True)
def _add_progress(self, job: Job | JobGroup):
if self._live_progress:
self._task_progress[job.id] = self._progress.add_task(
f"[bold green]{job.id}", total=None
)
def _update_progress(self, job: Job | JobGroup, state: AppState):
if self._live_progress:
color = "[bold green]" if state == AppState.SUCCEEDED else "[bold red]"
task_progress_id = self._task_progress[job.id]
self._progress.stop_task(task_progress_id)
self._progress.update(
task_progress_id,
description=f"{color}{job.id} {state}",
)
progress_task: RichTask = self._progress._tasks[task_progress_id]
progress_task.finished_time = progress_task.elapsed
progress_task.completed = progress_task.elapsed or 0.0
progress_task.total = progress_task.elapsed
self._progress.refresh()
def _cleanup(self, tunnels: bool = True):
if tunnels and hasattr(self, "tunnels"):
for tunnel in self.tunnels.values():
try:
tunnel.cleanup()
except Exception:
...
self._runner.close()
if (
_current_experiment is not None
and _current_experiment.get(None)
and self._current_experiment_token
):
_current_experiment.reset(self._current_experiment_token)
self._current_experiment_token = None
def __enter__(self) -> "Experiment":
self._current_experiment_token = _current_experiment.set(self)
self.console.rule(
f"[bold magenta]Entering Experiment {self._title} with id: {self._id}",
)
return self
def __exit__(self, exc_type, exc_value, tb):
try:
if hasattr(self, "detach") and self.detach:
self.console.rule(f"[bold magenta]Detaching from Experiment {self._id}.")
self.console.log(
"Task specific cleanup won't be run.\n"
"Ephemeral logs and artifacts may be lost.",
)
if self._launched and not self._skip_status_at_exit:
self.status()
return
if self._launched:
if hasattr(self, "_direct") and self._direct:
self.console.rule(
f"[bold magenta]Direct run Experiment {self._id}",
)
if not self._skip_status_at_exit:
self.status()
return
if hasattr(self, "_waited") and self._waited:
self.console.rule(
f"[bold magenta]Done waiting for Experiment {self._id}",
)
if not self._skip_status_at_exit:
self.status()
return
self.console.rule(
f"[bold magenta]Waiting for Experiment {self._id} to finish",
)
if not self._skip_status_at_exit:
self.status()
self._wait_for_jobs(jobs=self.jobs)
finally:
if self._live_progress:
self._live_progress.stop()
self._cleanup(tunnels=False)
if self._launched:
Path(os.path.join(self._exp_dir, self._DONE_FILE)).touch()
if self._enable_goodbye_message:
self.console.print(
Syntax(
self.GOODBYE_MESSAGE_PYTHON.format(
exp_id=self._id,
tasks=list(map(lambda job: job.id, self.jobs)),
),
"python",
theme=os.environ.get("NEMO_RUN_CODE_THEME", "monokai"),
)
)
self.console.print(
Syntax(
self.GOODBYE_MESSAGE_BASH.format(
exp_id=self._id,
tasks=list(map(lambda job: job.id, self.jobs)),
),
"shell",
theme=os.environ.get("NEMO_RUN_CODE_THEME", "monokai"),
)
)
def _repr_svg_(self):
return self.to_config()._repr_svg_()
def __del__(self):
try:
deconfigure_logging()
self._cleanup()
except Exception:
pass
@property
def jobs(self) -> list[Job | JobGroup]:
return Jobs(self._jobs)
@jobs.setter
def jobs(self, jobs: list[Job | JobGroup]):
self._jobs = jobs
@property
def tasks(self) -> list[Config]:
serializer = ZlibJSONSerializer()
for job in self._jobs:
if isinstance(job, Job):
if isinstance(job.task, str):
_task = serializer.deserialize(job.task)
if _task.__fn_or_cls__ == Script:
job.task = fdl.build(_task)
else:
job.task = _task # type: ignore
else:
if isinstance(job.tasks, str):
tasks = serializer.deserialize(job.tasks)
job.tasks = [
fdl.build(task) if task.__fn_or_cls__ == Script else task for task in tasks
]
return Tasks((job.task if isinstance(job, Job) else job.tasks) for job in self._jobs)
class Tasks(list, ConfigurableMixin): ...
class Jobs(list, ConfigurableMixin): ...
def _get_latest_dir(path) -> str:
dirs = [d for d in os.listdir(path) if os.path.isdir(os.path.join(path, d))]
latest_dir = max(dirs, key=lambda d: os.path.getctime(os.path.join(path, d)))
return os.path.join(path, latest_dir)
def _get_sorted_dirs(path: str) -> list[str]:
if not os.path.exists(path):
return []
dirs = [d for d in os.listdir(path) if os.path.isdir(os.path.join(path, d))]
dirs = sorted(dirs, key=lambda d: os.path.getctime(os.path.join(path, d)))
return list(dirs)
_LOADED_MAINS = set()
def maybe_load_external_main(exp_dir: str):
main_file = Path(exp_dir) / "__main__.py"
if main_file.exists() and main_file not in _LOADED_MAINS:
_LOADED_MAINS.add(main_file)
spec = importlib.util.spec_from_file_location("__external_main__", main_file)
if spec is not None and spec.loader is not None:
new_main_module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(new_main_module)
if "__external_main__" not in sys.modules:
sys.modules["__external_main__"] = new_main_module
else:
external = sys.modules["__external_main__"]
for attr in dir(new_main_module):
if not attr.startswith("__"):
setattr(external, attr, getattr(new_main_module, attr))
existing_main = sys.modules["__main__"]
for attr in dir(new_main_module):
if not attr.startswith("__"):
setattr(existing_main, attr, getattr(new_main_module, attr))