nemo_automodel.components.distributed.optimized_tp_plans
Model-specific parallel plans for tensor parallelism.
This module contains optimized tensor parallel plans for different model architectures including LLaMA, Qwen, Gemma3, and Ministral3 models.
Module Contents
Classes
Functions
Data
LLAMA_NEMOTRON_SUPER_TP_PLAN_NAME
API
Bases: SequenceParallel
Custom SequenceParallel subclass for Qwen2 / Gemma3 rotary embeddings, needed because their input is a tuple rather than a single tensor.
Bases: SequenceParallel
SequenceParallel that all-gathers activations for sequence parallelism.
Prepare outputs by redistributing sharded DTensors to replicated placement.
Bases: RowwiseParallel
RowwiseParallel for nn.Embedding with a MaskPartial mask-buffer fixup.
Some PyTorch versions have a DTensor bug where the MaskPartial
placement’s mask_buffer is not populated during the embedding
dispatch, leading to:
AssertionError: assert self.mask_buffer.data is not None
This subclass works around the issue by:
- Saving the original (unadjusted) input_ids in a pre-hook.
- Recomputing and populating the mask_buffer in the post-hook when the DTensor dispatch failed to do so.
In PyTorch versions where the dispatch works correctly the mask buffer is already populated and the fixup is a no-op.
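To see why the original input_ids matter, the row-wise (vocab-dim) sharded lookup can be simulated on one process. The sketch below is a pure-Python model of the general technique, not PyTorch's MaskPartial code, and the helper name vocab_parallel_lookup is hypothetical. Each rank builds a mask of token ids outside its vocab slice from the unshifted ids, shifts the rest into its local range, zeroes masked rows, and the per-rank partials are summed in place of the all-reduce. Because the mask is derived from the original ids, it cannot be rebuilt from the shifted ones, which is presumably why the pre-hook keeps a copy.

```python
def vocab_parallel_lookup(table, input_ids, world_size):
    """Simulate an embedding lookup whose table is sharded along the vocab dim."""
    vocab, dim = len(table), len(table[0])
    shard = vocab // world_size  # assumes vocab is divisible by world_size
    total = [[0.0] * dim for _ in input_ids]
    for rank in range(world_size):
        start, end = rank * shard, (rank + 1) * shard
        local_table = table[start:end]
        # True where this rank does NOT own the token; this plays the role of mask_buffer.
        mask = [not (start <= i < end) for i in input_ids]
        # Shift ids into the local range; masked ids become 0 so indexing stays valid.
        local_ids = [0 if m else i - start for i, m in zip(input_ids, mask)]
        # Look up locally, then zero the rows for masked positions.
        partial = [[0.0] * dim if m else local_table[j] for j, m in zip(local_ids, mask)]
        # Summing partials over ranks stands in for the all-reduce.
        for row, p in zip(total, partial):
            for k in range(dim):
                row[k] += p[k]
    return total


table = [[float(v), 10.0 * v] for v in range(8)]  # 8-token vocab, embedding dim 2
ids = [0, 5, 3, 7]
result = vocab_parallel_lookup(table, ids, world_size=4)
assert result == [table[i] for i in ids]  # matches an unsharded lookup
```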
Return the fully qualified name of a class as module.qualname.
Used as a stable dict key for PARALLELIZE_FUNCTIONS instead of the class object itself.
When NeMo-RL uses automodel, force_hf=True is auto-set for models
(e.g. LlamaForCausalLM) whose adapter does not implement
convert_single_tensor_to_hf. This causes _get_mixin_wrapped_class
in model_init.py to create a new class via type(...) that wraps
the original with HFCheckpointingMixin. The wrapper copies
__module__ and __qualname__ from the original but is a different
Python object, so type(model) in PARALLELIZE_FUNCTIONS (identity
check) returns False and the default plan is used instead of the
optimized one.
String comparison on module.qualname survives this wrapping and
correctly identifies the model class.
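The wrapping problem can be reproduced in a few lines of plain Python. In this sketch HFCheckpointingMixin and LlamaForCausalLM are empty stand-ins, and qualname_of and parallelize_llama are hypothetical names; only the type(...) call that copies __module__ and __qualname__ mirrors the mechanism described above. A dict keyed by the class object misses the wrapper, while a dict keyed by the module.qualname string still finds it.

```python
class HFCheckpointingMixin:
    """Hypothetical stand-in for the real mixin."""


class LlamaForCausalLM:
    """Stand-in for the transformers class of the same name."""


def qualname_of(cls: type) -> str:
    """Return 'module.qualname' for a class (hypothetical helper name)."""
    return f"{cls.__module__}.{cls.__qualname__}"


def parallelize_llama(model) -> str:
    """Hypothetical optimized-plan function."""
    return "llama-optimized"


identity_keyed = {LlamaForCausalLM: parallelize_llama}  # keyed by class object
PARALLELIZE_FUNCTIONS = {qualname_of(LlamaForCausalLM): parallelize_llama}  # keyed by string

# Mimic a type(...) wrapper that copies __module__ and __qualname__ from the original.
Wrapped = type(
    "LlamaForCausalLM",
    (HFCheckpointingMixin, LlamaForCausalLM),
    {"__module__": LlamaForCausalLM.__module__,
     "__qualname__": LlamaForCausalLM.__qualname__},
)
model = Wrapped()

print(type(model) in identity_keyed)                      # False: a different class object
print(qualname_of(type(model)) in PARALLELIZE_FUNCTIONS)  # True: same module.qualname
```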
Parallelizes a BaichuanForCausalLM model (MLP-only).
Only the MLP is sharded. The attention path stays fully replicated because W_pack uses a non-interleaved [Q|K|V] layout (ColwiseParallel would split it incorrectly) and NormHead (lm_head) is not an nn.Linear, so ColwiseParallel cannot be applied to it.
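A toy example of the W_pack problem: ColwiseParallel shards an nn.Linear weight along dim 0 (out_features), so each rank receives a contiguous block of output rows. The sketch below labels each output row of a fused, non-interleaved [Q|K|V] weight (head_dim = 1, 4 heads, 2 ranks, all assumed for illustration) and compares that naive split with a block-aware split. shard_fused_qkv is a hypothetical helper showing what a correct split would need; the documented plan does not do this and keeps attention replicated instead.

```python
def shard_dim0(rows, world_size, rank):
    """Take this rank's contiguous slice of dim 0 (out_features), like Shard(0)."""
    n = len(rows) // world_size
    return rows[rank * n:(rank + 1) * n]


def shard_fused_qkv(rows, world_size, rank):
    """Shard the Q, K and V blocks separately so each rank keeps matching heads."""
    third = len(rows) // 3
    blocks = [rows[i * third:(i + 1) * third] for i in range(3)]
    return [row for block in blocks for row in shard_dim0(block, world_size, rank)]


num_heads, world_size = 4, 2
# One label per output row of a fused, non-interleaved [Q|K|V] weight (head_dim = 1).
w_pack_rows = [f"{proj}{h}" for proj in "QKV" for h in range(num_heads)]

naive = [shard_dim0(w_pack_rows, world_size, r) for r in range(world_size)]
print(naive[0])  # ['Q0', 'Q1', 'Q2', 'Q3', 'K0', 'K1']: rank 0 has no V heads at all
print(shard_fused_qkv(w_pack_rows, world_size, 0))  # ['Q0', 'Q1', 'K0', 'K1', 'V0', 'V1']
```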
Parallelize Falcon-H1 (hybrid Transformer + Mamba2 SSM).
Every Falcon-H1 decoder layer runs an attention branch (self_attn) and a
Mamba2 branch (mamba) in parallel, followed by an MLP (feed_forward).
Only the attention and MLP linears are tensor-parallel sharded; the Mamba2
mixer stays replicated because its SSM scan / causal conv1d are not
TP-shardable with stock kernels (same approach as Qwen3.5’s GatedDeltaNet
linear-attention branch).
A dedicated plan is required because HuggingFace ships only
_tp_plan = {"lm_head": "colwise_gather_output"} for FalconH1, and its MLP
is named feed_forward (not mlp). The HF plan cannot be used because its
colwise_gather_output style is rejected, and the generic llama-style fallback
plan does not match the feed_forward module names. The dominant feed_forward
weights would therefore stay replicated across TP ranks, which OOMs large
variants such as Falcon-H1-34B even under LoRA.
sequence_parallel is accepted for signature compatibility but ignored: the
parallel Mamba2 branch emits non-sequence-parallel activations that cannot be
combined with sequence-parallel attention outputs.
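The module-name mismatch can be illustrated with wildcard matching. In this sketch the module names are assumed for one Falcon-H1 decoder layer (not read from a checkpoint), the plan keys are an illustrative llama-style subset, and the strings "colwise" / "rowwise" stand in for ColwiseParallel / RowwiseParallel. The llama-style keys only hit the attention projections; renaming mlp to feed_forward picks up the MLP, and the Mamba2 mixer is left out of the plan, so it stays replicated.

```python
from fnmatch import fnmatchcase

# Hypothetical module names for one Falcon-H1 decoder layer.
falcon_h1_modules = [
    "model.layers.0.self_attn.q_proj",
    "model.layers.0.self_attn.o_proj",
    "model.layers.0.mamba.in_proj",
    "model.layers.0.feed_forward.gate_proj",
    "model.layers.0.feed_forward.up_proj",
    "model.layers.0.feed_forward.down_proj",
]

# Illustrative llama-style plan keys.
llama_style_plan = {
    "model.layers.*.self_attn.q_proj": "colwise",
    "model.layers.*.self_attn.o_proj": "rowwise",
    "model.layers.*.mlp.gate_proj": "colwise",
    "model.layers.*.mlp.up_proj": "colwise",
    "model.layers.*.mlp.down_proj": "rowwise",
}


def sharded(modules, plan):
    """Return the module names matched by at least one plan key."""
    return [m for m in modules if any(fnmatchcase(m, key) for key in plan)]


# A dedicated plan renames mlp -> feed_forward; the mamba mixer gets no entry.
falcon_plan = {k.replace(".mlp.", ".feed_forward."): v for k, v in llama_style_plan.items()}

print(sharded(falcon_h1_modules, llama_style_plan))  # only the two self_attn projections
print(sharded(falcon_h1_modules, falcon_plan))       # attention + feed_forward, no mamba
```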
Parallelizes a Gemma3ForCausalLM model across data and tensor parallel dimensions.
Parallelizes a Gemma4ForConditionalGeneration model across the tensor parallel dimension.
Gemma4 VLM uses model.language_model.{embed_tokens, layers.*} for the text backbone, identical to the Gemma3 VLM layout.
Parallelizes a LlamaForCausalLM model across data and tensor parallel dimensions.
Parallelizes a Ministral3ForCausalLM model across data and tensor parallel dimensions.
TP plan for Mistral3ForConditionalGeneration (and subclasses like
Mistral3FP8VLMForConditionalGeneration). The Ministral3 text backbone
lives at model.language_model.{embed_tokens, layers.*}; vision_tower
and multi_modal_projector stay replicated across TP ranks.
TP plan for NemotronLabsDiffusionModel (Nemotron-Labs-Diffusion).
Same shape as _parallelize_ministral3, but the model uses
encoder.* (not model.*) and the output head is diffusion_head
(not lm_head).
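Several of these plans share one text-backbone layout and differ only in where it is mounted (model.language_model.* for the Gemma4 and Mistral3 VLMs, encoder.* for Nemotron-Labs-Diffusion) and in the name of the output head. A minimal sketch of that idea, assuming a backbone-relative plan keyed in the HF "layers.*.…" wildcard style; reroot, backbone_plan and the style strings are hypothetical and do not reproduce the module's actual entries. Anything not listed, such as vision_tower or multi_modal_projector, simply stays replicated.

```python
def reroot(plan, prefix):
    """Prefix every key of a backbone-relative plan (hypothetical helper)."""
    return {prefix + key: style for key, style in plan.items()}


# Illustrative backbone-relative plan; strings stand in for ParallelStyle objects.
backbone_plan = {
    "embed_tokens": "rowwise",
    "layers.*.self_attn.q_proj": "colwise",
    "layers.*.self_attn.o_proj": "rowwise",
    "layers.*.mlp.down_proj": "rowwise",
}

# VLM layout: backbone under model.language_model, head at lm_head.
mistral3_plan = {**reroot(backbone_plan, "model.language_model."), "lm_head": "colwise"}
# Diffusion layout: backbone under encoder, head at diffusion_head.
diffusion_plan = {**reroot(backbone_plan, "encoder."), "diffusion_head": "colwise"}

print(sorted(diffusion_plan))
```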
Parallelizes a PhiForCausalLM (Phi-2) model across the tensor parallel dimension.
Phi-2 uses self_attn.dense instead of self_attn.o_proj and
mlp.fc1/mlp.fc2 instead of mlp.gate_proj/mlp.up_proj/mlp.down_proj.
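The naming differences can be expressed as a translation of a llama-style per-layer plan. This is a sketch only: the plan contents and the RENAMES mapping are illustrative assumptions derived from the names above, with string styles standing in for ColwiseParallel / RowwiseParallel. Phi-2's MLP has a single up-projection (fc1) and no gate, so gate_proj is dropped.

```python
# Illustrative llama-style per-layer plan.
llama_layer_plan = {
    "self_attn.q_proj": "colwise",
    "self_attn.k_proj": "colwise",
    "self_attn.v_proj": "colwise",
    "self_attn.o_proj": "rowwise",
    "mlp.gate_proj": "colwise",
    "mlp.up_proj": "colwise",
    "mlp.down_proj": "rowwise",
}

# Hypothetical mapping from llama-style names to Phi-2 names.
RENAMES = {
    "self_attn.o_proj": "self_attn.dense",
    "mlp.up_proj": "mlp.fc1",
    "mlp.down_proj": "mlp.fc2",
}

phi2_layer_plan = {
    RENAMES.get(name, name): style
    for name, style in llama_layer_plan.items()
    if name != "mlp.gate_proj"  # Phi-2 has no gated MLP
}
print(phi2_layer_plan)
```

Keeping the colwise-then-rowwise pairing (fc1 then fc2, q/k/v then dense) lets the activation between the two projections stay local to each rank, so each block needs only one reduction on its output.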
Parallelizes a Qwen2/Qwen3 causal LM across data and tensor parallel dimensions.
Parallelize Qwen3.5 VLM by reusing transformers’ base_model_tp_plan.
Qwen3.5 has mixed attention: full self_attn on every 4th layer and linear_attn (GatedDeltaNet) on the others. The transformers-provided base_model_tp_plan covers only self_attn and the MLP; linear_attn is not TP-shardable with stock kernels.
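Because base_model_tp_plan uses wildcard keys such as "layers.*.self_attn.q_proj", only layers that actually contain a self_attn module get sharded, and the linear_attn layers fall through as replicated. The sketch below assumes a convention where layer i uses full attention when (i + 1) is divisible by 4, an 8-layer model, and a hypothetical linear_attn.in_proj module name; none of these are taken from the Qwen3.5 config.

```python
from fnmatch import fnmatchcase


def layer_types(num_layers: int, interval: int = 4) -> list:
    # Assumed convention: layer i uses full attention when (i + 1) % interval == 0.
    return ["full_attention" if (i + 1) % interval == 0 else "linear_attention"
            for i in range(num_layers)]


types = layer_types(8)
# Hypothetical module names: only full-attention layers own a self_attn module.
modules = [
    f"layers.{i}.self_attn.q_proj" if t == "full_attention" else f"layers.{i}.linear_attn.in_proj"
    for i, t in enumerate(types)
]
# Wildcard keys in the style of base_model_tp_plan (illustrative subset).
plan = {"layers.*.self_attn.q_proj": "colwise", "layers.*.mlp.down_proj": "rowwise"}

sharded_modules = [m for m in modules if any(fnmatchcase(m, k) for k in plan)]
print(sharded_modules)  # ['layers.3.self_attn.q_proj', 'layers.7.self_attn.q_proj']
```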
Return a TP plan for remote-code DeciLM Nemotron-NAS checkpoints.
DeciLM/Nemotron-NAS is close to Llama structurally, but its remote-code forward path performs model-level rotary embedding setup and per-layer block-config dispatch. In practice, the generic base-style plan is a safer match than the Llama-optimized named plan for this architecture.
Return the tensor parallel plan for Llama / Llama-3.3-Nemotron Super.
Same topology as Llama-3.3-Nemotron (e.g. nvidia/Llama-3_3-Nemotron-Super-49B-v1_5): fused QKV, fused gate+up, VocabParallelEmbedding, Row/ColwiseParallel for attention and MLP.
Use this plan explicitly by passing it as tp_shard_plan (dict) or by name
llama_nemotron_super_tp_plan when calling fsdp2_strategy_parallelize / _get_parallel_plan.