nemo_automodel.components.distributed.pipelining.config#
Pipeline parallel configuration class.
Design principles:

- Device mesh (world_mesh, moe_mesh) is passed separately to from_pretrained/from_config.
- PipelineConfig contains scheduling, splitting, and runtime options.
- loss_fn is included here since it's only used for pipelining.
- Axis names are inferred automatically from device_mesh in _instantiate_pipeline.
Usage:

```python
from nemo_automodel.components.distributed.pipelining.config import PipelineConfig

config = PipelineConfig(
    pp_schedule="1f1b",
    pp_microbatch_size=2,
    pp_batch_size=8,
    loss_fn=my_loss_fn,
)
```
Module Contents#
Classes#
| Class | Description |
| --- | --- |
| PipelineConfig | Configuration for pipeline parallel training. |
API#
- class nemo_automodel.components.distributed.pipelining.config.PipelineConfig(
- pp_schedule: Optional[str] = '1f1b',
- pp_schedule_csv: Optional[str] = None,
- pp_microbatch_size: int = 1,
- pp_batch_size: int = 1,
- layers_per_stage: Optional[int] = None,
- round_virtual_stages_to_pp_multiple: Optional[Literal['up', 'down']] = None,
- module_fqns_per_model_part: Optional[List[List[str]]] = None,
- patch_inner_model: bool = True,
- patch_causal_lm_model: bool = True,
- patch_stage_backward_maybe_with_nosync: bool = False,
- dtype: Optional[torch.dtype] = None,
- scale_grads_in_schedule: bool = False,
- loss_fn: Optional[Callable] = None,

)
Configuration for pipeline parallel training.
Note: The device mesh (world_mesh, moe_mesh) is passed separately via the from_pretrained/from_config method signature. Pipeline parallelism is enabled when pp_size > 1. Axis names are inferred automatically from the device mesh structure.
.. attribute:: pp_schedule
Pipeline schedule type. Supported values: "1f1b" (one-forward-one-backward), "gpipe", "interleaved_1f1b", "looped_bfs", "dfs", "v_schedule", "zero_bubble". Defaults to "1f1b".
- Type:
Optional[str]
.. attribute:: pp_schedule_csv
Path to a CSV file defining a custom pipeline schedule. If provided, overrides pp_schedule.
- Type:
Optional[str]
.. attribute:: pp_microbatch_size
Size of each microbatch for pipeline execution. pp_batch_size must be divisible by pp_microbatch_size.
- Type:
int
.. attribute:: pp_batch_size
Total batch size per pipeline stage. Must be divisible by pp_microbatch_size.
- Type:
int
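A minimal sketch of how the two fields relate; the names mirror the config fields, but the arithmetic here is purely illustrative (the schedule performs the actual split):

```python
pp_batch_size = 8
pp_microbatch_size = 2

assert pp_batch_size % pp_microbatch_size == 0, (
    "pp_batch_size must be divisible by pp_microbatch_size"
)
n_microbatches = pp_batch_size // pp_microbatch_size  # 4 microbatches per step
```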
.. attribute:: layers_per_stage
Number of transformer layers per pipeline stage. If None, layers are split evenly across stages.
- Type:
Optional[int]
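A quick illustration of how this setting maps onto a stage count, assuming a hypothetical model depth that divides evenly:

```python
# Hypothetical example: 32 transformer layers split 4 per stage.
num_layers = 32        # assumed model depth, for illustration only
layers_per_stage = 4
num_stages = num_layers // layers_per_stage  # 8 pipeline stages
```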
.. attribute:: round_virtual_stages_to_pp_multiple
When using virtual stages (interleaved schedules), round the number of virtual stages to a multiple of pp_size: "up" rounds up, "down" rounds down. If None, no rounding is applied.
- Type:
Optional[Literal['up', 'down']]
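A sketch of what the two rounding modes mean, assuming a simple round-to-multiple rule (the library's internal logic may differ):

```python
import math

def round_to_pp_multiple(n_virtual_stages: int, pp_size: int, mode: str) -> int:
    """Illustrative rounding of a virtual-stage count to a multiple of pp_size."""
    if mode == "up":
        return math.ceil(n_virtual_stages / pp_size) * pp_size
    if mode == "down":
        return (n_virtual_stages // pp_size) * pp_size
    return n_virtual_stages  # no rounding when mode is None

# With pp_size=4: 10 virtual stages round "up" to 12 and "down" to 8.
```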
.. attribute:: module_fqns_per_model_part
Explicit specification of which module FQNs belong to each model part/stage. If provided, overrides automatic layer splitting.
- Type:
Optional[List[List[str]]]
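For example, a hypothetical two-stage split; the FQNs below are illustrative and must match the module names your model actually uses:

```python
# Hypothetical two-stage split for a small CausalLM-style model.
module_fqns_per_model_part = [
    ["model.embed_tokens", "model.layers.0", "model.layers.1"],     # stage 0
    ["model.layers.2", "model.layers.3", "model.norm", "lm_head"],  # stage 1
]
```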
.. attribute:: patch_inner_model
Apply pipeline patches to the inner model (e.g., the base transformer in a CausalLM wrapper). Defaults to True.
- Type:
bool
.. attribute:: patch_causal_lm_model
Apply pipeline patches to the CausalLM wrapper model. Defaults to True.
- Type:
bool
.. attribute:: patch_stage_backward_maybe_with_nosync
Patch stage backward to use no_sync context for gradient accumulation efficiency. Useful when combining PP with FSDP.
- Type:
bool
.. attribute:: dtype
Data type for pipeline computation. If None, uses the model's default dtype.
- Type:
Optional[torch.dtype]
.. attribute:: scale_grads_in_schedule
Scale gradients within the pipeline schedule (by 1/n_microbatches). If False, gradients must be scaled externally. Defaults to False.
- Type:
bool
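When this is False, gradients can be scaled externally after the schedule runs. A self-contained sketch; the stand-in model and gradient are fabricated for illustration:

```python
import torch

model = torch.nn.Linear(4, 4)   # stand-in for one pipeline stage's parameters
n_microbatches = 4

# Simulate a gradient accumulated across microbatches.
model.weight.grad = torch.ones_like(model.weight)

# Scale externally by 1/n_microbatches, since the schedule did not.
for param in model.parameters():
    if param.grad is not None:
        param.grad.div_(n_microbatches)
```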
.. attribute:: loss_fn
Loss function used for pipeline training. Required when pipeline is enabled. The function signature should be compatible with the model's output format.
- Type:
Optional[Callable]
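A minimal sketch of a compatible loss function, assuming the pipeline passes logits and labels (the exact signature depends on the model's output format):

```python
import torch
import torch.nn.functional as F

def my_loss_fn(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    # Token-level cross-entropy: flatten (batch, seq, vocab) -> (batch*seq, vocab).
    return F.cross_entropy(logits.view(-1, logits.size(-1)), labels.view(-1))
```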
Initialization
- pp_schedule: Optional[str]#
'1f1b'
- pp_schedule_csv: Optional[str]#
None
- pp_microbatch_size: int#
1
- pp_batch_size: int#
1
- layers_per_stage: Optional[int]#
None
- round_virtual_stages_to_pp_multiple: Optional[Literal['up', 'down']]#
None
- module_fqns_per_model_part: Optional[List[List[str]]]#
None
- patch_inner_model: bool#
True
- patch_causal_lm_model: bool#
True
- patch_stage_backward_maybe_with_nosync: bool#
False
- dtype: Optional[torch.dtype]#
None
- scale_grads_in_schedule: bool#
False
- loss_fn: Optional[Callable]#
None
- to_dict() → Dict[str, Any]#
Convert config to dictionary.
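A short usage sketch, assuming the returned dict mirrors the attribute names:

```python
from nemo_automodel.components.distributed.pipelining.config import PipelineConfig

config = PipelineConfig(pp_schedule="1f1b", pp_microbatch_size=2, pp_batch_size=8)
config_dict = config.to_dict()  # plain dict, e.g. for logging or serialization
print(config_dict["pp_schedule"])  # '1f1b'
```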