nemo_automodel.components.distributed.pipelining.autopipeline

Module Contents

Classes

| Name | Description |
| --- | --- |
| AutoPipeline | Orchestrates pipeline-parallel training on top of torch.distributed.pipelining. |
| PipelineInfo | Runtime state produced by pipeline-parallel setup. |

Data

logger

API

class nemo_automodel.components.distributed.pipelining.autopipeline.AutoPipeline(
world_mesh: typing.Optional[torch.distributed.device_mesh.DeviceMesh] = None,
moe_mesh: typing.Optional[torch.distributed.device_mesh.DeviceMesh] = None,
pp_axis_name: str = 'pp',
dp_axis_names: tuple[str, ...] = ('dp',),
cp_axis_name: typing.Optional[str] = None,
tp_axis_name: typing.Optional[str] = None,
ep_axis_name: typing.Optional[str] = None,
ep_shard_axis_names: typing.Optional[tuple[str, ...]] = None,
pp_schedule: typing.Optional[str] = '1f1b',
pp_schedule_csv: typing.Optional[str] = None,
pp_microbatch_size: int = 1,
pp_batch_size: int = 1,
layers_per_stage: typing.Optional[int] = None,
round_virtual_stages_to_pp_multiple: typing.Optional[typing.Literal['up', 'down']] = None,
module_fqns_per_model_part: typing.Optional[list[list[str]]] = None,
patch_inner_model: bool = True,
patch_causal_lm_model: bool = True,
patch_stage_backward_maybe_with_nosync: bool = False,
defer_fsdp_grad_sync: bool = True,
device: typing.Optional[torch.device] = None,
dtype: typing.Optional[torch.dtype] = None,
scale_grads_in_schedule: bool = False,
pp_seq_len: typing.Optional[int] = None
)

Orchestrates pipeline-parallel training on top of torch.distributed.pipelining.

_device: device
_info
_pp_current_seq_len: Optional[int] = None
device: device
info: PipelineInfo
parts: list[Module]
pp_mesh: DeviceMesh = self.world_mesh[pp_axis_name]
nemo_automodel.components.distributed.pipelining.autopipeline.AutoPipeline._count_parameters(
module: torch.nn.Module,
trainable_only: bool = False
) -> int
staticmethod
nemo_automodel.components.distributed.pipelining.autopipeline.AutoPipeline.build(
model: torch.nn.Module,
loss_fn: typing.Optional[typing.Callable] = None,
parallelize_fn: typing.Optional[nemo_automodel.components.distributed.pipelining.functional.ParallelizeFnProtocol] = None
)

Build the pipeline: validate -> init meta -> split -> schedule.
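
The `module_fqns_per_model_part` constructor argument takes one list of module names per model part. As a rough illustration of that shape only (not the splitting logic that `build` actually applies), the hypothetical helper below partitions an ordered list of layer names into contiguous groups. The helper name `split_fqns_evenly` and the layer names are assumptions made for this sketch.

```python
def split_fqns_evenly(layer_fqns: list[str], num_parts: int) -> list[list[str]]:
    """Partition layer names into `num_parts` contiguous groups (hypothetical helper).

    Earlier groups receive one extra layer when the split is uneven.
    """
    if num_parts < 1:
        raise ValueError("num_parts must be at least 1")
    if num_parts > len(layer_fqns):
        raise ValueError("more parts than layers")
    base, extra = divmod(len(layer_fqns), num_parts)
    parts, start = [], 0
    for i in range(num_parts):
        size = base + (1 if i < extra else 0)
        parts.append(layer_fqns[start:start + size])
        start += size
    return parts


# Illustrative names only; real FQNs depend on the model being split.
layers = [f"model.layers.{i}" for i in range(5)]
example_parts = split_fqns_evenly(layers, 2)
```

A nested list of this form could then be passed as `module_fqns_per_model_part`, assuming the names match modules that exist in the model.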

nemo_automodel.components.distributed.pipelining.autopipeline.AutoPipeline.debug_summary() -> str
nemo_automodel.components.distributed.pipelining.autopipeline.AutoPipeline.get_stage_param_counts(
trainable_only: bool = False
) -> list[int]
nemo_automodel.components.distributed.pipelining.autopipeline.AutoPipeline.get_total_param_count(
trainable_only: bool = False
) -> int
nemo_automodel.components.distributed.pipelining.autopipeline.AutoPipeline.list_stage_modules() -> list[list[str]]
nemo_automodel.components.distributed.pipelining.autopipeline.AutoPipeline.log_debug_summary() -> None
nemo_automodel.components.distributed.pipelining.autopipeline.AutoPipeline.pretty_print_stages(
max_modules_per_stage: int = 16,
trainable_only: bool = False
) -> str
nemo_automodel.components.distributed.pipelining.autopipeline.AutoPipeline.update_seq_len(
seq_len: int
) -> None

Reset pipeline stage infrastructure for a new sequence length.

VLM training batches can vary widely in sequence length from step to step (for example, image batches versus text-only batches). PyTorch’s PipelineStage fixes its receive-buffer sizes on the first step, which causes a shape-mismatch error on later steps with a different sequence length.

Call this before every schedule.step() to update the stage shapes without running an expensive forward pass. The call is a no-op when seq_len has not changed.

Parameters:

seq_len (int): Sequence length of the upcoming batch (input_ids.shape[1]).
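
The "no-op when unchanged" behaviour can be pictured as a cached-value guard. The class below is a hypothetical stand-in, not `AutoPipeline` itself: `SeqLenGuard` and `_reset_stage_shapes` are invented placeholders for whatever work the real method performs, and only the attribute name `_pp_current_seq_len` is taken from this page.

```python
from typing import Optional


class SeqLenGuard:
    """Hypothetical stand-in that mimics the documented no-op behaviour."""

    def __init__(self) -> None:
        self._pp_current_seq_len: Optional[int] = None
        self.reset_count = 0  # counts how often the placeholder reset ran

    def _reset_stage_shapes(self, seq_len: int) -> None:
        # Placeholder: the real reset logic is not described on this page.
        self.reset_count += 1

    def update_seq_len(self, seq_len: int) -> None:
        if seq_len == self._pp_current_seq_len:
            return  # unchanged sequence length: nothing to do
        self._reset_stage_shapes(seq_len)
        self._pp_current_seq_len = seq_len


guard = SeqLenGuard()
for seq_len in (512, 512, 2048, 2048, 512):
    guard.update_seq_len(seq_len)
```

In a training loop, the corresponding call would sit immediately before each `schedule.step()`, with `input_ids.shape[1]` as the argument.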

nemo_automodel.components.distributed.pipelining.autopipeline.AutoPipeline.visualize_current_schedule(
filename: typing.Optional[str] = None
) -> None
class nemo_automodel.components.distributed.pipelining.autopipeline.PipelineInfo(
enabled: bool,
schedule: typing.Optional[torch.distributed.pipelining.schedules._PipelineSchedule],
has_first_stage: bool,
has_last_stage: bool,
model_parts: typing.Optional[list[torch.nn.Module]],
stages: typing.Optional[list[torch.distributed.pipelining.stage.PipelineStage]]
)
Dataclass

Runtime state produced by pipeline-parallel setup.

enabled: bool
has_first_stage: bool
has_last_stage: bool
model_parts: Optional[list[Module]]
schedule: Optional[_PipelineSchedule]
stages: Optional[list[PipelineStage]]
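
One way a training step might use `has_first_stage` and `has_last_stage` is sketched below. It follows the convention of the `torch.distributed.pipelining` schedules, where inputs are supplied on the rank that holds the first stage and `target`/`losses` on the rank that holds the last stage. `RecordingSchedule` and `run_step` are hypothetical names made up so the sketch runs without a process group; they are not part of this module.

```python
from typing import Any, Optional


class RecordingSchedule:
    """Hypothetical stub that records how step() was called."""

    def __init__(self) -> None:
        self.calls: list[tuple[tuple, dict]] = []

    def step(self, *args: Any, **kwargs: Any) -> None:
        self.calls.append((args, kwargs))


def run_step(schedule, has_first_stage: bool, has_last_stage: bool,
             inputs: Any, targets: Any) -> Optional[list]:
    """Call schedule.step() with only the arguments this rank should supply."""
    # Only the rank holding the first stage feeds inputs into the pipeline.
    args = (inputs,) if has_first_stage else ()
    kwargs = {}
    # Only the rank holding the last stage computes and collects losses.
    losses: Optional[list] = [] if has_last_stage else None
    if has_last_stage:
        kwargs["target"] = targets
        kwargs["losses"] = losses
    schedule.step(*args, **kwargs)
    return losses


first = RecordingSchedule()
run_step(first, True, False, inputs="x", targets="y")

last = RecordingSchedule()
last_losses = run_step(last, False, True, inputs="x", targets="y")
```

With a real schedule, the two flags would presumably come from the `PipelineInfo` instance, and a rank holding both the first and last stage would pass inputs and targets in the same call.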
nemo_automodel.components.distributed.pipelining.autopipeline.logger = logging.getLogger(__name__)