nemo_automodel.components.distributed.pipelining.functional#

Module Contents#

Classes#

ParallelizeFnProtocol

Functions#

scale_grads_by_divisor

Scale the gradients of the given pipeline stages by a divisor.

stage_ids_this_rank

Compute the stage IDs that will run on this pipeline-parallel rank for either a looped or V-style schedule.

generate_hf_model_fqn_per_model_part

Generates module names for each pipeline stage for HuggingFace models.

calculate_virtual_stages

Calculate the total number of virtual pipeline stages and the number of stages per rank.

_get_hidden_and_vocab_size

Extract hidden_size and vocab_size from a model config.

_precompute_stage_shapes

Precompute input/output meta tensors for each pipeline stage to bypass serial shape inference.

reset_pp_stage_shapes

Reset pipeline stage infrastructure and recompute shapes for a new sequence length.

split_model_into_stages

Splits a HuggingFace model for pipeline parallelism.

build_pipeline_schedule

Builds a pipeline schedule for the given job configuration and stages.

pipeline_model

HF-specific pipeline model splitting.

Data#

API#

nemo_automodel.components.distributed.pipelining.functional.logger#

‘getLogger(…)’

class nemo_automodel.components.distributed.pipelining.functional.ParallelizeFnProtocol#

Bases: typing.Protocol

__call__(
model: torch.nn.Module,
world_mesh: torch.distributed.device_mesh.DeviceMesh,
moe_mesh: torch.distributed.device_mesh.DeviceMesh,
*,
dp_axis_names: tuple[str, ...],
cp_axis_name: str | None = None,
tp_axis_name: str | None = None,
ep_axis_name: str | None = None,
ep_shard_axis_names: tuple[str, ...] | None = None,
) → None#
nemo_automodel.components.distributed.pipelining.functional.scale_grads_by_divisor(
stages: list[torch.distributed.pipelining.PipelineStage],
divisor: int,
) → None#
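A minimal sketch of what scaling stage gradients by a divisor typically looks like (illustrative, not the module's exact code; `submod` is the attribute torch.distributed.pipelining's PipelineStage uses for its wrapped module):

```python
import torch

def scale_grads_by_divisor_sketch(stages, divisor: int) -> None:
    # Divide every parameter gradient of every local stage module by
    # `divisor`, e.g. to average gradients accumulated over microbatches.
    with torch.no_grad():
        for stage in stages:
            for param in stage.submod.parameters():
                if param.grad is not None:
                    param.grad.div_(divisor)
```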
nemo_automodel.components.distributed.pipelining.functional.stage_ids_this_rank(
pp_rank: int,
pp_size: int,
num_stages: int,
style: str = 'loop',
) → tuple[int]#

Compute the stage IDs that will run on this pipeline-parallel rank for either a looped or V-style schedule.
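For concreteness, here is a sketch of the conventional looped/V-style assignment (an assumption about this implementation, following the common torchtitan-style semantics):

```python
def stage_ids_this_rank_sketch(
    pp_rank: int, pp_size: int, num_stages: int, style: str = "loop"
) -> tuple[int, ...]:
    # Stages must divide evenly across pipeline ranks.
    assert num_stages % pp_size == 0
    stages_per_rank = num_stages // pp_size
    if style == "loop":
        # Rank r runs stages r, r + pp_size, r + 2*pp_size, ...
        return tuple(pp_rank + s * pp_size for s in range(stages_per_rank))
    # "v" style pairs each rank with its mirrored stage (2 stages per rank).
    assert style == "v" and stages_per_rank == 2
    return (pp_rank, num_stages - 1 - pp_rank)

# With pp_size=4, num_stages=8: loop -> rank 0 runs (0, 4); v -> rank 0 runs (0, 7).
```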

nemo_automodel.components.distributed.pipelining.functional.generate_hf_model_fqn_per_model_part(
num_stages: int,
num_layers: int,
include_embeddings: bool = True,
include_lm_head: bool = True,
include_rotary_emb: bool = True,
include_multimodal_encoders: bool = True,
extra_module_fqns: Optional[list[str]] = None,
fqn_prefix: str = 'model.',
lm_head_fqn: str = 'lm_head',
) → list[list[str]]#

Generates module names for each pipeline stage for HuggingFace models.

Parameters:
  • num_stages – Number of pipeline stages

  • num_layers – Total number of transformer layers in the model

  • include_embeddings – Whether to include the embedding layer in the first stage

  • include_lm_head – Whether to include lm_head in the last stage (for CausalLM models)

  • include_rotary_emb – Whether to include the rotary embedding module

  • include_multimodal_encoders – Whether to include common vision/audio encoder modules in stage 0

  • extra_module_fqns – Optional list of extra module FQNs to include in stage 0

  • fqn_prefix – Prefix prepended to generated module FQNs (default 'model.')

  • lm_head_fqn – FQN to use for the LM head module (default 'lm_head')

Returns:

List of lists containing module names for each stage

Example:

generate_hf_model_fqn_per_model_part(4, 32) might return:

[
    ["model.embed_tokens", "model.layers.0", …, "model.layers.7"],
    ["model.layers.8", …, "model.layers.15"],
    ["model.layers.16", …, "model.layers.23"],
    ["model.layers.24", …, "model.layers.31", "model.norm", "lm_head"],
]
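A runnable usage sketch of the defaults described above:

```python
from nemo_automodel.components.distributed.pipelining.functional import (
    generate_hf_model_fqn_per_model_part,
)

# 32 decoder layers split across 4 stages; embeddings go to stage 0,
# the final norm and lm_head to the last stage.
fqns_per_stage = generate_hf_model_fqn_per_model_part(num_stages=4, num_layers=32)
for i, fqns in enumerate(fqns_per_stage):
    print(f"stage {i}: {fqns[0]} ... {fqns[-1]}")
```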

nemo_automodel.components.distributed.pipelining.functional.calculate_virtual_stages(
num_layers: int,
layers_per_stage: Optional[int],
pp_size: int,
is_single_stage_schedule: bool,
round_to_pp_multiple: str | None = None,
) → tuple[int, int]#

Calculates the total number of virtual pipeline stages and the number of stages per rank, given the layer count, the desired layers per stage, and whether the schedule is single- or multi-stage.
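A hedged example of the intended arithmetic, assuming the returned pair is (total virtual stages, stages per rank):

```python
from nemo_automodel.components.distributed.pipelining.functional import (
    calculate_virtual_stages,
)

# 32 layers at 4 layers per virtual stage across 4 PP ranks.
num_virtual_stages, stages_per_rank = calculate_virtual_stages(
    num_layers=32,
    layers_per_stage=4,
    pp_size=4,
    is_single_stage_schedule=False,
)
# Expected under the assumption above: 32 / 4 = 8 virtual stages, 2 per rank.
```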
nemo_automodel.components.distributed.pipelining.functional._get_hidden_and_vocab_size(model_config) → tuple[int, int]#

Extract hidden_size and vocab_size from a model config.

Handles both flat configs (LLM) and nested configs where these attributes live under text_config (VLM models such as Qwen3-VL, LLaVA, etc.).
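A minimal sketch of the flat-vs-nested lookup this describes (illustrative, assuming text_config carries the language-model sizes when present):

```python
def get_hidden_and_vocab_size_sketch(model_config) -> tuple[int, int]:
    # VLM configs (e.g. Qwen3-VL, LLaVA) nest the language-model sizes
    # under text_config; flat LLM configs expose them at the top level.
    cfg = getattr(model_config, "text_config", None) or model_config
    return cfg.hidden_size, cfg.vocab_size
```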

nemo_automodel.components.distributed.pipelining.functional._precompute_stage_shapes(
stages: list[torch.distributed.pipelining.PipelineStage],
model_config,
microbatch_size: int,
seq_len: int,
) → None#

Precompute input/output meta tensors for each pipeline stage to bypass serial shape inference.

By default, PipelineStage performs shape inference at runtime via a serial P2P chain: stage 0 → send → stage 1 → send → … → stage N-1. This is O(N) in the number of pipeline stages and becomes a bottleneck for large world sizes.

This function sets inputs_meta and _outputs_meta on each stage before the first step() call, so that _shape_inference is never invoked and the serial chain is completely eliminated.

Parameters:
  • stages – The local pipeline stages (already parallelized).

  • model_config – The HuggingFace model config (model.config).

  • microbatch_size – Microbatch size used by the pipeline schedule.

  • seq_len – Sequence length of the input data.
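To illustrate the shapes involved (the sizes below are assumptions, not defaults): the first stage consumes token IDs, interior stages exchange hidden-state activations, and the last stage emits logits, all constructible on the meta device without any communication:

```python
import torch

hidden, vocab = 4096, 32000  # e.g. from _get_hidden_and_vocab_size(model.config)
mb, seq = 2, 2048            # microbatch_size, seq_len

token_ids  = torch.empty(mb, seq, dtype=torch.long, device="meta")
activation = torch.empty(mb, seq, hidden, dtype=torch.bfloat16, device="meta")
logits     = torch.empty(mb, seq, vocab, dtype=torch.bfloat16, device="meta")
```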

nemo_automodel.components.distributed.pipelining.functional.reset_pp_stage_shapes(
schedule: torch.distributed.pipelining.schedules._PipelineSchedule,
stages: list[torch.distributed.pipelining.PipelineStage],
model_config,
microbatch_size: int,
seq_len: int,
) → None#

Reset pipeline stage infrastructure and recompute shapes for a new sequence length.

VLM training produces batches with highly variable sequence lengths (image tokens expand the sequence dramatically). PyTorch’s PipelineStage locks in output shapes and recv buffer sizes on the first schedule.step() call (_stages_initialized = True). Subsequent steps with a different seq_len therefore hit a shape-mismatch error.

This function resets the per-stage infrastructure so that _initialize_stages re-runs on the next step() call. It then calls _precompute_stage_shapes to set the correct shapes analytically — avoiding the expensive real-valued forward pass that _shape_inference would otherwise perform.

Parameters:
  • schedule – The active pipeline schedule.

  • stages – The local pipeline stages for this rank.

  • model_config – The HuggingFace model config (model.config).

  • microbatch_size – Per-microbatch batch size used by the schedule.

  • seq_len – Sequence length of the upcoming batch (e.g. input_ids.shape[1]).
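A hedged sketch of how this might sit in a training loop with variable-length VLM batches (dataloader, schedule, stages, and model are assumed names from earlier setup):

```python
from nemo_automodel.components.distributed.pipelining.functional import (
    reset_pp_stage_shapes,
)

prev_seq_len = None
for batch in dataloader:
    seq_len = batch["input_ids"].shape[1]
    if prev_seq_len is not None and seq_len != prev_seq_len:
        # Shapes were locked in on the previous step(); recompute them
        # analytically for the new sequence length before stepping again.
        reset_pp_stage_shapes(schedule, stages, model.config,
                              microbatch_size=1, seq_len=seq_len)
    prev_seq_len = seq_len
    schedule.step(batch["input_ids"])  # illustrative: first stage takes the input ids
```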

nemo_automodel.components.distributed.pipelining.functional.split_model_into_stages(
model: torch.nn.Module,
pp_mesh: torch.distributed.device_mesh.DeviceMesh,
pp_axis_name: str,
pp_schedule: str,
device: torch.device,
module_names_per_stage: Optional[list[list[str]]] = None,
layers_per_stage: Optional[int] = None,
patch_inner_model: bool = True,
patch_causal_lm_model: bool = True,
round_to_pp_multiple: str | None = None,
) → tuple[list[torch.distributed.pipelining.PipelineStage], list[torch.nn.Module]]#

Splits a HuggingFace model for pipeline parallelism.

Parameters:
  • model – The HuggingFace model to split

  • pp_mesh – Pipeline parallel device mesh

  • pp_axis_name – Name of the pipeline-parallel dimension in pp_mesh

  • pp_schedule – Name of the pipeline parallelism schedule

  • device – Device to place stages on

  • module_names_per_stage – Optional manual specification of modules per stage

  • layers_per_stage – Number of transformer layers per stage (used if module_names_per_stage is not provided)

Returns:

Tuple of (stages, models) where stages are PipelineStage objects and models are the corresponding model chunks
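A hedged usage sketch (the mesh construction and the model are illustrative assumptions):

```python
import torch
from torch.distributed.device_mesh import init_device_mesh

from nemo_automodel.components.distributed.pipelining.functional import (
    split_model_into_stages,
)

# Assumes torch.distributed is initialized with 4 ranks for pipeline parallelism.
pp_mesh = init_device_mesh("cuda", (4,), mesh_dim_names=("pp",))
stages, model_parts = split_model_into_stages(
    model,                      # an already-constructed HF model
    pp_mesh=pp_mesh,
    pp_axis_name="pp",
    pp_schedule="1F1B",
    device=torch.device("cuda"),
    layers_per_stage=8,         # derive stage boundaries from the layer count
)
```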

nemo_automodel.components.distributed.pipelining.functional.build_pipeline_schedule(
pipeline_parallel_schedule_csv: str | None,
pipeline_parallel_schedule: str | None,
microbatch_size: int,
local_batch_size: int,
stages: list[torch.distributed.pipelining.PipelineStage],
loss_fn: Callable,
scale_grads: bool = False,
) → torch.distributed.pipelining.schedules._PipelineSchedule#

Builds a pipeline schedule for the given job configuration and stages.

Parameters:
  • pipeline_parallel_schedule_csv (str | None) – The path to the pipeline parallel schedule csv file.

  • pipeline_parallel_schedule (str | None) – The name of the pipeline parallel schedule.

  • microbatch_size (int) – The microbatch size.

  • local_batch_size (int) – The local batch size.

  • stages (list[PipelineStage]) – The stages to be scheduled.

  • loss_fn (Callable) – The loss function.

  • scale_grads (bool) – Whether the schedule scales gradients by the number of microbatches.

Returns:

The pipeline schedule for the given stages.

Return type:

_PipelineSchedule
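A hedged usage sketch, continuing from stages produced by split_model_into_stages (loss_fn is an assumed callable):

```python
from nemo_automodel.components.distributed.pipelining.functional import (
    build_pipeline_schedule,
)

schedule = build_pipeline_schedule(
    pipeline_parallel_schedule_csv=None,      # use a named schedule, not a CSV
    pipeline_parallel_schedule="Interleaved1F1B",
    microbatch_size=1,
    local_batch_size=8,                       # 8 microbatches per step
    stages=stages,
    loss_fn=loss_fn,
    scale_grads=False,
)
```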

nemo_automodel.components.distributed.pipelining.functional.pipeline_model(
model: torch.nn.Module,
world_mesh: torch.distributed.device_mesh.DeviceMesh,
moe_mesh: torch.distributed.device_mesh.DeviceMesh,
*,
pp_axis_name: str,
dp_axis_names: tuple[str, ...],
cp_axis_name: str | None = None,
tp_axis_name: str | None = None,
ep_axis_name: str | None = None,
ep_shard_axis_names: tuple[str, ...] | None = None,
layers_per_stage: int | None,
pipeline_parallel_schedule_csv: str | None,
pipeline_parallel_schedule: str | None,
microbatch_size: int,
local_batch_size: int,
device: torch.device,
loss_fn: Callable = None,
parallelize_fn: Callable | None = None,
module_fqns_per_model_part: list[list[str]] | None = None,
patch_inner_model: bool = True,
patch_causal_lm_model: bool = True,
scale_grads: bool = False,
round_to_pp_multiple: str | None = None,
patch_stage_backward_maybe_with_nosync: bool = False,
seq_len: int | None = None,
) → tuple[torch.distributed.pipelining.schedules._PipelineSchedule, list[torch.nn.Module], bool, bool, list[torch.distributed.pipelining.PipelineStage]]#

HF-specific pipeline model splitting: splits a HuggingFace model into pipeline stages, optionally applies parallelize_fn to each model part, and builds the corresponding pipeline schedule.
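A hedged end-to-end sketch; the mesh layout, the loss function, and the names given to the two boolean outputs (whether this rank holds the first/last stage) are assumptions based on the annotated return type:

```python
import torch
from torch.distributed.device_mesh import init_device_mesh

from nemo_automodel.components.distributed.pipelining.functional import pipeline_model

# Assumes 8 ranks arranged as 4-way pipeline x 2-way data parallel.
world_mesh = init_device_mesh("cuda", (4, 2), mesh_dim_names=("pp", "dp"))
schedule, model_parts, has_first_stage, has_last_stage, stages = pipeline_model(
    model,                      # an already-constructed HF model
    world_mesh=world_mesh,
    moe_mesh=None,              # no expert parallelism in this sketch
    pp_axis_name="pp",
    dp_axis_names=("dp",),
    layers_per_stage=8,
    pipeline_parallel_schedule_csv=None,
    pipeline_parallel_schedule="1F1B",
    microbatch_size=1,
    local_batch_size=8,
    device=torch.device("cuda"),
    loss_fn=loss_fn,            # assumed callable computing the training loss
)
```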