Important

NeMo 2.0 is an experimental feature and is currently released only in the dev container: nvcr.io/nvidia/nemo:dev. Please refer to the Migration Guide for information on getting started.

Serialization

NeMo 2.0 offers the option to capture the initialization arguments for an experiment's trainer, model, and datamodule. This feature enables precise reconstruction of these objects, making experiments easy to reproduce.

IOMixin

Serialization is performed using the IOMixin class. This class captures the arguments passed to a class's __init__ method, which allows for exact restoration of a trainer, model, and datamodule from a given experiment. The following is a simple example:

from nemo.lightning import io

## model, trainer, and data are already-instantiated model, trainer, and datamodule objects
ckpt = io.TrainerContext(model, trainer, extra={"datamodule": data})
## dump the current state
ckpt.io_dump(save_dir)

## restore the serialized state
loaded = io.load_context(save_dir)
## model, trainer, and datamodule will be reinitialized using the same args as before
model = loaded.model
trainer = loaded.trainer
datamodule = loaded.extra["datamodule"]
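
The capture is provided by the mixin itself: any class that inherits from IOMixin records the arguments passed to its __init__ when it is instantiated. NeMo's own trainer, model, and datamodule classes mix in IOMixin, which is why the example above works without extra setup. The following is a minimal sketch of the pattern; the MyDataModule class and its arguments are hypothetical and purely illustrative:

from nemo.lightning import io

## hypothetical class for illustration only; inheriting from IOMixin
## means the arguments passed to __init__ are captured at instantiation
class MyDataModule(io.IOMixin):
    def __init__(self, data_path: str, batch_size: int = 32):
        self.data_path = data_path
        self.batch_size = batch_size

## the captured arguments ("/data/train", batch_size=64) are what a later
## io_dump/load_context round trip uses to rebuild an equivalent object
dm = MyDataModule("/data/train", batch_size=64)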

Saving these initialization states can be done automatically via ModelCheckpoint's enable_nemo_ckpt_io argument. If enable_nemo_ckpt_io=True, IOMixin's io_dump functionality will be invoked to save the trainer, model, and datamodule initialization states. These states can then be restored using the io.load_context function. Note that this feature is independent of checkpoint loading: load_context reinstantiates the objects with their original initialization arguments, but it does not restore trained weights. If you would like to use the weights from a previous run, they still need to be restored from the checkpoint once the objects have been instantiated. An example workflow is as follows:

First, run some training and save a checkpoint:

import nemo.lightning as nl
from nemo.collections import llm
from nemo.lightning import io

trainer = nl.Trainer(...)
model = llm.GPTModel(...)
datamodule = llm.PreTrainingDataModule(...)
optim = nl.MegatronOptimizerModule(...)
checkpoint_callback = nl.ModelCheckpoint(
    ...
    enable_nemo_ckpt_io=True,
    ...
)
nemo_logger = nl.NeMoLogger(
    ...
    explicit_log_dir='explicit_dir_test',
    ckpt=checkpoint_callback,
    ...
)
resume = nl.AutoResume( ## resumes from the latest checkpoint if one exists; starts fresh otherwise
    resume_if_exists=True,
    resume_ignore_no_checkpoint=True,
)

llm.train(
    model=model,
    data=datamodule,
    trainer=trainer,
    log=nemo_logger,
    resume=resume,
    tokenizer='data',
    optim=optim,
)

In the above example, ModelCheckpoint, NeMoLogger, and AutoResume are responsible for setting up the logging and checkpointing directories and determining when to save and restore checkpoints. More information about these classes can be found in the logging and checkpointing doc.

Once the initialization states have been saved, we can restore the trainer, model, and datamodule from the serialized path. Note that everything not captured by io_dump (e.g., the checkpoint callback, logger, and resume objects) must be reinitialized. Doing so ensures that the logging and checkpointing directories are set up correctly, and that the appropriate model weights are restored after reinitialization.

import nemo.lightning as nl
from nemo.collections import llm
from nemo.lightning import io

loaded = io.load_context("explicit_dir_test/<PATH TO LATEST CHECKPOINT>")
model = loaded.model
trainer = loaded.trainer
datamodule = loaded.extra["datamodule"]
optim = nl.MegatronOptimizerModule(...) ## the optimizer is not captured by io_dump and must be reinitialized

checkpoint_callback = nl.ModelCheckpoint(
    ...
    enable_nemo_ckpt_io=True,
    ...
)
nemo_logger = nl.NeMoLogger(
    ...
    explicit_log_dir='explicit_dir_test',
    ckpt=checkpoint_callback,
    ...
)
resume = nl.AutoResume( ## handles resuming of the latest checkpoint in `explicit_dir_test`
    resume_if_exists=True,
    resume_ignore_no_checkpoint=True,
)

llm.train(
    model=model,
    data=datamodule,
    trainer=trainer,
    log=nemo_logger,
    resume=resume,
    tokenizer='data',
    optim=optim,
)