Important

You are viewing the NeMo 2.0 documentation. This release introduces significant changes to the API and a new library, NeMo Run. We are currently porting all features from NeMo 1.0 to 2.0. For documentation on previous versions or features not yet available in 2.0, please refer to the NeMo 24.07 documentation.

Migrate Pre-Training from NeMo 1.0 to NeMo 2.0

NeMo 1.0 (Previous Release)

In NeMo 1.0, pre-training is configured using megatron_gpt_config.yaml and launched with megatron_gpt_pretraining.py.

NeMo 2.0 (New Release)

NeMo 2.0 introduces a Pythonic and modular approach to configuring experiments. For detailed instructions on migrating your NeMo 1.0 YAML configurations to NeMo 2.0, refer to the other documents in this migration guide.
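The following sketch gives a rough sense of the shift from YAML to Python. It translates a few `model` keys from a NeMo 1.0-style YAML dictionary into keyword arguments that match the `llm.GPTConfig` fields used in the pre-training example below. The helper `yaml_model_to_gpt_config_kwargs` and its key tables are hypothetical. The `encoder_seq_length` to `seq_length` rename is an assumption you should verify against the rest of this migration guide. The sketch uses only the standard library and does not import NeMo.

```python
from typing import Any, Dict

# Hypothetical mapping from NeMo 1.0 YAML keys (under `model:`) to the
# keyword arguments used for llm.GPTConfig in the NeMo 2.0 example.
# Keys not listed here are assumed to keep the same name.
KEY_RENAMES: Dict[str, str] = {
    "encoder_seq_length": "seq_length",  # assumption: verify against the guide
}

# Fields this sketch carries over (taken from the example GPTConfig).
SUPPORTED_FIELDS = {
    "num_layers",
    "hidden_size",
    "ffn_hidden_size",
    "num_attention_heads",
    "seq_length",
    "init_method_std",
    "hidden_dropout",
    "attention_dropout",
    "layernorm_epsilon",
    "make_vocab_size_divisible_by",
}


def yaml_model_to_gpt_config_kwargs(model_cfg: Dict[str, Any]) -> Dict[str, Any]:
    """Rename and filter a NeMo 1.0 `model:` section into GPTConfig kwargs."""
    kwargs: Dict[str, Any] = {}
    for key, value in model_cfg.items():
        new_key = KEY_RENAMES.get(key, key)
        if new_key in SUPPORTED_FIELDS:
            kwargs[new_key] = value
    return kwargs


if __name__ == "__main__":
    # A fragment shaped like the `model:` section of a NeMo 1.0 YAML file.
    v1_model = {
        "num_layers": 12,
        "hidden_size": 384,
        "encoder_seq_length": 2048,
        # In the NeMo 2.0 example this is passed to the data module instead,
        # so this sketch drops it.
        "micro_batch_size": 2,
    }
    print(yaml_model_to_gpt_config_kwargs(v1_model))
    # → {'num_layers': 12, 'hidden_size': 384, 'seq_length': 2048}
```

You could then pass the resulting dictionary as `llm.GPTConfig(**kwargs)`. Unknown keys are dropped rather than guessed, so anything missing from the tables has to be migrated by hand.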

In addition, NeMo 2.0 is compatible with NeMo-Run, which streamlines the configuration and execution of NeMo experiments. Refer to the NeMo-Run documentation for more.

The llm.train API can be used to run pre-training in NeMo 2.0, as follows:

from nemo import lightning as nl
from nemo.collections import llm
from nemo.collections.nlp.modules.common.tokenizer_utils import get_nmt_tokenizer
from megatron.core.optimizer import OptimizerConfig

# Set up the GPT model architecture.
gpt_config = llm.GPTConfig(
   num_layers=12,
   hidden_size=384,
   ffn_hidden_size=1536,
   num_attention_heads=6,
   seq_length=2048,
   init_method_std=0.023,
   hidden_dropout=0.1,
   attention_dropout=0.1,
   layernorm_epsilon=1e-5,
   make_vocab_size_divisible_by=128,
)

# Other documents in this guide explain how to configure the tokenizer and data module.
tokenizer = get_nmt_tokenizer("megatron", "GPT2BPETokenizer")
data = llm.PreTrainingDataModule(
   paths={
      # Training blend as a flat list: weight, path, weight, path, ...
      "train": [0.75, '/my/traindata1', 0.25, '/my/traindata2'],
      "validation": '/my/validdata1',
      "test": '/my/testdata',
   },
   global_batch_size=4,
   micro_batch_size=2,
   num_workers=8,
   pin_memory=True,
   seq_length=2048,
   tokenizer=tokenizer
)
model = llm.GPTModel(
   gpt_config,
   tokenizer=data.tokenizer
)
strategy = nl.MegatronStrategy(
   tensor_model_parallel_size=2,
   pipeline_model_parallel_size=2,
   virtual_pipeline_model_parallel_size=None,
   context_parallel_size=1,
   sequence_parallel=True,
   expert_model_parallel_size=1,
)
optim = nl.MegatronOptimizerModule(
   config=OptimizerConfig(
      optimizer="adam",
      lr=0.001,
      use_distributed_optimizer=True
   ),
   lr_scheduler=nl.lr_scheduler.CosineAnnealingScheduler(),
)
trainer = nl.Trainer(
   num_nodes=16,
   devices=8,
   accelerator="gpu",
   plugins=nl.MegatronMixedPrecision(precision="bf16-mixed"),
   max_epochs=None,
   max_steps=75000,
   max_time="05:00:00:00",  # DD:HH:MM:SS, i.e. 5 days
   log_every_n_steps=10,
   val_check_interval=2000,
   limit_val_batches=50,
   limit_test_batches=50,
   accumulate_grad_batches=1,
   gradient_clip_val=1.0,
)
nemo_logger = nl.NeMoLogger(
   log_dir="your/log/ckpt/dir/here"
)
resume = None # None for pretraining from scratch

llm.train(
   model=model,
   data=data,
   trainer=trainer,
   log=nemo_logger,
   tokenizer=tokenizer,
   resume=resume,
   optim=optim,
)
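The strategy, trainer, and data module settings above have to agree with one another. This sketch assumes the usual Megatron arithmetic. The data-parallel size is the total GPU count divided by tensor × pipeline × context parallel size. The global batch size must also be a multiple of micro_batch_size × data-parallel size. `megatron_batch_layout` is a hypothetical helper written for illustration and is not part of NeMo.

```python
from typing import Tuple


def megatron_batch_layout(
    num_nodes: int,
    devices: int,
    tp: int,
    pp: int,
    cp: int,
    global_batch_size: int,
    micro_batch_size: int,
) -> Tuple[int, int]:
    """Return (data_parallel_size, num_microbatches), or raise ValueError."""
    world_size = num_nodes * devices
    model_parallel = tp * pp * cp
    if world_size % model_parallel != 0:
        raise ValueError(
            f"world size {world_size} is not divisible by TP*PP*CP = {model_parallel}"
        )
    dp = world_size // model_parallel
    # Samples consumed when every data-parallel rank processes one micro-batch.
    per_step = micro_batch_size * dp
    if global_batch_size % per_step != 0:
        raise ValueError(
            f"global_batch_size {global_batch_size} must be a multiple of "
            f"micro_batch_size * data_parallel_size = {per_step}"
        )
    return dp, global_batch_size // per_step


if __name__ == "__main__":
    # 16 nodes x 8 GPUs, TP=2, PP=2, CP=1, as in the example above.
    print(megatron_batch_layout(16, 8, 2, 2, 1, global_batch_size=512, micro_batch_size=2))
    # → (32, 8)
```

The example above uses 16 nodes × 8 devices with TP=2, PP=2, and CP=1, which gives a data-parallel size of 32. With micro_batch_size=2, the global batch size must therefore be a multiple of 64. This check would reject the illustrative global_batch_size=4, so you would need to raise it (for example, to 512) for a run of that size.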

In addition to the generic GPTModel used in the example above, NeMo 2.0 also supports Gemma, Llama, Mistral, and Mixtral models. For the other components, refer to the Tokenizer, Data Module, Megatron Strategy, Megatron Optimizer Module, Trainer, and NeMo Logger sections of the NeMo 2.0 guide.
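In the data module example, the `train` entry interleaves sampling weights and dataset paths (`[0.75, '/my/traindata1', 0.25, '/my/traindata2']`). The helper below is a hypothetical, standard-library-only sketch of that layout. It pairs each weight with its path and normalizes the weights so they sum to 1. It illustrates the list format only and does not reproduce how NeMo itself parses or samples the blend.

```python
from typing import List, Sequence, Tuple, Union

Entry = Union[float, int, str]


def parse_blend(paths: Sequence[Entry]) -> List[Tuple[float, str]]:
    """Pair up [w1, path1, w2, path2, ...] and normalize weights to sum to 1."""
    if len(paths) % 2 != 0:
        raise ValueError("expected an even number of entries: weight, path, weight, path, ...")
    weights = [float(w) for w in paths[0::2]]
    prefixes = [str(p) for p in paths[1::2]]
    if any(w < 0 for w in weights):
        raise ValueError("weights must be non-negative")
    total = sum(weights)
    if total <= 0:
        raise ValueError("at least one weight must be positive")
    return [(w / total, p) for w, p in zip(weights, prefixes)]


if __name__ == "__main__":
    blend = parse_blend([0.75, "/my/traindata1", 0.25, "/my/traindata2"])
    for weight, prefix in blend:
        print(f"{weight:.2f}  {prefix}")
```

The `validation` and `test` entries in the example are single paths rather than weighted lists, so they would not go through a helper like this.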

Command Line Interface

You can also run existing recipes through the NeMo CLI, which is provided by NeMo-Run.

For example:

nemo llm pretrain --factory llama3_8b
  • llama3_8b can be replaced by any other recipe included in NeMo 2.0 (for example, mixtral_8x7b).

  • Long-context recipes (sequence length greater than 16k) support only the pretrain task, because there is no actual long-context fine-tuning use case yet.

Continue Training

In NeMo 2.0, you can set up a resume path to continue training from a NeMo checkpoint. For example:

... # same as above
resume = llm.default_resume()
resume.resume_from_path = "path/to/NeMo/checkpoint/you/want/to/resume/from"

llm.train(
   model=model,
   data=data,
   trainer=trainer,
   log=nemo_logger,
   tokenizer=tokenizer,
   resume=resume,
   optim=optim,
)

If resume.resume_from_path is not set, training tries to resume from a checkpoint in the logger's log_dir, if one exists.
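The fallback order described above can be sketched in plain Python. `resolve_resume_path` is a hypothetical helper and not NeMo's resume implementation. It assumes that checkpoints are stored as subdirectories of `<log_dir>/checkpoints` and that the most recently modified one is the latest, which may differ from how NeMo actually lays out and selects checkpoints.

```python
from pathlib import Path
from typing import Optional


def resolve_resume_path(resume_from_path: Optional[str], log_dir: str) -> Optional[Path]:
    """Pick a checkpoint to resume from, following the order described in the text.

    1. An explicit resume_from_path always wins.
    2. Otherwise, look for checkpoints under <log_dir>/checkpoints (assumed layout).
    3. Otherwise, return None, meaning: train from scratch.
    """
    if resume_from_path:
        return Path(resume_from_path)

    ckpt_root = Path(log_dir) / "checkpoints"  # hypothetical layout
    if not ckpt_root.is_dir():
        return None
    candidates = [p for p in ckpt_root.iterdir() if p.is_dir()]
    if not candidates:
        return None
    # Assumption: the most recently modified checkpoint directory is the latest one.
    return max(candidates, key=lambda p: p.stat().st_mtime)


if __name__ == "__main__":
    # Prints None unless checkpoint directories exist under this log_dir.
    print(resolve_resume_path(None, "your/log/ckpt/dir/here"))
```

Returning None in the last case mirrors `resume = None` in the pre-training example, where training starts from scratch.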

Command Line Interface

To continue training from a checkpoint with an existing recipe, add the resume.resume_from_path="to/some/path" option to the command. You can override other options, such as sequence length, in the same way. For example:

nemo llm pretrain --factory llama3_8b resume.resume_from_path="to/some/path" model.config.seq_length=4096 data.seq_length=4096
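Overrides such as `model.config.seq_length=4096` address nested fields with dot-separated paths. The sketch below shows the idea on a plain nested dictionary. `apply_override` is hypothetical and is not how NeMo-Run parses its command line. In particular, it converts values with `ast.literal_eval` and keeps anything that fails to parse as a plain string.

```python
import ast
from typing import Any, Dict


def apply_override(config: Dict[str, Any], override: str) -> None:
    """Apply one 'a.b.c=value' override to a nested dict, in place."""
    key_path, sep, raw_value = override.partition("=")
    if not sep:
        raise ValueError(f"expected key=value, got {override!r}")
    # Numbers become int/float and quoted values are unquoted;
    # anything literal_eval rejects (e.g. to/some/path) stays a plain string.
    try:
        value: Any = ast.literal_eval(raw_value)
    except (ValueError, SyntaxError, TypeError):
        value = raw_value
    *parents, leaf = key_path.split(".")
    node = config
    for name in parents:
        child = node.setdefault(name, {})
        if not isinstance(child, dict):
            raise TypeError(f"cannot descend into non-dict key {name!r}")
        node = child
    node[leaf] = value


if __name__ == "__main__":
    cfg: Dict[str, Any] = {"model": {"config": {"seq_length": 2048}}, "data": {"seq_length": 2048}}
    for arg in ['resume.resume_from_path="to/some/path"',
                "model.config.seq_length=4096",
                "data.seq_length=4096"]:
        apply_override(cfg, arg)
    print(cfg)
```

The command above sets both model.config.seq_length and data.seq_length. Each override touches only the single field it names, so related fields that must stay consistent have to be overridden together.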

Migration Steps

  1. Migrate your NeMo 1.0 YAML to NeMo 2.0 using the other documents in the migration guide.

  2. Run pre-training using the llm.train API.