nemo_automodel.components.distributed.pipelining.autopipeline#
Module Contents#
Classes#
AutoPipeline | Orchestrates pipeline-parallel training on top of torch.distributed.pipelining.
Data#
API#
- nemo_automodel.components.distributed.pipelining.autopipeline.logger#
'getLogger(...)'
- class nemo_automodel.components.distributed.pipelining.autopipeline.PipelineInfo#
- enabled: bool#
None
- schedule: Optional[torch.distributed.pipelining.schedules._PipelineSchedule]#
None
- has_first_stage: bool#
None
- has_last_stage: bool#
None
- model_parts: Optional[list[torch.nn.Module]]#
None
- stages: Optional[list[torch.distributed.pipelining.stage.PipelineStage]]#
None
- class nemo_automodel.components.distributed.pipelining.autopipeline.AutoPipeline(
- world_mesh: Optional[torch.distributed.device_mesh.DeviceMesh] = None,
- moe_mesh: Optional[torch.distributed.device_mesh.DeviceMesh] = None,
- pp_axis_name: str = 'pp',
- dp_axis_names: tuple[str, ...] = ('dp',),
- cp_axis_name: Optional[str] = None,
- tp_axis_name: Optional[str] = None,
- ep_axis_name: Optional[str] = None,
- ep_shard_axis_names: Optional[tuple[str, ...]] = None,
- pp_schedule: Optional[str] = '1f1b',
- pp_schedule_csv: Optional[str] = None,
- pp_microbatch_size: int = 1,
- pp_batch_size: int = 1,
- layers_per_stage: Optional[int] = None,
- round_virtual_stages_to_pp_multiple: Optional[Literal['up', 'down']] = None,
- module_fqns_per_model_part: Optional[list[list[str]]] = None,
- patch_inner_model: bool = True,
- patch_causal_lm_model: bool = True,
- patch_stage_backward_maybe_with_nosync: bool = False,
- device: Optional[torch.device] = None,
- dtype: Optional[torch.dtype] = None,
- scale_grads_in_schedule: bool = False,
- visualization_font_size_offset: int = 0,
)#
Orchestrates pipeline-parallel training on top of torch.distributed.pipelining.
Initialization
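A minimal construction sketch. Only the keyword names come from the signature above; the mesh shape, axis names, batch sizes, and dtype are illustrative assumptions, and `torch.distributed` must already be initialized:

```python
# Illustrative only: a 4-way pipeline x 2-way data-parallel mesh on CUDA.
import torch
from torch.distributed.device_mesh import init_device_mesh
from nemo_automodel.components.distributed.pipelining.autopipeline import AutoPipeline

world_mesh = init_device_mesh("cuda", (4, 2), mesh_dim_names=("pp", "dp"))

pipeline = AutoPipeline(
    world_mesh=world_mesh,
    pp_axis_name="pp",
    dp_axis_names=("dp",),
    pp_schedule="1f1b",          # default schedule name from the signature
    pp_microbatch_size=1,
    pp_batch_size=8,             # assumed to be a multiple of the microbatch size
    device=torch.device("cuda"),
    dtype=torch.bfloat16,
)
```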
- build(
- model: torch.nn.Module,
- *,
- loss_fn: Optional[Callable] = None,
- parallelize_fn: Optional[nemo_automodel.components.distributed.pipelining.functional.ParallelizeFnProtocol] = None,
)#
Build the pipeline: validate -> init meta -> split -> schedule.
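A minimal sketch of the build step, continuing the `pipeline` object constructed above. The stand-in model and the loss function choice are assumptions; in practice this would be a causal LM whose layers get split across stages:

```python
import torch.nn as nn
import torch.nn.functional as F

# Stand-in model; a real run would pass the (optionally patched) causal LM here.
model = nn.Sequential(*[nn.Linear(64, 64) for _ in range(8)])

pipeline.build(model, loss_fn=F.cross_entropy)

stage_modules = pipeline.parts           # nn.Module shards owned by this rank
print(pipeline.list_stage_modules())     # fully qualified module names per stage
```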
- property parts: list[torch.nn.Module]#
- property device: torch.device#
- list_stage_modules() -> list[list[str]]#
- visualize_current_schedule(
- filename: Optional[str] = None,
)#
- static _count_parameters(
- module: torch.nn.Module,
- trainable_only: bool = False,
)#
- get_stage_param_counts(trainable_only: bool = False) -> list[int]#
- get_total_param_count(trainable_only: bool = False) -> int#
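A short sketch of the parameter-accounting helpers, useful for checking that the stage split is balanced; the printed wording is illustrative:

```python
per_stage = pipeline.get_stage_param_counts()                # one count per local stage
total = pipeline.get_total_param_count(trainable_only=True)
print(f"{total:,} trainable parameters across {len(per_stage)} local stage(s)")
```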
- pretty_print_stages(
- max_modules_per_stage: int = 16,
- trainable_only: bool = False,
)#
- debug_summary() -> str#
- log_debug_summary() -> None#
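Finally, the inspection helpers above can be used to sanity-check the split on each rank. A sketch, assuming they print, render, and log text summaries as their names and signatures suggest:

```python
pipeline.pretty_print_stages(max_modules_per_stage=8)   # per-stage module listing
pipeline.visualize_current_schedule("schedule.png")     # assumed: renders the schedule to a file
print(pipeline.debug_summary())                         # single-string summary
pipeline.log_debug_summary()                            # same content via the module logger
```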