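Implement FSDP2Strategy

Overview

The FSDP2Strategy implements Fully Sharded Data Parallel (FSDP) training using PyTorch’s FSDP2 implementation. FSDP shards model parameters, gradients, and optimizer states across data-parallel ranks, so no single device has to hold a full copy of the model. The strategy enables distributed training with automatic model sharding and mixed precision support.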
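The following sketch makes the sharding idea concrete. It splits one 1-D parameter evenly across data-parallel ranks, padding the last shard, and then reassembles it the way an all-gather would. This is a simplified, framework-free illustration: plain lists stand in for tensors, and shard and all_gather are hypothetical helpers. FSDP2 itself represents each sharded parameter as a DTensor split along dimension 0.

```python
from typing import List


def shard(flat_param: List[float], world_size: int) -> List[List[float]]:
    """Split a 1-D parameter into world_size equal shards (hypothetical helper)."""
    shard_len = -(-len(flat_param) // world_size)  # ceiling division
    # Pad with zeros so that every rank stores the same number of elements.
    padded = flat_param + [0.0] * (shard_len * world_size - len(flat_param))
    return [padded[r * shard_len:(r + 1) * shard_len] for r in range(world_size)]


def all_gather(shards: List[List[float]], numel: int) -> List[float]:
    """Concatenate all shards in rank order and drop the padding (hypothetical helper)."""
    full = [x for s in shards for x in s]
    return full[:numel]


weights = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7]  # 7 elements
shards = shard(weights, world_size=4)            # each rank stores 2 elements
print([len(s) for s in shards])                  # [2, 2, 2, 2]
print(all_gather(shards, len(weights)) == weights)  # True
```

Each rank keeps roughly 1/world_size of the parameter in memory and only gathers the full value when it is needed.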

Features

  • Automatic model parallelism

  • Mixed precision training

  • Checkpoint management

  • Deferred optimizer state restoration

  • Device mesh initialization

Initialize

To initialize the FSDP2Strategy, use the following arguments:

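# Import path assumed for NeMo 2.x; it may differ in other versions.
from nemo.lightning.pytorch.strategies import FSDP2Strategy

strategy = FSDP2Strategy(
    data_parallel_size="auto",
    tensor_parallel_size="auto",
    checkpoint_io=None,
    mp_policy=None,
    parallelize_fn=None,
    # Any additional keyword arguments (**kwargs) can also be passed here.
)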

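Arguments:

  • data_parallel_size (Union[Literal["auto"], int]): Number of data-parallel ranks, that is, the size of the data-parallel dimension of the device mesh.

  • tensor_parallel_size (Union[Literal["auto"], int]): Number of ranks in each tensor-parallel group (the tensor-parallel degree).

  • checkpoint_io (optional): Checkpoint I/O handler used to save and load checkpoints.

  • mp_policy (optional): Mixed precision policy applied when the model is sharded.

  • parallelize_fn (callable, optional): Function that shards (parallelizes) the model; parallelize() calls it.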
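The "auto" values must resolve to concrete sizes whose product equals the total number of ranks. The sketch below assumes the resolution rule used by PyTorch Lightning's ModelParallelStrategy: "auto" data-parallel size becomes the number of nodes, and "auto" tensor-parallel size becomes the number of devices per node. resolve_parallel_sizes is a hypothetical helper, not part of FSDP2Strategy.

```python
from typing import Literal, Tuple, Union

Size = Union[Literal["auto"], int]


def resolve_parallel_sizes(
    data_parallel_size: Size,
    tensor_parallel_size: Size,
    num_nodes: int,
    devices_per_node: int,
) -> Tuple[int, int]:
    """Hypothetical helper: turn "auto" into concrete (dp, tp) sizes and validate them."""
    # Assumed rule (mirrors Lightning's ModelParallelStrategy defaults).
    dp = num_nodes if data_parallel_size == "auto" else data_parallel_size
    tp = devices_per_node if tensor_parallel_size == "auto" else tensor_parallel_size
    world_size = num_nodes * devices_per_node
    # Every rank must map to exactly one (data-parallel, tensor-parallel) coordinate.
    if dp * tp != world_size:
        raise ValueError(f"dp ({dp}) * tp ({tp}) must equal world size ({world_size})")
    return dp, tp


print(resolve_parallel_sizes("auto", "auto", num_nodes=2, devices_per_node=8))  # (2, 8)
print(resolve_parallel_sizes(16, 1, num_nodes=2, devices_per_node=8))           # (16, 1)
```

Setting tensor_parallel_size=1 puts every rank on the data-parallel dimension, which corresponds to pure FSDP sharding.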

Parallelize

The parallelize() method applies sharding to the model:

strategy.parallelize()

The method runs only once: after the first call shards the model, later calls do not shard it again.
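A run-once guard is one way to make repeated parallelize() calls safe. In this sketch the stored function is cleared after its first use, so later calls do nothing. ParallelizeOnce and shard_model are hypothetical names used only for illustration, and a plain dict stands in for the model.

```python
from typing import Callable, Optional


class ParallelizeOnce:
    """Hypothetical holder that applies parallelize_fn at most once."""

    def __init__(self, model: dict, parallelize_fn: Optional[Callable[[dict], None]]):
        self.model = model
        self.parallelize_fn = parallelize_fn

    def parallelize(self) -> bool:
        """Apply parallelize_fn on the first call only; return True if it ran."""
        if self.parallelize_fn is None:
            return False  # already applied, or nothing to apply
        self.parallelize_fn(self.model)
        self.parallelize_fn = None  # drop the function so later calls are no-ops
        return True


def shard_model(model: dict) -> None:
    # Hypothetical "sharding" step that just counts how often it ran.
    model["shard_count"] = model.get("shard_count", 0) + 1


s = ParallelizeOnce({"name": "toy"}, shard_model)
print(s.parallelize(), s.parallelize())  # True False
print(s.model["shard_count"])            # 1
```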

Environment Setup

The setup_environment() method initializes the distributed environment and device mesh:

strategy.setup_environment()
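A device mesh arranges ranks on a grid whose dimensions correspond to the parallelism types. PyTorch builds such meshes with torch.distributed.device_mesh.init_device_mesh, which lays ranks out in row-major order. The plain-Python sketch below reproduces that layout for a 2-D (data_parallel, tensor_parallel) mesh and lists the process groups each dimension implies; build_mesh and groups are hypothetical helpers.

```python
from typing import List, Tuple

Mesh = List[List[int]]


def build_mesh(dp: int, tp: int) -> Mesh:
    """Row-major layout: row d holds the ranks of data-parallel index d."""
    return [[d * tp + t for t in range(tp)] for d in range(dp)]


def groups(mesh: Mesh) -> Tuple[Mesh, Mesh]:
    # Each row is one tensor-parallel group (ranks that split the same layers).
    tp_groups = [list(row) for row in mesh]
    # Each column is one data-parallel group (ranks that shard parameters and see different data).
    dp_groups = [list(col) for col in zip(*mesh)]
    return dp_groups, tp_groups


mesh = build_mesh(dp=2, tp=4)
dp_groups, tp_groups = groups(mesh)
print(mesh)       # [[0, 1, 2, 3], [4, 5, 6, 7]]
print(dp_groups)  # [[0, 4], [1, 5], [2, 6], [3, 7]]
print(tp_groups)  # [[0, 1, 2, 3], [4, 5, 6, 7]]
```

With row-major ordering, tensor-parallel groups consist of consecutive ranks, which usually places them on the same node.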

Manage Checkpoints

Save Checkpoints

The save_checkpoint() method unshards the checkpoint and saves it to disk:

strategy.save_checkpoint(checkpoint, filepath)
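The idea of unsharding before saving can be sketched as follows: each parameter's per-rank shards are concatenated in rank order, the padding is dropped, and one full copy is written to disk. save_unsharded, the JSON format, and the shard layout are assumptions for illustration only; real checkpoints store tensors, and writing them is typically delegated to the checkpoint I/O handler.

```python
import json
import os
import tempfile
from typing import Dict, List


def save_unsharded(
    sharded: Dict[str, List[List[float]]], numels: Dict[str, int], filepath: str
) -> None:
    """Hypothetical helper: gather shards into full values and save them as JSON."""
    full = {}
    for name, shards in sharded.items():
        flat = [x for shard in shards for x in shard]  # concatenate shards in rank order
        full[name] = flat[: numels[name]]              # drop padding added during sharding
    with open(filepath, "w") as f:
        json.dump({"state_dict": full}, f)


# 3 values split across 2 ranks; the second shard carries one padding element.
sharded = {"linear.weight": [[1.0, 2.0], [3.0, 0.0]]}
path = os.path.join(tempfile.mkdtemp(), "ckpt.json")
save_unsharded(sharded, {"linear.weight": 3}, path)
with open(path) as f:
    print(json.load(f))  # {'state_dict': {'linear.weight': [1.0, 2.0, 3.0]}}
```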

Load Checkpoints

The load_checkpoint() method loads a checkpoint from disk:

checkpoint = strategy.load_checkpoint(filepath)
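After a full checkpoint is loaded, each data-parallel rank needs only its own portion of every parameter. The hypothetical load_shard helper below reads a JSON checkpoint (an assumed format) and slices out one rank's shard using the same ceiling-division split used when sharding. It is a sketch of the concept, not a description of how load_checkpoint() works internally.

```python
import json
import os
import tempfile
from typing import Dict, List


def load_shard(filepath: str, rank: int, world_size: int) -> Dict[str, List[float]]:
    """Hypothetical helper: read a full checkpoint and keep only this rank's slice."""
    with open(filepath) as f:
        full = json.load(f)["state_dict"]
    local = {}
    for name, values in full.items():
        shard_len = -(-len(values) // world_size)  # ceiling division, as when sharding
        local[name] = values[rank * shard_len:(rank + 1) * shard_len]
    return local


path = os.path.join(tempfile.mkdtemp(), "ckpt.json")
with open(path, "w") as f:
    json.dump({"state_dict": {"linear.weight": [1.0, 2.0, 3.0]}}, f)

print(load_shard(path, rank=0, world_size=2))  # {'linear.weight': [1.0, 2.0]}
print(load_shard(path, rank=1, world_size=2))  # {'linear.weight': [3.0]}
```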

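Restore Optimizer State

Restoring the optimizer state is deferred until the first training step. Call load_optimizer_state_dict() to store the optimizer state from a checkpoint; the stored state is applied when the first training step runs: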

strategy.load_optimizer_state_dict(checkpoint)
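The deferral can be pictured as follows: load_optimizer_state_dict() only keeps a reference to the checkpoint, and the first training_step() applies it and then clears the reference. ToyOptimizer and DeferredRestoreStrategy are hypothetical classes, and the "optimizer_states" key follows the PyTorch Lightning checkpoint layout as an assumption.

```python
from typing import Any, Dict, List, Optional


class ToyOptimizer:
    """Hypothetical optimizer with a dict as its state."""

    def __init__(self) -> None:
        self.state: Dict[str, Any] = {}

    def load_state_dict(self, state: Dict[str, Any]) -> None:
        self.state = dict(state)


class DeferredRestoreStrategy:
    """Hypothetical strategy that restores optimizer state lazily."""

    def __init__(self, optimizer: ToyOptimizer) -> None:
        self.optimizer = optimizer
        self._pending: Optional[Dict[str, Any]] = None

    def load_optimizer_state_dict(self, checkpoint: Dict[str, Any]) -> None:
        self._pending = checkpoint  # store a reference only; nothing is restored yet

    def training_step(self, batch: List[float], batch_idx: int) -> float:
        if self._pending is not None:
            # First step after loading: restore the optimizer state, then forget it.
            self.optimizer.load_state_dict(self._pending["optimizer_states"][0])
            self._pending = None
        return float(sum(batch))  # stand-in for the real forward pass and loss


opt = ToyOptimizer()
strategy = DeferredRestoreStrategy(opt)
strategy.load_optimizer_state_dict({"optimizer_states": [{"step": 100}]})
print(opt.state)  # {}  (not restored yet)
strategy.training_step([1.0, 2.0], 0)
print(opt.state)  # {'step': 100}
```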

Train and Evaluate the Model

Training Step

The training_step() method runs a single training iteration on a batch and returns the loss:

loss = strategy.training_step(batch, batch_idx)

Validation Step

The validation_step() method runs a single validation iteration on a batch:

loss = strategy.validation_step(batch, batch_idx)

Test Step

The test_step() method runs a single test iteration on a batch:

loss = strategy.test_step(batch, batch_idx)

Prediction Step

The predict_step() method runs a single prediction iteration on a batch and returns the result:

result = strategy.predict_step(batch, batch_idx)
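In Lightning-style strategies these step methods typically forward the batch to the LightningModule hook of the same name, so the module defines what each step computes. The sketch below shows that delegation with hypothetical ToyModule and DelegatingStrategy classes; it omits FSDP2-specific work such as the deferred optimizer restore.

```python
class ToyModule:
    """Hypothetical module that defines what each step computes."""

    def training_step(self, batch, batch_idx):
        return sum(batch) / len(batch)  # toy "loss": mean of the batch

    def validation_step(self, batch, batch_idx):
        return sum(batch) / len(batch)

    def test_step(self, batch, batch_idx):
        return sum(batch) / len(batch)

    def predict_step(self, batch, batch_idx):
        return [2 * x for x in batch]  # toy "prediction": doubled inputs


class DelegatingStrategy:
    """Hypothetical strategy whose step methods forward to the module's hooks."""

    def __init__(self, module):
        self.module = module

    def _run(self, hook, batch, batch_idx):
        # A real strategy may also enter precision or autocast contexts here.
        return getattr(self.module, hook)(batch, batch_idx)

    def training_step(self, batch, batch_idx):
        return self._run("training_step", batch, batch_idx)

    def validation_step(self, batch, batch_idx):
        return self._run("validation_step", batch, batch_idx)

    def test_step(self, batch, batch_idx):
        return self._run("test_step", batch, batch_idx)

    def predict_step(self, batch, batch_idx):
        return self._run("predict_step", batch, batch_idx)


strategy = DelegatingStrategy(ToyModule())
print(strategy.training_step([1.0, 3.0], 0))  # 2.0
print(strategy.predict_step([1.0, 3.0], 0))   # [2.0, 6.0]
```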

Process DataLoader

Use process_dataloader() to apply custom data sampling to a DataLoader:

dataloader = strategy.process_dataloader(dataloader)
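What "custom data sampling" does depends on the configured sampler. A common case in data-parallel training is giving each rank a disjoint, strided subset of the dataset. The sketch below mirrors the non-shuffled behavior of torch.utils.data.distributed.DistributedSampler with drop_last=False (wrap-around padding so every rank gets the same number of samples); rank_indices is a hypothetical helper.

```python
from typing import List


def rank_indices(dataset_len: int, rank: int, num_replicas: int) -> List[int]:
    """Indices that `rank` visits when a dataset is split across `num_replicas` ranks."""
    per_rank = -(-dataset_len // num_replicas)  # ceiling division
    total = per_rank * num_replicas
    # Pad by wrapping around to the start so every rank gets the same number of samples.
    indices = [i % dataset_len for i in range(total)]
    return indices[rank:total:num_replicas]


for r in range(4):
    print(r, rank_indices(10, r, 4))
# 0 [0, 4, 8]
# 1 [1, 5, 9]
# 2 [2, 6, 0]
# 3 [3, 7, 1]
```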

Retrieve State Dictionary

Retrieve the model’s state dictionary using lightning_module_state_dict():

state_dict = strategy.lightning_module_state_dict()
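A PyTorch state dictionary maps each parameter's dotted module path (for example, layers.0.weight) to its value. The sketch below builds such keys from a nested dict that stands in for a module tree; flatten_state is a hypothetical helper, and it says nothing about whether lightning_module_state_dict() returns full or sharded values.

```python
from typing import Any, Dict


def flatten_state(module: Dict[str, Any], prefix: str = "") -> Dict[str, Any]:
    """Hypothetical helper: turn a nested module tree into dotted state-dict keys."""
    out: Dict[str, Any] = {}
    for name, value in module.items():
        key = f"{prefix}{name}"
        if isinstance(value, dict):
            out.update(flatten_state(value, key + "."))  # recurse into a submodule
        else:
            out[key] = value  # leaf: a parameter or buffer
    return out


model = {"embed": {"weight": [0.1, 0.2]}, "layers": {"0": {"weight": [1.0], "bias": [0.0]}}}
print(flatten_state(model))
# {'embed.weight': [0.1, 0.2], 'layers.0.weight': [1.0], 'layers.0.bias': [0.0]}
```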

Remove Checkpoints

Use remove_checkpoint() to delete a checkpoint from the filesystem:

strategy.remove_checkpoint(filepath)
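In distributed runs, deleting a checkpoint is commonly restricted to one process so that ranks do not race on shared storage. The sketch below assumes that only global rank 0 deletes and that a checkpoint may be either a single file or a directory of shard files; remove_checkpoint here is a hypothetical standalone function, not the strategy method.

```python
import os
import shutil
import tempfile


def remove_checkpoint(filepath: str, global_rank: int) -> bool:
    """Hypothetical helper: delete a checkpoint; only global rank 0 touches the filesystem."""
    if global_rank != 0:
        return False  # other ranks skip deletion to avoid racing on shared storage
    if os.path.isdir(filepath):
        shutil.rmtree(filepath)  # sharded checkpoints are often directories of files
    elif os.path.exists(filepath):
        os.remove(filepath)
    return True


ckpt_dir = os.path.join(tempfile.mkdtemp(), "step_100")
os.makedirs(ckpt_dir)
open(os.path.join(ckpt_dir, "shard_0.bin"), "wb").close()

print(remove_checkpoint(ckpt_dir, global_rank=1), os.path.exists(ckpt_dir))  # False True
print(remove_checkpoint(ckpt_dir, global_rank=0), os.path.exists(ckpt_dir))  # True False
```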

Initialize Tensors

Use the tensor_init_context() context manager around code that creates tensors, such as model instantiation:

with strategy.tensor_init_context():
    # Create the model (and its parameters) here.
    pass