Important
You are viewing the NeMo 2.0 documentation. This release introduces significant changes to the API and a new library, NeMo Run. We are currently porting all features from NeMo 1.0 to 2.0. For documentation on previous versions or features not yet available in 2.0, please refer to the NeMo 24.07 documentation.
Implement FSDP2Strategy
Overview
The FSDP2Strategy provides Fully Sharded Data Parallel (FSDP) training using PyTorch’s FSDP2 implementation. It enables distributed training with automatic model sharding and mixed precision support.
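To make the sharding idea concrete, here is a minimal single-process simulation in plain Python. It is not NeMo or PyTorch code. It assumes FSDP2-style per-parameter sharding along the first dimension: each rank keeps one contiguous chunk of rows, and an all-gather (simulated here as a concatenation) rebuilds the full parameter when it is needed. The helper names `shard_rows` and `all_gather_rows` are hypothetical.

```python
import math
from typing import List, Sequence


def shard_rows(rows: Sequence[int], world_size: int, rank: int) -> List[int]:
    """Return the contiguous chunk of rows owned by `rank` (hypothetical helper).

    Chunks have size ceil(len(rows) / world_size); trailing ranks may get a
    shorter or empty chunk, which real FSDP2 handles internally with padding.
    """
    chunk = math.ceil(len(rows) / world_size)
    return list(rows[rank * chunk:(rank + 1) * chunk])


def all_gather_rows(shards: List[List[int]]) -> List[int]:
    """Simulate an all-gather: concatenate every rank's shard in rank order."""
    full: List[int] = []
    for shard in shards:
        full.extend(shard)
    return full


# A toy "parameter" with 10 rows, sharded across 4 ranks.
param = list(range(10))
shards = [shard_rows(param, world_size=4, rank=r) for r in range(4)]
print(shards)  # [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
assert all_gather_rows(shards) == param  # gathering restores the full parameter
```

Each rank stores roughly 1/world_size of every parameter, which is where FSDP's memory savings come from; the cost is the communication needed to gather parameters before they are used.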
Features
Automatic model parallelism
Mixed precision training
Checkpoint management
Deferred optimizer state restoration
Device mesh initialization
Initialize
To initialize the FSDP2Strategy, use the following arguments:
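```python
from nemo.lightning.pytorch.strategies import FSDP2Strategy

strategy = FSDP2Strategy(
    data_parallel_size="auto",
    tensor_parallel_size="auto",
    checkpoint_io=None,
    mp_policy=None,
    parallelize_fn=None,
    # Additional keyword arguments (**kwargs) can also be passed here.
)
```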
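Arguments:
data_parallel_size (Union["auto", int]): Data-parallel size, that is, the number of ranks that process different data and across which the model is sharded.
tensor_parallel_size (Union["auto", int]): Tensor-parallel size, that is, the number of ranks across which individual layers are split.
checkpoint_io (optional): Checkpoint I/O handler.
mp_policy (optional): Mixed precision policy.
parallelize_fn (callable, optional): Function that applies the parallelization (sharding) to the model.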
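The role of a mixed precision policy can be illustrated without a GPU. The sketch below emulates bfloat16 by rounding a float32 bit pattern to its upper 16 bits (round-to-nearest-even, finite inputs only). It shows why gradient reductions are often performed in a higher-precision dtype even when parameters are kept in bfloat16: many small updates added to 1.0 vanish in bfloat16 but not in full precision. This is plain Python for illustration only; it does not describe which dtypes NeMo's default policy uses, and the helper names are hypothetical.

```python
import struct


def to_bf16(x: float) -> float:
    """Round a float to the nearest bfloat16 value (finite inputs only).

    bfloat16 keeps the upper 16 bits of an IEEE float32, so we round the
    float32 bit pattern to nearest-even on its lower 16 bits, then clear them.
    """
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    rounding_bias = 0x7FFF + ((bits >> 16) & 1)
    bits = (bits + rounding_bias) & 0xFFFF0000
    return struct.unpack("<f", struct.pack("<I", bits))[0]


def accumulate(values, low_precision: bool) -> float:
    """Add `values` to 1.0, rounding after every add when emulating bfloat16."""
    total = 1.0
    for v in values:
        total = to_bf16(total + v) if low_precision else total + v
    return total


grads = [1e-3] * 1000
print(accumulate(grads, low_precision=True))   # prints 1.0: every tiny update is rounded away
print(accumulate(grads, low_precision=False))  # close to 2.0
```

Near 1.0 the spacing between bfloat16 values is 2^-7 (about 0.0078), so adding 0.001 always rounds back to 1.0.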
Parallelize
The parallelize() method applies sharding to the model:
```python
strategy.parallelize()
```
The model is parallelized only once: subsequent calls do not shard it again.
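The once-only guarantee is a simple guard pattern. The sketch below is a hypothetical stand-in, not NeMo's implementation: it records whether `parallelize_fn` has already been applied and skips later calls. `ToyStrategy`, `_parallelized`, and `fake_parallelize_fn` are illustrative names.

```python
from typing import Callable, Dict, List


class ToyStrategy:
    """Hypothetical strategy that applies `parallelize_fn` at most once."""

    def __init__(self, model: Dict[str, str], parallelize_fn: Callable[[Dict[str, str]], None]):
        self.model = model
        self.parallelize_fn = parallelize_fn
        self._parallelized = False  # guard flag

    def parallelize(self) -> None:
        if self._parallelized:
            return  # already sharded: calling again is a no-op
        self.parallelize_fn(self.model)
        self._parallelized = True


calls: List[int] = []


def fake_parallelize_fn(model: Dict[str, str]) -> None:
    # Stand-in for real sharding: mark the model and record the call.
    model["state"] = "sharded"
    calls.append(1)


toy = ToyStrategy({"state": "unsharded"}, fake_parallelize_fn)
toy.parallelize()
toy.parallelize()  # second call does nothing
print(len(calls), toy.model["state"])  # 1 sharded
```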
Environment Setup
The setup_environment() method initializes the distributed environment and the device mesh:
```python
strategy.setup_environment()
```
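A device mesh arranges all ranks in a grid so that each dimension defines a group of ranks. The sketch below assumes a 2-D mesh of shape (data_parallel_size, tensor_parallel_size) with ranks numbered in row-major order, which is how PyTorch's `init_device_mesh` lays out ranks; whether NeMo orders the mesh dimensions this way is an assumption. It is plain Python with hypothetical helper names and only computes which ranks share each group.

```python
from typing import List, Tuple


def build_mesh(world_size: int, dp: int, tp: int) -> List[List[int]]:
    """Lay out ranks 0..world_size-1 in a (dp, tp) grid, row-major."""
    if dp * tp != world_size:
        raise ValueError(f"dp * tp = {dp * tp} does not match world_size = {world_size}")
    return [[d * tp + t for t in range(tp)] for d in range(dp)]


def groups_for_rank(mesh: List[List[int]], rank: int) -> Tuple[List[int], List[int]]:
    """Return (tensor-parallel group, data-parallel group) containing `rank`."""
    tp = len(mesh[0])
    d, t = divmod(rank, tp)              # row-major coordinates of the rank
    tp_group = mesh[d]                   # same row: ranks splitting the same layers
    dp_group = [row[t] for row in mesh]  # same column: ranks handling different data
    return tp_group, dp_group


mesh = build_mesh(world_size=8, dp=4, tp=2)
print(mesh)                      # [[0, 1], [2, 3], [4, 5], [6, 7]]
print(groups_for_rank(mesh, 5))  # ([4, 5], [1, 3, 5, 7])
```

The size check mirrors the general requirement that the mesh dimensions multiply to the total number of ranks.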
Manage Checkpoints
Save Checkpoints
The save_checkpoint() method unshards the checkpoint and saves it to disk:
```python
strategy.save_checkpoint(checkpoint, filepath)
```
Load Checkpoints
The load_checkpoint() method loads a checkpoint from disk:
```python
checkpoint = strategy.load_checkpoint(filepath)
```
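Saving an unsharded checkpoint means each rank's shard of a parameter is gathered back into the full tensor before writing. The round trip below simulates this with plain Python lists and JSON. It assumes each rank holds a contiguous chunk of every parameter; `gather_full_state`, `save_full`, and `load_full` are hypothetical helpers, not NeMo APIs, and real checkpoints use PyTorch serialization rather than JSON.

```python
import json
import os
import tempfile
from typing import Dict, List

StateDict = Dict[str, List[float]]


def gather_full_state(rank_states: List[StateDict]) -> StateDict:
    """Concatenate each parameter's shards in rank order (simulated all-gather)."""
    full: StateDict = {}
    for state in rank_states:  # rank 0 first, then rank 1, ...
        for name, shard in state.items():
            full.setdefault(name, []).extend(shard)
    return full


def save_full(state: StateDict, filepath: str) -> None:
    with open(filepath, "w") as f:
        json.dump(state, f)


def load_full(filepath: str) -> StateDict:
    with open(filepath) as f:
        return json.load(f)


# Two ranks, each holding half of every parameter.
rank_states = [
    {"weight": [0.1, 0.2], "bias": [0.0]},
    {"weight": [0.3, 0.4], "bias": [1.0]},
]
full = gather_full_state(rank_states)

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "checkpoint.json")
    save_full(full, path)
    restored = load_full(path)

print(restored)  # {'weight': [0.1, 0.2, 0.3, 0.4], 'bias': [0.0, 1.0]}
```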
Restore Optimizer State
Restoring the optimizer state is deferred until the first training step. Use the following method to store the optimizer state from a checkpoint so that it can be applied at that step:
```python
strategy.load_optimizer_state_dict(checkpoint)
```
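Deferring the restore is useful when the optimizer and its sharded state are not ready yet at the moment the checkpoint is loaded. The sketch below shows the pattern with a hypothetical `DeferredRestoreStrategy`: loading only stashes the state, and the first training step applies it. The `optimizer_states` checkpoint key mirrors PyTorch Lightning's checkpoint layout, but its use here is an assumption, as are all other names.

```python
from typing import Any, Dict, List, Optional


class DeferredRestoreStrategy:
    """Hypothetical strategy that applies optimizer state on the first step."""

    def __init__(self) -> None:
        self.optimizer_state: Dict[str, Any] = {}
        self._pending: Optional[Dict[str, Any]] = None
        self.steps_run = 0

    def load_optimizer_state_dict(self, checkpoint: Dict[str, Any]) -> None:
        # Only store the state; the optimizer is not ready to receive it yet.
        self._pending = checkpoint["optimizer_states"][0]

    def training_step(self, batch: List[float], batch_idx: int) -> float:
        if self._pending is not None:
            # First step after loading: apply the stashed state exactly once.
            self.optimizer_state = self._pending
            self._pending = None
        self.steps_run += 1
        return float(sum(batch))  # toy "loss"


toy = DeferredRestoreStrategy()
toy.load_optimizer_state_dict({"optimizer_states": [{"step": 100, "lr": 3e-4}]})
print(toy.optimizer_state)  # {}: nothing applied yet
toy.training_step([1.0, 2.0], batch_idx=0)
print(toy.optimizer_state)  # {'step': 100, 'lr': 0.0003}
```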
Train and Evaluate the Model
Training Step
The training_step() method runs a single training iteration:
```python
loss = strategy.training_step(batch, batch_idx)
```
Validation Step
The validation_step() method runs a single validation iteration:
```python
loss = strategy.validation_step(batch, batch_idx)
```
Test Step
The test_step() method runs a single test iteration:
```python
loss = strategy.test_step(batch, batch_idx)
```
Prediction Step
The predict_step() method runs a single prediction iteration:
```python
result = strategy.predict_step(batch, batch_idx)
```
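In PyTorch Lightning, a strategy's step methods typically forward the call to the matching method of the LightningModule being trained, after handling precision and model wrapping. The toy classes below sketch only that delegation in plain Python; `ToyModule` and `DelegatingStrategy` are hypothetical and omit the distributed and precision logic a real strategy adds.

```python
from typing import List


class ToyModule:
    """Stand-in for a LightningModule with the four step hooks."""

    def training_step(self, batch: List[float], batch_idx: int) -> float:
        return sum(x * x for x in batch)  # toy loss

    def validation_step(self, batch: List[float], batch_idx: int) -> float:
        return sum(x * x for x in batch)

    def test_step(self, batch: List[float], batch_idx: int) -> float:
        return sum(x * x for x in batch)

    def predict_step(self, batch: List[float], batch_idx: int) -> List[float]:
        return [2 * x for x in batch]  # toy predictions


class DelegatingStrategy:
    """Hypothetical strategy whose step methods forward to the module."""

    def __init__(self, module: ToyModule) -> None:
        self.lightning_module = module

    def training_step(self, batch, batch_idx):
        return self.lightning_module.training_step(batch, batch_idx)

    def validation_step(self, batch, batch_idx):
        return self.lightning_module.validation_step(batch, batch_idx)

    def test_step(self, batch, batch_idx):
        return self.lightning_module.test_step(batch, batch_idx)

    def predict_step(self, batch, batch_idx):
        return self.lightning_module.predict_step(batch, batch_idx)


toy_strategy = DelegatingStrategy(ToyModule())
batch = [1.0, 2.0]
print(toy_strategy.training_step(batch, 0))  # 5.0
print(toy_strategy.predict_step(batch, 0))   # [2.0, 4.0]
```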
Process DataLoader
Use process_dataloader() to apply custom data sampling to a DataLoader:
```python
dataloader = strategy.process_dataloader(dataloader)
```
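A common form of distributed sampling, used by PyTorch's `DistributedSampler`, gives each rank a disjoint, equally sized slice of the dataset indices and pads by repeating indices so that the count divides evenly. Whether `process_dataloader()` installs exactly this sampler is an assumption. The plain-Python function below (hypothetical name `rank_indices`, shuffling omitted) only illustrates the index arithmetic.

```python
import math
from typing import List


def rank_indices(dataset_len: int, num_replicas: int, rank: int) -> List[int]:
    """Indices `rank` would see, mirroring DistributedSampler without shuffling."""
    num_samples = math.ceil(dataset_len / num_replicas)
    total_size = num_samples * num_replicas
    indices = list(range(dataset_len))
    padding = total_size - dataset_len
    if padding > 0:
        # Repeat indices from the start until the total divides evenly across ranks.
        indices += (indices * math.ceil(padding / len(indices)))[:padding]
    # Rank r takes every num_replicas-th index, starting at r.
    return indices[rank:total_size:num_replicas]


for r in range(4):
    print(r, rank_indices(10, num_replicas=4, rank=r))
# 0 [0, 4, 8]
# 1 [1, 5, 9]
# 2 [2, 6, 0]
# 3 [3, 7, 1]
```

Padding keeps every rank at the same number of batches, which avoids one rank waiting on collectives that the others never reach.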
Retrieve State Dictionary
Retrieve the model’s state dictionary using lightning_module_state_dict():
```python
state_dict = strategy.lightning_module_state_dict()
```
Remove Checkpoints
Remove a checkpoint from the filesystem using remove_checkpoint():
```python
strategy.remove_checkpoint(filepath)
```
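In PyTorch Lightning's base strategy, checkpoint removal is carried out on the global-zero rank only, and a checkpoint may be either a single file or a directory (as with distributed checkpoints). Whether FSDP2Strategy behaves exactly this way is an assumption. The hypothetical helper below sketches that behavior with the standard library; the `is_global_zero` flag stands in for the strategy's rank check, and the file names are placeholders.

```python
import os
import shutil
import tempfile


def remove_checkpoint(filepath: str, is_global_zero: bool) -> None:
    """Delete a checkpoint file or directory, on the global-zero rank only."""
    if not is_global_zero:
        return  # other ranks leave the filesystem untouched
    if os.path.isdir(filepath):
        shutil.rmtree(filepath)  # e.g. a directory holding checkpoint shards
    elif os.path.exists(filepath):
        os.remove(filepath)


with tempfile.TemporaryDirectory() as tmp:
    ckpt_dir = os.path.join(tmp, "step=100")
    os.makedirs(ckpt_dir)
    open(os.path.join(ckpt_dir, "shard_0.pt"), "w").close()  # placeholder shard file

    remove_checkpoint(ckpt_dir, is_global_zero=False)
    kept_on_other_rank = os.path.exists(ckpt_dir)
    remove_checkpoint(ckpt_dir, is_global_zero=True)
    removed_on_rank_zero = not os.path.exists(ckpt_dir)

print(kept_on_other_rank, removed_on_rank_zero)  # True True
```

Restricting deletion to one rank avoids several processes racing to remove the same files on a shared filesystem.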
Initialize Tensors
Use the tensor_init_context() context manager for tensor initialization:
```python
with strategy.tensor_init_context():
    # Initialization code (for example, building the model) goes here.
    pass
```