nemo_automodel.components.distributed.optimized_tp_plans


Model-specific parallel plans for tensor parallelism.

This module contains optimized tensor parallel plans for several model architectures, including LLaMA, Qwen, Gemma3, Gemma4, Ministral3, Phi, Baichuan, and Falcon-H1 models.

Module Contents

Classes

| Name | Description |
|---|---|
| RotaryEmbedParallel | Custom SequenceParallel class for Qwen2 / Gemma3 rotary embeddings because the input is a tuple. |
| SequenceParallelAllGatherActivation | SequenceParallel that all-gathers activations for sequence parallelism. |
| VocabParallelEmbedding | RowwiseParallel for nn.Embedding with a MaskPartial mask-buffer fixup. |

Functions

| Name | Description |
|---|---|
| _get_class_qualname | Return the fully qualified name of a class as module.qualname. |
| _parallelize_baichuan | Parallelizes a BaichuanForCausalLM model (MLP-only). |
| _parallelize_decilm_nemotron | - |
| _parallelize_falcon_h1 | Parallelize Falcon-H1 (hybrid Transformer + Mamba2 SSM). |
| _parallelize_gemma3 | Parallelizes a Gemma3ForCausalLM model across data and tensor parallel dimensions. |
| _parallelize_gemma4 | Parallelizes a Gemma4ForConditionalGeneration model across tensor parallel dimensions. |
| _parallelize_llama | Parallelizes a LlamaForCausalLM model across data and tensor parallel dimensions. |
| _parallelize_ministral3 | Parallelizes a Ministral3ForCausalLM model across data and tensor parallel dimensions. |
| _parallelize_mistral3_vlm | TP plan for Mistral3ForConditionalGeneration (and subclasses like Mistral3FP8VLMForConditionalGeneration). |
| _parallelize_nemotron_labs_diffusion | TP plan for NemotronLabsDiffusionModel (Nemotron-Labs-Diffusion). |
| _parallelize_phi | Parallelizes a PhiForCausalLM (Phi-2) model across tensor parallel dimensions. |
| _parallelize_phi3 | - |
| _parallelize_qwen | Parallelizes a Qwen2/Qwen3 causal LM across data and tensor parallel dimensions. |
| _parallelize_qwen3_5_vlm | Parallelize Qwen3.5 VLM by reusing transformers’ base_model_tp_plan. |
| _parallelize_qwen_classification | - |
| get_decilm_nemotron_tp_plan | Return a TP plan for remote-code DeciLM Nemotron-NAS checkpoints. |
| get_llama_nemotron_super_tp_plan | Return the tensor parallel plan for Llama / Llama-3.3-Nemotron Super. |

Data

LLAMA_NEMOTRON_SUPER_TP_PLAN_NAME

PARALLELIZE_FUNCTIONS

API

class nemo_automodel.components.distributed.optimized_tp_plans.RotaryEmbedParallel()

Bases: SequenceParallel

Custom SequenceParallel class for Qwen2 / Gemma3 rotary embeddings, required because their input is a tuple rather than a single tensor.

nemo_automodel.components.distributed.optimized_tp_plans.RotaryEmbedParallel._prepare_input_fn(
sequence_sharding,
mod,
inputs,
device_mesh
)
staticmethod
nemo_automodel.components.distributed.optimized_tp_plans.RotaryEmbedParallel._prepare_output_fn(
use_local_output,
mod,
outputs,
device_mesh
)
staticmethod
class nemo_automodel.components.distributed.optimized_tp_plans.SequenceParallelAllGatherActivation()

Bases: SequenceParallel

SequenceParallel that all-gathers activations for sequence parallelism.

nemo_automodel.components.distributed.optimized_tp_plans.SequenceParallelAllGatherActivation._prepare_output_fn(
use_local_output,
mod,
outputs,
device_mesh
)
staticmethod

Prepare outputs by redistributing sharded DTensors to replicated placement.
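Redistributing a sequence-sharded activation to a replicated placement is an all-gather along the sequence dimension. The pure-Python sketch below simulates that collective over per-rank shards so the data movement is visible; it does not use torch.distributed or DTensor, and all_gather_sequence is a hypothetical name.

```python
from typing import List

# Each activation is modeled as [seq][hidden]; in a [batch, seq, hidden] tensor
# the sequence dimension would be dim 1 (i.e. a Shard(1) placement).
Activation = List[List[float]]


def all_gather_sequence(shards: List[Activation]) -> List[Activation]:
    """Simulate a sequence-dim all-gather: sharded on seq -> replicated.

    shards[r] is the local sequence chunk held by rank r. Afterwards every
    rank holds the full sequence, concatenated in rank order.
    """
    full: Activation = []
    for shard in shards:  # rank order defines the token order
        full.extend(shard)
    # Every rank receives its own copy of the same full sequence.
    return [list(full) for _ in shards]


# Two TP ranks, each holding 2 of the 4 tokens (hidden size 3).
rank0 = [[0.0, 0.1, 0.2], [1.0, 1.1, 1.2]]
rank1 = [[2.0, 2.1, 2.2], [3.0, 3.1, 3.2]]
gathered = all_gather_sequence([rank0, rank1])
```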

class nemo_automodel.components.distributed.optimized_tp_plans.VocabParallelEmbedding()

Bases: RowwiseParallel

RowwiseParallel for nn.Embedding with a MaskPartial mask-buffer fixup.

Some PyTorch versions have a DTensor bug where the MaskPartial placement’s mask_buffer is not populated during the embedding dispatch, leading to:

AssertionError: assert self.mask_buffer.data is not None

This subclass works around the issue by:

  1. Saving the original (un-adjusted) input_ids in a pre-hook.
  2. Recomputing and populating the mask_buffer in the post-hook when the DTensor dispatch failed to do so.

In PyTorch versions where the dispatch works correctly, the mask buffer is already populated and the fixup is a no-op.
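To make concrete what the mask buffer holds, here is a pure-Python simulation of a vocab-parallel (row-wise) embedding lookup: each rank owns a contiguous slice of the vocabulary, masks out token ids it does not own, zeroes those output rows, and the partial outputs are summed across ranks. It is a sketch assuming a contiguous, ceil-divided vocabulary split; it does not call DTensor or MaskPartial, and shard_bounds, local_lookup, and all_reduce_sum are hypothetical helpers.

```python
from typing import List, Tuple

Table = List[List[float]]  # embedding table: [vocab][hidden]


def shard_bounds(vocab_size: int, world_size: int, rank: int) -> Tuple[int, int]:
    """Row range [start, end) of the embedding table owned by `rank`."""
    chunk = -(-vocab_size // world_size)  # ceiling division
    start = min(rank * chunk, vocab_size)
    end = min(start + chunk, vocab_size)
    return start, end


def local_lookup(table: Table, input_ids: List[int],
                 world_size: int, rank: int) -> Tuple[Table, List[bool]]:
    """Embedding lookup on one rank's shard; returns (partial_output, mask)."""
    start, end = shard_bounds(len(table), world_size, rank)
    local_table = table[start:end]
    hidden = len(table[0])
    # mask[i] is True when token i belongs to another rank (the mask_buffer role).
    mask = [not (start <= t < end) for t in input_ids]
    # Owned ids are shifted into local coordinates; masked ids are parked at 0.
    local_ids = [0 if m else t - start for t, m in zip(input_ids, mask)]
    # Masked rows contribute zeros so they vanish in the cross-rank sum.
    out = [[0.0] * hidden if m else list(local_table[i])
           for i, m in zip(local_ids, mask)]
    return out, mask


def all_reduce_sum(partials: List[Table]) -> Table:
    """Element-wise sum over ranks (the Partial -> Replicate reduction)."""
    return [[sum(vals) for vals in zip(*rows)] for rows in zip(*partials)]


table = [[float(r), float(r) + 0.5] for r in range(6)]  # vocab 6, hidden 2
ids = [0, 4, 2, 5]
partials = [local_lookup(table, ids, world_size=2, rank=r)[0] for r in range(2)]
result = all_reduce_sum(partials)
```

If the mask were missing, a rank would have no record of which rows to zero, which is the situation the post-hook fixup repairs by recomputing it from the saved input_ids.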

nemo_automodel.components.distributed.optimized_tp_plans.VocabParallelEmbedding._prepare_input_fn(
input_layouts,
desired_input_layouts,
mod,
inputs,
device_mesh
)
staticmethod
nemo_automodel.components.distributed.optimized_tp_plans.VocabParallelEmbedding._prepare_output_fn(
output_layouts,
use_local_output,
mod,
outputs,
device_mesh
)
staticmethod
nemo_automodel.components.distributed.optimized_tp_plans._get_class_qualname(
cls: type
) -> str

Return the fully qualified name of a class as module.qualname.

Used as a stable dict key for PARALLELIZE_FUNCTIONS instead of the class object itself.

When NeMo-RL uses automodel, force_hf=True is set automatically for models (e.g. LlamaForCausalLM) whose adapter does not implement convert_single_tensor_to_hf. This causes _get_mixin_wrapped_class in model_init.py to create a new class via type(...) that wraps the original with HFCheckpointingMixin. The wrapper copies __module__ and __qualname__ from the original but is a different Python object. Because class keys are compared by identity, a class-keyed check such as type(model) in PARALLELIZE_FUNCTIONS returns False, and the default plan is used instead of the optimized one.

String comparison on module.qualname survives this wrapping and correctly identifies the model class.
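The failure mode and the fix can be reproduced in plain Python. The classes below are empty stand-ins (not the real transformers or NeMo classes), and the wrapper is built with type(...) in the spirit of the description above; the exact base order used by _get_mixin_wrapped_class is an assumption.

```python
def get_class_qualname(cls: type) -> str:
    """Return "module.qualname", as described for _get_class_qualname."""
    return f"{cls.__module__}.{cls.__qualname__}"


class LlamaForCausalLM:      # stand-in for the HF model class
    pass


class HFCheckpointingMixin:  # stand-in for the checkpointing mixin
    pass


# Dynamically created wrapper that copies __module__ and __qualname__.
Wrapped = type(
    LlamaForCausalLM.__name__,
    (HFCheckpointingMixin, LlamaForCausalLM),
    {"__module__": LlamaForCausalLM.__module__,
     "__qualname__": LlamaForCausalLM.__qualname__},
)

plans_by_class = {LlamaForCausalLM: "optimized_llama_plan"}
plans_by_name = {get_class_qualname(LlamaForCausalLM): "optimized_llama_plan"}

# Class keys hash and compare by identity, so the wrapper misses ...
identity_hit = Wrapped in plans_by_class
# ... while the string key built from __module__/__qualname__ still matches.
name_hit = plans_by_name.get(get_class_qualname(Wrapped))
```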

nemo_automodel.components.distributed.optimized_tp_plans._parallelize_baichuan(
model: nemo_automodel.components.models.baichuan.model.BaichuanForCausalLM | None,
sequence_parallel: bool = False
) -> dict[str, torch.distributed.tensor.parallel.ParallelStyle]

Parallelizes a BaichuanForCausalLM model (MLP-only).

Only the MLP is sharded. The attention path stays fully replicated because W_pack uses a non-interleaved [Q|K|V] layout that ColwiseParallel would split incorrectly, and lm_head stays replicated because it is a NormHead rather than an nn.Linear, which ColwiseParallel does not support.
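Why a plain column-wise split breaks a fused [Q|K|V] weight can be seen by labeling its output rows. The sketch assumes a hidden size of 4, 2 TP ranks, and that a ColwiseParallel-style split takes contiguous equal chunks of the output dimension; colwise_shards and per_projection_shards are hypothetical helpers. The per-projection variant only illustrates what a correct split would need; this module does not do that for Baichuan and simply keeps attention replicated.

```python
from typing import List


def colwise_shards(rows: List[str], world_size: int) -> List[List[str]]:
    """Contiguous equal chunks of the output rows (len must divide evenly)."""
    chunk = len(rows) // world_size
    return [rows[r * chunk:(r + 1) * chunk] for r in range(world_size)]


def per_projection_shards(rows: List[str], world_size: int) -> List[List[str]]:
    """Split Q, K and V separately so each rank gets a matching slice of all three."""
    third = len(rows) // 3
    q, k, v = rows[:third], rows[third:2 * third], rows[2 * third:]
    qs, ks, vs = (colwise_shards(p, world_size) for p in (q, k, v))
    return [qs[r] + ks[r] + vs[r] for r in range(world_size)]


hidden = 4
# Output rows of a fused [Q|K|V] weight, labeled by projection and row index.
w_pack_rows = [f"{p}{i}" for p in "QKV" for i in range(hidden)]

naive = colwise_shards(w_pack_rows, world_size=2)           # rank 0: all of Q, half of K, no V
correct = per_projection_shards(w_pack_rows, world_size=2)  # each rank: half of Q, K and V
```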

nemo_automodel.components.distributed.optimized_tp_plans._parallelize_decilm_nemotron(
model,
sequence_parallel: bool = False
) -> dict[str, torch.distributed.tensor.parallel.ParallelStyle]
nemo_automodel.components.distributed.optimized_tp_plans._parallelize_falcon_h1(
model,
sequence_parallel: bool = False
) -> typing.Dict[str, torch.distributed.tensor.parallel.ParallelStyle]

Parallelize Falcon-H1 (hybrid Transformer + Mamba2 SSM).

Every Falcon-H1 decoder layer runs an attention branch (self_attn) and a Mamba2 branch (mamba) in parallel, followed by an MLP (feed_forward). Only the attention and MLP linears are tensor-parallel sharded; the Mamba2 mixer stays replicated because its SSM scan / causal conv1d are not TP-shardable with stock kernels (same approach as Qwen3.5’s GatedDeltaNet linear-attention branch).

A dedicated plan is required because HuggingFace ships only _tp_plan = {"lm_head": "colwise_gather_output"} for FalconH1, and its MLP is named feed_forward (not mlp). The generic llama-style fallback plan therefore matches neither the HF plan (the colwise_gather_output style is rejected) nor the MLP module names, leaving the dominant feed_forward weights replicated across TP ranks. That runs out of memory on large variants such as Falcon-H1-34B, even under LoRA.

sequence_parallel is accepted for signature compatibility but ignored: the parallel Mamba2 branch emits non-sequence-parallel activations that cannot be combined with sequence-parallel attention outputs.
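A small pure-Python sketch shows which modules such a plan touches. The module names (self_attn.{q,k,v,o}_proj, feed_forward.{gate,up,down}_proj, mamba.*) are assumptions about the HF FalconH1 layout, the style strings are shorthand rather than ParallelStyle objects, and glob matching with fnmatch is a simplification of how plans are applied; the plan dicts and resolve_style are hypothetical.

```python
from fnmatch import fnmatchcase
from typing import Dict, Optional

# Hypothetical plan: attention and the feed_forward MLP are sharded, the
# Mamba2 mixer has no entry and therefore stays replicated.
FALCON_H1_STYLE_PLAN: Dict[str, str] = {
    "model.layers.*.self_attn.q_proj": "colwise",
    "model.layers.*.self_attn.k_proj": "colwise",
    "model.layers.*.self_attn.v_proj": "colwise",
    "model.layers.*.self_attn.o_proj": "rowwise",
    "model.layers.*.feed_forward.gate_proj": "colwise",
    "model.layers.*.feed_forward.up_proj": "colwise",
    "model.layers.*.feed_forward.down_proj": "rowwise",
}

# A llama-style fallback names the MLP "mlp", so these keys never match Falcon-H1.
LLAMA_STYLE_MLP_KEYS: Dict[str, str] = {
    "model.layers.*.mlp.gate_proj": "colwise",
    "model.layers.*.mlp.up_proj": "colwise",
    "model.layers.*.mlp.down_proj": "rowwise",
}


def resolve_style(plan: Dict[str, str], module_name: str) -> Optional[str]:
    """Style of the first matching pattern, or None (module stays replicated)."""
    for pattern, style in plan.items():
        if fnmatchcase(module_name, pattern):
            return style
    return None


mlp_style = resolve_style(FALCON_H1_STYLE_PLAN, "model.layers.0.feed_forward.down_proj")
mamba_style = resolve_style(FALCON_H1_STYLE_PLAN, "model.layers.7.mamba.in_proj")
fallback_style = resolve_style(LLAMA_STYLE_MLP_KEYS, "model.layers.0.feed_forward.gate_proj")
```

The last lookup returning None is the failure described above: under the fallback, the large feed_forward projections get no sharding style at all.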

nemo_automodel.components.distributed.optimized_tp_plans._parallelize_gemma3(
model: typing.Union[transformers.models.gemma3.modeling_gemma3.Gemma3ForCausalLM, transformers.models.gemma3.modeling_gemma3.Gemma3ForConditionalGeneration],
sequence_parallel: bool = False
) -> dict[str, torch.distributed.tensor.parallel.ParallelStyle]

Parallelizes a Gemma3ForCausalLM or Gemma3ForConditionalGeneration model across data and tensor parallel dimensions.

nemo_automodel.components.distributed.optimized_tp_plans._parallelize_gemma4(
model: nemo_automodel.components.models.gemma4_moe.model.Gemma4ForConditionalGeneration,
sequence_parallel: bool = False
) -> dict[str, torch.distributed.tensor.parallel.ParallelStyle]

Parallelizes a Gemma4ForConditionalGeneration model across tensor parallel dimensions.

Gemma4 VLM uses model.language_model.{embed_tokens, layers.*} for the text backbone, identical to the Gemma3 VLM layout.

nemo_automodel.components.distributed.optimized_tp_plans._parallelize_llama(
model: transformers.models.llama.modeling_llama.LlamaForCausalLM | None,
sequence_parallel: bool = False
) -> dict[str, torch.distributed.tensor.parallel.ParallelStyle]

Parallelizes a LlamaForCausalLM model across data and tensor parallel dimensions.

nemo_automodel.components.distributed.optimized_tp_plans._parallelize_ministral3(
model: nemo_automodel.components.models.mistral3.model.Ministral3ForCausalLM,
sequence_parallel: bool = False
) -> dict[str, torch.distributed.tensor.parallel.ParallelStyle]

Parallelizes a Ministral3ForCausalLM model across data and tensor parallel dimensions.

nemo_automodel.components.distributed.optimized_tp_plans._parallelize_mistral3_vlm(
model,
sequence_parallel: bool = False
) -> dict[str, torch.distributed.tensor.parallel.ParallelStyle]

TP plan for Mistral3ForConditionalGeneration (and subclasses like Mistral3FP8VLMForConditionalGeneration). The Ministral3 text backbone lives at model.language_model.{embed_tokens, layers.*}; vision_tower and multi_modal_projector stay replicated across TP ranks.
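Because the text backbone is the same Ministral3 stack, one way to picture the VLM plan is as the text plan re-rooted under model.language_model, with no entries for the vision modules. The sketch below is an illustration under that assumption, not the module's implementation: TEXT_PLAN is a hypothetical, abbreviated plan with shorthand style strings, reroot_for_vlm is a hypothetical helper, and keeping lm_head at the top level is an assumed layout.

```python
from typing import Dict

# Hypothetical, abbreviated text-only plan keyed relative to the causal-LM root.
TEXT_PLAN: Dict[str, str] = {
    "model.embed_tokens": "rowwise",
    "model.layers.*.self_attn.q_proj": "colwise",
    "model.layers.*.self_attn.o_proj": "rowwise",
    "model.layers.*.mlp.down_proj": "rowwise",
    "lm_head": "colwise",
}


def reroot_for_vlm(text_plan: Dict[str, str],
                   old_prefix: str = "model.",
                   new_prefix: str = "model.language_model.") -> Dict[str, str]:
    """Move backbone keys under language_model; keys outside the backbone are kept."""
    return {
        (new_prefix + key[len(old_prefix):] if key.startswith(old_prefix) else key): style
        for key, style in text_plan.items()
    }


vlm_plan = reroot_for_vlm(TEXT_PLAN)
# vision_tower and multi_modal_projector never appear as keys, so they stay replicated.
```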

nemo_automodel.components.distributed.optimized_tp_plans._parallelize_nemotron_labs_diffusion(
model,
sequence_parallel: bool = False
) -> dict[str, torch.distributed.tensor.parallel.ParallelStyle]

TP plan for NemotronLabsDiffusionModel (Nemotron-Labs-Diffusion).

Same plan structure as _parallelize_ministral3, except that the model's submodules live under encoder.* (not model.*) and the output head is diffusion_head (not lm_head).

nemo_automodel.components.distributed.optimized_tp_plans._parallelize_phi(
model: transformers.models.phi.modeling_phi.PhiForCausalLM,
sequence_parallel: bool = False
) -> dict[str, torch.distributed.tensor.parallel.ParallelStyle]

Parallelizes a PhiForCausalLM (Phi-2) model across tensor parallel dimensions.

Phi-2 uses self_attn.dense instead of self_attn.o_proj and mlp.fc1/mlp.fc2 instead of mlp.gate_proj/mlp.up_proj/mlp.down_proj.
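Assuming the plan shards fc1 column-wise and fc2 row-wise (the roles that up_proj and down_proj play in the Llama-style plans), Phi-2's MLP follows the usual Megatron-style pairing: each rank computes its slice of the intermediate activations without communication, and one sum across ranks recovers the full output. The pure-Python check below assumes nn.Linear's [out_features][in_features] weight layout and an intermediate size divisible by the world size; it uses ReLU as a stand-in elementwise activation, and every helper name is hypothetical.

```python
from typing import List

Matrix = List[List[float]]  # nn.Linear layout: [out_features][in_features]


def matvec(w: Matrix, x: List[float]) -> List[float]:
    return [sum(wi * xi for wi, xi in zip(row, x)) for row in w]


def act(v: List[float]) -> List[float]:
    # Stand-in elementwise activation (ReLU); Phi-2's real activation differs.
    return [max(0.0, a) for a in v]


def mlp_reference(fc1: Matrix, fc2: Matrix, x: List[float]) -> List[float]:
    return matvec(fc2, act(matvec(fc1, x)))


def mlp_tensor_parallel(fc1: Matrix, fc2: Matrix, x: List[float],
                        world_size: int) -> List[float]:
    chunk = len(fc1) // world_size  # intermediate size must divide evenly
    partials = []
    for r in range(world_size):
        lo, hi = r * chunk, (r + 1) * chunk
        fc1_shard = fc1[lo:hi]                   # colwise: this rank's output features
        fc2_shard = [row[lo:hi] for row in fc2]  # rowwise: the matching input features
        h = act(matvec(fc1_shard, x))            # local, no communication needed
        partials.append(matvec(fc2_shard, h))    # partial sum of the full output
    # The single all-reduce (sum over ranks) at the end of the row-wise layer.
    return [sum(vals) for vals in zip(*partials)]


fc1 = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [1.0, -1.0]]  # hidden 2 -> intermediate 4
fc2 = [[1.0, 2.0, 3.0, 4.0], [0.0, 1.0, 0.0, 1.0]]       # intermediate 4 -> hidden 2
x = [2.0, -1.0]
```

The elementwise activation is safe to apply on each shard because it never mixes intermediate features, which is why the fc1/fc2 pair needs only one collective.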

nemo_automodel.components.distributed.optimized_tp_plans._parallelize_phi3(
model: transformers.models.phi3.modeling_phi3.Phi3ForCausalLM,
sequence_parallel: bool = False
) -> dict[str, torch.distributed.tensor.parallel.ParallelStyle]
nemo_automodel.components.distributed.optimized_tp_plans._parallelize_qwen(
model: typing.Union[transformers.models.qwen2.modeling_qwen2.Qwen2ForCausalLM, transformers.models.qwen3.modeling_qwen3.Qwen3ForCausalLM],
sequence_parallel: bool = False
) -> dict[str, torch.distributed.tensor.parallel.ParallelStyle]

Parallelizes a Qwen2/Qwen3 causal LM across data and tensor parallel dimensions.

nemo_automodel.components.distributed.optimized_tp_plans._parallelize_qwen3_5_vlm(
model,
sequence_parallel: bool = False
) -> typing.Dict[str, torch.distributed.tensor.parallel.ParallelStyle]

Parallelize Qwen3.5 VLM by reusing transformers’ base_model_tp_plan.

Qwen3.5 mixes two attention types: full self_attn (every 4th layer) and linear_attn (GatedDeltaNet). The transformers-provided base_model_tp_plan covers only self_attn and the MLP; linear_attn is not TP-shardable with stock kernels, so it stays replicated.
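The effect of such a plan on a mixed stack can be sketched by enumerating layers explicitly. This assumes that the 1-based every-4th layer uses full attention (so 0-based indices 3, 7, ...) and that every layer has a dense MLP; keys are relative to the text backbone, styles are shorthand strings, and both functions are hypothetical. A real plan would more likely use globs such as layers.*.self_attn.q_proj, which have the same effect because they only match modules that exist.

```python
from typing import Dict, List


def layer_types(num_layers: int, full_attention_interval: int = 4) -> List[str]:
    """Assumed layout: every `full_attention_interval`-th layer (1-based) is full attention."""
    return [
        "full_attention" if (i + 1) % full_attention_interval == 0 else "linear_attention"
        for i in range(num_layers)
    ]


def qwen3_5_style_plan(num_layers: int) -> Dict[str, str]:
    """Shard self_attn only on full-attention layers; shard the MLP on every layer."""
    plan: Dict[str, str] = {}
    for i, kind in enumerate(layer_types(num_layers)):
        prefix = f"layers.{i}"
        if kind == "full_attention":
            for proj in ("q_proj", "k_proj", "v_proj"):
                plan[f"{prefix}.self_attn.{proj}"] = "colwise"
            plan[f"{prefix}.self_attn.o_proj"] = "rowwise"
        # linear_attn (GatedDeltaNet) layers get no attention entries: replicated.
        plan[f"{prefix}.mlp.gate_proj"] = "colwise"
        plan[f"{prefix}.mlp.up_proj"] = "colwise"
        plan[f"{prefix}.mlp.down_proj"] = "rowwise"
    return plan


plan = qwen3_5_style_plan(num_layers=8)
```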

nemo_automodel.components.distributed.optimized_tp_plans._parallelize_qwen_classification(
model: typing.Union[transformers.models.qwen3.modeling_qwen3.Qwen3ForSequenceClassification],
sequence_parallel: bool = False
) -> dict[str, torch.distributed.tensor.parallel.ParallelStyle]
nemo_automodel.components.distributed.optimized_tp_plans.get_decilm_nemotron_tp_plan(
sequence_parallel: bool = False
) -> dict[str, torch.distributed.tensor.parallel.ParallelStyle]

Return a TP plan for remote-code DeciLM Nemotron-NAS checkpoints.

DeciLM/Nemotron-NAS is close to Llama structurally, but its remote-code forward path performs model-level rotary embedding setup and per-layer block-config dispatch. In practice, the generic base-style plan is a safer match than the Llama-optimized named plan for this architecture.

nemo_automodel.components.distributed.optimized_tp_plans.get_llama_nemotron_super_tp_plan(
sequence_parallel: bool = False
) -> dict[str, torch.distributed.tensor.parallel.ParallelStyle]

Return the tensor parallel plan for Llama / Llama-3.3-Nemotron Super.

Same topology as Llama-3.3-Nemotron (e.g. nvidia/Llama-3_3-Nemotron-Super-49B-v1_5): fused QKV, fused gate+up, VocabParallelEmbedding, Row/ColwiseParallel for attention and MLP.

To use this plan, pass it explicitly, either as a dict through tp_shard_plan or by its name, llama_nemotron_super_tp_plan, when calling fsdp2_strategy_parallelize or _get_parallel_plan.
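The two ways of selecting the plan (an explicit dict or a registered name) can be pictured with the resolver below. It is a hypothetical sketch, not the real logic of fsdp2_strategy_parallelize or _get_parallel_plan: resolve_tp_plan, NAMED_PLANS, and placeholder_super_plan are invented for illustration, and the placeholder returns a shorthand string plan instead of the ParallelStyle objects that get_llama_nemotron_super_tp_plan returns. Only the name string matches LLAMA_NEMOTRON_SUPER_TP_PLAN_NAME.

```python
from typing import Callable, Dict, Union

LLAMA_NEMOTRON_SUPER_TP_PLAN_NAME = "llama_nemotron_super_tp_plan"

Plan = Dict[str, str]  # shorthand: module pattern -> style name


def placeholder_super_plan(sequence_parallel: bool = False) -> Plan:
    """Hypothetical stand-in for get_llama_nemotron_super_tp_plan (flag unused here)."""
    return {"model.layers.*.self_attn.o_proj": "rowwise"}


# Hypothetical registry mapping plan names to plan factories.
NAMED_PLANS: Dict[str, Callable[..., Plan]] = {
    LLAMA_NEMOTRON_SUPER_TP_PLAN_NAME: placeholder_super_plan,
}


def resolve_tp_plan(tp_shard_plan: Union[Plan, str],
                    sequence_parallel: bool = False) -> Plan:
    """Return a dict plan unchanged, or build a named plan from the registry."""
    if isinstance(tp_shard_plan, dict):
        return tp_shard_plan
    factory = NAMED_PLANS.get(tp_shard_plan)
    if factory is None:
        raise ValueError(f"unknown TP plan name: {tp_shard_plan!r}")
    return factory(sequence_parallel=sequence_parallel)


explicit = {"lm_head": "colwise"}
by_dict = resolve_tp_plan(explicit)
by_name = resolve_tp_plan(LLAMA_NEMOTRON_SUPER_TP_PLAN_NAME)
```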

nemo_automodel.components.distributed.optimized_tp_plans.LLAMA_NEMOTRON_SUPER_TP_PLAN_NAME = 'llama_nemotron_super_tp_plan'
nemo_automodel.components.distributed.optimized_tp_plans.PARALLELIZE_FUNCTIONS: Dict[str, Callable[..., Dict[str, ParallelStyle]]] = {_get_class_qualname(BaichuanForCausalLM): _parallelize_baichuan, _get_class_qua...