Implement FSDP2Strategy
Overview
The FSDP2Strategy implements Fully Sharded Data Parallel (FSDP) training using PyTorch's FSDP2 implementation (the fully_shard API). It enables distributed training with automatic model sharding and mixed precision support.
Features
Automatic model parallelism
Mixed precision training
Checkpoint management
Deferred optimizer state restoration
Device mesh initialization
Initialize
To initialize the FSDP2Strategy, use the following arguments:
```python
strategy = FSDP2Strategy(
    data_parallel_size="auto",
    tensor_parallel_size="auto",
    checkpoint_io=None,
    mp_policy=None,
    parallelize_fn=None,
    **kwargs,
)
```
Arguments:
data_parallel_size (Union["auto", int]): Number of data-parallel replicas.
tensor_parallel_size (Union["auto", int]): Number of devices in each tensor-parallel group.
checkpoint_io (optional): Checkpoint I/O handler.
mp_policy (optional): Mixed precision policy.
parallelize_fn (callable, optional): Function that parallelizes (shards) the model.
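The "auto" values must be resolved to concrete sizes before the device mesh can be built. The sketch below assumes the convention documented for PyTorch Lightning's ModelParallelStrategy: "auto" data parallelism equals the number of nodes, and "auto" tensor parallelism equals the number of devices per node. It also assumes the product of the two sizes must equal the world size. resolve_parallel_sizes is a hypothetical helper, not part of FSDP2Strategy, so check the strategy's source for the exact rule it applies.

```python
def resolve_parallel_sizes(data_parallel_size, tensor_parallel_size, num_nodes, devices_per_node):
    """Hypothetical helper: turn "auto" sizes into integers and validate them."""
    world_size = num_nodes * devices_per_node
    # Assumed convention (as in Lightning's ModelParallelStrategy):
    # "auto" data parallelism spans the nodes ...
    if data_parallel_size == "auto":
        data_parallel_size = num_nodes
    # ... and "auto" tensor parallelism spans the devices within one node.
    if tensor_parallel_size == "auto":
        tensor_parallel_size = devices_per_node
    # Every device must belong to exactly one (data, tensor) mesh position.
    if data_parallel_size * tensor_parallel_size != world_size:
        raise ValueError(
            f"data_parallel_size ({data_parallel_size}) x tensor_parallel_size "
            f"({tensor_parallel_size}) must equal the world size ({world_size})"
        )
    return data_parallel_size, tensor_parallel_size


print(resolve_parallel_sizes("auto", "auto", num_nodes=2, devices_per_node=8))  # (2, 8)
print(resolve_parallel_sizes(16, 1, num_nodes=2, devices_per_node=8))  # (16, 1)
```

Passing explicit integers, as in the second call, is how a pure FSDP layout (all devices sharding, no tensor parallelism) would be requested under this assumed rule.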
Parallelize
The parallelize() method applies sharding to the model:
```python
strategy.parallelize()
```
The method is idempotent. If the model has already been parallelized, calling it again does not shard it a second time.
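A common way to make parallelization run only once is to check a flag before sharding. The sketch below shows only that guard logic. ParallelizeOnce and its shard_fn argument are hypothetical names and do not reflect FSDP2Strategy's internals.

```python
class ParallelizeOnce:
    """Hypothetical stand-in showing a parallelize-once guard."""

    def __init__(self, model, shard_fn):
        self.model = model
        self.shard_fn = shard_fn  # hypothetical: would wrap layers with fully_shard
        self._parallelized = False

    def parallelize(self):
        # Skip the work if the model has already been sharded.
        if self._parallelized:
            return self.model
        self.model = self.shard_fn(self.model)
        self._parallelized = True
        return self.model


calls = []


def fake_shard(model):
    """Toy sharding function that records how often it runs."""
    calls.append(model)
    return f"sharded({model})"


wrapper = ParallelizeOnce(model="model", shard_fn=fake_shard)
wrapper.parallelize()
wrapper.parallelize()  # second call is a no-op
print(wrapper.model, len(calls))  # sharded(model) 1
```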
Environment Setup
The setup_environment() method initializes the distributed environment and device mesh:
```python
strategy.setup_environment()
```
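A device mesh arranges the global ranks in a grid with one dimension per parallelism type. The pure-Python sketch below assumes a 2D (data-parallel, tensor-parallel) mesh laid out in row-major order, which is how torch.distributed.device_mesh.init_device_mesh arranges ranks. build_mesh and groups_for_rank are hypothetical helpers that only compute the layout; they do not create process groups.

```python
def build_mesh(data_parallel_size, tensor_parallel_size):
    """Hypothetical helper: lay out ranks in a (dp, tp) grid, row-major."""
    return [
        [dp * tensor_parallel_size + tp for tp in range(tensor_parallel_size)]
        for dp in range(data_parallel_size)
    ]


def groups_for_rank(mesh, rank):
    """Return the tensor-parallel group (row) and data-parallel group (column) of a rank."""
    for row in mesh:
        if rank in row:
            col = row.index(rank)
            dp_group = [r[col] for r in mesh]
            return row, dp_group
    raise ValueError(f"rank {rank} is not in the mesh")


mesh = build_mesh(data_parallel_size=2, tensor_parallel_size=4)
print(mesh)  # [[0, 1, 2, 3], [4, 5, 6, 7]]
print(groups_for_rank(mesh, 5))  # ([4, 5, 6, 7], [1, 5])
```

In this layout, ranks in the same row share tensor-parallel work, and ranks in the same column hold data-parallel replicas of the same model shard.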
Manage Checkpoints
Save Checkpoints
The save_checkpoint() method unshards the checkpoint and saves it to disk:
```python
strategy.save_checkpoint(checkpoint, filepath)
```
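By default, FSDP2 shards each parameter along its first dimension (dim 0) across the data-parallel ranks. Producing a full checkpoint therefore amounts to concatenating each parameter's shards in rank order. The sketch below uses nested Python lists in place of tensors. unshard_state_dict is a hypothetical helper that ignores padding and the real collective communication.

```python
def unshard_state_dict(shards_per_rank):
    """Hypothetical helper: rebuild full parameters from per-rank dim-0 shards.

    shards_per_rank[r] maps parameter names to the rows held by rank r.
    """
    full = {}
    for name in shards_per_rank[0]:
        rows = []
        # Concatenating along dim 0 in rank order restores the original parameter.
        for rank_shards in shards_per_rank:
            rows.extend(rank_shards[name])
        full[name] = rows
    return full


rank0 = {"linear.weight": [[1, 2], [3, 4]], "linear.bias": [0.1]}
rank1 = {"linear.weight": [[5, 6], [7, 8]], "linear.bias": [0.2]}
print(unshard_state_dict([rank0, rank1]))
# {'linear.weight': [[1, 2], [3, 4], [5, 6], [7, 8]], 'linear.bias': [0.1, 0.2]}
```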
Load Checkpoints
The load_checkpoint() method loads a checkpoint from disk:
```python
checkpoint = strategy.load_checkpoint(filepath)
```
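After a full checkpoint is loaded, each rank needs only its own dim-0 slice of every parameter. The sketch below is illustrative and is not FSDP2Strategy's loading code. shard_for_rank is a hypothetical helper, and its split sizes are an assumption modeled on torch.chunk: ceil(n / world_size) rows per rank, with a possibly shorter last shard.

```python
import math


def shard_for_rank(full_state_dict, rank, world_size):
    """Hypothetical helper: keep only this rank's dim-0 slice of each parameter."""
    local = {}
    for name, rows in full_state_dict.items():
        # torch.chunk-style split: equal chunks, last one may be shorter or empty.
        chunk = math.ceil(len(rows) / world_size)
        start = rank * chunk
        local[name] = rows[start:start + chunk]
    return local


full = {"linear.weight": [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10]]}
print(shard_for_rank(full, rank=0, world_size=2))  # {'linear.weight': [[1, 2], [3, 4], [5, 6]]}
print(shard_for_rank(full, rank=1, world_size=2))  # {'linear.weight': [[7, 8], [9, 10]]}
```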
Restore Optimizer State
Restoring the optimizer state is deferred until the first training step. Call load_optimizer_state_dict() to store the optimizer state from a checkpoint; it is applied at the first training step:
```python
strategy.load_optimizer_state_dict(checkpoint)
```
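Deferring the restore means the optimizer state is stashed when the checkpoint is read and applied only when the first training step runs. The sketch below shows that pattern using hypothetical names (DeferredOptimizerRestore, FakeOptimizer). It reads the states from an "optimizer_states" key, as Lightning checkpoints do. The real strategy's attributes and checkpoint layout may differ.

```python
class FakeOptimizer:
    """Stand-in for a torch optimizer; only load_state_dict is modeled."""

    def __init__(self):
        self.state = None

    def load_state_dict(self, state_dict):
        self.state = state_dict


class DeferredOptimizerRestore:
    """Hypothetical sketch of deferred optimizer-state restoration."""

    def __init__(self, optimizers):
        self.optimizers = optimizers
        self._pending = None

    def load_optimizer_state_dict(self, checkpoint):
        # Only store the states; the optimizers are not touched yet.
        self._pending = checkpoint["optimizer_states"]

    def training_step(self, batch):
        # Apply the stored states exactly once, right before the first step.
        if self._pending is not None:
            for optimizer, state in zip(self.optimizers, self._pending):
                optimizer.load_state_dict(state)
            self._pending = None
        return sum(batch)  # placeholder "loss"


opt = FakeOptimizer()
restorer = DeferredOptimizerRestore([opt])
restorer.load_optimizer_state_dict({"optimizer_states": [{"step": 100}]})
print(opt.state)  # None: nothing applied yet
restorer.training_step([1, 2, 3])
print(opt.state)  # {'step': 100}
```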
Train and Evaluate the Model
Training Step
The training_step() method defines a single training iteration:
```python
loss = strategy.training_step(batch, batch_idx)
```
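The trainer typically calls training_step() once per batch. The returned loss then drives the backward pass and the optimizer update. The loop below illustrates only this calling pattern. StubStrategy is a hypothetical stand-in whose "model" is a toy y = weight * x with a mean squared-error loss, and the backward and optimizer steps are omitted.

```python
class StubStrategy:
    """Hypothetical stand-in: computes a toy loss for each batch."""

    def __init__(self, weight):
        self.weight = weight

    def training_step(self, batch, batch_idx):
        # Toy model y = weight * x with a mean squared-error loss.
        xs, ys = batch
        errors = [(self.weight * x - y) ** 2 for x, y in zip(xs, ys)]
        return sum(errors) / len(errors)


strategy = StubStrategy(weight=2.0)
dataloader = [([1.0, 2.0], [2.0, 4.0]), ([3.0], [7.0])]
losses = []
for batch_idx, batch in enumerate(dataloader):
    loss = strategy.training_step(batch, batch_idx)
    # A real loop would call loss.backward() and optimizer.step() here.
    losses.append(loss)
print(losses)  # [0.0, 1.0]
```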
Validation Step
The validation_step() method defines a validation iteration:
```python
loss = strategy.validation_step(batch, batch_idx)
```
Test Step
The test_step() method defines a test iteration:
```python
loss = strategy.test_step(batch, batch_idx)
```
Prediction Step
The predict_step() method defines a prediction iteration:
```python
result = strategy.predict_step(batch, batch_idx)
```
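Validation and test steps are typically run over a whole dataloader, with their per-batch losses averaged. Outputs of predict_step(), by contrast, are collected rather than reduced. The sketch below shows only that aggregation. EvalStub and its toy model are hypothetical, and a real evaluation loop would typically also disable gradients (for example, with torch.no_grad()).

```python
class EvalStub:
    """Hypothetical stand-in exposing the evaluation-style step methods."""

    def __init__(self, weight):
        self.weight = weight

    def _loss(self, batch):
        xs, ys = batch
        return sum((self.weight * x - y) ** 2 for x, y in zip(xs, ys)) / len(xs)

    def validation_step(self, batch, batch_idx):
        return self._loss(batch)

    def test_step(self, batch, batch_idx):
        return self._loss(batch)

    def predict_step(self, batch, batch_idx):
        xs, _ = batch
        return [self.weight * x for x in xs]


strategy = EvalStub(weight=2.0)
dataloader = [([1.0, 2.0], [2.0, 5.0]), ([3.0], [6.0])]

# Validation/test: reduce per-batch losses to a single mean.
val_losses = [strategy.validation_step(b, i) for i, b in enumerate(dataloader)]
print(sum(val_losses) / len(val_losses))  # 0.25

# Prediction: keep every batch's outputs.
predictions = [strategy.predict_step(b, i) for i, b in enumerate(dataloader)]
print(predictions)  # [[2.0, 4.0], [6.0]]
```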
Process DataLoader
Use process_dataloader() to apply custom data sampling to a DataLoader:
```python
dataloader = strategy.process_dataloader(dataloader)
```
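A typical form of custom data sampling in data-parallel training gives each rank a disjoint subset of the dataset. The sketch below reimplements in plain Python the index assignment that torch.utils.data.DistributedSampler uses with shuffle=False and drop_last=False: it pads by wrapping around, then gives each rank every num_replicas-th index. Whether process_dataloader() applies exactly this scheme is an assumption, and distributed_indices is a hypothetical helper.

```python
import math


def distributed_indices(dataset_len, num_replicas, rank):
    """Plain-Python version of DistributedSampler's split (shuffle=False, drop_last=False)."""
    num_samples = math.ceil(dataset_len / num_replicas)
    total_size = num_samples * num_replicas
    indices = list(range(dataset_len))
    # Pad by wrapping around so every rank gets the same number of samples.
    padding = total_size - len(indices)
    if padding:
        repeats = math.ceil(padding / len(indices))
        indices += (indices * repeats)[:padding]
    # Each rank takes every num_replicas-th index, starting at its own rank.
    return indices[rank:total_size:num_replicas]


for rank in range(3):
    print(rank, distributed_indices(dataset_len=7, num_replicas=3, rank=rank))
# 0 [0, 3, 6]
# 1 [1, 4, 0]
# 2 [2, 5, 1]
```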
Retrieve State Dictionary
Retrieve the model’s state dictionary using lightning_module_state_dict():
```python
state_dict = strategy.lightning_module_state_dict()
```
Remove Checkpoints
Remove a checkpoint from the filesystem:
```python
strategy.remove_checkpoint(filepath)
```
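Checkpoints written in a distributed format are often directories rather than single files, and in a multi-process job only one rank should delete them. The sketch below handles both concerns with the standard library. remove_checkpoint_path and its is_global_zero argument are hypothetical stand-ins, not the strategy's implementation.

```python
import os
import shutil
import tempfile


def remove_checkpoint_path(filepath, is_global_zero=True):
    """Hypothetical helper: delete a checkpoint file or directory from rank 0 only."""
    if not is_global_zero:
        return  # other ranks leave the filesystem alone
    if os.path.isdir(filepath):
        shutil.rmtree(filepath)  # distributed checkpoints are often directories
    elif os.path.exists(filepath):
        os.remove(filepath)


root = tempfile.mkdtemp()
ckpt_dir = os.path.join(root, "step=100")
os.makedirs(ckpt_dir)
open(os.path.join(ckpt_dir, "shard_0.bin"), "wb").close()

remove_checkpoint_path(ckpt_dir, is_global_zero=False)
print(os.path.exists(ckpt_dir))  # True
remove_checkpoint_path(ckpt_dir)
print(os.path.exists(ckpt_dir))  # False
shutil.rmtree(root)
```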
Initialize Tensors
Use the tensor_init_context() context manager for tensor initialization:
```python
with strategy.tensor_init_context():
    # Initialization code
    pass
```
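A tensor-initialization context typically changes defaults while the model is built and restores them afterwards. Examples include the default dtype, or a "meta" device that skips memory allocation until the model is sharded. The sketch below demonstrates only this set-and-restore pattern, using a hypothetical module-level DEFAULTS dictionary and a toy make_tensor factory. A real implementation would rely on PyTorch facilities such as torch.set_default_dtype or a torch.device context instead.

```python
from contextlib import contextmanager

# Hypothetical global defaults standing in for PyTorch's default dtype/device.
DEFAULTS = {"dtype": "float32", "device": "cpu"}


@contextmanager
def tensor_init_context(dtype="bfloat16", device="meta"):
    """Hypothetical sketch: override defaults while tensors are created."""
    saved = dict(DEFAULTS)
    DEFAULTS.update(dtype=dtype, device=device)
    try:
        yield
    finally:
        # Restore the previous defaults even if initialization raises.
        DEFAULTS.clear()
        DEFAULTS.update(saved)


def make_tensor(shape):
    """Toy 'tensor' factory that records the defaults in effect."""
    return {"shape": shape, **DEFAULTS}


with tensor_init_context():
    weight = make_tensor((4, 4))
print(weight)  # {'shape': (4, 4), 'dtype': 'bfloat16', 'device': 'meta'}
print(DEFAULTS)  # {'dtype': 'float32', 'device': 'cpu'}
```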