nemo_automodel.components.distributed.fsdp2

Module Contents

Classes

FSDP2Manager

Manager for parallelizing models using FSDP2 with TP, DP, CP sharding.

Functions

_patch_is_packed_sequence_for_training

Eliminate CPU-GPU sync from flash attention for standard (non-packed) training.

Data

API

nemo_automodel.components.distributed.fsdp2.logger

‘getLogger(…)’

nemo_automodel.components.distributed.fsdp2._patch_is_packed_sequence_for_training() -> None

Eliminate CPU-GPU sync from flash attention for standard (non-packed) training.

transformers._is_packed_sequence() returns a GPU bool scalar when batch_size==1. Python's if statement then calls aten::is_nonzero on that tensor, which forces a CPU-GPU sync once per attention layer per forward pass. With FSDP, TP, and gradient checkpointing combined, this sync fires hundreds of times per iteration.

In standard (non-packed) training, sequences are never packed, so returning the Python constant False immediately is correct and avoids the sync. Do NOT apply this patch when training on packed sequences (multiple sequences concatenated into one tensor, with position_ids that reset to 0 mid-sequence).

class nemo_automodel.components.distributed.fsdp2.FSDP2Manager(
config: nemo_automodel.components.distributed.config.FSDP2Config,
device_mesh: torch.distributed.device_mesh.DeviceMesh,
moe_mesh: Optional[torch.distributed.device_mesh.DeviceMesh] = None,
)

Manager for parallelizing models using FSDP2 with TP, DP, CP sharding.

This manager applies parallelization to the model using a prescribed TP sharding plan. It supports mixed precision and CPU offloading options.
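The TP plan itself is not spelled out here. As an illustration of what a TP plan commonly prescribes for an MLP block, the sketch below uses the column-wise/row-wise pairing that PyTorch exposes as ColwiseParallel and RowwiseParallel; whether this manager's plan uses exactly that pairing is an assumption. It is pure Python with 2-way TP, and the matrices x, A, and B are hypothetical; the summation across ranks simulates an all-reduce.

```python
from typing import List

Matrix = List[List[int]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    return [
        [sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
        for i in range(len(a))
    ]


def relu(m: Matrix) -> Matrix:
    return [[max(0, v) for v in row] for row in m]


def col_shard(w: Matrix, rank: int, world: int) -> Matrix:
    # Column-wise: each rank keeps a contiguous slice of output features.
    n = len(w[0]) // world
    return [row[rank * n:(rank + 1) * n] for row in w]


def row_shard(w: Matrix, rank: int, world: int) -> Matrix:
    # Row-wise: each rank keeps the matching slice of input features.
    n = len(w) // world
    return w[rank * n:(rank + 1) * n]


def all_reduce_sum(partials: List[Matrix]) -> Matrix:
    # Simulated all-reduce: element-wise sum of every rank's partial output.
    out = [[0] * len(partials[0][0]) for _ in partials[0]]
    for p in partials:
        for i, row in enumerate(p):
            for j, v in enumerate(row):
                out[i][j] += v
    return out


x = [[1, -2, 3]]
A = [[1, 0, -1, 2], [0, 1, 1, -1], [2, -1, 0, 1]]  # up-projection, 3 -> 4
B = [[1, 2], [0, 1], [-1, 0], [3, -2]]             # down-projection, 4 -> 2

full = matmul(relu(matmul(x, A)), B)

tp = 2
partials = [
    matmul(relu(matmul(x, col_shard(A, r, tp))), row_shard(B, r, tp))
    for r in range(tp)
]
tp_out = all_reduce_sum(partials)
print(full, tp_out)  # [[28, 0]] [[28, 0]]
```

Because ReLU is element-wise, splitting A by columns needs no communication before the activation; the only collective in this block is the final all-reduce after the row-wise projection.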

The device mesh must be created externally and passed in.

Parameters:
  • config (FSDP2Config) – Configuration for FSDP2 distributed training.

  • device_mesh (DeviceMesh) – Device mesh for distributed operations.

  • moe_mesh (Optional[DeviceMesh]) – Optional device mesh for expert parallelism.
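To make the role of device_mesh concrete, the sketch below shows how ranks could be grouped in a two-dimensional (dp, tp) mesh. It mirrors the row-major layout that torch.distributed.device_mesh.init_device_mesh uses (ranks arranged as arange(world_size) reshaped to mesh_shape). The dimension names, their order, and the helper mesh_groups are assumptions for illustration; the mesh produced by create_device_mesh() may contain further dimensions (for example for CP).

```python
from typing import List, Tuple


def mesh_groups(world_size: int, dp: int, tp: int) -> Tuple[List[List[int]], List[List[int]]]:
    """Return (tp_groups, dp_groups) for a row-major (dp, tp) mesh."""
    if dp * tp != world_size:
        raise ValueError(f"dp * tp must equal world_size, got {dp} * {tp} != {world_size}")
    # Row-major layout: rank = dp_index * tp + tp_index
    mesh = [[d * tp + t for t in range(tp)] for d in range(dp)]
    tp_groups = [list(row) for row in mesh]                          # one group per mesh row
    dp_groups = [[mesh[d][t] for d in range(dp)] for t in range(tp)]  # one group per mesh column
    return tp_groups, dp_groups


tp_groups, dp_groups = mesh_groups(world_size=8, dp=2, tp=4)
print(tp_groups)  # [[0, 1, 2, 3], [4, 5, 6, 7]]
print(dp_groups)  # [[0, 4], [1, 5], [2, 6], [3, 7]]
```

Putting tp as the innermost dimension keeps each TP group on contiguous ranks, which usually means the same node and its fast interconnect, since TP communicates far more often than DP.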

Example:

    from nemo_automodel.components.distributed.config import FSDP2Config
    from nemo_automodel.components.distributed.fsdp2 import FSDP2Manager

    config = FSDP2Config(sequence_parallel=True, activation_checkpointing=True)

    # device_mesh created externally via create_device_mesh()
    manager = FSDP2Manager(config, device_mesh=device_mesh, moe_mesh=moe_mesh)
    model = manager.parallelize(model)


parallelize(model)

Parallelizes the given model using FSDP2 and TP sharding strategies.

Parameters:

model (nn.Module) – The model to be parallelized.

Returns:

The parallelized model.
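On the FSDP2 side, PyTorch's fully_shard shards each parameter along dim 0 across the data-parallel ranks. The sketch below illustrates that split in pure Python; the ceil-sized, torch.chunk-style rule for uneven sizes (with FSDP2 padding the short shards internally) reflects our understanding of FSDP2 and should be treated as an assumption, and shard_dim0 and all_gather are hypothetical helpers, not this manager's code.

```python
import math
from typing import List, Sequence, TypeVar

T = TypeVar("T")


def shard_dim0(rows: Sequence[T], rank: int, world_size: int) -> List[T]:
    # Ceil-sized chunks along dim 0, like torch.chunk: trailing ranks may
    # receive fewer rows or none at all.
    chunk = math.ceil(len(rows) / world_size)
    start = min(rank * chunk, len(rows))
    return list(rows[start:start + chunk])


def all_gather(shards: List[List[T]]) -> List[T]:
    # Concatenating every rank's shard in rank order rebuilds the full
    # parameter, roughly what FSDP's all-gather does before a sharded module runs.
    return [row for shard in shards for row in shard]


weight = [[10 * r + c for c in range(3)] for r in range(10)]  # 10 x 3 "parameter"
dp = 4
shards = [shard_dim0(weight, r, dp) for r in range(dp)]
print([len(s) for s in shards])  # [3, 3, 3, 1]
```

Each rank therefore stores roughly 1/dp of every parameter between uses, and the full tensor exists only transiently around the module's computation.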

maybe_compile(model)

Apply per-layer compile after sharding, alongside whole-model compile_model().