nemo_automodel.components.distributed.optimized_tp_plans#

Model-specific parallel plans for tensor parallelism.

This module contains optimized tensor-parallel plans for different model architectures, including LLaMA, Qwen, and Gemma3.

Module Contents#

Classes#

RotaryEmbedParallel

Custom SequenceParallel class for Qwen2 / Gemma3 rotary embeddings because the input is a tuple.

Functions#

_parallelize_gemma3

Parallelizes a Gemma3ForCausalLM model across data parallel dimensions.

_parallelize_llama

Parallelizes a LlamaForCausalLM model across data and tensor parallel dimensions.

_parallelize_qwen

Parallelizes a Qwen2ForCausalLM model across data and tensor parallel dimensions.

Data#

API#

class nemo_automodel.components.distributed.optimized_tp_plans.RotaryEmbedParallel(
*,
sequence_dim: int = 1,
use_local_output: bool = False,
)[source]#

Bases: torch.distributed.tensor.parallel.SequenceParallel

Custom SequenceParallel class for Qwen2 / Gemma3 rotary embeddings because the input is a tuple.

Initialization

static _prepare_input_fn(sequence_sharding, mod, inputs, device_mesh)[source]#

Prepare the module's tuple of inputs, sharding each element along the sequence dimension.

static _prepare_output_fn(use_local_output, mod, outputs, device_mesh)[source]#

Prepare the module's tuple of outputs, converting each element to a local tensor when use_local_output is set.
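The stock SequenceParallel style assumes the module receives a single tensor, whereas the Qwen2/Gemma3 rotary-embedding modules receive a tuple, so each element must be prepared separately. The pure-Python sketch below illustrates only the tuple-handling pattern; `shard_seq` is a hypothetical stand-in for DTensor redistribution, which the real class performs on an initialized device mesh.

```python
# Sketch of a tuple-aware input-preparation hook. `shard_seq` is a
# hypothetical stand-in for sharding one tensor along the sequence
# dimension; the real class redistributes DTensors on a device mesh.
def shard_seq(x, rank, world_size):
    """Give this rank its contiguous slice of the sequence dimension."""
    chunk = len(x) // world_size
    return x[rank * chunk:(rank + 1) * chunk]

def prepare_tuple_input(inputs, rank, world_size):
    """Apply sequence sharding to every element of a tuple input."""
    return tuple(shard_seq(t, rank, world_size) for t in inputs)

# Two ranks, sequence length 4: each rank sees half of each element.
cos, sin = [1, 2, 3, 4], [5, 6, 7, 8]
print(prepare_tuple_input((cos, sin), rank=0, world_size=2))  # ([1, 2], [5, 6])
```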
nemo_automodel.components.distributed.optimized_tp_plans._parallelize_gemma3(
model: Union[transformers.models.gemma3.modeling_gemma3.Gemma3ForCausalLM, transformers.models.gemma3.modeling_gemma3.Gemma3ForConditionalGeneration],
sequence_parallel: bool = False,
)[source]#

Parallelizes a Gemma3 model (Gemma3ForCausalLM or Gemma3ForConditionalGeneration) across data parallel dimensions.

Tensor parallelism is not supported for Gemma3 models because of tied word embeddings.

nemo_automodel.components.distributed.optimized_tp_plans._parallelize_llama(
model: transformers.models.llama.modeling_llama.LlamaForCausalLM,
sequence_parallel: bool = False,
)[source]#

Parallelizes a LlamaForCausalLM model across data and tensor parallel dimensions.

nemo_automodel.components.distributed.optimized_tp_plans._parallelize_qwen(
model: Union[transformers.models.qwen2.modeling_qwen2.Qwen2ForCausalLM, transformers.models.qwen3.modeling_qwen3.Qwen3ForCausalLM],
sequence_parallel: bool = False,
)[source]#

Parallelizes a Qwen2ForCausalLM or Qwen3ForCausalLM model across data and tensor parallel dimensions.
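Per the return type in PARALLELIZE_FUNCTIONS below, these functions produce a Dict[str, ParallelStyle]: module-path patterns mapped to parallel styles, as consumed by torch.distributed.tensor.parallel.parallelize_module. The sketch below shows only the shape of such a plan; the keys and style names are illustrative stand-ins (real values would be ColwiseParallel/RowwiseParallel instances), not the exact plan this module builds.

```python
from fnmatch import fnmatch

# Illustrative shape of a tensor-parallel plan: module-FQN patterns
# mapped to parallel styles. Keys and style names are hypothetical
# examples; real plans map to ParallelStyle instances.
plan = {
    "model.layers.*.self_attn.q_proj": "ColwiseParallel",
    "model.layers.*.mlp.down_proj": "RowwiseParallel",
}

def styles_for(module_fqn, plan):
    """Return the styles whose pattern matches a concrete module path."""
    return [s for pat, s in plan.items() if fnmatch(module_fqn, pat)]

print(styles_for("model.layers.0.self_attn.q_proj", plan))  # ['ColwiseParallel']
```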

nemo_automodel.components.distributed.optimized_tp_plans.PARALLELIZE_FUNCTIONS: Dict[type, Callable[..., Dict[str, torch.distributed.tensor.parallel.ParallelStyle]]]#

Mapping from supported model classes (LLaMA, Qwen2/Qwen3, Gemma3) to the functions that build their tensor parallel plans.
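PARALLELIZE_FUNCTIONS is keyed by model class, so a caller selects a plan builder by looking up the model's concrete type. A minimal sketch of that dispatch pattern, using hypothetical dummy classes and builders in place of the real transformers models and _parallelize_* functions:

```python
# Type-based dispatch, mirroring how a mapping like PARALLELIZE_FUNCTIONS
# is used: look up the model's concrete class to pick its plan builder.
# DummyLlama/DummyQwen and the builders are hypothetical stand-ins.
class DummyLlama: ...
class DummyQwen: ...

def build_llama_plan(model, sequence_parallel=False):
    return {"arch": "llama", "sp": sequence_parallel}

def build_qwen_plan(model, sequence_parallel=False):
    return {"arch": "qwen", "sp": sequence_parallel}

PARALLELIZE_FUNCTIONS = {DummyLlama: build_llama_plan, DummyQwen: build_qwen_plan}

model = DummyQwen()
fn = PARALLELIZE_FUNCTIONS.get(type(model))
if fn is None:
    raise ValueError(f"No optimized TP plan for {type(model).__name__}")
plan = fn(model, sequence_parallel=True)
print(plan)  # {'arch': 'qwen', 'sp': True}
```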