Important

You are viewing the NeMo 2.0 documentation. This release introduces significant changes to the API and a new library, NeMo Run. We are currently porting all features from NeMo 1.0 to 2.0. For documentation on previous versions or features not yet available in 2.0, please refer to the NeMo 24.07 documentation.

NeMo 2.0#

In NeMo 1.0, experiments are configured primarily through YAML files. This approach allows experiments to be set up declaratively, but it limits flexibility and programmatic control. NeMo 2.0 shifts to a Python-based configuration, which offers several advantages:

  • More flexibility and control over the configuration.

  • Better integration with IDEs for code completion and type checking.

  • Easier to extend and customize configurations programmatically.
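To illustrate the last point, here is a minimal sketch of programmatic customization using the GPTConfig class shown in the Quickstart below. The variant names and dropout values are purely illustrative:

from nemo.collections import llm

# A hedged sketch: derive experiment variants in a loop, something that would
# require one YAML file per variant in NeMo 1.0. GPTConfig is the same class
# used in the Quickstart below; unspecified fields keep their defaults.
base = dict(num_layers=6, hidden_size=384, ffn_hidden_size=1536, num_attention_heads=6)
configs = {
    f"dropout_{p}": llm.GPTConfig(**base, hidden_dropout=p, attention_dropout=p)
    for p in (0.0, 0.1)
}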

By adopting PyTorch Lightning’s modular abstractions, NeMo 2.0 makes it easy for users to adapt the framework to their specific use cases and experiment with various configurations. This section offers an overview of the new features in NeMo 2.0 and includes a migration guide with step-by-step instructions for transitioning your models from NeMo 1.0 to NeMo 2.0.

Install NeMo 2.0#

NeMo 2.0 installation instructions can be found in the Getting Started guide.

Quickstart#

Important

In any script you write, please make sure you wrap your code in an if __name__ == "__main__": block. See Working with scripts in NeMo 2.0 for details.

The following is an example of running a simple training loop using NeMo 2.0. This example uses the train API from the NeMo Framework LLM collection. Once you have set up your environment using the instructions above, you’re ready to run this simple train script.

import torch
from nemo import lightning as nl
from nemo.collections import llm
from megatron.core.optimizer import OptimizerConfig

if __name__ == "__main__":
    seq_length = 2048
    global_batch_size = 16

    ## setup the dummy dataset
    data = llm.MockDataModule(seq_length=seq_length, global_batch_size=global_batch_size)

    ## initialize a small GPT model
    gpt_config = llm.GPTConfig(
        num_layers=6,
        hidden_size=384,
        ffn_hidden_size=1536,
        num_attention_heads=6,
        seq_length=seq_length,
        init_method_std=0.023,
        hidden_dropout=0.1,
        attention_dropout=0.1,
        layernorm_epsilon=1e-5,
        make_vocab_size_divisible_by=128,
    )
    model = llm.GPTModel(gpt_config, tokenizer=data.tokenizer)

    ## initialize the strategy
    strategy = nl.MegatronStrategy(
        tensor_model_parallel_size=1,
        pipeline_model_parallel_size=1,
        pipeline_dtype=torch.bfloat16,
    )

    ## setup the optimizer
    opt_config = OptimizerConfig(
        optimizer='adam',
        lr=6e-4,
        bf16=True,
    )
    opt = nl.MegatronOptimizerModule(config=opt_config)

    trainer = nl.Trainer(
        devices=1, ## you can change the number of devices to suit your setup
        max_steps=50,
        accelerator="gpu",
        strategy=strategy,
        plugins=nl.MegatronMixedPrecision(precision="bf16-mixed"),
    )

    nemo_logger = nl.NeMoLogger(
        log_dir="test_logdir", ## logs and checkpoints will be written here
    )

    llm.train(
        model=model,
        data=data,
        trainer=trainer,
        log=nemo_logger,
        tokenizer='data',
        optim=opt,
    )

NeMo 2.0 also seamlessly supports scaling to thousands of GPUs using NeMo-Run. For examples of launching large-scale experiments using NeMo-Run, refer to Quickstart with NeMo-Run.

Note

If you are an existing user of NeMo 1.0 and would like to use a NeMo 1.0 dataset in place of the MockDataModule in the example, refer to the data migration guide for instructions.

Extend Quickstart with NeMo-Run#

While Quickstart with NeMo-Run covers how to configure your NeMo 2.0 experiment using NeMo-Run, it is not mandatory to use the configuration system from NeMo-Run. In fact, you can take the Python script from the Quickstart above and launch it on remote clusters directly using NeMo-Run. For more details about NeMo-Run, refer to NeMo-Run Github and the hello_scripts example. Below, we will walk through how to do this.

Prerequisites#

  1. Save the script above as train.py in your working directory.

  2. Install NeMo-Run using the following command:

pip install git+https://github.com/NVIDIA/NeMo-Run.git

Launch the Experiment Locally#

Here, locally means from your local workstation: either a Python virtual environment (venv) on the workstation or an interactive NeMo Docker container.

  1. Write a new file called run.py with the following contents:

import os
import nemo_run as run

if __name__ == "__main__":
    training_job = run.Script(
        inline="""
# This string will get saved to a sh file and executed with bash
# Run any preprocessing commands

# Run the training command
python train.py

# Run any post processing commands
"""
    )

    # Run it locally
    executor = run.LocalExecutor()

    with run.Experiment("nemo_2.0_training_experiment", log_level="INFO") as exp:
        exp.add(training_job, executor=executor, tail_logs=True, name="training")
        # Add more jobs as needed

        # Run the experiment
        exp.run(detach=False)
  2. Then, launch the experiment using the following command:

python run.py

Launch the Experiment on Slurm#

Writing an extra script just to launch a job locally is of limited use on its own, so let's see how we can extend run.py to launch the job on any of the supported NeMo-Run executors. For this tutorial, we will use the Slurm executor.

Note

Each cluster might have different settings. It is recommended that you reach out to the cluster administrators for specific details.

  1. You can define a function to configure your Slurm executor as follows:

import os
from typing import Optional

import nemo_run as run

def slurm_executor(
    user: str,
    host: str,
    remote_job_dir: str,
    account: str,
    partition: str,
    nodes: int,
    devices: int,
    time: str = "01:00:00",
    custom_mounts: Optional[list[str]] = None,
    custom_env_vars: Optional[dict[str, str]] = None,
    container_image: str = "nvcr.io/nvidia/nemo:dev",
    retries: int = 0,
) -> run.SlurmExecutor:
    if not (user and host and remote_job_dir and account and partition and nodes and devices):
        raise RuntimeError(
            "Please set user, host, remote_job_dir, account, partition, nodes and devices args for using this function."
        )

    mounts = []
    # Custom mounts are defined here.
    if custom_mounts:
        mounts.extend(custom_mounts)

    # Env vars for jobs are configured here
    env_vars = {
        "TRANSFORMERS_OFFLINE": "1",
        "TORCH_NCCL_AVOID_RECORD_STREAMS": "1",
        "NCCL_NVLS_ENABLE": "0",
        "NVTE_DP_AMAX_REDUCE_INTERVAL": "0",
        "NVTE_ASYNC_AMAX_REDUCTION": "1",
        "NVTE_FUSED_ATTN": "0",
    }
    if custom_env_vars:
        env_vars |= custom_env_vars

    # This will package the train.py script in the current working directory to the remote cluster.
    # If you are inside a git repo, you can also use https://github.com/NVIDIA/NeMo-Run/blob/main/src/nemo_run/core/packaging/git.py.
    # If the script already exists on your container and you call it with the absolute path, you can also just use `run.Packager()`.
    packager = run.PatternPackager(include_pattern="train.py", relative_path=os.getcwd())

    # This defines the slurm executor.
    # We connect to the executor via the tunnel defined by user, host and remote_job_dir.
    executor = run.SlurmExecutor(
        account=account,
        partition=partition,
        tunnel=run.SSHTunnel(
            user=user,
            host=host,
            job_dir=remote_job_dir, # This is where the results of the run will be stored by default.
            # identity="/path/to/identity/file" OPTIONAL: Provide path to the private key that can be used to establish the SSH connection without entering your password.
        ),
        nodes=nodes,
        ntasks_per_node=devices,
        gpus_per_node=devices,
        mem="0",
        exclusive=True,
        gres="gpu:8",
        packager=packager,
    )

    executor.container_image = container_image
    executor.container_mounts = mounts
    executor.env_vars = env_vars
    executor.retries = retries
    executor.time = time

    return executor
  2. Then, replace the executor in run.py as follows (a hypothetical filled-in call is sketched after these steps):

executor = slurm_executor(...) # pass in args relevant to your cluster
  3. Now, you can run the file with the same command, and it will launch your job on the cluster. Similarly, you can define multiple Slurm executors for multiple Slurm clusters and use them interchangeably, or use any of the other supported executors in NeMo-Run.
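For illustration, a filled-in call for step 2 might look like the following; every value below is a made-up placeholder, so substitute the details of your own cluster:

executor = slurm_executor(
    user="jdoe",                            # placeholder username
    host="login.my-cluster.example.com",    # placeholder login node
    remote_job_dir="/home/jdoe/nemo-runs",  # placeholder results directory
    account="my_account",                   # placeholder Slurm account
    partition="gpu",                        # placeholder Slurm partition
    nodes=1,
    devices=8,
)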

Where to Find NeMo 2.0#

Currently, the code for NeMo 2.0 can be found in three main locations within the NeMo GitHub repository:

  1. LLM collection: This is the first collection to adopt the NeMo 2.0 APIs. It provides implementations of common language models using NeMo 2.0; refer to the model-specific documentation in the collection for the models currently supported.

  2. NeMo 2.0 LLM Recipes: Provides comprehensive recipes for pre-training and fine-tuning large language models. Recipes can be easily configured and modified for specific use cases with the help of NeMo-Run; a brief sketch of the recipe workflow appears after this list.

  3. NeMo Lightning: Provides custom PyTorch Lightning-compatible objects that make it possible to train Megatron Core-based models using PTL in a modular fashion. NeMo 2.0 employs these objects to train models in a simple and efficient manner.

Pretraining, Supervised Fine-Tuning (SFT), and Parameter-Efficient Fine-Tuning (PEFT) are all supported by the LLM collection. More information about each model can be found in its model-specific documentation.
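As a brief illustration of the recipe workflow, the following sketch configures a prebuilt pretraining recipe and runs it locally with NeMo-Run. The llm.llama3_8b module and the pretrain_recipe factory are assumptions based on the recipes documentation; verify the exact names against your installed version:

import nemo_run as run
from nemo.collections import llm

if __name__ == "__main__":
    # A hedged sketch: llm.llama3_8b.pretrain_recipe is assumed here; check
    # the NeMo 2.0 LLM recipes documentation for the factory names in your version.
    recipe = llm.llama3_8b.pretrain_recipe(
        name="llama3_8b_pretrain",
        num_nodes=1,
        num_gpus_per_node=8,
    )
    run.run(recipe, executor=run.LocalExecutor())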

Long context recipes are also supported with the help of context parallelism. For more information on the available long context recipes, refer to the long context documentation.
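For instance, enabling context parallelism is a small change to the strategy from the Quickstart. The sketch below assumes that MegatronStrategy accepts a context_parallel_size argument; check the NeMo Lightning API reference for your installed version:

import torch
from nemo import lightning as nl

# A hedged sketch: shard the sequence dimension across two GPUs.
# context_parallel_size is assumed to be a MegatronStrategy argument.
strategy = nl.MegatronStrategy(
    tensor_model_parallel_size=1,
    pipeline_model_parallel_size=1,
    context_parallel_size=2,
    pipeline_dtype=torch.bfloat16,
)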

Inference via TRT-LLM is coming soon.

Additional Resources#

Known Issues#

  • TRT-LLM support will be added to NeMo 2.0 in a future release.

  • Instructions for converting a NeMo 1.0 checkpoint to NeMo 2.0 format are coming soon.