nemo_automodel.components.distributed.config#

Strategy-specific distributed training configuration classes.

Design principle:

  • Size params (dp_size, dp_replicate_size, tp_size, pp_size, cp_size, ep_size) go directly on the from_pretrained/from_config method signature

  • dp_replicate_size is FSDP2-only: passing it with a non-FSDP2 config raises an AssertionError

  • Strategy-specific configs contain only additional flags unique to each strategy

  • Managers become normal classes that accept (config, device_mesh)

Usage::

   from nemo_automodel.components.distributed.config import FSDP2Config, MegatronFSDPConfig, DDPConfig

   # FSDP2 with custom options
   config = FSDP2Config(sequence_parallel=True, activation_checkpointing=True)

   # MegatronFSDP with custom options
   config = MegatronFSDPConfig(zero_dp_strategy=3, overlap_grad_reduce=True)

   # DDP with activation checkpointing
   config = DDPConfig(activation_checkpointing=True)
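The dp_replicate_size rule from the design principles above can be sketched with a minimal stand-in. The dataclasses and `from_pretrained_sketch` function below are hypothetical illustrations of the design, not the library's actual API:

```python
from dataclasses import dataclass
from typing import Optional

# Hypothetical stand-ins for the strategy configs, for illustration only.
@dataclass
class FSDP2ConfigSketch:
    sequence_parallel: bool = False

@dataclass
class DDPConfigSketch:
    activation_checkpointing: bool = False

def from_pretrained_sketch(config, dp_size: int = 1, dp_replicate_size: Optional[int] = None):
    """Sketch of the design principle: size params live on the method
    signature, and dp_replicate_size is only legal with an FSDP2 config."""
    assert dp_replicate_size is None or isinstance(config, FSDP2ConfigSketch), (
        "dp_replicate_size is FSDP2-only"
    )
    return {"config": config, "dp_size": dp_size, "dp_replicate_size": dp_replicate_size}

# FSDP2 accepts dp_replicate_size; DDP does not.
out = from_pretrained_sketch(FSDP2ConfigSketch(), dp_size=8, dp_replicate_size=2)
try:
    from_pretrained_sketch(DDPConfigSketch(), dp_size=8, dp_replicate_size=2)
except AssertionError as e:
    print(e)  # dp_replicate_size is FSDP2-only
```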

Module Contents#

Classes#

FSDP2Config

Additional configuration for FSDP2 distributed training.

MegatronFSDPConfig

Additional configuration for MegatronFSDP distributed training.

DDPConfig

Additional configuration for DDP distributed training.

Data#

API#

nemo_automodel.components.distributed.config.DistributedConfig#

None

class nemo_automodel.components.distributed.config.FSDP2Config#

Additional configuration for FSDP2 distributed training.

Note: Size parameters (dp_size, dp_replicate_size, tp_size, pp_size, cp_size, ep_size) are passed separately on the from_pretrained/from_config method signature.

.. attribute:: sequence_parallel

Enable sequence parallelism in TP plan.

Type:

bool

.. attribute:: tp_plan

Custom TP plan. If None, auto-selected based on model type.

Type:

Optional[dict]

.. attribute:: patch_is_packed_sequence

Patch transformers._is_packed_sequence to always return Python False. This does two things: (1) removes a CPU-GPU sync per attention layer (aten::is_nonzero triggered by HF when batch_size==1), and (2) ensures static attention shapes for torch.compile. Safe for standard (non-packed) training only. Disable if using packed-sequence training (position_ids that reset to 0 mid-sequence). Default False.

Type:

bool
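The packed-sequence condition mentioned above (position_ids that reset to 0 mid-sequence) can be sketched without torch; this is a plain-Python illustration with a hypothetical helper name, not the library's detection logic:

```python
def looks_packed(position_ids):
    """Return True if position_ids reset to 0 mid-sequence, i.e. the row
    packs several sequences back-to-back (patch_is_packed_sequence should
    then stay False)."""
    return any(p == 0 for p in position_ids[1:])

print(looks_packed([0, 1, 2, 3, 4]))        # False: a single sequence
print(looks_packed([0, 1, 2, 0, 1, 2, 3]))  # True: two packed sequences
```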

.. attribute:: mp_policy

MixedPrecisionPolicy for FSDP2. Can be configured from YAML using the _target_ pattern::

   mp_policy:
     _target_: torch.distributed.fsdp.MixedPrecisionPolicy
     param_dtype: bfloat16
     reduce_dtype: float32
     output_dtype: float32

Type:

Optional[MixedPrecisionPolicy]
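As a sketch of how a `_target_` entry like the YAML above can be resolved at load time: the `instantiate` helper below is hypothetical (the library's real instantiation logic is not shown here), and a stdlib class stands in for `MixedPrecisionPolicy` so the example is self-contained:

```python
import importlib

def instantiate(cfg: dict):
    """Resolve a {'_target_': 'module.Class', **kwargs} dict into an object.
    Minimal sketch; the library's actual _target_ handling may differ."""
    cfg = dict(cfg)  # avoid mutating the caller's dict
    module_path, _, class_name = cfg.pop("_target_").rpartition(".")
    cls = getattr(importlib.import_module(module_path), class_name)
    return cls(**cfg)

# Stand-in target: fractions.Fraction in place of
# torch.distributed.fsdp.MixedPrecisionPolicy.
obj = instantiate({"_target_": "fractions.Fraction", "numerator": 3, "denominator": 4})
print(obj)  # 3/4
```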

.. attribute:: offload_policy

CPUOffloadPolicy for CPU offloading.

Type:

Optional[CPUOffloadPolicy]

.. attribute:: autocast_dtype

If set, wraps the forward pass in torch.autocast(device_type="cuda", dtype=autocast_dtype). Use with output_dtype=float32 in mp_policy to keep the residual stream in fp32 while running matmuls in lower precision. Set to None to disable. Can be set from YAML as a string (e.g. autocast_dtype: bfloat16).

Type:

Optional[torch.dtype]

.. attribute:: activation_checkpointing

Enable activation checkpointing.

Type:

bool

.. attribute:: defer_fsdp_grad_sync

Defer FSDP gradient sync to final micro-batch.

Type:

bool

.. attribute:: backend

Distributed backend.

Type:

str

.. attribute:: enable_async_tensor_parallel

Enable async tensor parallelism via torch._inductor.config._micro_pipeline_tp. Overlaps ReduceScatter with compute in row-parallel layers. Requires sequence_parallel=True (forced automatically with a warning if not set). Also enables symmetric memory for the TP group.

Type:

bool

.. attribute:: enable_compile

Apply per-layer torch.compile to transformer decoder layers (with NO_REENTRANT activation checkpointing inside each compiled layer). Skips whole-model compile so that checkpoint loading does not produce _orig_mod key-prefix mismatches.

Type:

bool

.. attribute:: enable_fsdp2_prefetch

Enable explicit forward/backward prefetch chains between FSDP2 sharded layers. Default False.

Type:

bool

.. attribute:: fsdp2_backward_prefetch_depth

Number of FSDP units to prefetch during backward pass. 2 hides AllGather behind compute; 1 reduces peak memory at a small throughput cost. Default 2.

Type:

int

.. attribute:: fsdp2_forward_prefetch_depth

Number of FSDP units to prefetch during forward pass. Default 1.

Type:

int
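The prefetch-depth semantics can be sketched as building, for each sharded layer, the list of upcoming layers whose AllGather is issued early. This is an illustrative stdlib-only sketch with a hypothetical helper name, not FSDP2's actual mechanism:

```python
def prefetch_chain(num_layers, depth):
    """For each layer index, the indices of the next `depth` layers to
    prefetch while that layer is computing."""
    return {
        i: list(range(i + 1, min(i + 1 + depth, num_layers)))
        for i in range(num_layers)
    }

# depth=2 in backward hides AllGather behind compute;
# depth=1 reduces peak memory at a small throughput cost.
print(prefetch_chain(4, 2))  # {0: [1, 2], 1: [2, 3], 2: [3], 3: []}
```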

sequence_parallel: bool#

False

tp_plan: Optional[dict]#

None

patch_is_packed_sequence: bool#

False

mp_policy: Optional[torch.distributed.fsdp.MixedPrecisionPolicy]#

None

offload_policy: Optional[torch.distributed.fsdp.CPUOffloadPolicy]#

None

autocast_dtype: Optional[torch.dtype]#

None

activation_checkpointing: bool#

False

defer_fsdp_grad_sync: bool#

True

backend: str#

'nccl'

enable_async_tensor_parallel: bool#

False

enable_compile: bool#

False

enable_fsdp2_prefetch: bool#

False

fsdp2_backward_prefetch_depth: int#

2

fsdp2_forward_prefetch_depth: int#

1

__post_init__()#

to_dict() → Dict[str, Any]#

Convert config to dictionary (shallow, preserves policy objects).
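A shallow conversion of this kind can be sketched by iterating dataclass fields and copying values by reference, so policy objects are preserved rather than recursively converted. This is a hypothetical sketch, not the library's implementation:

```python
from dataclasses import dataclass, fields
from typing import Any, Dict, Optional

class Policy:
    """Stand-in for a policy object such as MixedPrecisionPolicy."""

@dataclass
class ConfigSketch:
    sequence_parallel: bool = False
    mp_policy: Optional[Policy] = None

    def to_dict(self) -> Dict[str, Any]:
        # Shallow: field values are copied as-is, so the policy object is
        # preserved (dataclasses.asdict would instead recurse into nested
        # dataclasses and lose object identity).
        return {f.name: getattr(self, f.name) for f in fields(self)}

policy = Policy()
cfg = ConfigSketch(sequence_parallel=True, mp_policy=policy)
d = cfg.to_dict()
print(d["mp_policy"] is policy)  # True
```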

class nemo_automodel.components.distributed.config.MegatronFSDPConfig#

Additional configuration for MegatronFSDP distributed training.

Note: Size parameters (dp_size, tp_size, cp_size) are passed separately on the from_pretrained/from_config method signature. MegatronFSDP does not support pp_size, dp_replicate_size, or ep_size.

.. attribute:: sequence_parallel

Enable sequence parallelism in TP plan. Note: currently not supported with MegatronFSDP.

Type:

bool

.. attribute:: megatron_fsdp_unit_modules

List of unit modules to be wrapped with MegatronFSDP.

Type:

Optional[List[str]]

.. attribute:: zero_dp_strategy

Data parallel sharding strategy.

Type:

int

.. attribute:: init_fsdp_with_meta_device

Initialize MegatronFSDP with meta device if True.

Type:

bool

.. attribute:: grad_reduce_in_fp32

Reduce gradients in fp32 if True.

Type:

bool

.. attribute:: preserve_fp32_weights

Preserve fp32 weights if True.

Type:

bool

.. attribute:: overlap_grad_reduce

Overlap gradient reduction if True.

Type:

bool

.. attribute:: overlap_param_gather

Overlap parameter gathering if True.

Type:

bool

.. attribute:: check_for_nan_in_grad

Check for NaN in gradients if True.

Type:

bool

.. attribute:: average_in_collective

Average in collective if True.

Type:

bool

.. attribute:: disable_bucketing

Disable bucketing if True.

Type:

bool

.. attribute:: calculate_per_token_loss

Calculate per-token loss if True.

Type:

bool

.. attribute:: keep_fp8_transpose_cache

Keep fp8 transpose cache when using custom FSDP if True.

Type:

bool

.. attribute:: nccl_ub

Use NCCL user buffers (UBs) if True.

Type:

bool

.. attribute:: fsdp_double_buffer

Use FSDP double buffering if True.

Type:

bool

.. attribute:: activation_checkpointing

Enable activation checkpointing for transformer MLP layers to save memory.

Type:

bool

.. attribute:: backend

Distributed backend, e.g. 'nccl' or 'gloo'.

Type:

str

sequence_parallel: bool#

False

tp_plan: dataclasses.InitVar[Optional[dict]]#

None

megatron_fsdp_unit_modules: Optional[List[str]]#

None

zero_dp_strategy: int#

3

init_fsdp_with_meta_device: bool#

False

grad_reduce_in_fp32: bool#

False

preserve_fp32_weights: bool#

False

overlap_grad_reduce: bool#

True

overlap_param_gather: bool#

True

check_for_nan_in_grad: bool#

True

average_in_collective: bool#

False

disable_bucketing: bool#

False

calculate_per_token_loss: bool#

False

keep_fp8_transpose_cache: bool#

False

nccl_ub: bool#

False

fsdp_double_buffer: bool#

False

activation_checkpointing: bool#

False

backend: str#

'nccl'

__post_init__(tp_plan: Optional[dict])#

to_dict() → Dict[str, Any]#

Convert config to dictionary (shallow, preserves objects).

class nemo_automodel.components.distributed.config.DDPConfig#

Additional configuration for DDP distributed training.

Note: DDP does not support tensor parallelism, pipeline parallelism, or expert parallelism. Only dp_size is relevant (inferred from world_size).

.. attribute:: activation_checkpointing

Enable activation checkpointing if True.

Type:

bool

.. attribute:: backend

Distributed backend, e.g. 'nccl' or 'gloo'.

Type:

str

activation_checkpointing: bool#

False

backend: str#

'nccl'

to_dict() → Dict[str, Any]#

Convert config to dictionary.