nemo_automodel.components.distributed.optimized_tp_plans#
Model-specific parallel plans for tensor parallelism.
This module contains optimized tensor-parallel plans for several model architectures, including LLaMA, Qwen, and Gemma3.
Module Contents#
Classes#
| Class | Description |
|---|---|
| `RotaryEmbedParallel` | Custom SequenceParallel class for Qwen2 / Gemma3 rotary embeddings, whose input is a tuple rather than a single tensor. |
Functions#
| Function | Description |
|---|---|
| `_parallelize_gemma3` | Parallelizes a Gemma3ForCausalLM model across data parallel dimensions. |
| `_parallelize_llama` | Parallelizes a LlamaForCausalLM model across data and tensor parallel dimensions. |
| `_parallelize_qwen` | Parallelizes a Qwen2ForCausalLM model across data and tensor parallel dimensions. |
Data#
API#
- class nemo_automodel.components.distributed.optimized_tp_plans.RotaryEmbedParallel(*, sequence_dim: int = 1, use_local_output: bool = False)
Bases: torch.distributed.tensor.parallel.SequenceParallel
Custom SequenceParallel class for Qwen2 / Gemma3 rotary embeddings, whose input is a tuple rather than a single tensor.
Initialization
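For illustration, a minimal sketch of how this style could be attached to a model's rotary-embedding submodule in a tensor-parallel plan. The module path `model.rotary_emb`, the mesh size, and the `use_local_output=True` choice are assumptions for the example, not values taken from this module.

```python
from torch.distributed.device_mesh import init_device_mesh

from nemo_automodel.components.distributed.optimized_tp_plans import RotaryEmbedParallel

# One tensor-parallel mesh dimension across 8 GPUs (illustrative size).
tp_mesh = init_device_mesh("cuda", (8,), mesh_dim_names=("tp",))

# RotaryEmbedParallel behaves like SequenceParallel but accepts the tuple
# inputs/outputs used by Qwen2 / Gemma3 rotary-embedding modules.
plan = {
    "model.rotary_emb": RotaryEmbedParallel(use_local_output=True),  # hypothetical module path
}
# parallelize_module(model, tp_mesh, plan)  # applied together with the rest of the model's plan
```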
- nemo_automodel.components.distributed.optimized_tp_plans._parallelize_gemma3(model: Union[transformers.models.gemma3.modeling_gemma3.Gemma3ForCausalLM, transformers.models.gemma3.modeling_gemma3.Gemma3ForConditionalGeneration], sequence_parallel: bool = False)
Parallelizes a Gemma3ForCausalLM model across data parallel dimensions.
Tensor parallelism is not supported for Gemma3 models because of tied word embeddings.
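A minimal usage sketch. The checkpoint name is illustrative, and normal callers would reach this private function through PARALLELIZE_FUNCTIONS (see the Data entry below) rather than importing it directly.

```python
from transformers import AutoModelForCausalLM

from nemo_automodel.components.distributed.optimized_tp_plans import _parallelize_gemma3

# Illustrative checkpoint name; any Gemma3ForCausalLM instance works the same way.
model = AutoModelForCausalLM.from_pretrained("google/gemma-3-1b-it")

# Returns a dict mapping submodule paths to ParallelStyle objects. Gemma3 ties its
# input embedding and LM-head weights, so the plan contains no tensor-parallel
# shardings (see the note above); sequence parallelism can still be requested.
plan = _parallelize_gemma3(model, sequence_parallel=True)
```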
- nemo_automodel.components.distributed.optimized_tp_plans._parallelize_llama(model: transformers.models.llama.modeling_llama.LlamaForCausalLM, sequence_parallel: bool = False)
Parallelizes a LlamaForCausalLM model across data and tensor parallel dimensions.
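To make the plan structure concrete, here is a sketch of the kind of dict such a function returns for LLaMA-style decoders; the module paths and the exact styles are illustrative, not the literal contents produced by `_parallelize_llama`.

```python
from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel

# Typical Megatron-style sharding of a decoder layer: input projections are split
# column-wise, output projections row-wise, so each attention and MLP block needs
# only a single all-reduce on its output.
example_plan = {
    "model.layers.*.self_attn.q_proj": ColwiseParallel(),
    "model.layers.*.self_attn.k_proj": ColwiseParallel(),
    "model.layers.*.self_attn.v_proj": ColwiseParallel(),
    "model.layers.*.self_attn.o_proj": RowwiseParallel(),
    "model.layers.*.mlp.gate_proj": ColwiseParallel(),
    "model.layers.*.mlp.up_proj": ColwiseParallel(),
    "model.layers.*.mlp.down_proj": RowwiseParallel(),
}
```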
- nemo_automodel.components.distributed.optimized_tp_plans._parallelize_qwen(model: Union[transformers.models.qwen2.modeling_qwen2.Qwen2ForCausalLM, transformers.models.qwen3.modeling_qwen3.Qwen3ForCausalLM], sequence_parallel: bool = False)
Parallelizes a Qwen2ForCausalLM model across data and tensor parallel dimensions.
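A minimal call sketch, assuming `model` is a Qwen2ForCausalLM (or Qwen3ForCausalLM) instance created elsewhere; that the sequence-parallel plan uses RotaryEmbedParallel for the rotary embeddings is an assumption based on the class documented above.

```python
from nemo_automodel.components.distributed.optimized_tp_plans import _parallelize_qwen

# With sequence_parallel=True the plan can additionally shard activations along
# the sequence dimension (e.g. via RotaryEmbedParallel on the rotary embeddings).
tp_plan = _parallelize_qwen(model, sequence_parallel=True)
# tp_plan maps submodule paths to ParallelStyle objects for parallelize_module.
```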
- nemo_automodel.components.distributed.optimized_tp_plans.PARALLELIZE_FUNCTIONS: Dict[type, Callable[..., Dict[str, torch.distributed.tensor.parallel.ParallelStyle]]]#
Mapping from supported model classes (LlamaForCausalLM, Qwen2/Qwen3 causal LMs, and the Gemma3 variants) to the functions that build their parallel plans.
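As a final sketch, one way this registry could be used to pick and apply a plan. Here `model` and `tp_mesh` are assumed to exist already, and looking the builder up by the exact model class is an assumption based on the Dict[type, ...] annotation.

```python
from torch.distributed.tensor.parallel import parallelize_module

from nemo_automodel.components.distributed.optimized_tp_plans import PARALLELIZE_FUNCTIONS

# Look up the plan builder registered for this model's class.
plan_fn = PARALLELIZE_FUNCTIONS.get(type(model))
if plan_fn is None:
    raise ValueError(f"No optimized TP plan registered for {type(model).__name__}")

# Build the architecture-specific plan and shard the model over the TP mesh.
plan = plan_fn(model, sequence_parallel=False)
parallelize_module(model, tp_mesh, plan)
```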