nemo_automodel.components.models.diffusion_gemma.fsdp
FSDP2 sharding for diffusion_gemma under pure FSDP (ep_size=1).
At ep_size=1 there is no MoE mesh, so the model is sharded by the generic
nemo_automodel.components.distributed.parallelizer.DefaultParallelizationStrategy
(via FSDP2Manager.parallelize), which applies fully_shard per decoder
layer and to the root. The generic fully_shard groups all of a decoder
layer’s parameters into one FSDP unit, which folds each layer’s grouped-expert
tensors (moe.experts.{gate_and_up_projs,down_projs}, the bulk of the 26B
parameters) into the layer’s single all-gather. That gathers every expert of a
layer at once on each forward, causing a large peak-memory spike for a model that
runs the shared stack twice (causal encode + bidirectional decode, plus an
optional self-conditioning pass).
fully_shard_diffusion_gemma mirrors deepseek_v4’s
fully_shard_deepseek_v4: it makes moe.experts its own FSDP unit
(sharded dim-0 on the dp mesh) before wrapping the rest of the decoder layer.
Consequences:
- The grouped-expert parameters become global [n_experts]-shaped Shard(0) DTensors on the dp mesh, so DCP sees the checkpoint’s global expert shape and each rank reads only its shard (no [128] vs [16] size mismatch).
- During the experts’ forward, FSDP all-gathers their parameters back to the full [n_experts, ...] tensor, so GroupedExperts sees a plain (non-DTensor) tensor and runs with ep_size == 1: all experts are local and there is no expert-parallel token shuffle. This is pure FSDP, not EP.
- Experts gather and reshard independently of the rest of the layer, bounding peak memory across the double (encode + decode) pass.
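As a rough illustration of the first point, the sketch below computes which rows of dim 0 each rank owns under a Shard(0) split, using torch.chunk-style sizing (ceil(n / world_size) rows per rank, trailing ranks possibly getting fewer). The sizes are assumptions: 128 experts over an 8-rank dp mesh, chosen to match the [128] vs [16] example above. DCP's actual read planning is not modeled, and shard0_bounds is a hypothetical helper.

```python
def shard0_bounds(n_rows: int, world_size: int, rank: int) -> tuple:
    """Rows [start, end) of dim 0 owned by `rank` under a Shard(0)-style split.

    Follows torch.chunk sizing: each rank gets ceil(n_rows / world_size)
    rows, except trailing ranks, which may get fewer (or none).
    """
    chunk = -(-n_rows // world_size)  # ceiling division
    start = min(rank * chunk, n_rows)
    end = min(start + chunk, n_rows)
    return start, end


# Hypothetical sizes: 128 experts sharded over an 8-rank dp mesh.
N_EXPERTS, DP_SIZE = 128, 8
for rank in range(DP_SIZE):
    start, end = shard0_bounds(N_EXPERTS, DP_SIZE, rank)
    # The checkpoint keeps the global shape [128]; this rank reads only start:end.
    print(f"rank {rank}: experts {start}:{end} ({end - start} local)")
```

Because the parameter's global shape stays [n_experts], the checkpoint and the live model agree on shape, and only the per-rank slice differs.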
moe.experts becomes a distinct FSDPModule that MoEFSDPSyncMixin._iter_fsdp_modules discovers (block.moe.experts).
No expert parallelism is introduced; this is the ep_size=1 path only.
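The carve-out relies on an FSDP2 rule: a parent fully_shard call manages only the parameters that are not already owned by a submodule sharded before it. The pure-Python sketch below models that ownership rule by parameter name. The helper assign_fsdp_units is hypothetical, and every parameter name other than the two moe.experts tensors named above is an illustrative assumption.

```python
def assign_fsdp_units(param_names, child_unit_prefixes):
    """Map each parameter name to the FSDP unit that manages it.

    Parameters under a child prefix belong to that child's unit (sharded
    first); everything else falls through to the enclosing layer's unit.
    """
    units = {}
    for name in param_names:
        owner = "layer"
        for prefix in child_unit_prefixes:
            if name.startswith(prefix + "."):
                owner = prefix
                break
        units[name] = owner
    return units


# Hypothetical parameter names for one decoder layer; only the two
# moe.experts.* tensors come from the text above.
layer_params = [
    "self_attn.q_proj.weight",
    "input_layernorm.weight",
    "moe.router.weight",
    "moe.experts.gate_and_up_projs",
    "moe.experts.down_projs",
]
units = assign_fsdp_units(layer_params, ["moe.experts"])
for name, unit in units.items():
    print(f"{name:32s} -> {unit}")
```

The expert tensors land in their own unit, so the layer's all-gather covers only the comparatively small attention, norm, and router parameters.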
Module Contents
Functions
API
Apply fully_shard to module unless it is already an FSDP unit.
Return True if module has already been wrapped by fully_shard.
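The two helpers above can be sketched roughly as follows. This is a torch-free mock, not the module's code: fully_shard_once, is_fsdp_module, mock_fully_shard, and the local FSDPModule class are hypothetical stand-ins. The mock imitates FSDP2's behavior of swapping a wrapped module's class for a dynamically created FSDP<Name> subclass of FSDPModule, which is what makes an isinstance check a cheap "already wrapped" test.

```python
class FSDPModule:
    """Mock stand-in for torch.distributed.fsdp.FSDPModule (no torch needed)."""


def mock_fully_shard(module):
    # Imitates FSDP2: replace the instance's class with a dynamically
    # created subclass that also inherits from FSDPModule.
    cls = type(module)
    module.__class__ = type(f"FSDP{cls.__name__}", (FSDPModule, cls), {})
    return module


def is_fsdp_module(module) -> bool:
    # Hypothetical helper: True once fully_shard has wrapped the module.
    return isinstance(module, FSDPModule)


def fully_shard_once(module):
    # Hypothetical helper: skip modules that are already FSDP units.
    if not is_fsdp_module(module):
        mock_fully_shard(module)
    return module


class Experts:
    pass


experts = fully_shard_once(Experts())
experts = fully_shard_once(experts)  # second call is a no-op
print(type(experts).__name__)  # FSDPExperts
```

Without the guard, the second call would wrap again and yield an FSDPFSDPExperts class, i.e. a nested, redundant FSDP unit.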
Apply FSDP2 to a diffusion_gemma decoder layer (or any other module).
For a DiffusionGemmaMoEDecoderLayer, shard its grouped experts
(moe.experts) as a separate FSDP unit first, then shard the rest of the
layer. All other modules (embeddings, final norm, self-conditioning, the
root model) are sharded as a single unit.
Parameters:
The module to shard (a decoder layer or the root model).
The (1-D) data-parallel device mesh to shard across.
FSDP2 mixed-precision policy.
Optional FSDP2 CPU-offload policy.
Forwarded to fully_shard (e.g. reshard_after_forward).
Returns: nn.Module
The sharded module.
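A minimal sketch of the dispatch described above, assuming mock classes in place of the real layer and a mock shard function that only records call order (the mesh, mixed-precision, and offload arguments are omitted). The module paths embed_tokens and layers.0 are hypothetical. Sharding moe.experts before its parent is what keeps the expert tensors out of the layer's own all-gather.

```python
calls = []


def mock_fully_shard(module, path):
    """Mock fully_shard: records the order in which units are created."""
    calls.append(path)
    return module


# Mock module tree; only the moe.experts attribute path mirrors the text.
class GroupedExperts:
    pass


class MoE:
    def __init__(self):
        self.experts = GroupedExperts()


class DiffusionGemmaMoEDecoderLayer:
    def __init__(self):
        self.moe = MoE()


class Embedding:
    pass


def shard_module(module, path):
    """Sketch of the dispatch: experts get their own unit before the layer."""
    if isinstance(module, DiffusionGemmaMoEDecoderLayer):
        mock_fully_shard(module.moe.experts, f"{path}.moe.experts")
    mock_fully_shard(module, path)  # remaining params (or the whole module)
    return module


shard_module(Embedding(), "embed_tokens")
shard_module(DiffusionGemmaMoEDecoderLayer(), "layers.0")
print(calls)  # ['embed_tokens', 'layers.0.moe.experts', 'layers.0']
```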
Register the diffusion_gemma FSDP2 strategy (idempotent).
Binds fully_shard_diffusion_gemma as the per-module shard function
of a DefaultParallelizationStrategy subclass, keyed on the model
class name so get_parallelization_strategy selects it at ep_size=1.
Invoked at import of model.py (a torch-enabled context), which always
runs before the model is parallelized.