Important

NeMo 2.0 is an experimental feature and is currently released only in the dev container: nvcr.io/nvidia/nemo:dev. Refer to the Migration Guide for information on getting started.

NeMo 2.0

In NeMo 1.0, experiments are configured primarily through YAML files. This declarative approach is simple to set up, but it limits flexibility and programmatic control. NeMo 2.0 shifts to a Python-based configuration, which offers several advantages:

  • More flexibility and control over the configuration.

  • Better integration with IDEs for code completion and type checking.

  • Easier programmatic extension and customization of configurations.
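
As a rough illustration of these advantages, the sketch below uses a plain Python dataclass. TinyGPTConfig and depth_sweep are hypothetical names introduced for this example and are not part of NeMo. The sketch shows a config that is validated when it is constructed and that can be varied programmatically, which a static YAML file cannot do on its own.

```python
from dataclasses import dataclass, replace


@dataclass
class TinyGPTConfig:
    """Hypothetical stand-in for a model config; not a NeMo class."""

    num_layers: int = 6
    hidden_size: int = 384
    num_attention_heads: int = 6

    def __post_init__(self) -> None:
        # Validation runs as soon as the config object is built.
        if self.hidden_size % self.num_attention_heads != 0:
            raise ValueError("hidden_size must be divisible by num_attention_heads")

    @property
    def head_dim(self) -> int:
        # Derived values are computed from the fields instead of duplicated.
        return self.hidden_size // self.num_attention_heads


def depth_sweep(base: TinyGPTConfig, depths: list[int]) -> list[TinyGPTConfig]:
    """Derive one config per depth; replace() re-runs the validation."""
    return [replace(base, num_layers=depth) for depth in depths]


base = TinyGPTConfig()
configs = depth_sweep(base, [6, 12, 24])
print([c.num_layers for c in configs])  # [6, 12, 24]
print(base.head_dim)  # 64
```

Because the fields carry type annotations, an IDE or type checker can flag a misspelled field or a wrong value type before the experiment starts.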

By adopting PyTorch Lightning’s modular abstractions, NeMo 2.0 makes it easy for users to adapt the framework to their specific use cases and experiment with various configurations. This section offers an overview of the new features in NeMo 2.0 and includes a migration guide with step-by-step instructions for transitioning your models from NeMo 1.0 to NeMo 2.0.

Install NeMo 2.0

NeMo 2.0 installation instructions can be found in the Getting Started guide.

Quickstart

The following example runs a simple training loop using NeMo 2.0. It uses the train API from the NeMo Framework LLM collection. Once you have set up your environment using the instructions above, you are ready to run this training script.

import torch
from nemo import lightning as nl
from nemo.collections import llm
from megatron.core.optimizer import OptimizerConfig

if __name__ == "__main__":
    seq_length = 2048
    global_batch_size = 16

    ## setup the dummy dataset
    data = llm.MockDataModule(seq_length=seq_length, global_batch_size=global_batch_size)

    ## initialize a small GPT model
    gpt_config = llm.GPTConfig(
        num_layers=6,
        hidden_size=384,
        ffn_hidden_size=1536,
        num_attention_heads=6,
        seq_length=seq_length,
        init_method_std=0.023,
        hidden_dropout=0.1,
        attention_dropout=0.1,
        layernorm_epsilon=1e-5,
        make_vocab_size_divisible_by=128,
    )
    model = llm.GPTModel(gpt_config, tokenizer=data.tokenizer)

    ## initialize the strategy
    strategy = nl.MegatronStrategy(
        tensor_model_parallel_size=1,
        pipeline_model_parallel_size=1,
        pipeline_dtype=torch.bfloat16,
    )

    ## setup the optimizer
    opt_config = OptimizerConfig(
        optimizer='adam',
        lr=6e-4,
        bf16=True,
    )
    opt = nl.MegatronOptimizerModule(config=opt_config)

    trainer = nl.Trainer(
        devices=1, ## you can change the number of devices to suit your setup
        max_steps=50,
        accelerator="gpu",
        strategy=strategy,
        plugins=nl.MegatronMixedPrecision(precision="bf16-mixed"),
    )

    nemo_logger = nl.NeMoLogger(
        dir="test_logdir", ## logs and checkpoints will be written here
    )

    llm.train(
        model=model,
        data=data,
        trainer=trainer,
        log=nemo_logger,
        tokenizer='data', ## use the tokenizer provided by the data module
        optim=opt,
    )

NeMo 2.0 also seamlessly supports scaling to thousands of GPUs using NeMo-Run. For examples of launching large-scale experiments using NeMo-Run, refer to NeMo 2.0 Recipes.
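
When moving beyond a single GPU, it helps to see how the device count relates to the parallelism settings used in the quickstart. The sketch below assumes the usual Megatron Core convention that the data-parallel size is the total GPU count divided by the product of the tensor, pipeline, and context parallel sizes. The helper names and the micro_batch_size value are assumptions made for illustration; they are not NeMo APIs or values taken from the example above.

```python
def data_parallel_size(
    world_size: int,
    tensor_parallel: int = 1,
    pipeline_parallel: int = 1,
    context_parallel: int = 1,
) -> int:
    """Number of model replicas, assuming DP = world_size / (TP * PP * CP)."""
    model_parallel = tensor_parallel * pipeline_parallel * context_parallel
    if world_size % model_parallel != 0:
        raise ValueError(
            f"world_size={world_size} is not divisible by TP*PP*CP={model_parallel}"
        )
    return world_size // model_parallel


def grad_accumulation_steps(
    global_batch_size: int, micro_batch_size: int, dp_size: int
) -> int:
    """Micro-batches each replica processes per optimizer step."""
    samples_per_pass = micro_batch_size * dp_size
    if global_batch_size % samples_per_pass != 0:
        raise ValueError(
            "global_batch_size must be divisible by micro_batch_size * dp_size"
        )
    return global_batch_size // samples_per_pass


# Quickstart settings: 1 GPU, tensor and pipeline parallel sizes of 1.
dp = data_parallel_size(world_size=1)
print(dp)  # 1

# micro_batch_size=4 is an assumed value for illustration only.
print(grad_accumulation_steps(global_batch_size=16, micro_batch_size=4, dp_size=dp))  # 4

# A hypothetical larger run: 1024 GPUs with TP=8 and PP=4.
print(data_parallel_size(1024, tensor_parallel=8, pipeline_parallel=4))  # 32
```

A check like this catches an invalid combination of device count and parallel sizes before a job is submitted.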

Note

If you are an existing user of NeMo 1.0 and would like to use a NeMo 1.0 dataset in place of the MockDataModule in the example, refer to the data migration guide for instructions.

Where to Find NeMo 2.0

Currently, the code for NeMo 2.0 can be found in two main locations within the NeMo GitHub repository:

  1. NeMo Lightning: Provides custom PyTorch Lightning-compatible objects that make it possible to train Megatron Core-based models with PyTorch Lightning (PTL) in a modular fashion.

  2. LLM collection: This is the first collection to adopt the NeMo 2.0 APIs. This collection provides implementations of common language models using NeMo 2.0. Currently, the collection supports the following models:

Pretraining, Supervised Fine-Tuning (SFT), and Parameter-Efficient Fine-Tuning (PEFT) are all supported by the LLM collection. Long-context recipes are also supported with the help of context parallelism. For more information on the available long-context recipes, refer to the long context documentation.

Inference via TRT-LLM is coming soon.

Additional Resources

  • The Design Guide provides an in-depth exploration of the main features of NeMo 2.0.

  • If you are familiar with NeMo 1.0, the Migration Guide explains how to migrate your experiments from NeMo 1.0 to NeMo 2.0.

  • NeMo 2.0 Recipes contains additional examples of launching large-scale runs using NeMo 2.0 and NeMo-Run.

Known Issues

  • TRT-LLM support will be added to NeMo 2.0 in a future release.

  • Instructions for converting a NeMo 1.0 checkpoint to NeMo 2.0 format are coming soon.