nemo_automodel.components.distributed.pipelining.config#

Pipeline parallel configuration class.

Design principle:

  • Device mesh (world_mesh, moe_mesh) is passed separately to from_pretrained/from_config

  • PipelineConfig contains scheduling, splitting, and runtime options

  • loss_fn is included here since it’s only used for pipelining

  • Axis names are inferred automatically from device_mesh in _instantiate_pipeline

Usage:

```python
from nemo_automodel.components.distributed.pipelining.config import PipelineConfig

config = PipelineConfig(
    pp_schedule="1f1b",
    pp_microbatch_size=2,
    pp_batch_size=8,
    loss_fn=my_loss_fn,
)
```

Module Contents#

Classes#

PipelineConfig

Configuration for pipeline parallel training.

API#

class nemo_automodel.components.distributed.pipelining.config.PipelineConfig(
pp_schedule: Optional[str] = '1f1b',
pp_schedule_csv: Optional[str] = None,
pp_microbatch_size: int = 1,
pp_batch_size: int = 1,
layers_per_stage: Optional[int] = None,
round_virtual_stages_to_pp_multiple: Optional[Literal['up', 'down']] = None,
module_fqns_per_model_part: Optional[List[List[str]]] = None,
patch_inner_model: bool = True,
patch_causal_lm_model: bool = True,
patch_stage_backward_maybe_with_nosync: bool = False,
dtype: Optional[torch.dtype] = None,
scale_grads_in_schedule: bool = False,
loss_fn: Optional[Callable] = None,
)#

Configuration for pipeline parallel training.

Note: Device mesh (world_mesh, moe_mesh) is passed separately to the from_pretrained/from_config methods. Pipeline parallelism is enabled when pp_size > 1. Axis names are inferred automatically from the device mesh structure.
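For illustration, a pipeline-enabled device mesh can be built with PyTorch's `init_device_mesh`. The mesh shape and the dimension names below ("pp", "dp") are assumptions for this sketch; the config only documents that axis names are inferred from the mesh structure.

```python
# Sketch only: constructing a device mesh that contains a pipeline dimension.
# Requires torch.distributed to be launched with 8 ranks in this example.
from torch.distributed.device_mesh import init_device_mesh

# 2 pipeline stages x 4 data-parallel replicas (assumed layout and dim names).
world_mesh = init_device_mesh("cuda", mesh_shape=(2, 4), mesh_dim_names=("pp", "dp"))

# Pipeline parallelism is enabled because the "pp" dimension has size > 1.
pp_size = world_mesh["pp"].size()
assert pp_size > 1
```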

.. attribute:: pp_schedule

Pipeline schedule type. Supported values: “1f1b” (one-forward-one-backward), “gpipe”, “interleaved_1f1b”, “looped_bfs”, “dfs”, “v_schedule”, “zero_bubble”. Defaults to “1f1b”.

Type:

Optional[str]

.. attribute:: pp_schedule_csv

Path to a CSV file defining a custom pipeline schedule. If provided, overrides pp_schedule.

Type:

Optional[str]
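For example, a custom schedule can be supplied via a CSV file; the path below is hypothetical, and `my_loss_fn` is the placeholder from the usage snippet above:

```python
# Using a custom schedule defined in a CSV file; the path is hypothetical.
config = PipelineConfig(
    pp_schedule_csv="/path/to/custom_schedule.csv",  # takes precedence over pp_schedule
    pp_microbatch_size=2,
    pp_batch_size=8,
    loss_fn=my_loss_fn,
)
```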

.. attribute:: pp_microbatch_size

Size of each microbatch for pipeline execution. pp_batch_size must be divisible by pp_microbatch_size.

Type:

int

.. attribute:: pp_batch_size

Total batch size per pipeline stage. Must be divisible by pp_microbatch_size.

Type:

int
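As a concrete illustration of the divisibility requirement, the number of microbatches the schedule runs per step follows directly from these two fields (the values below are arbitrary):

```python
# Illustrative arithmetic only: how pp_batch_size and pp_microbatch_size relate.
pp_batch_size = 8
pp_microbatch_size = 2

# pp_batch_size must divide evenly into microbatches.
assert pp_batch_size % pp_microbatch_size == 0
n_microbatches = pp_batch_size // pp_microbatch_size  # 4 microbatches per step
```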

.. attribute:: layers_per_stage

Number of transformer layers per pipeline stage. If None, layers are split evenly across stages.

Type:

Optional[int]

.. attribute:: round_virtual_stages_to_pp_multiple

When using virtual stages (interleaved schedules), round the number of virtual stages to a multiple of pp_size. “up” rounds up, “down” rounds down. If None, no rounding is applied.

Type:

Optional[Literal[“up”, “down”]]
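A minimal sketch of the rounding semantics described above, with `pp_size` and `n_virtual_stages` values chosen purely for illustration:

```python
# Illustration of the "up"/"down" rounding applied to the virtual stage count.
import math

pp_size = 4
n_virtual_stages = 10  # assumed value for this sketch

rounded_up = math.ceil(n_virtual_stages / pp_size) * pp_size  # 12 ("up")
rounded_down = (n_virtual_stages // pp_size) * pp_size        # 8  ("down")
```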

.. attribute:: module_fqns_per_model_part

Explicit specification of which module FQNs belong to each model part/stage. If provided, overrides automatic layer splitting.

Type:

Optional[List[List[str]]]
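For example, an explicit two-stage split might look like the following; the module FQNs are hypothetical and depend entirely on the model's own module naming:

```python
# Hypothetical FQN split for a small decoder-only model with 4 layers.
config = PipelineConfig(
    pp_microbatch_size=2,
    pp_batch_size=8,
    loss_fn=my_loss_fn,
    module_fqns_per_model_part=[
        # Stage 0: embeddings plus the first two transformer layers.
        ["model.embed_tokens", "model.layers.0", "model.layers.1"],
        # Stage 1: remaining layers, final norm, and the LM head.
        ["model.layers.2", "model.layers.3", "model.norm", "lm_head"],
    ],
)
```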

.. attribute:: patch_inner_model

Apply pipeline patches to the inner model (e.g., the base transformer in a CausalLM wrapper). Defaults to True.

Type:

bool

.. attribute:: patch_causal_lm_model

Apply pipeline patches to the CausalLM wrapper model. Defaults to True.

Type:

bool

.. attribute:: patch_stage_backward_maybe_with_nosync

Patch stage backward to use no_sync context for gradient accumulation efficiency. Useful when combining PP with FSDP.

Type:

bool

.. attribute:: dtype

Data type for pipeline computation. If None, uses the model’s default dtype.

Type:

Optional[torch.dtype]

.. attribute:: scale_grads_in_schedule

Scale gradients within the pipeline schedule (by 1/n_microbatches). If False, gradients must be scaled externally. Defaults to False.

Type:

bool
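When scale_grads_in_schedule is False, the accumulated gradients can be scaled externally before the optimizer step. A rough sketch, assuming a standard PyTorch training loop with `model` and `optimizer` already defined:

```python
# Sketch: external gradient scaling when scale_grads_in_schedule=False.
import torch

n_microbatches = pp_batch_size // pp_microbatch_size

# After the pipeline schedule has run forward/backward for all microbatches,
# divide the accumulated gradients by the microbatch count before stepping.
with torch.no_grad():
    for param in model.parameters():
        if param.grad is not None:
            param.grad.div_(n_microbatches)

optimizer.step()
optimizer.zero_grad()
```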

.. attribute:: loss_fn

Loss function used for pipeline training. Required when pipeline is enabled. The function signature should be compatible with the model’s output format.

Type:

Optional[Callable]
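A minimal example of a function that could be passed as loss_fn, assuming the pipeline hands it logits and integer labels; the exact signature expected depends on the model's output format:

```python
import torch
import torch.nn.functional as F

def my_loss_fn(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    # Flatten [batch, seq, vocab] logits and [batch, seq] labels for cross-entropy.
    return F.cross_entropy(
        logits.reshape(-1, logits.size(-1)),
        labels.reshape(-1),
    )
```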

Initialization

pp_schedule: Optional[str]#

‘1f1b’

pp_schedule_csv: Optional[str]#

None

pp_microbatch_size: int#

1

pp_batch_size: int#

1

layers_per_stage: Optional[int]#

None

round_virtual_stages_to_pp_multiple: Optional[Literal['up', 'down']]#

None

module_fqns_per_model_part: Optional[List[List[str]]]#

None

patch_inner_model: bool#

True

patch_causal_lm_model: bool#

True

patch_stage_backward_maybe_with_nosync: bool#

False

dtype: Optional[torch.dtype]#

None

scale_grads_in_schedule: bool#

False

loss_fn: Optional[Callable]#

None

to_dict() → Dict[str, Any]#

Convert config to dictionary.
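For example, the config can be serialized for logging; the field values shown reuse the usage snippet above:

```python
config = PipelineConfig(pp_schedule="1f1b", pp_microbatch_size=2, pp_batch_size=8)
cfg_dict = config.to_dict()
print(cfg_dict["pp_schedule"])  # "1f1b"
```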