Quickstart with NeMo-Run

This tutorial explains how to run any of the supported NeMo 2.0 recipes using NeMo-Run. We demonstrate how to run a pretraining recipe and a fine-tuning recipe, both locally and remotely on a Slurm-based cluster. Let’s get started!

For a high-level overview of NeMo-Run, please refer to the NeMo-Run README.

Minimum Requirements

This tutorial requires at least 1 NVIDIA GPU with 48GB of memory for fine-tuning and 2 NVIDIA GPUs with 48GB of memory each for pretraining. Pretraining can also run on a single GPU, or on GPUs with less memory, if you decrease the model size. Each section can be followed independently, based on your needs. You will also need to run this tutorial inside the NeMo container with the dev tag.

You can launch the NeMo container using the following command:

docker run --rm -it --gpus all --ipc=host --ulimit memlock=-1 --ulimit stack=67108864 nvcr.io/nvidia/nemo:dev

Pretraining

Note

The default pretraining recipe uses the MockDataModule. If you want to use a real dataset, follow the instructions here.

For this pretraining quickstart, we will use a relatively small model. We will begin with the Nemotron 3 4B pretraining recipe and go through the steps required to configure and launch pretraining.

As mentioned in the requirements, this tutorial was run on a node with 2 GPUs (each RTX 5880 with 48GB of memory). If you intend to run it on just 1 GPU or GPUs with less memory, please adjust the configuration to match your host. For example, you can reduce num_layers or hidden_size in the model config to make it fit on a single GPU.
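To see why reducing num_layers or hidden_size helps, a common back-of-the-envelope approximation puts a dense transformer at roughly 12·L·h² non-embedding parameters (L layers, hidden size h). The sketch below applies that rule of thumb. It assumes standard multi-head attention and a 4h-wide MLP, so for any specific architecture, including Nemotron 3 4B, it is only a ballpark, and the hidden size used is an illustrative value rather than one read from the recipe.

```python
def approx_non_embedding_params(num_layers: int, hidden_size: int) -> int:
    """Rough dense-transformer estimate of ~12 * L * h^2 parameters.

    Counts about 4*h^2 for attention and 8*h^2 for a 4h-wide MLP per layer,
    and ignores embeddings, biases, and normalization weights.
    """
    return 12 * num_layers * hidden_size**2


def bf16_weight_gib(num_params: int) -> float:
    """Memory for the weights alone at 2 bytes per parameter, in GiB."""
    return num_params * 2 / 2**30


if __name__ == "__main__":
    hidden_size = 3072  # Illustrative value only; read the real one from recipe.model.config
    for num_layers in (32, 8):
        params = approx_non_embedding_params(num_layers, hidden_size)
        print(
            f"num_layers={num_layers:>2}: ~{params / 1e9:.2f}B parameters, "
            f"~{bf16_weight_gib(params):.1f} GiB of bf16 weights"
        )
```

Parameters scale linearly with num_layers and quadratically with hidden_size. Gradients, optimizer state, and activations add substantially to the weight memory reported here.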

Set Up the Prerequisites

Run the following commands to set up your workspace and files:

# Check GPU access
nvidia-smi

# Create and go to workspace
mkdir -p /workspace/nemo-run
cd /workspace/nemo-run

# Create a Python file for the pretraining script
touch nemotron_pretraining.py

Configure the Recipe

Important

In any script you write, please make sure you wrap your code in an if __name__ == "__main__": block. See Working with scripts in NeMo 2.0 for details.

Configure the recipe inside nemotron_pretraining.py:

import nemo_run as run

from nemo.collections import llm


def configure_recipe(nodes: int = 1, gpus_per_node: int = 2):
    recipe = llm.nemotron3_4b.pretrain_recipe(
        dir="/checkpoints/nemotron", # Path to store checkpoints
        name="nemotron_pretraining",
        tensor_parallelism=2,
        num_nodes=nodes,
        num_gpus_per_node=gpus_per_node,
        max_steps=100, # Setting a small value for the quickstart
    )

    # Add overrides here

    return recipe

Here, the recipe variable holds a configured run.Partial object. For those familiar with the NeMo 1.0-style YAML configuration, this recipe is just a Pythonic version of a YAML config file for pretraining.

Note

The configuration in the recipes is done using the NeMo-Run run.Config and run.Partial configuration objects. Please review the NeMo-Run documentation to learn more about its configuration and execution system.

Override the Attributes

You can override the recipe's attributes just as you would those of any normal Python object. For example, to change val_check_interval, set it after defining your recipe:

recipe.trainer.val_check_interval = 100

Note

An important thing to remember is that you are only configuring your task at this stage; the underlying code is not being executed at this time.
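The separation between configuring and executing can be illustrated in plain Python. The sketch below is not NeMo-Run's implementation: LazyCall is a hypothetical stand-in for run.Partial that stores a function and its keyword arguments, lets you override them with attribute assignment (including nested ones, as in recipe.trainer.val_check_interval), and calls the function only when build() is invoked.

```python
from typing import Any, Callable


class LazyCall:
    """Hypothetical stand-in for run.Partial: a function plus arguments, not yet called."""

    def __init__(self, fn: Callable[..., Any], **kwargs: Any) -> None:
        # Bypass our own __setattr__ so these live in the instance __dict__.
        object.__setattr__(self, "_fn", fn)
        object.__setattr__(self, "_kwargs", dict(kwargs))

    def __getattr__(self, name: str) -> Any:
        # Only reached for names not found normally, i.e. configured arguments.
        kwargs = self.__dict__.get("_kwargs", {})
        if name in kwargs:
            return kwargs[name]
        raise AttributeError(name)

    def __setattr__(self, name: str, value: Any) -> None:
        # "Overriding an attribute" only records a new argument value.
        self._kwargs[name] = value

    def build(self) -> Any:
        # Nested LazyCall arguments are built first; only now does any code run.
        args = {k: v.build() if isinstance(v, LazyCall) else v for k, v in self._kwargs.items()}
        return self._fn(**args)


def make_trainer(max_steps: int, val_check_interval: int) -> dict:
    print("Building trainer...")  # Side effect shows when execution happens
    return {"max_steps": max_steps, "val_check_interval": val_check_interval}


if __name__ == "__main__":
    recipe = LazyCall(dict, trainer=LazyCall(make_trainer, max_steps=100, val_check_interval=50))
    recipe.trainer.val_check_interval = 100  # Nothing has been executed yet
    print(recipe.build())  # "Building trainer..." is printed only here
```

In NeMo-Run, the equivalent of build() happens only after you hand the recipe to an executor, which is why attribute overrides are cheap and safe to make at configuration time.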

Swap Recipes

Recipes in NeMo 2.0 are easily swappable. For instance, to swap the Nemotron recipe for a Llama 3 recipe, replace the recipe call inside configure_recipe (where nodes and gpus_per_node are defined) with:

recipe = llm.llama3_8b.pretrain_recipe(
    dir="/checkpoints/llama3", # Path to store checkpoints
    name="llama3_pretraining",
    num_nodes=nodes,
    num_gpus_per_node=gpus_per_node,
)

Once you have the final recipe configured, you are ready to move to the execution stage.

Execute Locally

  1. Execute locally using torchrun. To do so, we will define a LocalExecutor as shown:

def local_executor_torchrun(nodes: int = 1, devices: int = 2) -> run.LocalExecutor:
    # Env vars for jobs are configured here
    env_vars = {
        "TORCH_NCCL_AVOID_RECORD_STREAMS": "1",
        "NCCL_NVLS_ENABLE": "0",
        "NVTE_DP_AMAX_REDUCE_INTERVAL": "0",
        "NVTE_ASYNC_AMAX_REDUCTION": "1",
    }

    executor = run.LocalExecutor(ntasks_per_node=devices, launcher="torchrun", env_vars=env_vars)

    return executor

To find out more about NeMo-Run executors, see the execution guide.

  2. Combine the recipe and executor to launch the pretraining run:

def run_pretraining():
    recipe = configure_recipe()
    executor = local_executor_torchrun(nodes=recipe.trainer.num_nodes, devices=recipe.trainer.devices)

    run.run(recipe, executor=executor, name="nemotron3_4b_pretraining")

# Wrap the call in an if __name__ == "__main__": block to work with Python's multiprocessing module.
if __name__ == "__main__":
    run_pretraining()

The full code for nemotron_pretraining.py looks like this:

import nemo_run as run

from nemo.collections import llm


def configure_recipe(nodes: int = 1, gpus_per_node: int = 2):
    recipe = llm.nemotron3_4b.pretrain_recipe(
        dir="/checkpoints/nemotron", # Path to store checkpoints
        name="nemotron_pretraining",
        tensor_parallelism=2,
        num_nodes=nodes,
        num_gpus_per_node=gpus_per_node,
        max_steps=100, # Setting a small value for the quickstart
    )

    recipe.trainer.val_check_interval = 100
    return recipe

def local_executor_torchrun(nodes: int = 1, devices: int = 2) -> run.LocalExecutor:
    # Env vars for jobs are configured here
    env_vars = {
        "TORCH_NCCL_AVOID_RECORD_STREAMS": "1",
        "NCCL_NVLS_ENABLE": "0",
        "NVTE_DP_AMAX_REDUCE_INTERVAL": "0",
        "NVTE_ASYNC_AMAX_REDUCTION": "1",
    }

    executor = run.LocalExecutor(ntasks_per_node=devices, launcher="torchrun", env_vars=env_vars)

    return executor

def run_pretraining():
    recipe = configure_recipe()
    executor = local_executor_torchrun(nodes=recipe.trainer.num_nodes, devices=recipe.trainer.devices)

    run.run(recipe, executor=executor, name="nemotron3_4b_pretraining")

# This condition is necessary for the script to be compatible with Python's multiprocessing module.
if __name__ == "__main__":
    run_pretraining()

  3. Run the file using the following command:

python nemotron_pretraining.py

Here’s a recording that demonstrates all the steps mentioned above, leading up to the start of pretraining:
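As a mental model, launcher="torchrun" with ntasks_per_node=devices starts one training process per GPU, much as if you had run torchrun yourself with --nproc_per_node set to the device count and the env_vars exported. The helper below only builds such a command line for illustration; torchrun_command and train.py are hypothetical, and the command NeMo-Run actually generates will differ in its details.

```python
import shlex
from typing import Optional


def torchrun_command(
    script: str,
    nodes: int = 1,
    devices: int = 2,
    env_vars: Optional[dict[str, str]] = None,
) -> str:
    """Build a torchrun command line resembling a local launch (illustration only)."""
    # One worker process per GPU, as with ntasks_per_node=devices.
    args = ["torchrun", f"--nnodes={nodes}", f"--nproc_per_node={devices}", script]
    # Environment variables go in front of the command, shell-quoted.
    env = " ".join(f"{key}={shlex.quote(value)}" for key, value in (env_vars or {}).items())
    command = shlex.join(args)
    return f"{env} {command}" if env else command


if __name__ == "__main__":
    print(torchrun_command("train.py", devices=2, env_vars={"NCCL_NVLS_ENABLE": "0"}))
```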

Change the Number of GPUs

Let’s see how we can change the configuration to run on just 1 GPU instead of 2. All you need to do is change the configuration in run_pretraining, as shown below:

def run_pretraining():
    recipe = configure_recipe()
    executor = local_executor_torchrun(nodes=recipe.trainer.num_nodes, devices=recipe.trainer.devices)

    # Change to 1 GPU

    # Change executor params
    executor.ntasks_per_node = 1
    executor.env_vars["CUDA_VISIBLE_DEVICES"] = "0"

    # Change recipe params

    # The default num_layers (32) comes from the Nemotron model config in NeMo
    # Ref: https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/llm/gpt/model/nemotron.py
    # To run on 1 GPU without TP, we can reduce the number of layers to 8 by setting recipe.model.config.num_layers = 8
    recipe.model.config.num_layers = 8
    # We also need to set TP to 1, since we had used 2 for 2 GPUs.
    recipe.trainer.strategy.tensor_model_parallel_size = 1
    # Lastly, we need to set devices to 1 in the trainer.
    recipe.trainer.devices = 1

    run.run(recipe, executor=executor, name="nemotron3_4b_pretraining")
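The recipe changes have to agree with the single GPU. In Megatron-style parallelism, the product of the tensor-, pipeline-, and context-parallel sizes must divide the total number of GPUs, and the quotient is the data-parallel size. The hypothetical helper below checks that constraint in plain Python, which shows why TP must drop from 2 to 1 here. It does not read anything from the recipe and ignores other dimensions such as expert parallelism.

```python
def data_parallel_size(
    num_nodes: int, devices_per_node: int, tp: int = 1, pp: int = 1, cp: int = 1
) -> int:
    """Return the data-parallel size implied by a GPU count and model-parallel sizes.

    world_size = num_nodes * devices_per_node must be divisible by tp * pp * cp;
    the quotient is the number of data-parallel replicas.
    """
    world_size = num_nodes * devices_per_node
    model_parallel = tp * pp * cp
    if world_size % model_parallel != 0:
        raise ValueError(
            f"{world_size} GPU(s) cannot be split into groups of "
            f"TP*PP*CP = {tp}*{pp}*{cp} = {model_parallel}"
        )
    return world_size // model_parallel


if __name__ == "__main__":
    print(data_parallel_size(1, 2, tp=2))  # Original 2-GPU setup: one TP=2 replica
    print(data_parallel_size(1, 1, tp=1))  # 1-GPU setup after the changes above
    try:
        data_parallel_size(1, 1, tp=2)  # Forgetting to reset TP on 1 GPU
    except ValueError as err:
        print(f"Invalid: {err}")
```

The same constraint explains the context_parallel_size = 1 override in the fine-tuning recipe later on: on one GPU, a context-parallel size of 2 cannot be satisfied.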

Execute on a Slurm Cluster

One benefit of NeMo-Run is that it lets you scale easily from local runs to remote Slurm-based clusters. Next, let’s see how to launch the same pretraining recipe on a Slurm cluster.

Note

Each cluster might have different settings. It is recommended that you reach out to the cluster administrators for specific details.

  1. Define a Slurm executor:

from typing import Optional


def slurm_executor(
    user: str,
    host: str,
    remote_job_dir: str,
    account: str,
    partition: str,
    nodes: int,
    devices: int,
    time: str = "01:00:00",
    custom_mounts: Optional[list[str]] = None,
    custom_env_vars: Optional[dict[str, str]] = None,
    container_image: str = "nvcr.io/nvidia/nemo:dev",
    retries: int = 0,
) -> run.SlurmExecutor:
    if not (user and host and remote_job_dir and account and partition and nodes and devices):
        raise RuntimeError(
            "Please set user, host, remote_job_dir, account, partition, nodes, and devices args for using this function."
        )

    mounts = []
    # Custom mounts are defined here.
    if custom_mounts:
        mounts.extend(custom_mounts)

    # Env vars for jobs are configured here
    env_vars = {
        "TORCH_NCCL_AVOID_RECORD_STREAMS": "1",
        "NCCL_NVLS_ENABLE": "0",
        "NVTE_DP_AMAX_REDUCE_INTERVAL": "0",
        "NVTE_ASYNC_AMAX_REDUCTION": "1",
    }
    if custom_env_vars:
        env_vars |= custom_env_vars

    # This defines the slurm executor.
    # We connect to the executor via the tunnel defined by user, host, and remote_job_dir.
    executor = run.SlurmExecutor(
        account=account,
        partition=partition,
        tunnel=run.SSHTunnel(
            user=user,
            host=host,
            job_dir=remote_job_dir, # This is where the results of the run will be stored by default.
            # identity="/path/to/identity/file" OPTIONAL: Provide path to the private key that can be used to establish the SSH connection without entering your password.
        ),
        nodes=nodes,
        ntasks_per_node=devices,
        gpus_per_node=devices,
        mem="0",
        exclusive=True,
        gres=f"gpu:{devices}", # Request the same number of GPUs as gpus_per_node
        packager=run.Packager(),
    )

    executor.container_image = container_image
    executor.container_mounts = mounts
    executor.env_vars = env_vars
    executor.retries = retries
    executor.time = time

    return executor

  2. Replace the local executor with the Slurm executor, as shown below:

def run_pretraining_with_slurm():
    recipe = configure_recipe(nodes=1, gpus_per_node=8)
    executor = slurm_executor(
        user="", # TODO: Set the username you want to use
        host="", # TODO: Set the host of your cluster
        remote_job_dir="", # TODO: Set the directory on the cluster where you want to save results
        account="", # TODO: Set the account for your cluster
        partition="", # TODO: Set the partition for your cluster
        container_image="nvcr.io/nvidia/nemo:dev", # TODO: Set the container image you want to use for your job
        # custom_mounts=[], # TODO: Set any custom mounts
        # custom_env_vars={}, # TODO: Set any custom env vars
        nodes=recipe.trainer.num_nodes,
        devices=recipe.trainer.devices,
    )

    run.run(recipe, executor=executor, detach=True, name="nemotron3_4b_pretraining")

  3. Update the __main__ block to call run_pretraining_with_slurm, then run the file as before:

if __name__ == "__main__":
    run_pretraining_with_slurm()

python nemotron_pretraining.py

Since we have set detach=True, the process will exit after scheduling the job on the cluster. It will provide information about directories and commands to manage the run/experiment.
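NeMo-Run generates and submits the batch script for you. If you are used to writing sbatch scripts by hand, the sketch below shows roughly which standard #SBATCH options the run.SlurmExecutor arguments in slurm_executor correspond to. sbatch_header and the account and partition values are hypothetical, and the script NeMo-Run actually produces (including container setup and the job launch) will look different.

```python
from typing import Optional


def sbatch_header(
    account: str,
    partition: str,
    nodes: int,
    devices: int,
    time: str = "01:00:00",
    gres: Optional[str] = None,
) -> str:
    """Render #SBATCH lines resembling the slurm_executor arguments (illustration only)."""
    lines = [
        "#!/bin/bash",
        f"#SBATCH --account={account}",
        f"#SBATCH --partition={partition}",
        f"#SBATCH --nodes={nodes}",
        f"#SBATCH --ntasks-per-node={devices}",  # One task per GPU
        f"#SBATCH --gpus-per-node={devices}",
        "#SBATCH --mem=0",  # Request all of each node's memory
        "#SBATCH --exclusive",  # Do not share nodes with other jobs
        f"#SBATCH --time={time}",
    ]
    if gres:
        lines.append(f"#SBATCH --gres={gres}")
    return "\n".join(lines)


if __name__ == "__main__":
    print(sbatch_header("my_account", "my_partition", nodes=1, devices=8, gres="gpu:8"))
```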

Continue Pretraining

To continue pretraining from a previous checkpoint with long context, follow the guide here.

Fine-Tuning

Note

The default fine-tuning recipe uses the SquadDataModule. If you want to use your own dataset, follow the instructions here.

One of the main benefits of NeMo-Run is that it decouples configuration and execution, allowing us to reuse predefined executors and simply change the recipe. For the purpose of this tutorial, we will include the executor definition so that this section can be followed independently.

Set Up the Prerequisites

Run the following commands to set up your Hugging Face token, which is needed to download the model from Hugging Face and convert it automatically.

mkdir -p /tokens

# Fetch your Hugging Face token and export it.
# See https://huggingface.co/docs/hub/en/security-tokens for instructions.
export HF_TOKEN="hf_your_token" # Change this to your Hugging Face token

# Save token to /tokens/huggingface
echo "$HF_TOKEN" > /tokens/huggingface
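The fine-tuning jobs below locate this file through the HF_TOKEN_PATH environment variable, so a quick check before launching can save a failed download. The stdlib-only sketch below is a hypothetical pre-flight helper, read_hf_token: it resolves HF_TOKEN_PATH (falling back to /tokens/huggingface), strips the newline that echo adds, and applies a simple heuristic based on the hf_ prefix of current Hugging Face user access tokens. It does not check that the token is valid or that it has access to the gated Llama 3 repository.

```python
import os
from pathlib import Path


def read_hf_token(default_path: str = "/tokens/huggingface") -> str:
    """Read the token file named by HF_TOKEN_PATH (or default_path) and sanity-check it."""
    path = Path(os.environ.get("HF_TOKEN_PATH", default_path))
    if not path.is_file():
        raise FileNotFoundError(f"No token file at {path}")
    token = path.read_text().strip()  # `echo` leaves a trailing newline
    # Heuristic only: current Hugging Face user access tokens start with "hf_".
    if not token.startswith("hf_"):
        raise ValueError(f"{path} does not contain something that looks like a Hugging Face token")
    return token
```

You could call read_hf_token() inside the container, or at the start of run_finetuning, before launching the experiment.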

Configure the Recipe

In this section, we will fine-tune a Llama 3 8B model from Hugging Face on a single GPU. To achieve this, we need to follow two steps:

  1. Convert the checkpoint from Hugging Face to NeMo.

  2. Run fine-tuning using the converted checkpoint from step 1.

We will accomplish this using a NeMo-Run experiment, which lets you define these two tasks and execute them sequentially. We will create a new file, nemotron_finetuning.py, in the same directory. For the fine-tuning configuration, we will use the Llama 3 8B fine-tuning recipe. This recipe uses LoRA, which lets it fit on 1 GPU (this example uses a GPU with 48GB of memory).

Let’s first define the configuration for the two tasks:

import nemo_run as run
from nemo.collections import llm

def configure_checkpoint_conversion():
    return run.Partial(
        llm.import_ckpt,
        model=llm.llama3_8b.model(),
        source="hf://meta-llama/Meta-Llama-3-8B",
        overwrite=False,
    )

def configure_finetuning_recipe(nodes: int = 1, gpus_per_node: int = 1):
    recipe = llm.llama3_8b.finetune_recipe(
        dir="/checkpoints/llama3_finetuning", # Path to store checkpoints
        name="llama3_lora",
        num_nodes=nodes,
        num_gpus_per_node=gpus_per_node,
    )

    recipe.trainer.max_steps = 100
    recipe.trainer.num_sanity_val_steps = 0

    # Async checkpointing doesn't work with PEFT
    recipe.trainer.strategy.ckpt_async_save = False

    # Need to set this to 1 since the default is 2
    recipe.trainer.strategy.context_parallel_size = 1
    recipe.trainer.val_check_interval = 100

    # This is currently required for LoRA/PEFT
    recipe.trainer.strategy.ddp = "megatron"

    return recipe

You can refer to overrides for details on overriding more of the default attributes.
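When the list of overrides grows, it can help to keep them in one dictionary keyed by dotted attribute paths. apply_overrides below is a hypothetical helper, demonstrated on a types.SimpleNamespace stand-in rather than a real recipe. It rejects paths whose final attribute does not already exist, to catch typos; that check is a design choice of this sketch and may be stricter than you want for real configuration objects.

```python
from functools import reduce
from types import SimpleNamespace
from typing import Any


def apply_overrides(obj: Any, overrides: dict[str, Any]) -> Any:
    """Set nested attributes such as "trainer.strategy.ddp" on obj, in place."""
    for dotted, value in overrides.items():
        *parents, leaf = dotted.split(".")
        target = reduce(getattr, parents, obj)  # Walk down to the object that owns `leaf`
        if not hasattr(target, leaf):  # Catch typos instead of silently adding attributes
            raise AttributeError(f"{dotted!r}: {type(target).__name__} has no attribute {leaf!r}")
        setattr(target, leaf, value)
    return obj


if __name__ == "__main__":
    # Stand-in for a recipe; the starting values are placeholders, not the recipe's defaults.
    recipe = SimpleNamespace(
        trainer=SimpleNamespace(
            max_steps=None,
            num_sanity_val_steps=None,
            val_check_interval=None,
            strategy=SimpleNamespace(context_parallel_size=None, ddp=None),
        )
    )
    apply_overrides(
        recipe,
        {
            "trainer.max_steps": 100,
            "trainer.num_sanity_val_steps": 0,
            "trainer.val_check_interval": 100,
            "trainer.strategy.context_parallel_size": 1,
            "trainer.strategy.ddp": "megatron",
        },
    )
    print(recipe.trainer)
```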

Execute Locally

Note

You will need to import the checkpoint first by running the recipe returned by configure_checkpoint_conversion(). Skipping this step will most likely result in an error unless you have a pre-converted checkpoint.

Execution is straightforward, since we reuse the local executor (its definition is repeated here for reference). Next, we define the experiment and launch it:

def local_executor_torchrun(nodes: int = 1, devices: int = 2) -> run.LocalExecutor:
    # Env vars for jobs are configured here
    env_vars = {
        "TORCH_NCCL_AVOID_RECORD_STREAMS": "1",
        "NCCL_NVLS_ENABLE": "0",
        "NVTE_DP_AMAX_REDUCE_INTERVAL": "0",
        "NVTE_ASYNC_AMAX_REDUCTION": "1",
    }

    executor = run.LocalExecutor(ntasks_per_node=devices, launcher="torchrun", env_vars=env_vars)

    return executor

def run_finetuning():
    import_ckpt = configure_checkpoint_conversion()
    finetune = configure_finetuning_recipe(nodes=1, gpus_per_node=1)

    executor = local_executor_torchrun(nodes=finetune.trainer.num_nodes, devices=finetune.trainer.devices)
    executor.env_vars["CUDA_VISIBLE_DEVICES"] = "0"

    # Set this env var for model download from huggingface
    executor.env_vars["HF_TOKEN_PATH"] = "/tokens/huggingface"

    with run.Experiment("llama3-8b-peft-finetuning") as exp:
        exp.add(import_ckpt, executor=run.LocalExecutor(), name="import_from_hf") # We don't need torchrun for the checkpoint conversion
        exp.add(finetune, executor=executor, name="peft_finetuning")
        exp.run(sequential=True, tail_logs=True) # This will run the tasks sequentially and stream the logs

# Wrap the call in an if __name__ == "__main__": block to work with Python's multiprocessing module.
if __name__ == "__main__":
    run_finetuning()
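With sequential=True, the two tasks run one after the other, so fine-tuning starts only after the checkpoint conversion task has finished. The toy runner below illustrates that ordering with plain Python callables. It is a hypothetical sketch, not how run.Experiment works internally (NeMo-Run launches each task through its own executor), and halting on the first failure is a choice made in this sketch.

```python
from typing import Callable


def run_sequentially(tasks: list[tuple[str, Callable[[], object]]]) -> list[str]:
    """Run named tasks in order, stopping at the first failure; return the completed names."""
    completed: list[str] = []
    for name, task in tasks:
        print(f"Starting {name}")
        try:
            task()
        except Exception as exc:  # This sketch's choice: a failed step halts the rest
            print(f"{name} failed ({exc}); skipping the remaining tasks")
            break
        completed.append(name)
    return completed


if __name__ == "__main__":
    def convert() -> None:
        print("  converting checkpoint...")

    def finetune() -> None:
        print("  fine-tuning from the converted checkpoint...")

    run_sequentially([("import_from_hf", convert), ("peft_finetuning", finetune)])
```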

The full file looks like this:

import nemo_run as run
from nemo.collections import llm


def configure_checkpoint_conversion():
    return run.Partial(
        llm.import_ckpt,
        model=llm.llama3_8b.model(),
        source="hf://meta-llama/Meta-Llama-3-8B",
        overwrite=False,
    )


def configure_finetuning_recipe(nodes: int = 1, gpus_per_node: int = 1):
    recipe = llm.llama3_8b.finetune_recipe(
        dir="/checkpoints/llama3_finetuning",  # Path to store checkpoints
        name="llama3_lora",
        num_nodes=nodes,
        num_gpus_per_node=gpus_per_node,
    )

    recipe.trainer.max_steps = 100
    recipe.trainer.num_sanity_val_steps = 0

    # Async checkpointing doesn't work with PEFT
    recipe.trainer.strategy.ckpt_async_save = False

    # Need to set this to 1 since the default is 2
    recipe.trainer.strategy.context_parallel_size = 1
    recipe.trainer.val_check_interval = 100

    # This is currently required for LoRA/PEFT
    recipe.trainer.strategy.ddp = "megatron"

    return recipe


def local_executor_torchrun(nodes: int = 1, devices: int = 2) -> run.LocalExecutor:
    # Env vars for jobs are configured here
    env_vars = {
        "TORCH_NCCL_AVOID_RECORD_STREAMS": "1",
        "NCCL_NVLS_ENABLE": "0",
        "NVTE_DP_AMAX_REDUCE_INTERVAL": "0",
        "NVTE_ASYNC_AMAX_REDUCTION": "1",
    }

    executor = run.LocalExecutor(ntasks_per_node=devices, launcher="torchrun", env_vars=env_vars)

    return executor


def run_finetuning():
    import_ckpt = configure_checkpoint_conversion()
    finetune = configure_finetuning_recipe(nodes=1, gpus_per_node=1)

    executor = local_executor_torchrun(nodes=finetune.trainer.num_nodes, devices=finetune.trainer.devices)
    executor.env_vars["CUDA_VISIBLE_DEVICES"] = "0"

    # Set this env var for model download from huggingface
    executor.env_vars["HF_TOKEN_PATH"] = "/tokens/huggingface"

    with run.Experiment("llama3-8b-peft-finetuning") as exp:
        exp.add(
            import_ckpt, executor=run.LocalExecutor(), name="import_from_hf"
        )  # We don't need torchrun for the checkpoint conversion
        exp.add(finetune, executor=executor, name="peft_finetuning")
        exp.run(sequential=True, tail_logs=True)  # This will run the tasks sequentially and stream the logs


# Wrap the call in an if __name__ == "__main__": block to work with Python's multiprocessing module.
if __name__ == "__main__":
    run_finetuning()

Here’s a recording showing all the steps above leading up to the start of fine-tuning:

Switch from PEFT to Full Fine-Tuning

The default recipe uses PEFT for fine-tuning. If you want to use full fine-tuning, you will need to use a minimum of 2 GPUs and pass peft_scheme=None to the recipe.

Warning

When using import_ckpt in NeMo 2.0, ensure your script includes if __name__ == "__main__":. Without this, Python’s multiprocessing won’t initialize threads properly, causing a “Failure to acquire lock” error.

You can change the code as follows:

from typing import Optional


def configure_finetuning_recipe(nodes: int = 1, gpus_per_node: int = 2, peft_scheme: Optional[str] = None): # Minimum of 2 GPUs
    recipe = llm.llama3_8b.finetune_recipe(
        dir="/checkpoints/llama3_finetuning", # Path to store checkpoints
        name="llama3_lora",
        num_nodes=nodes,
        num_gpus_per_node=gpus_per_node,
        peft_scheme=peft_scheme, # None disables PEFT and uses full fine-tuning
    )

    recipe.trainer.max_steps = 100
    recipe.trainer.num_sanity_val_steps = 0

    # Need to set this to 1 since the default is 2
    recipe.trainer.strategy.context_parallel_size = 1
    recipe.trainer.val_check_interval = 100

    # This is currently required for LoRA/PEFT
    recipe.trainer.strategy.ddp = "megatron"

    return recipe

...

def run_finetuning():
    import_ckpt = configure_checkpoint_conversion()
    finetune = configure_finetuning_recipe(nodes=1, gpus_per_node=2, peft_scheme=None)

    executor = local_executor_torchrun(nodes=finetune.trainer.num_nodes, devices=finetune.trainer.devices)
    executor.env_vars["CUDA_VISIBLE_DEVICES"] = "0,1"

    # Set this env var for model download from huggingface
    executor.env_vars["HF_TOKEN_PATH"] = "/tokens/huggingface"

    with run.Experiment("llama3-8b-peft-finetuning") as exp:
        exp.add(
            import_ckpt, executor=run.LocalExecutor(), name="import_from_hf"
        )  # We don't need torchrun for the checkpoint conversion
        exp.add(finetune, executor=executor, name="peft_finetuning")
        exp.run(sequential=True, tail_logs=True)  # This will run the tasks sequentially and stream the logs
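The GPU count now has to stay in sync in three places: gpus_per_node, the executor's devices, and the CUDA_VISIBLE_DEVICES string. A small hypothetical helper can derive the string from the count, so that switching between the 1-GPU PEFT setup and the 2-GPU full fine-tuning setup means changing only one number.

```python
def visible_devices(num_gpus: int, first: int = 0) -> str:
    """Return a CUDA_VISIBLE_DEVICES value listing num_gpus consecutive GPU indices."""
    if num_gpus < 1:
        raise ValueError("num_gpus must be at least 1")
    return ",".join(str(index) for index in range(first, first + num_gpus))


if __name__ == "__main__":
    print(visible_devices(1))  # PEFT example: one GPU
    print(visible_devices(2))  # Full fine-tuning example: two GPUs
```

In run_finetuning, the hard-coded value could then become executor.env_vars["CUDA_VISIBLE_DEVICES"] = visible_devices(finetune.trainer.devices).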

Use a NeMo 2.0 Pretraining Checkpoint as the Base

If you already have a NeMo 2.0 pretrained checkpoint and want to use it as the starting point for fine-tuning instead of the Hugging Face checkpoint, you can do the following:

def run_finetuning():
    finetune = configure_finetuning_recipe(nodes=1, gpus_per_node=1)
    finetune.resume.restore_config.path = "/path/to/pretrained/NeMo-2/checkpoint"

    executor = local_executor_torchrun(nodes=finetune.trainer.num_nodes, devices=finetune.trainer.devices)
    executor.env_vars["CUDA_VISIBLE_DEVICES"] = "0"

    with run.Experiment("llama3-8b-peft-finetuning") as exp:
        exp.add(finetune, executor=executor, name="peft_finetuning")
        exp.run(sequential=True, tail_logs=True)  # This will run the tasks sequentially and stream the logs
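Because restore_config.path is only used once the job is running, a mistyped path shows up late. A quick local check can catch that earlier. The sketch below assumes that a NeMo 2.0 checkpoint is a directory containing context/ and weights/ subdirectories; confirm that layout against one of your own checkpoints before relying on it. looks_like_nemo2_checkpoint is a hypothetical helper.

```python
from pathlib import Path


def looks_like_nemo2_checkpoint(path: str) -> bool:
    """Heuristic: assumes a NeMo 2.0 checkpoint directory has context/ and weights/ subdirectories."""
    ckpt = Path(path)
    return ckpt.is_dir() and all((ckpt / sub).is_dir() for sub in ("context", "weights"))
```

For example, you could raise FileNotFoundError when looks_like_nemo2_checkpoint(path) is False, before assigning finetune.resume.restore_config.path. This only helps when the checkpoint is visible from the machine running the script; on a Slurm cluster, the path is resolved on the compute nodes.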

Execute on a Slurm Cluster with More Nodes

You can reuse the Slurm executor from above. The experiment can then be configured as follows:

Note

The import_ckpt configuration should write to a shared filesystem accessible by all nodes in the cluster for multi-node training.

You can control the default cache location by setting the NEMO_HOME environment variable.

Warning

When using import_ckpt in NeMo 2.0, ensure your script includes if __name__ == "__main__":. Without this, Python’s multiprocessing won’t initialize threads properly, causing a “Failure to acquire lock” error.

def run_finetuning_on_slurm():
    import_ckpt = configure_checkpoint_conversion()

    # This will make finetuning run on 2 nodes with 8 GPUs each.
    recipe = configure_finetuning_recipe(gpus_per_node=8, nodes=2)
    executor = slurm_executor(
        ...
        nodes=recipe.trainer.num_nodes,
        devices=recipe.trainer.devices,
        ...
    )
    executor.env_vars["NEMO_HOME"] = "/path/to/a/shared/filesystem"

    # Importing a checkpoint only needs 1 node and 1 task per node
    import_executor = executor.clone() # Clone the executor instance, not the slurm_executor function
    import_executor.nodes = 1
    import_executor.ntasks_per_node = 1
    # Set this env var for model download from huggingface
    import_executor.env_vars["HF_TOKEN_PATH"] = "/tokens/huggingface"

    with run.Experiment("llama3-8b-peft-finetuning-slurm") as exp:
        exp.add(import_ckpt, executor=import_executor, name="import_from_hf")
        exp.add(recipe, executor=executor, name="peft_finetuning")
        exp.run(sequential=True, tail_logs=True)
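Because every node must see the converted checkpoint, it helps to know where import_ckpt will write it. The sketch below resolves a probable location from NEMO_HOME. It rests on two assumptions to verify against your NeMo version: that the cache defaults to ~/.cache/nemo when NEMO_HOME is unset, and that imported models are stored under a models/ subdirectory named after the Hugging Face repository. nemo_model_dir is a hypothetical helper.

```python
import os
from pathlib import Path
from typing import Mapping, Optional


def nemo_model_dir(hf_repo: str, env: Optional[Mapping[str, str]] = None) -> Path:
    """Guess where an imported model is cached, e.g. for "meta-llama/Meta-Llama-3-8B"."""
    env = os.environ if env is None else env
    # Assumptions: NEMO_HOME defaults to ~/.cache/nemo and models live under models/.
    nemo_home = Path(env.get("NEMO_HOME", str(Path.home() / ".cache" / "nemo")))
    return nemo_home / "models" / hf_repo


if __name__ == "__main__":
    print(nemo_model_dir("meta-llama/Meta-Llama-3-8B", {"NEMO_HOME": "/path/to/a/shared/filesystem"}))
```

The default location may not be shared between nodes or containers, which is why run_finetuning_on_slurm sets NEMO_HOME explicitly to a shared filesystem before cloning the executor for the import task.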