nemo_automodel.components.distributed.optimized_tp_plans#
Model-specific parallel plans for tensor parallelism.
This module contains optimized tensor parallel plans for different model architectures including LLaMA, Qwen, Gemma3, and Ministral3 models.
Module Contents#
Classes#

| Class | Description |
| --- | --- |
| SequenceParallelAllGatherActivation | SequenceParallel that all-gathers activations for sequence parallelism. |
| VocabParallelEmbedding | RowwiseParallel for nn.Embedding with a MaskPartial mask-buffer fixup. |
| RotaryEmbedParallel | Custom SequenceParallel class for Qwen2 / Gemma3 rotary embeddings because the input is a tuple. |
Functions#

| Function | Description |
| --- | --- |
| _parallelize_gemma3 | Parallelizes a Gemma3ForCausalLM model across data and tensor parallel dimensions. |
| _parallelize_llama | Parallelizes a LlamaForCausalLM model across data and tensor parallel dimensions. |
| _parallelize_ministral3 | Parallelizes a Ministral3ForCausalLM model across data and tensor parallel dimensions. |
| _parallelize_qwen | Parallelizes a Qwen2/Qwen3 causal LM across data and tensor parallel dimensions. |
Data#
API#
- class nemo_automodel.components.distributed.optimized_tp_plans.SequenceParallelAllGatherActivation#
Bases: torch.distributed.tensor.parallel.SequenceParallel

SequenceParallel that all-gathers activations for sequence parallelism.
- static _prepare_output_fn(use_local_output, mod, outputs, device_mesh)#
Prepare outputs by redistributing sharded DTensors to replicated placement.
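A minimal, torch-free sketch of what the all-gather output step accomplishes (this is an illustration of the concept, not the class's implementation): under sequence parallelism each rank holds one shard of the sequence dimension, and redistributing to a replicated placement means every rank ends up with the full activation.

```python
# Hypothetical sketch: sequence parallelism shards activations along the
# sequence dimension, one shard per rank. An all-gather on the module
# output makes every rank hold the full sequence again, which is what
# redistributing a sharded DTensor to a replicated placement does.

def all_gather_sequence(shards):
    """Concatenate per-rank sequence shards so each rank sees the full tensor.

    `shards` has one entry per rank; each entry is a list of token
    activations (stand-ins for tensor rows along the sequence dim).
    """
    full = [tok for shard in shards for tok in shard]
    # After the gather, every rank holds an identical replicated copy.
    return [list(full) for _ in shards]

# Two ranks, each holding half of a 4-token sequence.
rank_shards = [["t0", "t1"], ["t2", "t3"]]
replicated = all_gather_sequence(rank_shards)
```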
- class nemo_automodel.components.distributed.optimized_tp_plans.VocabParallelEmbedding#
Bases: torch.distributed.tensor.parallel.RowwiseParallel

RowwiseParallel for nn.Embedding with a MaskPartial mask-buffer fixup.

Some PyTorch versions have a DTensor bug where the MaskPartial placement's mask_buffer is not populated during the embedding dispatch, leading to:

AssertionError: assert self.mask_buffer.data is not None

This subclass works around the issue by:

1. Saving the original (un-adjusted) input_ids in a pre-hook.
2. Recomputing and populating the mask_buffer in the post-hook when the DTensor dispatch failed to do so.

In PyTorch versions where the dispatch works correctly the mask buffer is already populated and the fixup is a no-op.
- static _prepare_input_fn(input_layouts, desired_input_layouts, mod, inputs, device_mesh)#
- static _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_mesh)#
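The pre-hook/post-hook workaround can be illustrated with a small torch-free sketch (all names here are invented for illustration; the real class hooks into DTensor's RowwiseParallel machinery): the pre-hook stashes the raw input_ids, and the post-hook repopulates the mask buffer only when dispatch left it empty, so on fixed PyTorch versions it is a no-op.

```python
# Hypothetical sketch of the fixup pattern. On buggy PyTorch versions the
# dispatch leaves mask_buffer.data as None; the post-hook recomputes the
# out-of-shard mask from the input_ids saved by the pre-hook.

class MaskBuffer:
    def __init__(self):
        self.data = None  # the buggy dispatch leaves this unpopulated

class EmbeddingWithFixup:
    def __init__(self, vocab_start, vocab_end):
        # This rank owns vocab rows [vocab_start, vocab_end).
        self.vocab_start, self.vocab_end = vocab_start, vocab_end
        self.mask_buffer = MaskBuffer()
        self._saved_input_ids = None

    def pre_hook(self, input_ids):
        # Save the original (un-adjusted) input_ids before dispatch.
        self._saved_input_ids = list(input_ids)

    def post_hook(self):
        # Recompute the mask only if dispatch failed to populate it;
        # otherwise this is a no-op.
        if self.mask_buffer.data is None:
            self.mask_buffer.data = [
                not (self.vocab_start <= t < self.vocab_end)
                for t in self._saved_input_ids
            ]

emb = EmbeddingWithFixup(vocab_start=0, vocab_end=4)
emb.pre_hook([1, 5, 3])   # token 5 falls outside this rank's vocab shard
emb.post_hook()           # mask marks the out-of-shard token
```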
- class nemo_automodel.components.distributed.optimized_tp_plans.RotaryEmbedParallel#
Bases: torch.distributed.tensor.parallel.SequenceParallel

Custom SequenceParallel class for Qwen2 / Gemma3 rotary embeddings because the input is a tuple.
- static _prepare_input_fn(sequence_sharding, mod, inputs, device_mesh)#
- static _prepare_output_fn(use_local_output, mod, outputs, device_mesh)#
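Why the tuple input needs a custom class can be shown with a torch-free sketch (function names invented for illustration): a hook written for a single tensor cannot shard a tuple, so the input-preparation step must walk the tuple and shard each element along the sequence dimension.

```python
# Hypothetical sketch: the rotary-embedding forward receives a tuple of
# inputs rather than a single tensor, so the input hook must shard each
# tuple element separately along the sequence dimension.

def shard_sequence(x, rank, world_size):
    """Take this rank's contiguous slice along the sequence dimension."""
    chunk = len(x) // world_size
    return x[rank * chunk:(rank + 1) * chunk]

def prepare_tuple_inputs(inputs, rank, world_size):
    # Shard every element of the tuple; a single-tensor hook would fail here.
    return tuple(shard_sequence(x, rank, world_size) for x in inputs)

hidden = ["h0", "h1", "h2", "h3"]   # stand-in for hidden states
positions = [0, 1, 2, 3]            # stand-in for position ids
local = prepare_tuple_inputs((hidden, positions), rank=1, world_size=2)
```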
- nemo_automodel.components.distributed.optimized_tp_plans._parallelize_gemma3(
- model: Union[transformers.models.gemma3.modeling_gemma3.Gemma3ForCausalLM, transformers.models.gemma3.modeling_gemma3.Gemma3ForConditionalGeneration],
- sequence_parallel: bool = False,
)#
Parallelizes a Gemma3ForCausalLM model across data and tensor parallel dimensions.
- nemo_automodel.components.distributed.optimized_tp_plans._parallelize_llama(
- model: transformers.models.llama.modeling_llama.LlamaForCausalLM,
- sequence_parallel: bool = False,
)#
Parallelizes a LlamaForCausalLM model across data and tensor parallel dimensions.
- nemo_automodel.components.distributed.optimized_tp_plans._parallelize_ministral3(
- model: nemo_automodel.components.models.mistral3.model.Ministral3ForCausalLM,
- sequence_parallel: bool = False,
)#
Parallelizes a Ministral3ForCausalLM model across data and tensor parallel dimensions.
- nemo_automodel.components.distributed.optimized_tp_plans._parallelize_qwen(
- model: Union[transformers.models.qwen2.modeling_qwen2.Qwen2ForCausalLM, transformers.models.qwen3.modeling_qwen3.Qwen3ForCausalLM],
- sequence_parallel: bool = False,
)#
Parallelizes a Qwen2/Qwen3 causal LM across data and tensor parallel dimensions.
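Per the PARALLELIZE_FUNCTIONS type below, each of these functions produces a mapping from module path to a parallel style. A hypothetical sketch of that shape (module paths follow the usual Hugging Face decoder layout; the styles are shown as placeholder strings here, whereas the real plans use ParallelStyle instances such as ColwiseParallel and RowwiseParallel):

```python
# Hypothetical sketch of a TP plan: a dict keyed by module path, as consumed
# by torch.distributed.tensor.parallel.parallelize_module. Placeholder
# strings stand in for ParallelStyle instances.

def make_decoder_tp_plan(num_layers):
    plan = {"model.embed_tokens": "rowwise-embedding"}
    for i in range(num_layers):
        prefix = f"model.layers.{i}"
        # Attention: shard q/k/v column-wise, output projection row-wise.
        plan[f"{prefix}.self_attn.q_proj"] = "colwise"
        plan[f"{prefix}.self_attn.k_proj"] = "colwise"
        plan[f"{prefix}.self_attn.v_proj"] = "colwise"
        plan[f"{prefix}.self_attn.o_proj"] = "rowwise"
        # MLP: shard up/gate column-wise, down projection row-wise.
        plan[f"{prefix}.mlp.gate_proj"] = "colwise"
        plan[f"{prefix}.mlp.up_proj"] = "colwise"
        plan[f"{prefix}.mlp.down_proj"] = "rowwise"
    plan["lm_head"] = "colwise"
    return plan

plan = make_decoder_tp_plan(num_layers=2)
```

Pairing a column-wise shard with a following row-wise shard keeps each attention/MLP block to a single collective on its output.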
- nemo_automodel.components.distributed.optimized_tp_plans._parallelize_qwen_classification(
- model: Union[transformers.models.qwen3.modeling_qwen3.Qwen3ForSequenceClassification],
- sequence_parallel: bool = False,
)#
- nemo_automodel.components.distributed.optimized_tp_plans._parallelize_phi3(
- model: transformers.models.phi3.modeling_phi3.Phi3ForCausalLM,
- sequence_parallel: bool = False,
)#
- nemo_automodel.components.distributed.optimized_tp_plans.PARALLELIZE_FUNCTIONS: Dict[type, Callable[..., Dict[str, torch.distributed.tensor.parallel.ParallelStyle]]]#
None
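Given its annotation (Dict[type, Callable[..., Dict[str, ParallelStyle]]]), a registry like this supports dispatch on the model's concrete class. A self-contained sketch of that usage pattern (the model classes and plan contents below are stand-ins, not the real ones):

```python
# Hypothetical sketch: look up the parallelization function by the model's
# concrete class, mirroring the Dict[type, Callable[...]] registry shape.

class LlamaForCausalLM: ...      # stand-in for the transformers class
class Qwen2ForCausalLM: ...      # stand-in for the transformers class

def _parallelize_llama_stub(model, sequence_parallel=False):
    return {"lm_head": "colwise"}

def _parallelize_qwen_stub(model, sequence_parallel=False):
    return {"lm_head": "colwise", "model.norm": "sequence-parallel"}

PARALLELIZE_FUNCTIONS = {
    LlamaForCausalLM: _parallelize_llama_stub,
    Qwen2ForCausalLM: _parallelize_qwen_stub,
}

def get_tp_plan(model, sequence_parallel=False):
    # Dispatch on the exact model class; unknown architectures are rejected.
    fn = PARALLELIZE_FUNCTIONS.get(type(model))
    if fn is None:
        raise ValueError(f"No optimized TP plan for {type(model).__name__}")
    return fn(model, sequence_parallel=sequence_parallel)

plan = get_tp_plan(Qwen2ForCausalLM())
```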