nemo_automodel.components.moe.fsdp_mixin


Module Contents

Classes

Name | Description
MoEFSDPSyncMixin | Mixin for managing FSDP synchronization state during MoE model training.

Functions

Name | Description
_configure_fsdp_module | -
_disable_fsdp_for_moe_module | -
_iter_fsdp_modules | -
_run_post_backward_for_moe_module | -
_run_post_backward_hooks | -
patched_backward_maybe_with_nosync | Whether using PP with FSDP or DDP, there are some runtime differences between the last backward step and the other steps.

API

class nemo_automodel.components.moe.fsdp_mixin.MoEFSDPSyncMixin()

Mixin for managing FSDP synchronization state during MoE model training.

Controls gradient synchronization and resharding for FSDP-wrapped modules so that intermediate gradient accumulation steps can defer both, which avoids redundant communication before the optimizer step.

Usage differs based on parallelism strategy:

  • Without pipeline parallelism (PP): call prepare_for_grad_accumulation() at the start of gradient accumulation to defer sync and resharding, then call prepare_for_final_backward() before the last backward pass to re-enable them. FSDP’s autograd hooks then handle post-backward synchronization and resharding automatically.
  • With pipeline parallelism (PP): FSDP state management is handled by patching _PipelineStageBase.backward_maybe_with_nosync (see patched_backward_maybe_with_nosync below). The patch disables sync and resharding for every backward pass except the last one before the optimizer step, where it manually triggers the post-backward hooks and resharding.
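The non-PP call order described above can be sketched as follows. `RecordingModel` and `run_accumulation_step` are hypothetical names used only for illustration; the stand-in records method calls instead of touching real FSDP state, so the sketch runs without `torch.distributed`.

```python
from typing import Callable, List, Sequence


class RecordingModel:
    """Hypothetical stand-in for a model that mixes in MoEFSDPSyncMixin.

    It only records which methods are called, so the call order of the
    accumulation loop can be inspected without a process group.
    """

    def __init__(self) -> None:
        self.calls: List[str] = []

    def prepare_for_grad_accumulation(self, pp_enabled: bool = False) -> None:
        self.calls.append("prepare_for_grad_accumulation")

    def prepare_for_final_backward(self, pp_enabled: bool = False) -> None:
        self.calls.append("prepare_for_final_backward")


def run_accumulation_step(
    model: RecordingModel,
    microbatches: Sequence[object],
    backward: Callable[[object], None],
) -> None:
    """One optimizer step of gradient accumulation without PP (hypothetical helper)."""
    # Defer gradient sync and resharding for the whole accumulation window.
    model.prepare_for_grad_accumulation(pp_enabled=False)
    for i, mb in enumerate(microbatches):
        if i == len(microbatches) - 1:
            # Re-enable sync and resharding just before the last backward pass;
            # FSDP's own post-backward hooks take over from here.
            model.prepare_for_final_backward(pp_enabled=False)
        backward(mb)
    # optimizer.step() and optimizer.zero_grad() would follow here.


if __name__ == "__main__":
    model = RecordingModel()
    run_accumulation_step(
        model, ["mb0", "mb1", "mb2"], lambda mb: model.calls.append(f"backward:{mb}")
    )
    print(model.calls)
```

With three microbatches, `model.calls` ends up as `['prepare_for_grad_accumulation', 'backward:mb0', 'backward:mb1', 'prepare_for_final_backward', 'backward:mb2']`, so only the last backward runs with sync and resharding enabled.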
nemo_automodel.components.moe.fsdp_mixin.MoEFSDPSyncMixin.prepare_for_final_backward(
pp_enabled: bool = False
) -> None

Enable gradient sync and resharding for the final backward pass.

Parameters:

pp_enabled
bool, defaults to False

Whether pipeline parallelism is enabled.

nemo_automodel.components.moe.fsdp_mixin.MoEFSDPSyncMixin.prepare_for_grad_accumulation(
pp_enabled: bool = False
) -> None

Prepare FSDP states before starting gradient accumulation.

Parameters:

pp_enabled
bool, defaults to False

Whether pipeline parallelism is enabled.

nemo_automodel.components.moe.fsdp_mixin._configure_fsdp_module(
fsdp_module: torch.distributed.fsdp.FSDPModule,
is_last_backward: bool,
reshard_after_backward: bool,
requires_gradient_sync: bool
) -> None
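The three flags of `_configure_fsdp_module` line up with per-module setters that FSDP2's `FSDPModule` exposes: `set_is_last_backward`, `set_reshard_after_backward`, and `set_requires_gradient_sync`. The sketch below assumes the helper simply forwards each flag to the matching setter; that is an assumption, not a reading of the actual implementation. `configure_fsdp_module_sketch`, `FSDPLike`, and `StubFSDPModule` are hypothetical names, and the stub stands in for a real `FSDPModule` so the code runs without a process group.

```python
from typing import Dict, Protocol


class FSDPLike(Protocol):
    """The subset of FSDP2's FSDPModule interface this sketch relies on."""

    def set_is_last_backward(self, is_last_backward: bool) -> None: ...

    def set_reshard_after_backward(self, reshard_after_backward: bool) -> None: ...

    def set_requires_gradient_sync(self, requires_gradient_sync: bool) -> None: ...


def configure_fsdp_module_sketch(
    fsdp_module: FSDPLike,
    is_last_backward: bool,
    reshard_after_backward: bool,
    requires_gradient_sync: bool,
) -> None:
    # Assumed behavior: forward each flag to the corresponding FSDP2 setter.
    fsdp_module.set_is_last_backward(is_last_backward)
    fsdp_module.set_reshard_after_backward(reshard_after_backward)
    fsdp_module.set_requires_gradient_sync(requires_gradient_sync)


class StubFSDPModule:
    """Hypothetical stand-in that stores the flags instead of changing FSDP state."""

    def __init__(self) -> None:
        self.flags: Dict[str, bool] = {}

    def set_is_last_backward(self, is_last_backward: bool) -> None:
        self.flags["is_last_backward"] = is_last_backward

    def set_reshard_after_backward(self, reshard_after_backward: bool) -> None:
        self.flags["reshard_after_backward"] = reshard_after_backward

    def set_requires_gradient_sync(self, requires_gradient_sync: bool) -> None:
        self.flags["requires_gradient_sync"] = requires_gradient_sync
```

For a non-final microbatch, a caller would pass `False` for all three flags: gradients then accumulate locally without a reduce-scatter, and parameters stay unsharded so the next forward does not need to all-gather them again.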
nemo_automodel.components.moe.fsdp_mixin._disable_fsdp_for_moe_module(
module: torch.nn.Module
) -> None
nemo_automodel.components.moe.fsdp_mixin._iter_fsdp_modules(
module: torch.nn.Module
) -> typing.Iterator[torch.distributed.fsdp.FSDPModule]
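A minimal sketch of what `_iter_fsdp_modules` might do, assuming it walks `module.modules()` and keeps the submodules that `fully_shard` has turned into `FSDPModule` instances. FSDP2 replaces the class of each sharded module with a subclass of `FSDPModule`, so an `isinstance` check finds them. `iter_instances`, `Node`, and `ShardedNode` are hypothetical, torch-free names; `Node` only mimics the traversal order of `torch.nn.Module.modules()` (self first, then descendants).

```python
from typing import Iterable, Iterator, List, Protocol, Type, TypeVar

T = TypeVar("T")


class HasModules(Protocol):
    """Anything with a torch.nn.Module-style modules() traversal."""

    def modules(self) -> Iterable[object]: ...


def iter_instances(root: HasModules, cls: Type[T]) -> Iterator[T]:
    """Yield every module in root's subtree (root included) that is an instance of cls."""
    for m in root.modules():
        if isinstance(m, cls):
            yield m


class Node:
    """Torch-free stand-in for nn.Module with a pre-order modules() traversal."""

    def __init__(self, name: str, children: Iterable["Node"] = ()) -> None:
        self.name = name
        self.children: List["Node"] = list(children)

    def modules(self) -> Iterator["Node"]:
        yield self
        for child in self.children:
            yield from child.modules()


class ShardedNode(Node):
    """Plays the role of a module that fully_shard has turned into an FSDPModule."""


if __name__ == "__main__":
    tree = Node("model", [ShardedNode("layer0", [Node("mlp")]), Node("layer1", [ShardedNode("experts")])])
    print([m.name for m in iter_instances(tree, ShardedNode)])
```

With PyTorch installed, `iter_instances(model, FSDPModule)` (after `from torch.distributed.fsdp import FSDPModule`) would play the role of the helper; the demo above prints `['layer0', 'experts']`.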
nemo_automodel.components.moe.fsdp_mixin._run_post_backward_for_moe_module(
module: torch.nn.Module
) -> None
nemo_automodel.components.moe.fsdp_mixin._run_post_backward_hooks(
fsdp_module: torch.distributed.fsdp.FSDPModule
) -> typing.Callable
nemo_automodel.components.moe.fsdp_mixin.patched_backward_maybe_with_nosync(
self,
backward_type,
bwd_kwargs: dict,
last_backward: bool = False
) -> tuple[tuple[typing.Optional[torch.Tensor], ...], typing.Optional[list[dict[str, typing.Any]]]]

Whether using PP with FSDP or DDP, there are some runtime differences between the last backward step and the other steps. Namely, gradients must be accumulated on the earlier steps and reduced on the last step, and the data-parallel library in use adds its own state variables and performance considerations. This helper is intended to adapt any pipeline-parallel schedule to work with the commonly supported data-parallel libraries.
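The accumulate-then-reduce pattern described above can be illustrated without any distributed runtime. In this hypothetical sketch, each simulated data-parallel rank sums its microbatch gradients locally, and the cross-rank reduction (an element-wise mean, standing in for FSDP's reduce-scatter or DDP's all-reduce) runs once, on the last backward. Because summation is linear, the result matches reducing after every microbatch while issuing one collective per optimizer step instead of one per microbatch. `simulate_step`, `mean_across_ranks`, and the list-of-floats gradients are assumptions made for illustration.

```python
from typing import List, Tuple

Grad = List[float]


def mean_across_ranks(per_rank: List[Grad]) -> Grad:
    """Stand-in for a collective: element-wise mean over ranks."""
    world = len(per_rank)
    return [sum(vals) / world for vals in zip(*per_rank)]


def simulate_step(microbatch_grads: List[List[Grad]], defer_sync: bool) -> Tuple[Grad, int]:
    """microbatch_grads[rank][mb] is the gradient one rank computes on one microbatch.

    Returns the final reduced gradient and the number of collectives issued.
    """
    num_ranks = len(microbatch_grads)
    num_mbs = len(microbatch_grads[0])
    dim = len(microbatch_grads[0][0])
    local = [[0.0] * dim for _ in range(num_ranks)]
    reduced = [0.0] * dim
    collectives = 0
    for mb in range(num_mbs):
        last_backward = mb == num_mbs - 1
        grads_this_mb = [microbatch_grads[r][mb] for r in range(num_ranks)]
        if defer_sync:
            # Accumulate locally on every rank; communicate only on the last backward.
            local = [[a + g for a, g in zip(acc, gr)] for acc, gr in zip(local, grads_this_mb)]
            if last_backward:
                reduced = mean_across_ranks(local)
                collectives += 1
        else:
            # Reduce after every backward and accumulate the reduced result.
            step = mean_across_ranks(grads_this_mb)
            reduced = [a + s for a, s in zip(reduced, step)]
            collectives += 1
    return reduced, collectives
```

For two ranks and three microbatches, both modes produce the same gradient, but the deferred mode issues one collective instead of three.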