nemo_automodel.components.distributed.optimized_tp_plans#

Model-specific parallel plans for tensor parallelism.

This module contains optimized tensor parallel plans for different model architectures including LLaMA, Qwen, Gemma3, and Ministral3 models.

Module Contents#

Classes#

SequenceParallelAllGatherActivation

SequenceParallel that all-gathers activations for sequence parallelism.

VocabParallelEmbedding

RowwiseParallel for nn.Embedding with a MaskPartial mask-buffer fixup.

RotaryEmbedParallel

Custom SequenceParallel class for Qwen2 / Gemma3 rotary embeddings because the input is a tuple.

Functions#

_parallelize_gemma3

Parallelizes a Gemma3ForCausalLM model across data and tensor parallel dimensions.

get_llama_nemotron_super_tp_plan

Return the tensor parallel plan for Llama / Llama-3.3-Nemotron Super.

get_decilm_nemotron_tp_plan

Return a TP plan for remote-code DeciLM Nemotron-NAS checkpoints.

_parallelize_llama

Parallelizes a LlamaForCausalLM model across data and tensor parallel dimensions.

_parallelize_ministral3

Parallelizes a Ministral3ForCausalLM model across data and tensor parallel dimensions.

_parallelize_qwen

Parallelizes a Qwen2/Qwen3 causal LM across data and tensor parallel dimensions.

_parallelize_qwen_classification

_parallelize_phi3

_get_class_qualname

Return the fully qualified name of a class as module.qualname.

Data#

API#

class nemo_automodel.components.distributed.optimized_tp_plans.SequenceParallelAllGatherActivation#

Bases: torch.distributed.tensor.parallel.SequenceParallel

SequenceParallel that all-gathers activations for sequence parallelism.

static _prepare_output_fn(use_local_output, mod, outputs, device_mesh)#

Prepare outputs by redistributing sharded DTensors to replicated placement.

class nemo_automodel.components.distributed.optimized_tp_plans.VocabParallelEmbedding#

Bases: torch.distributed.tensor.parallel.RowwiseParallel

RowwiseParallel for nn.Embedding with a MaskPartial mask-buffer fixup.

Some PyTorch versions have a DTensor bug where the MaskPartial placement's mask_buffer is not populated during the embedding dispatch, leading to:

AssertionError: assert self.mask_buffer.data is not None

This subclass works around the issue by:

  1. Saving the original (un-adjusted) input_ids in a pre-hook.

  2. Recomputing and populating the mask_buffer in the post-hook when the DTensor dispatch failed to do so.

In PyTorch versions where the dispatch works correctly the mask buffer is already populated and the fixup is a no-op.

static _prepare_input_fn(
input_layouts,
desired_input_layouts,
mod,
inputs,
device_mesh,
)#
static _prepare_output_fn(
output_layouts,
use_local_output,
mod,
outputs,
device_mesh,
)#
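The mask buffer exists because a rowwise (vocab-sharded) embedding must mask token ids that fall outside each rank's vocabulary slice, look them up as a safe index, zero the masked rows, and sum the partial results across ranks. The following is a single-process numpy sketch of that arithmetic, not the DTensor/MaskPartial implementation:

```python
import numpy as np

def vocab_parallel_lookup(input_ids, shards):
    """Simulate a rowwise (vocab-sharded) embedding lookup.

    Each "rank" holds a contiguous slice of the vocabulary. Tokens outside
    the local slice are masked, indexed as 0 for safety, zeroed out, and the
    partial results are summed (the role of the all-reduce in the real plan).
    """
    out = np.zeros((len(input_ids), shards[0].shape[1]))
    offset = 0
    for shard in shards:
        local = input_ids - offset                       # shift into local range
        mask = (local < 0) | (local >= shard.shape[0])   # out-of-vocab-slice tokens
        local = np.where(mask, 0, local)                 # safe index for masked rows
        part = shard[local]                              # fancy indexing copies
        part[mask] = 0.0                                 # zero masked rows
        out += part                                      # sum of partials
        offset += shard.shape[0]
    return out

rng = np.random.default_rng(0)
table = rng.standard_normal((8, 4))   # full vocab of 8 tokens, embedding dim 4
shards = [table[:4], table[4:]]       # two vocab-parallel shards
ids = np.array([1, 5, 7, 0])
assert np.allclose(vocab_parallel_lookup(ids, shards), table[ids])
```

If the mask is lost before the zeroing step (the symptom of the unpopulated mask_buffer), the out-of-range rows contribute garbage to the sum, which is why the subclass recomputes it in the post-hook.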
class nemo_automodel.components.distributed.optimized_tp_plans.RotaryEmbedParallel#

Bases: torch.distributed.tensor.parallel.SequenceParallel

Custom SequenceParallel class for Qwen2 / Gemma3 rotary embeddings because the input is a tuple.

static _prepare_input_fn(sequence_sharding, mod, inputs, device_mesh)#
static _prepare_output_fn(use_local_output, mod, outputs, device_mesh)#
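Why a custom class is needed at all: the stock SequenceParallel prepare fn assumes a single tensor input, while rotary-embedding modules receive a tuple such as (hidden_states, position_ids), of which only the first element is sequence-sharded. A plain-Python sketch of that unpack-shard-repack step (the helper name is hypothetical; the real prepare fns operate on DTensors):

```python
def shard_first_only(inputs, shard_fn):
    """Apply sequence sharding to the first tuple element only.

    Mirrors the shape-handling concern in RotaryEmbedParallel: unpack the
    tuple, shard the activation, and pass the remaining elements through.
    """
    first, *rest = inputs
    return (shard_fn(first), *rest)

halve = lambda xs: xs[: len(xs) // 2]   # stand-in for DTensor sequence sharding
out = shard_first_only(([1, 2, 3, 4], "position_ids"), halve)
assert out == ([1, 2], "position_ids")
```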
nemo_automodel.components.distributed.optimized_tp_plans._parallelize_gemma3(
model: Union[transformers.models.gemma3.modeling_gemma3.Gemma3ForCausalLM, transformers.models.gemma3.modeling_gemma3.Gemma3ForConditionalGeneration],
sequence_parallel: bool = False,
) dict[str, torch.distributed.tensor.parallel.ParallelStyle]#

Parallelizes a Gemma3ForCausalLM model across data and tensor parallel dimensions.

nemo_automodel.components.distributed.optimized_tp_plans.get_llama_nemotron_super_tp_plan(
sequence_parallel: bool = False,
) dict[str, torch.distributed.tensor.parallel.ParallelStyle]#

Return the tensor parallel plan for Llama / Llama-3.3-Nemotron Super.

Same topology as Llama-3.3-Nemotron (e.g. nvidia/Llama-3_3-Nemotron-Super-49B-v1_5): fused QKV, fused gate+up, VocabParallelEmbedding, Row/ColwiseParallel for attention and MLP.

Use this plan explicitly by passing it as tp_shard_plan (dict) or by name llama_nemotron_super_tp_plan when calling fsdp2_strategy_parallelize / _get_parallel_plan.
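A TP plan is a dict mapping module-name patterns to ParallelStyle instances. The sketch below only illustrates that shape; the layer-name keys are hypothetical and are not the exact keys this function returns:

```python
from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel

# Illustrative plan dict: colwise for the fused input projections,
# rowwise for the output projections (keys are hypothetical patterns).
plan = {
    "model.layers.*.self_attn.qkv_proj": ColwiseParallel(),
    "model.layers.*.self_attn.o_proj": RowwiseParallel(),
    "model.layers.*.mlp.gate_up_proj": ColwiseParallel(),
    "model.layers.*.mlp.down_proj": RowwiseParallel(),
}
```

A dict of this form is what gets passed as tp_shard_plan; the named-plan string is just an indirection to a function that builds such a dict.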

nemo_automodel.components.distributed.optimized_tp_plans.get_decilm_nemotron_tp_plan(
sequence_parallel: bool = False,
) dict[str, torch.distributed.tensor.parallel.ParallelStyle]#

Return a TP plan for remote-code DeciLM Nemotron-NAS checkpoints.

DeciLM/Nemotron-NAS is close to Llama structurally, but its remote-code forward path performs model-level rotary embedding setup and per-layer block-config dispatch. In practice, the generic base-style plan is a safer match than the Llama-optimized named plan for this architecture.

nemo_automodel.components.distributed.optimized_tp_plans._parallelize_llama(
model: transformers.models.llama.modeling_llama.LlamaForCausalLM | None,
sequence_parallel: bool = False,
) dict[str, torch.distributed.tensor.parallel.ParallelStyle]#

Parallelizes a LlamaForCausalLM model across data and tensor parallel dimensions.

nemo_automodel.components.distributed.optimized_tp_plans._parallelize_ministral3(
model: nemo_automodel.components.models.mistral3.model.Ministral3ForCausalLM,
sequence_parallel: bool = False,
) dict[str, torch.distributed.tensor.parallel.ParallelStyle]#

Parallelizes a Ministral3ForCausalLM model across data and tensor parallel dimensions.

nemo_automodel.components.distributed.optimized_tp_plans._parallelize_qwen(
model: Union[transformers.models.qwen2.modeling_qwen2.Qwen2ForCausalLM, transformers.models.qwen3.modeling_qwen3.Qwen3ForCausalLM],
sequence_parallel: bool = False,
) dict[str, torch.distributed.tensor.parallel.ParallelStyle]#

Parallelizes a Qwen2/Qwen3 causal LM across data and tensor parallel dimensions.

nemo_automodel.components.distributed.optimized_tp_plans._parallelize_qwen_classification(
model: Union[transformers.models.qwen3.modeling_qwen3.Qwen3ForSequenceClassification],
sequence_parallel: bool = False,
) dict[str, torch.distributed.tensor.parallel.ParallelStyle]#
nemo_automodel.components.distributed.optimized_tp_plans._parallelize_phi3(
model: transformers.models.phi3.modeling_phi3.Phi3ForCausalLM,
sequence_parallel: bool = False,
) dict[str, torch.distributed.tensor.parallel.ParallelStyle]#
nemo_automodel.components.distributed.optimized_tp_plans.LLAMA_NEMOTRON_SUPER_TP_PLAN_NAME#

‘llama_nemotron_super_tp_plan’

nemo_automodel.components.distributed.optimized_tp_plans._get_class_qualname(cls: type) str#

Return the fully qualified name of a class as module.qualname.

Used as a stable dict key for PARALLELIZE_FUNCTIONS instead of the class object itself.

When NeMo-RL uses automodel, force_hf=True is auto-set for models (e.g. LlamaForCausalLM) whose adapter does not implement convert_single_tensor_to_hf. This causes _get_mixin_wrapped_class in model_init.py to create a new class via type(...) that wraps the original with HFCheckpointingMixin. The wrapper copies __module__ and __qualname__ from the original but is a different Python object, so type(model) in PARALLELIZE_FUNCTIONS (identity check) returns False and the default plan is used instead of the optimized one.

String comparison on module.qualname survives this wrapping and correctly identifies the model class.

nemo_automodel.components.distributed.optimized_tp_plans.PARALLELIZE_FUNCTIONS: Dict[str, Callable[..., Dict[str, torch.distributed.tensor.parallel.ParallelStyle]]]#

None