nemo_automodel.components.distributed.pipelining.functional


Module Contents

Classes

Name | Description
--- | ---
ParallelizeFnProtocol | Callable protocol for applying distributed parallelism to a model.

Functions

Name | Description
--- | ---
_get_hidden_and_vocab_size | Extract hidden_size and vocab_size from a model config.
_get_optional_hook | -
_precompute_stage_shapes | Precompute input/output meta tensors for each pipeline stage to bypass serial shape inference.
_wrap_stage_forward_to_emit_tensor | Make a pipeline stage’s forward emit a tensor, not a ModelOutput.
build_pipeline_schedule | Builds a pipeline schedule for the given job configuration and stages.
calculate_virtual_stages | Calculate virtual pipeline stages and layers per stage.
generate_hf_model_fqn_per_model_part | Generates module names for each pipeline stage for HuggingFace models.
pipeline_model | HF-specific pipeline model splitting.
reset_pp_stage_shapes | Reset pipeline stage infrastructure and recompute shapes for a new sequence length.
scale_grads_by_divisor | Scale pipeline stage gradients by a common divisor when supported.
split_model_into_stages | Splits a HuggingFace model for pipeline parallelism.
stage_ids_this_rank | Compute the stage ids for the stages that will run on this pp rank for either a looped or V style schedule.

Data

logger

API

class nemo_automodel.components.distributed.pipelining.functional.ParallelizeFnProtocol()
Bases: Protocol

Callable protocol for applying distributed parallelism to a model.

nemo_automodel.components.distributed.pipelining.functional.ParallelizeFnProtocol.__call__(
model: torch.nn.Module,
world_mesh: torch.distributed.device_mesh.DeviceMesh,
moe_mesh: torch.distributed.device_mesh.DeviceMesh,
dp_axis_names: tuple[str, ...],
cp_axis_name: str | None = None,
tp_axis_name: str | None = None,
ep_axis_name: str | None = None,
ep_shard_axis_names: tuple[str, ...] | None = None
) -> None
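The protocol only fixes a call signature, so any function with these parameters can serve as a parallelize_fn. Because the protocol returns None, a conforming function presumably modifies the model in place. The sketch below is hypothetical: the name noop_parallelize_fn, the received_axes attribute, and the axis names in the usage are illustrative, and the function applies no actual parallelism. The torch imports sit behind TYPE_CHECKING so the sketch runs without torch installed.

```python
from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Only needed for the annotations; never imported at runtime.
    import torch
    from torch.distributed.device_mesh import DeviceMesh


def noop_parallelize_fn(
    model: torch.nn.Module,
    world_mesh: DeviceMesh,
    moe_mesh: DeviceMesh,
    dp_axis_names: tuple[str, ...],
    cp_axis_name: str | None = None,
    tp_axis_name: str | None = None,
    ep_axis_name: str | None = None,
    ep_shard_axis_names: tuple[str, ...] | None = None,
) -> None:
    """Hypothetical parallelize_fn with the ParallelizeFnProtocol signature.

    Applies no parallelism. It only records the mesh axis names it received
    on the model object so the call contract can be inspected; a real
    implementation would shard or replicate `model` in place instead.
    """
    model.received_axes = {
        "dp": dp_axis_names,
        "cp": cp_axis_name,
        "tp": tp_axis_name,
        "ep": ep_axis_name,
        "ep_shard": ep_shard_axis_names,
    }
```

Keeping the optional axes as keyword arguments with None defaults lets one function serve meshes that lack CP, TP, or EP dimensions.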
nemo_automodel.components.distributed.pipelining.functional._get_hidden_and_vocab_size(
model_config
) -> tuple[int, int]

Extract hidden_size and vocab_size from a model config.

Handles both flat configs (LLMs) and nested configs in which these attributes live under text_config (VLMs such as Qwen3-VL and LLaVA).
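A minimal sketch of the flat-versus-nested lookup, assuming the top-level attributes take precedence and text_config is only consulted when they are missing. The helper name and the config values are hypothetical; SimpleNamespace stands in for a HuggingFace config object.

```python
from types import SimpleNamespace


def get_hidden_and_vocab_size(model_config) -> tuple[int, int]:
    """Sketch: read hidden_size/vocab_size from a flat or nested HF-style config."""
    # Flat LLM configs expose both attributes at the top level.
    if hasattr(model_config, "hidden_size") and hasattr(model_config, "vocab_size"):
        return model_config.hidden_size, model_config.vocab_size
    # VLM configs nest the language-model settings under `text_config`.
    text_config = getattr(model_config, "text_config", None)
    if text_config is not None:
        return text_config.hidden_size, text_config.vocab_size
    raise AttributeError("config has no hidden_size/vocab_size (flat or under text_config)")


llm_cfg = SimpleNamespace(hidden_size=4096, vocab_size=32000)
vlm_cfg = SimpleNamespace(text_config=SimpleNamespace(hidden_size=2048, vocab_size=151936))
print(get_hidden_and_vocab_size(llm_cfg))  # (4096, 32000)
print(get_hidden_and_vocab_size(vlm_cfg))  # (2048, 151936)
```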

nemo_automodel.components.distributed.pipelining.functional._get_optional_hook(
module: object,
name: str
) -> typing.Callable | None
nemo_automodel.components.distributed.pipelining.functional._precompute_stage_shapes(
stages: list[torch.distributed.pipelining.PipelineStage],
model_config,
microbatch_size: int,
seq_len: int,
tensor_dtype: torch.dtype | None = None
) -> None

Precompute input/output meta tensors for each pipeline stage to bypass serial shape inference.

By default, PipelineStage performs shape inference at runtime via a serial P2P chain: stage 0 → send → stage 1 → send → … → stage N-1. This is O(N) in the number of pipeline stages and becomes a bottleneck for large world sizes.

This function sets inputs_meta and _outputs_meta on each stage before the first step() call, so that _shape_inference is never invoked and the serial chain is completely eliminated.

Parameters:

stages
list[PipelineStage]

The local pipeline stages (already parallelized).

model_config

The HuggingFace model config (model.config).

microbatch_size
int

Microbatch size used by the pipeline schedule.

seq_len
int

Sequence length of the input data.
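The shapes can be derived analytically because each stage's boundary tensors depend only on the microbatch size, sequence length, and model config. A minimal sketch, assuming stage 0 receives token IDs of shape (microbatch_size, seq_len), non-final stages pass hidden states of shape (microbatch_size, seq_len, hidden_size), and the final stage emits logits of shape (microbatch_size, seq_len, vocab_size). The helper name is hypothetical and this is not the library's code; VLM stages may carry additional inputs.

```python
def stage_io_shapes(
    stage_index: int,
    num_stages: int,
    microbatch_size: int,
    seq_len: int,
    hidden_size: int,
    vocab_size: int,
) -> tuple[tuple[int, ...], tuple[int, ...]]:
    """Return (input_shape, output_shape) for one pipeline stage (sketch, not the library code)."""
    hidden = (microbatch_size, seq_len, hidden_size)
    # Stage 0 consumes token IDs; every later stage consumes the previous stage's hidden states.
    input_shape = (microbatch_size, seq_len) if stage_index == 0 else hidden
    # The final stage emits vocabulary logits; earlier stages pass hidden states downstream.
    output_shape = (microbatch_size, seq_len, vocab_size) if stage_index == num_stages - 1 else hidden
    return input_shape, output_shape


for idx in range(4):
    print(idx, stage_io_shapes(idx, 4, microbatch_size=1, seq_len=2048, hidden_size=4096, vocab_size=32000))
```

In a real implementation such tuples would be turned into meta tensors, for example torch.empty(shape, device="meta"), which carry shape and dtype without allocating memory, so every rank can fill in its own stages' metadata without waiting on its neighbors.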

nemo_automodel.components.distributed.pipelining.functional._wrap_stage_forward_to_emit_tensor(
stage_model: torch.nn.Module
) -> None

Make a pipeline stage’s forward emit a tensor, not a ModelOutput.

Custom *ForCausalLM / *ForConditionalGeneration models now return a CausalLMOutputWithPast from forward (fused-linear cross-entropy support, compute_lm_head_logits). torch.distributed.pipelining requires every stage to emit a tensor (or tuple/list of tensors): PipelineStage._validate_fwd_outputs and the inter-stage P2P send/recv treat the output as tensor leaves and read .shape on each, which raises AttributeError: 'CausalLMOutputWithPast' object has no attribute 'shape'.

The stage’s outer forward is left intact (a) for models that opt out of patching via _pp_keep_self_forward and (b) for MoE configs that set patch_causal_lm_model=False so that only the inner model is patched. In both cases the kept outer forward returns a ModelOutput. This function wraps it so that the return value is unwrapped to its .logits tensor. compute_lm_head_logits stores the projected logits there on the final stage and the pass-through hidden_states on non-final stages (where lm_head is None). That is exactly the tensor each stage must forward, and on the final stage it is the logits that the last-stage loss (PipelineCausalLMLoss / MaskedCrossEntropy) consumes.

No-op when forward already returns a tensor or a tuple (the patched create_pipeline_forward_causal_lm path, and MTP models that emit a (logits, *mtp, seq_idx) tuple), since only ModelOutput is unwrapped.
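The unwrapping rule can be sketched as a forward wrapper that converts only ModelOutput-like results and passes tensors and tuples through untouched. This is an illustrative sketch: FakeCausalLMOutput is a stand-in for CausalLMOutputWithPast, the function name is hypothetical, and the real check is presumably an isinstance test against transformers' ModelOutput base class (importable as transformers.utils.ModelOutput).

```python
import functools
from dataclasses import dataclass
from typing import Any


@dataclass
class FakeCausalLMOutput:
    """Stand-in for a transformers ModelOutput such as CausalLMOutputWithPast."""
    logits: Any


def wrap_forward_to_emit_tensor(module, output_cls=FakeCausalLMOutput) -> None:
    """Sketch: make module.forward return `.logits` instead of an `output_cls` instance."""
    original_forward = module.forward

    @functools.wraps(original_forward)
    def forward(*args, **kwargs):
        out = original_forward(*args, **kwargs)
        # Only the ModelOutput-like type is unwrapped; tensors and tuples pass through as-is.
        if isinstance(out, output_cls):
            return out.logits
        return out

    # An instance attribute shadows the class-level forward, which is how the patch takes effect.
    module.forward = forward


class DemoStage:
    def forward(self, x):
        return FakeCausalLMOutput(logits=[v * 2 for v in x])


stage = DemoStage()
wrap_forward_to_emit_tensor(stage)
print(stage.forward([1, 2]))  # [2, 4]
```

Checking the output type rather than the stage position keeps the wrapper correct for both final stages (logits) and intermediate stages (hidden states), since both arrive in the same .logits field.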

nemo_automodel.components.distributed.pipelining.functional.build_pipeline_schedule(
pipeline_parallel_schedule_csv: str | None,
pipeline_parallel_schedule: str | None,
microbatch_size: int,
local_batch_size: int,
stages: list[torch.distributed.pipelining.PipelineStage],
loss_fn: typing.Callable,
scale_grads: bool = False
) -> torch.distributed.pipelining.schedules._PipelineSchedule

Builds a pipeline schedule for the given job configuration and stages.

Parameters:

pipeline_parallel_schedule_csv
str | None

The path to the pipeline parallel schedule CSV file.

pipeline_parallel_schedule
str | None

The name of the pipeline parallel schedule.

microbatch_size
int

The microbatch size.

local_batch_size
int

The local batch size.

stages
list[PipelineStage]

The stages to be scheduled.

loss_fn
Callable

The loss function.

Returns: _PipelineSchedule

The pipeline schedule for the given stages.
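The function takes both microbatch_size and local_batch_size, which suggests the schedule's microbatch count is derived from their ratio. A minimal sketch of that arithmetic, assuming the local batch must split evenly into microbatches; the helper name is hypothetical, and whether the library raises, warns, or rounds on a remainder is not documented here.

```python
def num_microbatches(local_batch_size: int, microbatch_size: int) -> int:
    """Sketch: number of microbatches a schedule would be built with."""
    if microbatch_size <= 0:
        raise ValueError("microbatch_size must be positive")
    # Each schedule step splits the local batch into equally sized microbatches.
    if local_batch_size % microbatch_size != 0:
        raise ValueError(
            f"local_batch_size ({local_batch_size}) must be divisible by "
            f"microbatch_size ({microbatch_size})"
        )
    return local_batch_size // microbatch_size


print(num_microbatches(local_batch_size=8, microbatch_size=2))  # 4
```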

nemo_automodel.components.distributed.pipelining.functional.calculate_virtual_stages(
num_layers: int,
layers_per_stage: typing.Optional[int],
pp_size: int,
is_single_stage_schedule: bool,
round_to_pp_multiple: str | None = None
) -> tuple[int, int]

Calculate virtual pipeline stages and layers per stage.

nemo_automodel.components.distributed.pipelining.functional.generate_hf_model_fqn_per_model_part(
num_stages: int,
num_layers: int,
include_embeddings: bool = True,
include_lm_head: bool = True,
include_rotary_emb: bool = True,
include_multimodal_encoders: bool = True,
extra_module_fqns: typing.Optional[list[str]] = None,
fqn_prefix: str = 'model.',
lm_head_fqn: str = 'lm_head'
) -> list[list[str]]

Generates module names for each pipeline stage for HuggingFace models.

Parameters:

num_stages
int

Number of pipeline stages

num_layers
int

Total number of transformer layers in the model

include_embeddings
bool, defaults to True

Whether to include embedding layer in first stage

include_lm_head
bool, defaults to True

Whether to include lm_head in last stage (for CausalLM models)

include_multimodal_encoders
bool, defaults to True

Whether to include common vision/audio encoder modules in stage 0

extra_module_fqns
Optional[list[str]], defaults to None

Optional list of extra module FQNs to include in stage 0

Returns: list[list[str]]

List of lists containing module names for each stage
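The shape of the return value can be illustrated with a simplified split. This sketch is hypothetical: it assumes Llama-style module names (embed_tokens, layers.N, norm), spreads layers as evenly as possible with earlier stages absorbing any remainder, and places norm on the last stage. It omits rotary embeddings, multimodal encoders, and extra_module_fqns, and the library's actual distribution may differ. The fqn_prefix and lm_head_fqn defaults mirror the documented ones.

```python
def fqns_per_stage(
    num_stages: int,
    num_layers: int,
    include_embeddings: bool = True,
    include_lm_head: bool = True,
    fqn_prefix: str = "model.",
    lm_head_fqn: str = "lm_head",
) -> list[list[str]]:
    """Sketch: assign HF-style module FQNs to pipeline stages."""
    if num_stages <= 0 or num_layers < num_stages:
        raise ValueError("need at least one layer per stage")
    base, extra = divmod(num_layers, num_stages)
    stages, start = [], 0
    for s in range(num_stages):
        count = base + (1 if s < extra else 0)  # earlier stages absorb the remainder
        names = [f"{fqn_prefix}layers.{i}" for i in range(start, start + count)]
        start += count
        if s == 0 and include_embeddings:
            names.insert(0, f"{fqn_prefix}embed_tokens")
        if s == num_stages - 1:
            names.append(f"{fqn_prefix}norm")  # assumed: final norm lives with the last stage
            if include_lm_head:
                names.append(lm_head_fqn)
        stages.append(names)
    return stages


for part in fqns_per_stage(num_stages=2, num_layers=4):
    print(part)
```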

nemo_automodel.components.distributed.pipelining.functional.pipeline_model(
model: torch.nn.Module,
world_mesh: torch.distributed.device_mesh.DeviceMesh,
moe_mesh: torch.distributed.device_mesh.DeviceMesh,
pp_axis_name: str,
dp_axis_names: tuple[str, ...],
cp_axis_name: str | None = None,
tp_axis_name: str | None = None,
ep_axis_name: str | None = None,
ep_shard_axis_names: tuple[str, ...] | None = None,
layers_per_stage: int | None,
pipeline_parallel_schedule_csv: str | None,
pipeline_parallel_schedule: str | None,
microbatch_size: int,
local_batch_size: int,
device: torch.device,
loss_fn: typing.Callable = None,
parallelize_fn: typing.Callable | None = None,
module_fqns_per_model_part: list[list[str]] | None = None,
patch_inner_model: bool = True,
patch_causal_lm_model: bool = True,
scale_grads: bool = False,
round_to_pp_multiple: str | None = None,
patch_stage_backward_maybe_with_nosync: bool = False,
reduce_grad_per_microbatch: bool = False,
seq_len: int | None = None,
tensor_dtype: torch.dtype | None = None
) -> tuple[torch.distributed.pipelining.schedules._PipelineSchedule, list[torch.nn.Module], bool, bool, list[torch.distributed.pipelining.PipelineStage]]

HF-specific pipeline model splitting: splits the model into pipeline stages and builds the pipeline schedule for them.

nemo_automodel.components.distributed.pipelining.functional.reset_pp_stage_shapes(
schedule: torch.distributed.pipelining.schedules._PipelineSchedule,
stages: list[torch.distributed.pipelining.PipelineStage],
model_config,
microbatch_size: int,
seq_len: int,
tensor_dtype: torch.dtype | None = None
) -> None

Reset pipeline stage infrastructure and recompute shapes for a new sequence length.

VLM training produces batches with highly variable sequence lengths (image tokens expand the sequence dramatically). PyTorch’s PipelineStage locks in output shapes and recv buffer sizes on the first schedule.step() call (_stages_initialized = True). Subsequent steps with a different seq_len therefore hit a shape-mismatch error.

This function resets the per-stage infrastructure so that _initialize_stages re-runs on the next step() call. It then calls _precompute_stage_shapes to set the correct shapes analytically — avoiding the expensive real-valued forward pass that _shape_inference would otherwise perform.

Parameters:

schedule
_PipelineSchedule

The active pipeline schedule.

stages
list[PipelineStage]

The local pipeline stages for this rank.

model_config

The HuggingFace model config (model.config).

microbatch_size
int

Per-microbatch batch size used by the schedule.

seq_len
int

Sequence length of the upcoming batch (e.g. input_ids.shape[1]).
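Because a reset discards the initialized stage state, it only needs to run when the sequence length actually changes between steps. A minimal usage sketch, assuming the reset is injected as a callable; the class SeqLenResetGuard is hypothetical and not part of this module.

```python
from typing import Callable, Optional


class SeqLenResetGuard:
    """Hypothetical helper: re-run a shape reset only when the sequence length changes."""

    def __init__(self, reset_fn: Callable[[int], None], initial_seq_len: Optional[int] = None):
        self._reset_fn = reset_fn
        # seq_len the stage shapes were last set up for (None = not yet known).
        self._current = initial_seq_len

    def before_step(self, seq_len: int) -> bool:
        """Call before schedule.step(); returns True if a reset was triggered."""
        if self._current is not None and seq_len != self._current:
            self._reset_fn(seq_len)
            self._current = seq_len
            return True
        # First batch, or same length as before: keep the existing stage shapes.
        self._current = seq_len
        return False


# Possible wiring in a training loop (names from the documented signature):
# guard = SeqLenResetGuard(
#     lambda n: reset_pp_stage_shapes(schedule, stages, model.config, microbatch_size, n)
# )
# guard.before_step(input_ids.shape[1])  # then call schedule.step(...)
```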

nemo_automodel.components.distributed.pipelining.functional.scale_grads_by_divisor(
stages: list[torch.distributed.pipelining.PipelineStage],
divisor: int
) -> None

Scale pipeline stage gradients by a common divisor when supported.
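"When supported" plausibly means that only stages exposing a gradient-scaling method are touched; recent PyTorch releases provide PipelineStage.scale_grads(grad_scale_factor) for this purpose. A minimal sketch under that assumption, duck-typing on the method name; the helper name and the fake stage classes are hypothetical.

```python
def scale_stage_grads(stages, divisor: int) -> None:
    """Sketch: call scale_grads(divisor) on every stage that provides it."""
    if divisor <= 0:
        raise ValueError("divisor must be a positive integer")
    for stage in stages:
        scale = getattr(stage, "scale_grads", None)
        # Assumed behavior: stage objects without the method are skipped silently.
        if callable(scale):
            scale(divisor)


class FakeStage:
    def __init__(self):
        self.grads = [2.0, 4.0]

    def scale_grads(self, divisor):
        self.grads = [g / divisor for g in self.grads]


class StageWithoutScaling:
    pass


s = FakeStage()
scale_stage_grads([s, StageWithoutScaling()], divisor=2)
print(s.grads)  # [1.0, 2.0]
```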

nemo_automodel.components.distributed.pipelining.functional.split_model_into_stages(
model: torch.nn.Module,
pp_mesh: torch.distributed.device_mesh.DeviceMesh,
pp_axis_name: str,
pp_schedule: str,
device: torch.device,
module_names_per_stage: typing.Optional[list[list[str]]] = None,
layers_per_stage: typing.Optional[int] = None,
patch_inner_model: bool = True,
patch_causal_lm_model: bool = True,
round_to_pp_multiple: str | None = None
) -> tuple[list[torch.distributed.pipelining.PipelineStage], list[torch.nn.Module]]

Splits a HuggingFace model for pipeline parallelism.

Parameters:

model
torch.nn.Module

The HuggingFace model to split

pp_mesh
DeviceMesh

Pipeline parallel device mesh

pp_schedule
str

Name of pipeline parallelism schedule

device
torch.device

Device to place stages on

module_names_per_stage
Optional[list[list[str]]], defaults to None

Optional manual specification of modules per stage

layers_per_stage
Optional[int], defaults to None

Number of layers per stage (used if module_names_per_stage is not provided)

Returns: tuple[list[PipelineStage], list[torch.nn.Module]]

Tuple of (stages, models) where stages are PipelineStage objects and models are the corresponding model parts (one torch.nn.Module per stage).

nemo_automodel.components.distributed.pipelining.functional.stage_ids_this_rank(
pp_rank: int,
pp_size: int,
num_stages: int,
style: str = 'loop'
) -> tuple[int]

Compute the IDs of the stages that will run on this PP rank for either a looped or a V-style schedule.
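A sketch of the two placement styles, assuming stages divide evenly across ranks. In the looped style, rank r holds stages r, r + pp_size, r + 2*pp_size, and so on; in the V style, each rank holds two stages, one on the way "down" the pipeline and its mirror on the way "back". The function name is hypothetical, and its validation rules are assumptions rather than the module's documented behavior.

```python
def stage_ids_for_rank(pp_rank: int, pp_size: int, num_stages: int, style: str = "loop") -> tuple[int, ...]:
    """Sketch: global stage indices owned by `pp_rank` under a looped or V-style layout."""
    if num_stages % pp_size != 0:
        raise ValueError(f"num_stages={num_stages} must be divisible by pp_size={pp_size}")
    stages_per_rank = num_stages // pp_size
    if style == "loop":
        # Round-robin: rank r owns every pp_size-th stage starting at r.
        return tuple(pp_rank + s * pp_size for s in range(stages_per_rank))
    if style == "v":
        if stages_per_rank != 2:
            raise ValueError("V-style schedules assume exactly two stages per rank")
        # Rank r owns stage r and its mirror image num_stages - 1 - r.
        return (pp_rank, num_stages - 1 - pp_rank)
    raise ValueError(f"unknown style {style!r}")


for rank in range(2):
    print(rank, stage_ids_for_rank(rank, 2, 4), stage_ids_for_rank(rank, 2, 4, style="v"))
```

The V layout places the first and last stages on the same rank, so the final stage's output never has to travel back across the whole pipeline before the backward pass starts.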

nemo_automodel.components.distributed.pipelining.functional.logger = logging.getLogger(__name__)