nemo_automodel.components.moe.fsdp_mixin#
Module Contents#
Classes#
MoEFSDPSyncMixin | Mixin for managing FSDP synchronization state during MoE model training.
Functions#
set_is_optim_step | Set the global IS_OPTIM_STEP flag.
get_is_optim_step | Get the global IS_OPTIM_STEP flag.
patched_backward_maybe_with_nosync | When PP is combined with FSDP or DDP, the last backward step behaves differently at runtime from the earlier ones: gradients are accumulated on the earlier steps and reduced on the last one. The state variables involved and the performance considerations depend on the data-parallel library in use. This helper adapts any pipeline-parallel schedule to work with the common, supported data-parallel libraries.
Data#
IS_OPTIM_STEP
API#
- nemo_automodel.components.moe.fsdp_mixin.IS_OPTIM_STEP#
False
- nemo_automodel.components.moe.fsdp_mixin.set_is_optim_step(value: bool) → None#
Set the global IS_OPTIM_STEP flag.
- Parameters:
value – Whether we are in an optimization step.
- nemo_automodel.components.moe.fsdp_mixin.get_is_optim_step() → bool#
Get the global IS_OPTIM_STEP flag.
- Returns:
Whether we are in an optimization step.
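The setter and getter above follow the usual module-level flag pattern. The block below is a minimal, self-contained sketch of that pattern, not the library's source: the helper `accumulate` and the choice to raise the flag only on the last micro-batch are assumptions made for illustration.

```python
# Illustrative stand-in for a module-level flag with accessor functions.
# This mirrors the documented names but is NOT the nemo_automodel source.
IS_OPTIM_STEP = False


def set_is_optim_step(value: bool) -> None:
    """Set the module-level flag."""
    global IS_OPTIM_STEP
    IS_OPTIM_STEP = value


def get_is_optim_step() -> bool:
    """Read the module-level flag."""
    return IS_OPTIM_STEP


def accumulate(num_micro_batches: int) -> list:
    """Hypothetical loop: record the flag value seen on each micro-batch.

    Assumption: the caller raises the flag only for the micro-batch that
    precedes the optimizer step, then resets it afterwards.
    """
    observed = []
    for index in range(num_micro_batches):
        set_is_optim_step(index == num_micro_batches - 1)
        observed.append(get_is_optim_step())
    set_is_optim_step(False)
    return observed
```

Readers of the flag should call `get_is_optim_step()` rather than importing `IS_OPTIM_STEP` by name, because a `from ... import IS_OPTIM_STEP` binding would not see later updates.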
- nemo_automodel.components.moe.fsdp_mixin._iter_fsdp_modules(
- module: torch.nn.Module,
- nemo_automodel.components.moe.fsdp_mixin._configure_fsdp_module(
- fsdp_module: torch.distributed.fsdp.FSDPModule,
- *,
- is_last_backward: bool,
- reshard_after_backward: bool,
- requires_gradient_sync: bool,
- nemo_automodel.components.moe.fsdp_mixin._run_post_backward_hooks(
- fsdp_module: torch.distributed.fsdp.FSDPModule,
- class nemo_automodel.components.moe.fsdp_mixin.MoEFSDPSyncMixin#
Mixin for managing FSDP synchronization state during MoE model training.
Controls gradient sync and resharding for FSDP-wrapped modules to optimize performance during gradient accumulation steps.
Usage differs based on parallelism strategy:
Without pipeline parallelism (PP): prepare_for_grad_accumulation() defers sync and resharding at the start of gradient accumulation. prepare_for_final_backward() enables sync and resharding before the last backward pass. FSDP’s autograd hooks automatically handle post-backward synchronization and resharding.
With pipeline parallelism (PP): FSDP state management is handled by patching _PipelineStageBase.backward_maybe_with_nosync (see patched_backward_maybe_with_nosync below). The patch disables gradient sync and resharding for every backward pass except the last one before the optimizer step; on that last pass it manually triggers the post-backward hooks and resharding.
- prepare_for_grad_accumulation(pp_enabled: bool = False) → None#
Prepare FSDP states before starting gradient accumulation.
- Parameters:
pp_enabled – Whether pipeline parallelism is enabled.
Note: When PP is enabled, FSDP state management is handled by the patched _PipelineStageBase.backward_maybe_with_nosync method, so prepare_for_grad_accumulation() only applies its optimizations in the non-PP case.
- prepare_for_final_backward(pp_enabled: bool = False) → None#
Enable gradient sync and resharding for the final backward pass.
- Parameters:
pp_enabled – Whether pipeline parallelism is enabled.
Note: When PP is enabled, FSDP state management is handled by the patched _PipelineStageBase.backward_maybe_with_nosync method, so prepare_for_final_backward() only applies its optimizations in the non-PP case.
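The non-PP call order described for the mixin can be sketched as a gradient accumulation loop. Everything except the two method names is hypothetical: `RecordingSyncModel` is a stand-in that only logs calls, and `run_accumulation` and `backward_fn` are illustrative helpers, not part of nemo_automodel.

```python
from typing import Callable, Iterable, List


class RecordingSyncModel:
    """Hypothetical stand-in exposing the two documented method names.

    A real model would inherit MoEFSDPSyncMixin; this stub only logs calls.
    """

    def __init__(self) -> None:
        self.events: List[str] = []

    def prepare_for_grad_accumulation(self, pp_enabled: bool = False) -> None:
        self.events.append("defer_sync")

    def prepare_for_final_backward(self, pp_enabled: bool = False) -> None:
        self.events.append("enable_sync")


def run_accumulation(
    model: RecordingSyncModel,
    micro_batches: Iterable[int],
    backward_fn: Callable[[int], None],
    pp_enabled: bool = False,
) -> None:
    """Defer sync up front, re-enable it just before the last backward."""
    batches = list(micro_batches)
    model.prepare_for_grad_accumulation(pp_enabled=pp_enabled)
    for index, batch in enumerate(batches):
        if index == len(batches) - 1:
            model.prepare_for_final_backward(pp_enabled=pp_enabled)
        backward_fn(batch)  # stands in for loss.backward() on this batch


model = RecordingSyncModel()
run_accumulation(
    model,
    [1, 2, 3],
    backward_fn=lambda batch: model.events.append(f"backward:{batch}"),
)
```

In this sketch the optimizer step would follow the loop, after the final backward has run with sync and resharding enabled.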
- nemo_automodel.components.moe.fsdp_mixin._disable_fsdp_for_moe_module(module: torch.nn.Module) → None#
- nemo_automodel.components.moe.fsdp_mixin._run_post_backward_for_moe_module(module: torch.nn.Module) → None#
- nemo_automodel.components.moe.fsdp_mixin.patched_backward_maybe_with_nosync(
- self,
- backward_type,
- bwd_kwargs: dict,
- last_backward: bool = False,
When PP is combined with FSDP or DDP, the last backward step behaves differently at runtime from the earlier ones: gradients are accumulated on the earlier steps and reduced on the last one. The state variables involved and the performance considerations depend on the data-parallel library in use. This helper adapts any pipeline-parallel schedule to work with the common, supported data-parallel libraries.
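The per-backward toggling described here can be sketched with a stand-in object. The sketch assumes the FSDP2 `FSDPModule` setters `set_is_last_backward`, `set_reshard_after_backward`, and `set_requires_gradient_sync`; `FakeFSDPModule` and `configure_for_backward` are hypothetical names, and the manual post-backward hook call on the last pass is not modeled.

```python
from dataclasses import dataclass
from typing import Iterable


@dataclass
class FakeFSDPModule:
    """Hypothetical stand-in that records the three per-backward flags."""

    is_last_backward: bool = True
    reshard_after_backward: bool = True
    requires_gradient_sync: bool = True

    def set_is_last_backward(self, value: bool) -> None:
        self.is_last_backward = value

    def set_reshard_after_backward(self, value: bool, *, recurse: bool = True) -> None:
        self.reshard_after_backward = value

    def set_requires_gradient_sync(self, value: bool, *, recurse: bool = True) -> None:
        self.requires_gradient_sync = value


def configure_for_backward(
    modules: Iterable[FakeFSDPModule], *, last_backward: bool
) -> None:
    """Enable sync and resharding only on the last backward pass.

    Assumption: all three flags simply follow `last_backward`; the real
    patch may consult additional state before deciding.
    """
    for module in modules:
        module.set_is_last_backward(last_backward)
        module.set_reshard_after_backward(last_backward)
        module.set_requires_gradient_sync(last_backward)


modules = [FakeFSDPModule(), FakeFSDPModule()]
configure_for_backward(modules, last_backward=False)
flags_during_accumulation = [m.requires_gradient_sync for m in modules]
configure_for_backward(modules, last_backward=True)
flags_on_last_backward = [m.requires_gradient_sync for m in modules]
```

Skipping the gradient reduction and the resharding on intermediate passes avoids repeated collective communication; the cost is that parameters stay unsharded in memory until the last pass.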