nemo_automodel.components.models.diffusion_gemma.fsdp


FSDP2 sharding for diffusion_gemma under pure FSDP (ep_size=1).

At ep_size=1 there is no MoE mesh, so the model is sharded by the generic DefaultParallelizationStrategy from nemo_automodel.components.distributed.parallelizer (via FSDP2Manager.parallelize), which applies fully_shard to each decoder layer and then to the root. The generic fully_shard groups all of a decoder layer’s parameters into one FSDP unit, which folds each layer’s grouped-expert tensors (moe.experts.{gate_and_up_projs,down_projs}, the bulk of the 26B parameters) into that layer’s single all-gather. Every forward pass therefore gathers all of a layer’s experts at once, a large memory spike for a model that runs the shared stack twice (causal encode plus bidirectional decode, plus an optional self-conditioning pass).
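To see why the experts land in the layer’s all-gather, it helps to model FSDP2’s ownership rule: when units are wrapped bottom-up, each parameter belongs to the innermost module that fully_shard was applied to, and each unit all-gathers its own parameters together. The pure-Python sketch below applies that rule to parameter names under the generic one-unit-per-layer layout. Only the moe.experts.* tensor names come from this page; the layers.N paths and the attention, embedding, and norm names are hypothetical.

```python
def owning_unit(param_name: str, unit_paths: list[str]) -> str:
    """Return the innermost FSDP unit whose module path prefixes param_name ("" = root)."""
    best = ""
    for path in unit_paths:
        # Match whole path components so "layers.1" does not claim "layers.10.*".
        if path and param_name.startswith(path + ".") and len(path) > len(best):
            best = path
    return best


def group_by_unit(param_names: list[str], unit_paths: list[str]) -> dict[str, list[str]]:
    """Map each FSDP unit (unit_paths must include "" for the root) to its parameters."""
    groups = {path: [] for path in unit_paths}
    for name in param_names:
        groups[owning_unit(name, unit_paths)].append(name)
    return groups


# Hypothetical parameter names for a two-layer model.
params = [
    "embed_tokens.weight",
    "layers.0.self_attn.q_proj.weight",
    "layers.0.moe.experts.gate_and_up_projs",
    "layers.0.moe.experts.down_projs",
    "layers.1.self_attn.q_proj.weight",
    "layers.1.moe.experts.gate_and_up_projs",
    "layers.1.moe.experts.down_projs",
    "norm.weight",
]

# Generic strategy: one unit per decoder layer, plus the root ("").
default_units = ["layers.0", "layers.1", ""]
groups = group_by_unit(params, default_units)
```

Under this layout both expert tensors share layers.0’s unit with the attention weights, so one collective materializes all of them together.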

fully_shard_diffusion_gemma mirrors deepseek_v4’s fully_shard_deepseek_v4: it makes moe.experts its own FSDP unit (sharded along dim 0 on the dp mesh) before wrapping the rest of the decoder layer. Consequences:

  • The grouped-expert parameters become DTensors of global shape [n_experts, ...] with placement Shard(0) on the dp mesh, so DCP sees the checkpoint’s global expert shape and each rank reads only its own shard (no [128] vs [16] size mismatch).
  • During the experts’ forward, FSDP all-gathers their parameters back to the full [n_experts, ...] tensor, so GroupedExperts sees a plain (non-DTensor) tensor and runs with ep_size == 1: all experts are local and there is no expert-parallel token shuffle. This is pure FSDP, not EP.
  • Experts gather/reshard independently of the rest of the layer, bounding peak memory across the double (encode + decode) pass.
  • moe.experts becomes a distinct FSDPModule, which MoEFSDPSyncMixin._iter_fsdp_modules discovers as block.moe.experts.
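The first two consequences reduce to index arithmetic on dim 0. The sketch below assumes 8 data-parallel ranks (an assumption chosen to match the [128] vs [16] shapes above) and models Shard(0) with torch.chunk-style splitting; plain integers stand in for expert rows rather than real DTensors.

```python
import math


def shard0_row_range(global_rows: int, world_size: int, rank: int) -> tuple[int, int]:
    """Half-open [start, end) rows of a Shard(0) tensor held by `rank`.

    torch.chunk-style split: each rank gets ceil(global_rows / world_size) rows,
    and trailing ranks may get fewer (or none) when the split is uneven.
    """
    chunk = math.ceil(global_rows / world_size)
    start = min(rank * chunk, global_rows)
    end = min(start + chunk, global_rows)
    return start, end


n_experts, dp_size = 128, 8  # assumption: 8 data-parallel ranks
ranges = [shard0_row_range(n_experts, dp_size, r) for r in range(dp_size)]
local_sizes = [end - start for start, end in ranges]  # experts stored per rank

# Checkpoint load: each rank reads only rows [start, end) of the global tensor.
# Forward: an all-gather concatenates the shards in rank order, restoring the
# full [n_experts, ...] tensor that GroupedExperts computes with.
experts = list(range(n_experts))  # stand-in for dim 0 of an expert tensor
gathered = [e for start, end in ranges for e in experts[start:end]]
```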

No expert parallelism is introduced; this is the ep_size=1 path only.

Module Contents

Functions

| Name | Description |
| --- | --- |
| _fully_shard_once | Apply fully_shard to module unless it is already an FSDP unit. |
| _has_fsdp_state | Return True if module has already been wrapped by fully_shard. |
| fully_shard_diffusion_gemma | Apply FSDP2 to a diffusion_gemma decoder layer (or any other module). |
| register_diffusion_gemma_parallel_strategy | Register the diffusion_gemma FSDP2 strategy (idempotent). |

API

nemo_automodel.components.models.diffusion_gemma.fsdp._fully_shard_once(
module: torch.nn.Module,
mesh,
mp_policy,
offload_policy,
fsdp_kwargs = {}
) -> torch.nn.Module

Apply fully_shard to module unless it is already an FSDP unit.

nemo_automodel.components.models.diffusion_gemma.fsdp._has_fsdp_state(
module: torch.nn.Module
) -> bool

Return True if module has already been wrapped by fully_shard.

nemo_automodel.components.models.diffusion_gemma.fsdp.fully_shard_diffusion_gemma(
module: torch.nn.Module,
mesh,
mp_policy,
offload_policy = None,
**fsdp_kwargs
) -> torch.nn.Module

Apply FSDP2 to a diffusion_gemma decoder layer (or any other module).

For a DiffusionGemmaMoEDecoderLayer, first shard its grouped experts (moe.experts) as a separate FSDP unit, then shard the rest of the layer. All other modules (embeddings, final norm, self-conditioning, the root model) are each sharded as a single unit.

Parameters:

module
nn.Module

The module to shard (a decoder layer or the root model).

mesh

The (1-D) data-parallel device mesh to shard across.

mp_policy

FSDP2 mixed-precision policy.

offload_policy
Defaults to None

Optional FSDP2 CPU-offload policy.

**fsdp_kwargs
Defaults to {}

Forwarded to fully_shard (e.g. reshard_after_forward).

Returns: nn.Module

The sharded module.
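The dispatch in fully_shard_diffusion_gemma can be sketched without any distributed setup. Here shard_once stands in for _fully_shard_once with the mesh and policies already bound, and is_moe_decoder_layer stands in for an isinstance check against DiffusionGemmaMoEDecoderLayer; both parameters, the sketch’s function name, and the SimpleNamespace modules in the usage are hypothetical.

```python
from types import SimpleNamespace
from typing import Callable


def fully_shard_diffusion_gemma_sketch(module, shard_once: Callable, is_moe_decoder_layer: Callable):
    # Decoder layers: wrap the grouped experts first so they become their own
    # inner unit; the layer's unit then owns only the non-expert parameters.
    if is_moe_decoder_layer(module):
        shard_once(module.moe.experts)
    # Every module (and the layer itself) ends up as one unit here.
    return shard_once(module)


# Usage with a recording stand-in for the real shard call.
wrapped = []


def record(m):
    wrapped.append(m.name)
    return m


def is_layer(m):
    # Hypothetical stand-in for isinstance(m, DiffusionGemmaMoEDecoderLayer).
    return hasattr(m, "moe")


layer = SimpleNamespace(
    name="layers.0",
    moe=SimpleNamespace(experts=SimpleNamespace(name="layers.0.moe.experts")),
)
norm = SimpleNamespace(name="norm")
fully_shard_diffusion_gemma_sketch(layer, record, is_layer)
fully_shard_diffusion_gemma_sketch(norm, record, is_layer)
```

Wrapping the experts before the layer matters: FSDP2 expects nested units to be wrapped bottom-up, and the outer unit then manages only the parameters not already owned by the experts unit.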

nemo_automodel.components.models.diffusion_gemma.fsdp.register_diffusion_gemma_parallel_strategy() -> None

Register the diffusion_gemma FSDP2 strategy (idempotent).

Binds fully_shard_diffusion_gemma as the per-module shard function of a DefaultParallelizationStrategy subclass, keyed on the model class name so that get_parallelization_strategy selects it at ep_size=1. It is invoked when model.py is imported (a torch-enabled context), which always happens before the model is parallelized.
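The idempotent, class-name-keyed registration can be sketched with a plain dictionary. The registry, DefaultStrategySketch, register_strategy_once, lookup_strategy, and ToyDiffusionModel are all hypothetical stand-ins; the real DefaultParallelizationStrategy and get_parallelization_strategy may be structured differently.

```python
# Hypothetical registry: model class name -> strategy class.
_STRATEGY_REGISTRY: dict[str, type] = {}


class DefaultStrategySketch:
    """Stand-in for DefaultParallelizationStrategy: holds a per-module shard function."""

    @staticmethod
    def shard_module(module, **kwargs):
        return module


def register_strategy_once(model_class_name: str, shard_fn) -> type:
    # Idempotent: a second call returns the class registered by the first.
    if model_class_name in _STRATEGY_REGISTRY:
        return _STRATEGY_REGISTRY[model_class_name]
    strategy_cls = type(
        f"{model_class_name}Strategy",
        (DefaultStrategySketch,),
        {"shard_module": staticmethod(shard_fn)},
    )
    _STRATEGY_REGISTRY[model_class_name] = strategy_cls
    return strategy_cls


def lookup_strategy(model) -> DefaultStrategySketch:
    # Keyed on the model's class name, falling back to the generic strategy.
    return _STRATEGY_REGISTRY.get(type(model).__name__, DefaultStrategySketch)()


class ToyDiffusionModel:  # hypothetical model class
    pass


first = register_strategy_once("ToyDiffusionModel", lambda m, **kw: m)
second = register_strategy_once("ToyDiffusionModel", lambda m, **kw: None)
```

Keying on the class name rather than the class object lets a registration made at import time match a model instance created later, as long as the names agree.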