# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Model-specific parallel plans for tensor parallelism.
This module contains optimized tensor parallel plans for different model architectures
including LLaMA, Qwen, and Gemma3 models.
"""
from typing import Callable, Dict, Union

import torch
from torch.distributed.tensor import DTensor
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    ParallelStyle,
    PrepareModuleInput,
    PrepareModuleOutput,
    RowwiseParallel,
    SequenceParallel,
)
from torch.distributed.tensor.placement_types import Replicate, Shard
# Import model classes for type checking and parallel plan mapping
from transformers.models.gemma3.modeling_gemma3 import (
    Gemma3ForCausalLM,
    Gemma3ForConditionalGeneration,
)
from transformers.models.llama.modeling_llama import LlamaForCausalLM
from transformers.models.qwen2.modeling_qwen2 import Qwen2ForCausalLM
from transformers.models.qwen3.modeling_qwen3 import Qwen3ForCausalLM


class RotaryEmbedParallel(SequenceParallel):
    """Custom SequenceParallel class for Qwen2 / Gemma3 rotary embeddings because the input is a tuple."""

    @staticmethod
    def _prepare_output_fn(use_local_output, mod, outputs, device_mesh):
        # The rotary embedding returns a tuple (cos, sin), so convert each element to a
        # local tensor when requested instead of assuming a single output tensor.
        return type(outputs)([o.to_local() if use_local_output else o for o in outputs])


def _parallelize_gemma3(
    model: Union[Gemma3ForCausalLM, Gemma3ForConditionalGeneration],
    sequence_parallel: bool = False,
) -> Dict[str, ParallelStyle]:
    """Builds a tensor parallel plan for a Gemma3 model.

    Because Gemma3 ties its word embeddings to the LM head, ``embed_tokens`` and
    ``lm_head`` are not tensor-parallelized; only the attention and MLP projections are.
    """
    if model.__class__ == Gemma3ForConditionalGeneration:
        model_prefix = "language_model"
    else:
        model_prefix = "model"

    # For Gemma3 models, we don't include model.embed_tokens and lm_head in the
    # parallelization plans because they have tied weights.
    base_model_tp_plan = {
        f"{model_prefix}.layers.*.self_attn.q_proj": ColwiseParallel(),
        f"{model_prefix}.layers.*.self_attn.k_proj": ColwiseParallel(),
        f"{model_prefix}.layers.*.self_attn.v_proj": ColwiseParallel(),
        f"{model_prefix}.layers.*.self_attn.o_proj": RowwiseParallel(),
        f"{model_prefix}.layers.*.mlp.up_proj": ColwiseParallel(),
        f"{model_prefix}.layers.*.mlp.gate_proj": ColwiseParallel(),
        f"{model_prefix}.layers.*.mlp.down_proj": RowwiseParallel(),
    }

    base_model_sp_plan = {
        f"{model_prefix}.embed_tokens": PrepareModuleOutput(
            output_layouts=Replicate(),
            desired_output_layouts=Shard(1),
            use_local_output=False,
        ),
        f"{model_prefix}.rotary_emb": RotaryEmbedParallel(use_local_output=True),
        f"{model_prefix}.rotary_emb_local": RotaryEmbedParallel(use_local_output=True),
        f"{model_prefix}.layers.*.input_layernorm": SequenceParallel(),
        f"{model_prefix}.layers.*.self_attn.o_proj": RowwiseParallel(output_layouts=Shard(1)),
        f"{model_prefix}.layers.*.post_attention_layernorm": SequenceParallel(),
        f"{model_prefix}.layers.*.pre_feedforward_layernorm": SequenceParallel(),
        f"{model_prefix}.layers.*.mlp.down_proj": RowwiseParallel(output_layouts=Shard(1)),
        f"{model_prefix}.layers.*.post_feedforward_layernorm": SequenceParallel(),
        f"{model_prefix}.norm": SequenceParallel(),
        f"{model_prefix}.lm_head": PrepareModuleInput(
            input_layouts=(Shard(1),),
            desired_input_layouts=(Replicate(),),
            use_local_output=True,
        ),
    }

    if sequence_parallel:
        # Sequence parallelism overlays the SP plan on top of the base TP plan;
        # it is only meaningful when the TP size is > 1.
        base_model_tp_plan.update(base_model_sp_plan)

    return base_model_tp_plan


def _parallelize_llama(
    model: LlamaForCausalLM,
    sequence_parallel: bool = False,
) -> Dict[str, ParallelStyle]:
    """Parallelizes a LlamaForCausalLM model across data and tensor parallel dimensions."""
    assert not model.config.tie_word_embeddings, "Tied word embeddings are not supported when TP is enabled"
    base_model_tp_plan = {
        "model.embed_tokens": RowwiseParallel(input_layouts=Replicate()),
        "model.layers.*.self_attn.q_proj": ColwiseParallel(),
        "model.layers.*.self_attn.k_proj": ColwiseParallel(),
        "model.layers.*.self_attn.v_proj": ColwiseParallel(),
        "model.layers.*.self_attn.o_proj": RowwiseParallel(),
        "model.layers.*.mlp.up_proj": ColwiseParallel(),
        "model.layers.*.mlp.gate_proj": ColwiseParallel(),
        "model.layers.*.mlp.down_proj": RowwiseParallel(),
        "lm_head": ColwiseParallel(output_layouts=Shard(-1), use_local_output=False),
    }

    base_model_sp_plan = {
        "model.embed_tokens": RowwiseParallel(input_layouts=Replicate(), output_layouts=Shard(1)),
        "model.norm": SequenceParallel(),
        "model.layers.*.input_layernorm": SequenceParallel(),
        "model.layers.*.self_attn.o_proj": RowwiseParallel(output_layouts=Shard(1)),
        "model.layers.*.post_attention_layernorm": SequenceParallel(),
        "model.layers.*.mlp.down_proj": RowwiseParallel(output_layouts=Shard(1)),
        "lm_head": ColwiseParallel(input_layouts=Shard(1), output_layouts=Shard(-1), use_local_output=False),
    }

    if sequence_parallel:
        # Sequence parallelism overlays the SP plan on top of the base TP plan;
        # it is only meaningful when the TP size is > 1.
        base_model_tp_plan.update(base_model_sp_plan)

    return base_model_tp_plan


def _parallelize_qwen(
    model: Union[Qwen2ForCausalLM, Qwen3ForCausalLM],
    sequence_parallel: bool = False,
) -> Dict[str, ParallelStyle]:
    """Parallelizes a Qwen2 or Qwen3 causal LM across data and tensor parallel dimensions."""

    class Qwen3QKNorm(SequenceParallel):
        """SequenceParallel variant for the Qwen3 q/k norms, whose DTensor input is sharded on dim 2."""

        @staticmethod
        def _prepare_input_fn(sequence_sharding, mod, inputs, device_mesh):
            input_tensor = inputs[0]
            if isinstance(input_tensor, DTensor):
                # The norm input arrives already sharded on dim 2; pass it through unchanged.
                assert input_tensor.placements == (Shard(dim=2),)
                return input_tensor
            elif isinstance(input_tensor, torch.Tensor):
                # assume the input passed in is already sharded on the sequence dim and create the DTensor
                return DTensor.from_local(input_tensor, device_mesh, sequence_sharding, run_check=False)
            else:
                raise ValueError(f"expecting input of {mod} to be a torch.Tensor or DTensor, but got {input_tensor}")

    assert not model.config.tie_word_embeddings, "Tied word embeddings are not supported when TP is enabled"

    if sequence_parallel:
        base_model_tp_plan = {
            "lm_head": ColwiseParallel(
                input_layouts=Shard(1),
                output_layouts=Shard(-1),
                use_local_output=False,
            ),
            "model.embed_tokens": RowwiseParallel(
                input_layouts=Replicate(),
                output_layouts=Shard(1),
            ),
            "model.rotary_emb": RotaryEmbedParallel(),
            "model.norm": SequenceParallel(),
            "model.layers.*.input_layernorm": SequenceParallel(),
            "model.layers.*.self_attn.q_proj": ColwiseParallel(use_local_output=False),
            "model.layers.*.self_attn.k_proj": ColwiseParallel(use_local_output=False),
            "model.layers.*.self_attn.v_proj": ColwiseParallel(use_local_output=False),
            "model.layers.*.self_attn.o_proj": RowwiseParallel(output_layouts=Shard(1)),
            "model.layers.*.self_attn.q_norm": Qwen3QKNorm(),
            "model.layers.*.self_attn.k_norm": Qwen3QKNorm(),
            "model.layers.*.post_attention_layernorm": SequenceParallel(),
            "model.layers.*.mlp.up_proj": ColwiseParallel(),
            "model.layers.*.mlp.gate_proj": ColwiseParallel(),
            "model.layers.*.mlp.down_proj": RowwiseParallel(output_layouts=Shard(1)),
        }
    else:
        base_model_tp_plan = {
            "lm_head": ColwiseParallel(output_layouts=Shard(-1), use_local_output=False),
            "model.embed_tokens": RowwiseParallel(
                input_layouts=Replicate(),
            ),
            "model.layers.*.self_attn.q_proj": ColwiseParallel(),
            "model.layers.*.self_attn.k_proj": ColwiseParallel(),
            "model.layers.*.self_attn.v_proj": ColwiseParallel(),
            "model.layers.*.self_attn.o_proj": RowwiseParallel(),
            "model.layers.*.mlp.up_proj": ColwiseParallel(),
            "model.layers.*.mlp.gate_proj": ColwiseParallel(),
            "model.layers.*.mlp.down_proj": RowwiseParallel(),
        }

    return base_model_tp_plan


# Create the model-specific parallel plan mapping
PARALLELIZE_FUNCTIONS: Dict[type, Callable[..., Dict[str, ParallelStyle]]] = {
    Qwen2ForCausalLM: _parallelize_qwen,
    Qwen3ForCausalLM: _parallelize_qwen,
    LlamaForCausalLM: _parallelize_llama,
    # gemma-3-1b-it uses Gemma3ForCausalLM since it is a text-only model
    Gemma3ForCausalLM: _parallelize_gemma3,
    # The larger Gemma3 models use Gemma3ForConditionalGeneration, which handles text-image input
    Gemma3ForConditionalGeneration: _parallelize_gemma3,
}
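
# Usage sketch (illustrative, not part of this module's API): a plan returned by one of
# the functions above can be applied with torch's ``parallelize_module``. The ``model``
# variable and ``tp_size`` below are hypothetical placeholders.
#
#     from torch.distributed.device_mesh import init_device_mesh
#     from torch.distributed.tensor.parallel import parallelize_module
#
#     tp_mesh = init_device_mesh("cuda", (tp_size,), mesh_dim_names=("tp",))
#     plan_fn = PARALLELIZE_FUNCTIONS[type(model)]
#     plan = plan_fn(model, sequence_parallel=False)
#     parallelize_module(model, tp_mesh, plan)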