nemo_automodel.components.moe.fsdp_mixin
Module Contents
Classes
Functions
API
Mixin for managing FSDP synchronization state during MoE model training.
Controls gradient sync and resharding for FSDP-wrapped modules to optimize performance during gradient accumulation steps.
Usage differs based on parallelism strategy:
- Without pipeline parallelism (PP): prepare_for_grad_accumulation() defers sync and resharding at the start of gradient accumulation. prepare_for_final_backward() enables sync and resharding before the last backward pass. FSDP’s autograd hooks automatically handle post-backward synchronization and resharding.
- With pipeline parallelism (PP): FSDP state management is handled by patching _PipelineStageBase.backward_maybe_with_nosync (see patched_backward_maybe_with_nosync below). The patch disables sync/resharding for all backwards except the last one before optimizer step, where it manually triggers post-backward hooks and resharding.
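The non-PP flow described above can be sketched with a toy stand-in. This is a minimal illustration, not the library's implementation: `ToySyncState`, its `backward` method, the `accumulate` helper, and the parameter name `pp_enabled` are all hypothetical and only mirror the two method names mentioned in the text. The toy tracks a flag where a real FSDP module would defer or perform gradient reduction and resharding.

```python
from dataclasses import dataclass, field


@dataclass
class ToySyncState:
    """Hypothetical stand-in for an FSDP-wrapped model that uses the mixin."""

    sync_enabled: bool = True
    reshard_enabled: bool = True
    log: list = field(default_factory=list)

    def prepare_for_grad_accumulation(self, pp_enabled: bool = False) -> None:
        # Assumption: with PP the patched backward manages state instead,
        # so this toy only changes flags in the non-PP case.
        if not pp_enabled:
            self.sync_enabled = False
            self.reshard_enabled = False

    def prepare_for_final_backward(self, pp_enabled: bool = False) -> None:
        # Re-enable sync and resharding before the last backward pass.
        if not pp_enabled:
            self.sync_enabled = True
            self.reshard_enabled = True

    def backward(self, micro_batch: int) -> None:
        # A real backward would compute gradients; the toy records whether
        # gradient sync was enabled for this micro-batch.
        self.log.append((micro_batch, self.sync_enabled))


def accumulate(model: ToySyncState, num_micro_batches: int) -> list:
    """Run one gradient-accumulation window without pipeline parallelism."""
    model.prepare_for_grad_accumulation(pp_enabled=False)
    for i in range(num_micro_batches):
        if i == num_micro_batches - 1:
            model.prepare_for_final_backward(pp_enabled=False)
        model.backward(i)
    return model.log


if __name__ == "__main__":
    for micro_batch, synced in accumulate(ToySyncState(), 3):
        print(f"micro-batch {micro_batch}: sync enabled = {synced}")
```

In this sketch only the last micro-batch runs with sync enabled, which is the point of deferring: the earlier backward passes skip the communication and resharding cost.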
prepare_for_final_backward(): Enable gradient sync and resharding for the final backward pass.
Parameters:
Whether pipeline parallelism is enabled.
prepare_for_grad_accumulation(): Prepare FSDP states before starting gradient accumulation.
Parameters:
Whether pipeline parallelism is enabled.
patched_backward_maybe_with_nosync: Whether PP is combined with FSDP or DDP, the last backward step differs at runtime from the earlier ones. Gradients are accumulated on the earlier steps and reduced on the last step, and the additional state variables and performance considerations depend on the data parallelism used. This helper is intended to adapt any pipeline parallel schedule to work with common, supported data parallel libraries.
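The accumulate-then-reduce pattern this describes can be illustrated without any distributed runtime. The following is a pure-Python simulation under stated assumptions: `accumulate_then_reduce` is a hypothetical helper, ranks are modelled as plain lists of scalar gradients, and the reduction is taken to be a mean across ranks. It does not reproduce the patched function or any PyTorch API.

```python
def accumulate_then_reduce(per_rank_grads):
    """Simulate local accumulation with a single reduction on the last step.

    per_rank_grads[r][s] is the (scalar) gradient produced on rank r for
    micro-batch s. Every rank must have the same number of micro-batches.
    """
    num_ranks = len(per_rank_grads)
    num_steps = len(per_rank_grads[0])
    local_sums = [0.0] * num_ranks
    reductions = 0

    for step in range(num_steps):
        # Every backward adds into the rank-local gradient buffer.
        for rank, grads in enumerate(per_rank_grads):
            local_sums[rank] += grads[step]

        # Only the last backward triggers the cross-rank reduction
        # (assumed here to be an average over ranks).
        if step == num_steps - 1:
            mean = sum(local_sums) / num_ranks
            local_sums = [mean] * num_ranks
            reductions += 1

    return local_sums, reductions


if __name__ == "__main__":
    grads, calls = accumulate_then_reduce([[1.0, 2.0], [3.0, 4.0]])
    print(grads, calls)
```

Reducing once per accumulation window instead of once per micro-batch is what saves communication; the final gradient is the same because the sum of per-step means equals the mean of the per-rank sums.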