Important

NeMo 2.0 is an experimental feature and is currently released only in the dev container: nvcr.io/nvidia/nemo:dev. Please refer to the NeMo 2.0 overview for information on getting started.

Quickstart with NeMo-Run

This is an introduction to running any of the supported NeMo 2.0 recipes using NeMo-Run. In this tutorial, we will take a pretraining and a finetuning recipe and run them locally, as well as remotely on a Slurm-based cluster. Let’s get started.

Please go through the NeMo-Run README to get a high-level overview of NeMo-Run.

Pretraining

For this pretraining quickstart, we will use a relatively small model. We will start with the NeMotron3 4B pretraining recipe and walk through the steps required to configure and launch pretraining.

All steps are run from inside the NeMo dev container (nvcr.io/nvidia/nemo:dev). This tutorial was run on a node with 2 GPUs (RTX 5880, 48 GB of memory each). Adjust the configuration to match your host; for example, you can reduce num_layers or hidden_size in the model config to make the model fit on a single GPU.
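To see why reducing num_layers or hidden_size helps, here is a rough back-of-the-envelope estimate. It assumes a plain GPT-style block (about 4·h² attention weights plus 2·h·f MLP weights per layer, ignoring biases, norms, gating, and embeddings) and the commonly cited figure of roughly 16 bytes per parameter for mixed-precision training with Adam (bf16 weights and gradients, fp32 master weights, two fp32 optimizer moments), before activations. The layer sizes below are hypothetical, not the exact NeMotron3 4B config, and the helper names are illustrative only.

```python
def approx_layer_params(hidden_size: int, ffn_hidden_size: int) -> int:
    """Rough per-layer weight count for a GPT-style block (no biases, norms, or gating)."""
    attention = 4 * hidden_size * hidden_size  # Q, K, V and output projections
    mlp = 2 * hidden_size * ffn_hidden_size    # up and down projections
    return attention + mlp


def approx_training_gib(num_layers: int, hidden_size: int, ffn_hidden_size: int,
                        bytes_per_param: int = 16, tensor_parallel: int = 1) -> float:
    """Weights + gradients + Adam state per GPU, excluding activations and embeddings."""
    params = num_layers * approx_layer_params(hidden_size, ffn_hidden_size)
    # Tensor parallelism shards these weight-proportional terms across GPUs.
    return params * bytes_per_param / tensor_parallel / 2**30


# Hypothetical sizes for illustration only.
full = approx_training_gib(num_layers=32, hidden_size=3072, ffn_hidden_size=9216)
reduced = approx_training_gib(num_layers=8, hidden_size=3072, ffn_hidden_size=9216)
print(f"32 layers: ~{full:.2f} GiB, 8 layers: ~{reduced:.2f} GiB")
```

Because the estimate scales linearly with num_layers and quadratically with hidden_size, cutting layers from 32 to 8 divides it by four; passing tensor_parallel=2 instead splits it across two GPUs.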

Set up the Prerequisites

Run the following commands to set up your workspace and files:

# Check GPU access
nvidia-smi

# Create and go to workspace
mkdir -p /workspace/nemo-run
cd /workspace/nemo-run

# Create a Python file to run pretraining
touch nemotron_pretraining.py

Configure the Recipe

Configure the recipe inside nemotron_pretraining.py:

import nemo_run as run

from nemo.collections import llm


def configure_recipe(nodes: int = 1, gpus_per_node: int = 2):
    recipe = llm.nemotron3_4b.pretrain_recipe(
        dir="/checkpoints/nemotron", # Path to store checkpoints
        name="nemotron_pretraining",
        tensor_parallelism=2,
        num_nodes=nodes,
        num_gpus_per_node=gpus_per_node,
        max_steps=100, # Setting a small value for the quickstart
    )

    # Add overrides here

    return recipe

Here, the recipe variable holds a configured run.Partial object. Please read about the configuration system in the NeMo-Run documentation for more details. For those familiar with NeMo 1.0-style YAML configuration, this recipe is simply a Pythonic version of a YAML config file for pretraining.

Override attributes

You can override the recipe's attributes just as you would on any Python object. For example, to change val_check_interval, set it after defining your recipe:

recipe.trainer.val_check_interval = 100

Note

Remember that at this stage you are only configuring your task; none of the underlying code is executed yet.
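The configure-now, run-later behavior described in the note can be illustrated with the standard library alone. This sketch is not NeMo-Run's implementation of run.Partial; the DeferredCall class and the train function are hypothetical stand-ins that store a function with its (possibly nested) arguments, accept attribute overrides, and only call the function when build() is invoked.

```python
from types import SimpleNamespace
from typing import Any, Callable

calls = []  # Records real executions so we can see when work actually happens.


def train(trainer: SimpleNamespace, max_steps: int) -> str:
    """Hypothetical stand-in for an expensive training entry point."""
    calls.append(max_steps)
    return f"trained {max_steps} steps, val every {trainer.val_check_interval}"


class DeferredCall:
    """Hypothetical minimal analogue of a partial config: stores fn + kwargs, runs nothing."""

    def __init__(self, fn: Callable[..., Any], **kwargs: Any) -> None:
        # Bypass our own __setattr__ so these become real instance attributes.
        object.__setattr__(self, "_fn", fn)
        object.__setattr__(self, "_kwargs", dict(kwargs))

    def __getattr__(self, name: str) -> Any:
        # Only called for names not found normally, i.e. the stored arguments.
        try:
            return self._kwargs[name]
        except KeyError:
            raise AttributeError(name) from None

    def __setattr__(self, name: str, value: Any) -> None:
        self._kwargs[name] = value  # Overrides only edit the stored configuration.

    def build(self) -> Any:
        return self._fn(**self._kwargs)  # The wrapped function runs only here.


recipe = DeferredCall(train, trainer=SimpleNamespace(val_check_interval=1000), max_steps=100)
recipe.trainer.val_check_interval = 100  # Nested override, same syntax as with the recipe
assert calls == []  # Nothing has executed yet
result = recipe.build()
```

Real recipes hold richer nested configuration (such as recipe.trainer and recipe.model above), but the principle is the same: edits change only the stored configuration until an executor runs it.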

Swap Recipes

The recipes in NeMo 2.0 are easily swappable. For instance, to swap the NeMotron recipe for a LLaMA 3 8B recipe, replace the recipe call inside configure_recipe with the following:

recipe = llm.llama3_8b.pretrain_recipe(
    dir="/checkpoints/llama3", # Path to store checkpoints
    name="llama3_pretraining",
    num_nodes=nodes,
    num_gpus_per_node=gpus_per_node,
)
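Because both recipe calls above share the dir, name, num_nodes, and num_gpus_per_node arguments, the model choice can be turned into a parameter. The sketch below uses SimpleNamespace stubs in place of llm.nemotron3_4b and llm.llama3_8b so that it runs without NeMo installed; looking the recipe module up by name with getattr is an assumption about how you might structure your own script, not a documented NeMo-Run pattern.

```python
from types import SimpleNamespace


def _stub_recipe(model: str) -> SimpleNamespace:
    """Hypothetical stand-in for a recipe module; pretrain_recipe just echoes its kwargs."""
    def pretrain_recipe(**kwargs):
        return {"model": model, **kwargs}
    return SimpleNamespace(pretrain_recipe=pretrain_recipe)


# Stub namespace standing in for `nemo.collections.llm` in this sketch.
llm = SimpleNamespace(
    nemotron3_4b=_stub_recipe("nemotron3_4b"),
    llama3_8b=_stub_recipe("llama3_8b"),
)


def configure_recipe(model: str = "nemotron3_4b", nodes: int = 1, gpus_per_node: int = 2):
    recipe_module = getattr(llm, model)  # Select the recipe module by name
    return recipe_module.pretrain_recipe(
        dir=f"/checkpoints/{model}",  # Path to store checkpoints
        name=f"{model}_pretraining",
        num_nodes=nodes,
        num_gpus_per_node=gpus_per_node,
    )


recipe = configure_recipe("llama3_8b")
```

Model-specific arguments, such as tensor_parallelism and max_steps in the NeMotron configuration, would still need per-model handling on top of this.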

Once you have the final recipe configured, you are ready to move to the execution stage.

Execute Locally

  1. First, we will execute locally using torchrun. To do that, define a LocalExecutor as shown:

def local_executor_torchrun(nodes: int = 1, devices: int = 2) -> run.LocalExecutor:
    # Env vars for jobs are configured here
    env_vars = {
        "TORCH_NCCL_AVOID_RECORD_STREAMS": "1",
        "NCCL_NVLS_ENABLE": "0",
        "NVTE_DP_AMAX_REDUCE_INTERVAL": "0",
        "NVTE_ASYNC_AMAX_REDUCTION": "1",
        "NVTE_FUSED_ATTN": "0",
    }

    executor = run.LocalExecutor(ntasks_per_node=devices, launcher="torchrun", env_vars=env_vars)

    return executor

To find out more about NeMo-Run executors, see the execution guide.

  2. Next, we will combine the recipe and executor to launch the pretraining run:

def run_pretraining():
    recipe = configure_recipe()
    executor = local_executor_torchrun(nodes=recipe.trainer.num_nodes, devices=recipe.trainer.devices)

    run.run(recipe, executor=executor)

# Wrap the call in an if __name__ == "__main__": block to work with Python's multiprocessing module.
if __name__ == "__main__":
    run_pretraining()
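Conceptually, a LocalExecutor with launcher="torchrun" and ntasks_per_node=devices starts one process per GPU, much as a torchrun command with --nproc_per_node would, with the env_vars dictionary added to the job environment. The helper below is a hypothetical illustration of that mapping, not the command NeMo-Run actually generates: train.py is a placeholder entry point, multi-node rendezvous options are omitted, and nothing is launched.

```python
import os
from typing import Dict, List, Tuple


def torchrun_command(script: str, nodes: int, devices: int,
                     env_vars: Dict[str, str]) -> Tuple[List[str], Dict[str, str]]:
    """Hypothetical sketch: build (argv, env) for a torchrun launch without running it."""
    argv = [
        "torchrun",
        f"--nnodes={nodes}",
        f"--nproc_per_node={devices}",  # One process (task) per GPU
        script,
    ]
    env = {**os.environ, **env_vars}  # Job env vars override inherited ones
    return argv, env


argv, env = torchrun_command(
    "train.py", nodes=1, devices=2,
    env_vars={"TORCH_NCCL_AVOID_RECORD_STREAMS": "1", "NVTE_FUSED_ATTN": "0"},
)
print(" ".join(argv))  # torchrun --nnodes=1 --nproc_per_node=2 train.py
```

The argv/env pair could be handed to subprocess.run(argv, env=env) if you wanted to launch by hand, but with NeMo-Run the executor takes care of this step.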

The full code for nemotron_pretraining.py looks like:

import nemo_run as run

from nemo.collections import llm


def configure_recipe(nodes: int = 1, gpus_per_node: int = 2):
    recipe = llm.nemotron3_4b.pretrain_recipe(
        dir="/checkpoints/nemotron", # Path to store checkpoints
        name="nemotron_pretraining",
        tensor_parallelism=2,
        num_nodes=nodes,
        num_gpus_per_node=gpus_per_node,
        max_steps=100, # Setting a small value for the quickstart
    )

    recipe.trainer.val_check_interval = 100
    return recipe

def local_executor_torchrun(nodes: int = 1, devices: int = 2) -> run.LocalExecutor:
    # Env vars for jobs are configured here
    env_vars = {
        "TORCH_NCCL_AVOID_RECORD_STREAMS": "1",
        "NCCL_NVLS_ENABLE": "0",
        "NVTE_DP_AMAX_REDUCE_INTERVAL": "0",
        "NVTE_ASYNC_AMAX_REDUCTION": "1",
        "NVTE_FUSED_ATTN": "0",
    }

    executor = run.LocalExecutor(ntasks_per_node=devices, launcher="torchrun", env_vars=env_vars)

    return executor

def run_pretraining():
    recipe = configure_recipe()
    executor = local_executor_torchrun(nodes=recipe.trainer.num_nodes, devices=recipe.trainer.devices)

    run.run(recipe, executor=executor)

# This condition is necessary for the script to be compatible with Python's multiprocessing module.
if __name__ == "__main__":
    run_pretraining()

You can then run the file with:

python nemotron_pretraining.py

Here’s a recording showing all the steps above leading up to the start of pretraining:

Change the number of GPUs

Let’s see how we can change the configuration to run on just 1 GPU instead of 2. All you need to do is change the configuration in run_pretraining, as shown below:

def run_pretraining():
    recipe = configure_recipe()
    executor = local_executor_torchrun(nodes=recipe.trainer.num_nodes, devices=recipe.trainer.devices)

    # Switch to a single GPU.

    # Executor: launch one task and expose only GPU 0.
    executor.ntasks_per_node = 1
    executor.env_vars["CUDA_VISIBLE_DEVICES"] = "0"

    # Recipe: shrink the model and disable tensor parallelism.

    # The NeMotron recipe defaults to num_layers = 32.
    # Ref: https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/llm/gpt/model/nemotron.py
    # Without TP, reduce the model to 8 layers so it fits on one GPU.
    recipe.model.config.num_layers = 8
    # TP was set to 2 for two GPUs; a single GPU needs TP = 1.
    recipe.trainer.strategy.tensor_model_parallel_size = 1
    # Finally, tell the trainer to use a single device.
    recipe.trainer.devices = 1

    run.run(recipe, executor=executor)
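When changing the GPU count, the executor and the recipe have to agree: ntasks_per_node should equal trainer.devices, and the total number of GPUs (num_nodes × devices) must be divisible by the model-parallel size (tensor × pipeline parallelism), with the quotient being the data-parallel size. The hypothetical helper below checks these constraints on plain integers before launch; it does not read NeMo objects, and it ignores other parallelism dimensions such as context or expert parallelism.

```python
def check_parallel_layout(num_nodes: int, devices: int, ntasks_per_node: int,
                          tensor_parallel: int = 1, pipeline_parallel: int = 1) -> int:
    """Hypothetical pre-launch check; returns the implied data-parallel size."""
    if ntasks_per_node != devices:
        raise ValueError(
            f"executor ntasks_per_node={ntasks_per_node} != trainer devices={devices}"
        )
    world_size = num_nodes * devices
    model_parallel = tensor_parallel * pipeline_parallel
    if world_size % model_parallel != 0:
        raise ValueError(
            f"{world_size} GPU(s) cannot be split into "
            f"TP={tensor_parallel} x PP={pipeline_parallel}"
        )
    return world_size // model_parallel


# Original layout: 1 node x 2 GPUs with TP=2 -> data-parallel size 1.
print(check_parallel_layout(num_nodes=1, devices=2, ntasks_per_node=2, tensor_parallel=2))
# Single-GPU layout after the overrides above: TP=1 -> data-parallel size 1.
print(check_parallel_layout(num_nodes=1, devices=1, ntasks_per_node=1, tensor_parallel=1))
```

Leaving TP at 2 with a single GPU, or forgetting to update ntasks_per_node, makes the check raise, which mirrors why each of the overrides above is needed.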

Execute on a Slurm Cluster

One benefit of NeMo-Run is that it lets you scale easily from local execution to remote Slurm-based clusters. Next, let’s see how we can launch the same pretraining recipe on a Slurm cluster.

  1. First, define a Slurm executor:

from typing import Optional

import nemo_run as run


def slurm_executor(
    user: str,
    host: str,
    remote_job_dir: str,
    account: str,
    partition: str,
    nodes: int,
    devices: int,
    time: str = "01:00:00",
    custom_mounts: Optional[list[str]] = None,
    custom_env_vars: Optional[dict[str, str]] = None,
    container_image: str = "nvcr.io/nvidia/nemo:dev",
    retries: int = 0,
) -> run.SlurmExecutor:
    if not (user and host and remote_job_dir and account and partition and nodes and devices):
        raise RuntimeError(
            "Please set user, host, remote_job_dir, account, partition, nodes and devices args for using this function."
        )

    mounts = []
    # Custom mounts are defined here.
    if custom_mounts:
        mounts.extend(custom_mounts)

    # Env vars for jobs are configured here
    env_vars = {
        "TRANSFORMERS_OFFLINE": "1",
        "TORCH_NCCL_AVOID_RECORD_STREAMS": "1",
        "NCCL_NVLS_ENABLE": "0",
        "NVTE_DP_AMAX_REDUCE_INTERVAL": "0",
        "NVTE_ASYNC_AMAX_REDUCTION": "1",
        "NVTE_FUSED_ATTN": "0",
    }
    if custom_env_vars:
        env_vars |= custom_env_vars

    # This defines the slurm executor.
    # We connect to the executor via the tunnel defined by user, host and remote_job_dir.
    executor = run.SlurmExecutor(
        account=account,
        partition=partition,
        tunnel=run.SSHTunnel(
            user=user,
            host=host,
            job_dir=remote_job_dir, # This is where the results of the run will be stored by default.
            # identity="/path/to/identity/file" OPTIONAL: Provide path to the private key that can be used to establish the SSH connection without entering your password.
        ),
        nodes=nodes,
        ntasks_per_node=devices,
        gpus_per_node=devices,
        mem="0",
        exclusive=True,
        gres=f"gpu:{devices}",  # Request the same number of GPUs per node as tasks
        packager=run.Packager(),
    )

    executor.container_image = container_image
    executor.container_mounts = mounts
    executor.env_vars = env_vars
    executor.retries = retries
    executor.time = time

    return executor

  2. Next, replace the local executor with the Slurm executor. Fill in the TODO values first; slurm_executor raises a RuntimeError if any required argument is empty:

def run_pretraining_with_slurm():
    recipe = configure_recipe(nodes=1, gpus_per_node=8)
    executor = slurm_executor(
        user="", # TODO: Set the username you want to use
        host="", # TODO: Set the host of your cluster
        remote_job_dir="", # TODO: Set the directory on the cluster where you want to save results
        account="", # TODO: Set the account for your cluster
        partition="", # TODO: Set the partition for your cluster
        container_image="", # TODO: Set the container image you want to use for your job
        # custom_mounts=[], TODO: Set any custom mounts
        # custom_env_vars={}, TODO: Set any custom env vars
        nodes=recipe.trainer.num_nodes,
        devices=recipe.trainer.devices,
    )

    run.run(recipe, executor=executor, detach=True)

  3. Finally, call run_pretraining_with_slurm from the main guard:

if __name__ == "__main__":
    run_pretraining_with_slurm()

Then run the script:

python nemotron_pretraining.py

Because detach=True is set, the process exits once the job has been scheduled on the cluster, after printing information about the directories and commands you can use to manage the run/experiment.
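If you are used to writing batch scripts by hand, most SlurmExecutor arguments used above correspond to standard sbatch options. The helper below is a hypothetical illustration of that correspondence; NeMo-Run generates its own submission script, which also handles the container image, mounts, and packaging, and may differ in detail. The account and partition values are placeholders.

```python
from typing import List


def sbatch_header(account: str, partition: str, nodes: int, devices: int,
                  time: str = "01:00:00", exclusive: bool = True, mem: str = "0") -> List[str]:
    """Hypothetical: map the executor settings used above to #SBATCH lines."""
    lines = [
        "#!/bin/bash",
        f"#SBATCH --account={account}",
        f"#SBATCH --partition={partition}",
        f"#SBATCH --nodes={nodes}",
        f"#SBATCH --ntasks-per-node={devices}",  # One task per GPU
        f"#SBATCH --gpus-per-node={devices}",
        f"#SBATCH --mem={mem}",  # 0 requests all memory on each node
        f"#SBATCH --time={time}",
    ]
    if exclusive:
        lines.append("#SBATCH --exclusive")  # Do not share nodes with other jobs
    return lines


print("\n".join(sbatch_header("my_account", "batch", nodes=1, devices=8)))
```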