# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Union
import torch
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
checkpoint_wrapper,
)
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.fsdp import CPUOffloadPolicy, MixedPrecisionPolicy, fully_shard
from torch.distributed.tensor import DTensor
from torch.distributed.tensor.parallel import (
ColwiseParallel,
PrepareModuleInput,
PrepareModuleOutput,
RowwiseParallel,
SequenceParallel,
parallelize_module,
)
from torch.distributed.tensor.placement_types import Replicate, Shard
from transformers.models.gemma3.modeling_gemma3 import (
Gemma3ForCausalLM,
Gemma3ForConditionalGeneration,
)
from transformers.models.llama.modeling_llama import LlamaForCausalLM
from transformers.models.qwen2.modeling_qwen2 import Qwen2ForCausalLM
from transformers.models.qwen3.modeling_qwen3 import Qwen3ForCausalLM
from nemo_rl.distributed.model_utils import from_parallel_logits_to_logprobs
class RotaryEmbedParallel(SequenceParallel):
"""Custom SequenceParallel class for Qwen2 / Gemma3 rotary embeddings because the input is a tuple."""
@staticmethod
def _prepare_output_fn(use_local_output, mod, outputs, device_mesh):
return type(outputs)([o.to_local() if use_local_output else o for o in outputs])
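# Illustrative note (not part of the original module): the HF rotary embedding modules
# return a (cos, sin) tuple rather than a single tensor, which is why the stock
# SequenceParallel output hook cannot be reused as-is. The override above simply maps
# to_local() over each element of the tuple when use_local_output is requested, roughly:
#
#   cos, sin = rotary_emb(x, position_ids)      # each element is a DTensor under TP
#   cos, sin = cos.to_local(), sin.to_local()   # what _prepare_output_fn does per element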
def _parallelize_gemma3(
model: Union[Gemma3ForCausalLM, Gemma3ForConditionalGeneration],
dp_mesh: DeviceMesh,
tp_mesh: DeviceMesh,
mp_policy: MixedPrecisionPolicy,
offload_policy: torch.distributed.fsdp.OffloadPolicy,
sequence_parallel: bool = False,
activation_checkpointing: bool = False,
):
"""Parallelizes a Gemma3ForCausalLM model across data parallel dimensions.
Tensor parallelism is not supported for Gemma3 models because of tied word embeddings.
"""
if isinstance(model, Gemma3ForConditionalGeneration):
layers = model.language_model.model.layers
model_prefix = "language_model.model"
num_attention_heads = model.config.text_config.num_attention_heads
num_key_value_heads = model.config.text_config.num_key_value_heads
else:
layers = model.model.layers
model_prefix = "model"
num_attention_heads = model.config.num_attention_heads
num_key_value_heads = model.config.num_key_value_heads
if tp_mesh.size() > 1:
assert num_key_value_heads % tp_mesh.size() == 0, (
f"num_key_value_heads ({num_key_value_heads}) must be divisible by TP size ({tp_mesh.size()})"
)
assert num_attention_heads % tp_mesh.size() == 0, (
f"num_attention_heads ({num_attention_heads}) must be divisible by TP size ({tp_mesh.size()})"
)
# For gemma3 models, we don't include the model.embed_tokens and lm_head in the
# parallelization plans because they have tied weights.
base_model_tp_plan = {
f"{model_prefix}.layers.*.self_attn.q_proj": ColwiseParallel(),
f"{model_prefix}.layers.*.self_attn.k_proj": ColwiseParallel(),
f"{model_prefix}.layers.*.self_attn.v_proj": ColwiseParallel(),
f"{model_prefix}.layers.*.self_attn.o_proj": RowwiseParallel(),
f"{model_prefix}.layers.*.mlp.up_proj": ColwiseParallel(),
f"{model_prefix}.layers.*.mlp.gate_proj": ColwiseParallel(),
f"{model_prefix}.layers.*.mlp.down_proj": RowwiseParallel(),
}
base_model_sp_plan = {
f"{model_prefix}.embed_tokens": PrepareModuleOutput(
output_layouts=Replicate(),
desired_output_layouts=Shard(1),
use_local_output=False,
),
f"{model_prefix}.rotary_emb": RotaryEmbedParallel(use_local_output=True),
f"{model_prefix}.rotary_emb_local": RotaryEmbedParallel(
use_local_output=True
),
f"{model_prefix}.layers.*.input_layernorm": SequenceParallel(),
f"{model_prefix}.layers.*.self_attn.o_proj": RowwiseParallel(
output_layouts=Shard(1)
),
f"{model_prefix}.layers.*.post_attention_layernorm": SequenceParallel(),
f"{model_prefix}.layers.*.pre_feedforward_layernorm": SequenceParallel(),
f"{model_prefix}.layers.*.mlp.down_proj": RowwiseParallel(
output_layouts=Shard(1)
),
f"{model_prefix}.layers.*.post_feedforward_layernorm": SequenceParallel(),
f"{model_prefix}.norm": SequenceParallel(),
f"{model_prefix}.lm_head": PrepareModuleInput(
input_layouts=(Shard(1),),
desired_input_layouts=(Replicate(),),
use_local_output=True,
),
}
if sequence_parallel:
# Enable sequence parallelism only if TP size > 1
base_model_tp_plan.update(base_model_sp_plan)
parallelize_module(model, tp_mesh, base_model_tp_plan)
if activation_checkpointing:
for i in range(len(layers)):
layers[i].mlp = checkpoint_wrapper(layers[i].mlp)
for layer in layers:
fully_shard(
layer, mesh=dp_mesh, mp_policy=mp_policy, offload_policy=offload_policy
)
return fully_shard(
model, mesh=dp_mesh, mp_policy=mp_policy, offload_policy=offload_policy
)
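# Sketch of the tensor-parallel sharding applied above (illustrative, not executed):
# for each decoder layer and a TP mesh of size T,
#
#   q/k/v_proj, up/gate_proj : ColwiseParallel -> nn.Linear weight [out, in] sharded on
#                              dim 0, so each rank holds out/T output features
#   o_proj, down_proj        : RowwiseParallel -> weight sharded on dim 1; the partial
#                              outputs are all-reduced (or reduce-scattered to Shard(1)
#                              on the sequence dim when sequence_parallel=True)
#
# The tied embed_tokens / lm_head weights stay replicated on every TP rank.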
def _parallelize_llama(
model: LlamaForCausalLM,
dp_mesh: DeviceMesh,
tp_mesh: DeviceMesh,
mp_policy: MixedPrecisionPolicy,
offload_policy: torch.distributed.fsdp.OffloadPolicy,
sequence_parallel: bool = False,
activation_checkpointing: bool = False,
):
"""Parallelizes a LlamaForCausalLM model across data and tensor parallel dimensions."""
if tp_mesh.size() > 1:
assert not model.config.tie_word_embeddings, (
"Tie word embeddings not supported when TP is enabled"
)
base_model_tp_plan = {
"model.embed_tokens": RowwiseParallel(input_layouts=Replicate()),
"model.layers.*.self_attn.q_proj": ColwiseParallel(),
"model.layers.*.self_attn.k_proj": ColwiseParallel(),
"model.layers.*.self_attn.v_proj": ColwiseParallel(),
"model.layers.*.self_attn.o_proj": RowwiseParallel(),
"model.layers.*.mlp.up_proj": ColwiseParallel(),
"model.layers.*.mlp.gate_proj": ColwiseParallel(),
"model.layers.*.mlp.down_proj": RowwiseParallel(),
"lm_head": ColwiseParallel(
output_layouts=Shard(-1), use_local_output=False
),
}
base_model_sp_plan = {
"model.embed_tokens": RowwiseParallel(
input_layouts=Replicate(), output_layouts=Shard(1)
),
"model.norm": SequenceParallel(),
"model.layers.*.input_layernorm": SequenceParallel(),
"model.layers.*.self_attn.o_proj": RowwiseParallel(output_layouts=Shard(1)),
"model.layers.*.post_attention_layernorm": SequenceParallel(),
"model.layers.*.mlp.down_proj": RowwiseParallel(output_layouts=Shard(1)),
"lm_head": ColwiseParallel(
input_layouts=Shard(1), output_layouts=Shard(-1), use_local_output=False
),
}
if sequence_parallel:
# Enable sequence parallelism only if TP size > 1
base_model_tp_plan.update(base_model_sp_plan)
parallelize_module(model, tp_mesh, base_model_tp_plan)
if activation_checkpointing:
for i in range(len(model.model.layers)):
model.model.layers[i].mlp = checkpoint_wrapper(model.model.layers[i].mlp)
for layer in model.model.layers:
fully_shard(
layer, mesh=dp_mesh, mp_policy=mp_policy, offload_policy=offload_policy
)
return fully_shard(
model, mesh=dp_mesh, mp_policy=mp_policy, offload_policy=offload_policy
)
def _parallelize_qwen(
model: Union[Qwen2ForCausalLM, Qwen3ForCausalLM],
dp_mesh: DeviceMesh,
tp_mesh: DeviceMesh,
mp_policy: MixedPrecisionPolicy,
offload_policy: torch.distributed.fsdp.OffloadPolicy,
sequence_parallel: bool = False,
activation_checkpointing: bool = False,
):
"""Parallelizes a Qwen2ForCausalLM model across data and tensor parallel dimensions."""
    # Custom SequenceParallel style for Qwen3's q_norm / k_norm: under TP their inputs may
    # arrive either as DTensors already sharded on dim 2 (the heads dim, coming from the
    # colwise q_proj / k_proj with use_local_output=False) or as plain tensors sharded on
    # the sequence dim.
    class Qwen3QKNorm(SequenceParallel):
        @staticmethod
        def _prepare_input_fn(sequence_sharding, mod, inputs, device_mesh):
            input_tensor = inputs[0]
            if isinstance(input_tensor, DTensor):
                # Already a DTensor: expect it to be sharded on the heads dim and pass it through.
                assert input_tensor.placements == (Shard(dim=2),)
                return input_tensor
            elif isinstance(input_tensor, torch.Tensor):
                # Assume the input passed in is already sharded on the sequence dim and wrap it in a DTensor.
                return DTensor.from_local(
                    input_tensor, device_mesh, sequence_sharding, run_check=False
                )
            else:
                raise ValueError(
                    f"expecting input of {mod} to be a torch.Tensor or DTensor, but got {input_tensor}"
                )
if tp_mesh.size() > 1:
assert not model.config.tie_word_embeddings, (
"Tie word embeddings not supported when TP is enabled"
)
if sequence_parallel:
base_model_tp_plan = {
"lm_head": ColwiseParallel(
input_layouts=Shard(1),
output_layouts=Shard(-1),
use_local_output=False,
),
"model.embed_tokens": RowwiseParallel(
input_layouts=Replicate(),
output_layouts=Shard(1),
),
"model.rotary_emb": RotaryEmbedParallel(),
"model.norm": SequenceParallel(),
"model.layers.*.input_layernorm": SequenceParallel(),
"model.layers.*.self_attn.q_proj": ColwiseParallel(
use_local_output=False
),
"model.layers.*.self_attn.k_proj": ColwiseParallel(
use_local_output=False
),
"model.layers.*.self_attn.v_proj": ColwiseParallel(
use_local_output=False
),
"model.layers.*.self_attn.o_proj": RowwiseParallel(
output_layouts=Shard(1)
),
"model.layers.*.self_attn.q_norm": Qwen3QKNorm(),
"model.layers.*.self_attn.k_norm": Qwen3QKNorm(),
"model.layers.*.post_attention_layernorm": SequenceParallel(),
"model.layers.*.mlp.up_proj": ColwiseParallel(),
"model.layers.*.mlp.gate_proj": ColwiseParallel(),
"model.layers.*.mlp.down_proj": RowwiseParallel(
output_layouts=Shard(1)
),
}
else:
base_model_tp_plan = {
"lm_head": ColwiseParallel(
output_layouts=Shard(-1), use_local_output=False
),
"model.embed_tokens": RowwiseParallel(
input_layouts=Replicate(),
),
"model.layers.*.self_attn.q_proj": ColwiseParallel(),
"model.layers.*.self_attn.k_proj": ColwiseParallel(),
"model.layers.*.self_attn.v_proj": ColwiseParallel(),
"model.layers.*.self_attn.o_proj": RowwiseParallel(),
"model.layers.*.mlp.up_proj": ColwiseParallel(),
"model.layers.*.mlp.gate_proj": ColwiseParallel(),
"model.layers.*.mlp.down_proj": RowwiseParallel(),
}
parallelize_module(model, tp_mesh, base_model_tp_plan)
if activation_checkpointing:
for i in range(len(model.model.layers)):
model.model.layers[i].mlp = checkpoint_wrapper(model.model.layers[i].mlp)
for layer in model.model.layers:
fully_shard(
layer, mesh=dp_mesh, mp_policy=mp_policy, offload_policy=offload_policy
)
return fully_shard(
model, mesh=dp_mesh, mp_policy=mp_policy, offload_policy=offload_policy
)
PARALLIZE_FUNCTIONS = {
Qwen2ForCausalLM: _parallelize_qwen,
Qwen3ForCausalLM: _parallelize_qwen,
LlamaForCausalLM: _parallelize_llama,
# gemma-3-1b-it uses Gemma3ForCausalLM since it is a text-only model
Gemma3ForCausalLM: _parallelize_gemma3,
    # The larger Gemma3 models use Gemma3ForConditionalGeneration, which handles text-image input
Gemma3ForConditionalGeneration: _parallelize_gemma3,
}
def _parallelize_model(
model: Union[Qwen2ForCausalLM, LlamaForCausalLM],
dp_mesh: DeviceMesh,
tp_mesh: DeviceMesh,
param_dtype: torch.dtype,
sequence_parallel: bool = False,
activation_checkpointing: bool = False,
cpu_offload: bool = False,
):
"""Parallelize a model using DTensor.
Args:
model (Union[Qwen2ForCausalLM, LlamaForCausalLM]): The model to parallelize.
dp_mesh (DeviceMesh): Device mesh for data parallelism.
tp_mesh (DeviceMesh): Device mesh for tensor parallelism.
param_dtype (torch.dtype): Data type for model parameters.
sequence_parallel (bool, optional): Whether to use sequence parallelism. Defaults to False.
activation_checkpointing (bool, optional): Whether to use activation checkpointing. Defaults to False.
cpu_offload (bool, optional): Whether to enable cpu offloading for FSDP. Defaults to False.
Returns:
The parallelized model.
Raises:
ValueError: If the model type is not supported for parallelization.
"""
mp_policy = MixedPrecisionPolicy(
param_dtype=param_dtype,
reduce_dtype=torch.float32,
output_dtype=torch.float32,
)
    offload_policy = (
        CPUOffloadPolicy(pin_memory=False)
        if cpu_offload
        else torch.distributed.fsdp.OffloadPolicy()
    )
model_cls = type(model)
    if model_cls not in PARALLIZE_FUNCTIONS:
        raise ValueError(f"Model class {model_cls} is not supported for DTensor parallelization")
    func = PARALLIZE_FUNCTIONS[model_cls]
return func(
model,
dp_mesh,
tp_mesh,
mp_policy,
offload_policy,
sequence_parallel,
activation_checkpointing,
)
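# Usage sketch (assumed caller code, not part of this module): with torch.distributed
# already initialized, the expected call pattern is roughly
#
#   from torch.distributed.device_mesh import init_device_mesh
#
#   mesh = init_device_mesh("cuda", (dp_size, tp_size), mesh_dim_names=("dp", "tp"))
#   model = _parallelize_model(
#       model,                     # a supported HF model (see PARALLIZE_FUNCTIONS)
#       dp_mesh=mesh["dp"],
#       tp_mesh=mesh["tp"],
#       param_dtype=torch.bfloat16,
#       sequence_parallel=True,
#       activation_checkpointing=True,
#   )
#
# dp_size and tp_size are placeholders; mesh construction is up to the caller.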
def to_local_if_dtensor(tensor: Union[torch.Tensor, DTensor]) -> torch.Tensor:
"""Returns the local shard of the given tensor if it is a DTensor.
Taken and modified from: https://github.com/NVIDIA/Megatron-LM/blob/605f618f237cda8fa80132bc2ccff933512d5a0d/megatron/core/utils.py#L746
"""
with torch.no_grad():
return tensor.to_local() if isinstance(tensor, DTensor) else tensor
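# Example (illustrative): for a sharded parameter p, to_local_if_dtensor(p.grad) returns
# this rank's local shard of the gradient, while a plain torch.Tensor passes through
# unchanged. This lets the norm/clip helpers below treat DTensor and regular gradients
# uniformly.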
def clip_grad_by_total_norm_(
parameters: Union[List[Union[torch.Tensor, DTensor]], Union[torch.Tensor, DTensor]],
max_grad_norm: Union[int, float],
total_norm: float,
dtype: torch.dtype = torch.float32,
):
"""Clips gradient of an iterable of parameters by total norm.
Taken and modified from: https://github.com/NVIDIA/Megatron-LM/blob/a695b2bd2a0ca9ca63385a48c41a1c5a033cdd1e/megatron/core/optimizer/clip_grads.py#L138
Note that the gradients are modified in place.
Args:
parameters (Union[List[Union[torch.Tensor, DTensor]], Union[torch.Tensor, DTensor]]):
An iterable of Tensors or DTensors, or a single Tensor or DTensor
that will have gradients normalized.
max_grad_norm (Union[float, int]): Maximum norm of the gradients.
total_norm (float): The pre-computed total norm of the gradients to use for scaling.
"""
if isinstance(parameters, (torch.Tensor, DTensor)):
parameters = [parameters]
# Grads.
grads = [
to_local_if_dtensor(p.grad.detach()).to(dtype)
for p in parameters
if p.grad is not None
]
# Scale.
clip_coeff = max_grad_norm / (total_norm + 1.0e-6)
if clip_coeff < 1.0:
for g in grads:
g.mul_(clip_coeff)
def get_grad_norm(
parameters: Union[List[Union[torch.Tensor, DTensor]], Union[torch.Tensor, DTensor]],
dp_group: torch.distributed.ProcessGroup,
tp_group: torch.distributed.ProcessGroup,
norm_type: Union[int, float] = 2,
dtype: torch.dtype = torch.float32,
) -> float:
"""Calculate the norm of gradients.
Taken and modified from: https://github.com/NVIDIA/Megatron-LM/blob/a695b2bd2a0ca9ca63385a48c41a1c5a033cdd1e/megatron/core/optimizer/clip_grads.py#L51
Args:
parameters (Union[List[Union[torch.Tensor, DTensor]], Union[torch.Tensor, DTensor]]):
An iterable of Tensors or DTensors, or a single Tensor or DTensor
that will have gradient norm calculated.
dp_group (torch.distributed.ProcessGroup): Process group for data parallel communication.
tp_group (torch.distributed.ProcessGroup): Process group for tensor parallel communication.
norm_type (Union[int, float]): Type of the used p-norm. Can be ``'inf'`` for
infinity norm.
Returns:
float: Total norm of the gradients (viewed as a single vector)
"""
if isinstance(parameters, (torch.Tensor, DTensor)):
parameters = [parameters]
# Grads.
grads_for_norm = [
to_local_if_dtensor(p.grad.detach()).to(dtype)
for p in parameters
if p.grad is not None
]
# Norm parameters.
norm_type = float(norm_type)
total_norm = 0.0
# Calculate norm.
if norm_type == torch.inf:
total_norm = max(grad.abs().max() for grad in grads_for_norm)
total_norm_cuda = torch.tensor(
[float(total_norm)], dtype=torch.float, device="cuda"
)
# Take max across all data-parallel GPUs if using FSDP and then all model-parallel GPUs.
torch.distributed.all_reduce(
total_norm_cuda, op=torch.distributed.ReduceOp.MAX, group=dp_group
)
torch.distributed.all_reduce(
total_norm_cuda, op=torch.distributed.ReduceOp.MAX, group=tp_group
)
total_norm = total_norm_cuda[0].item()
else:
for grad in grads_for_norm:
grad_norm = torch.norm(grad, norm_type)
total_norm += grad_norm**norm_type
total_norm = total_norm.cuda()
# Sum across all data-parallel GPUs if using FSDP and then all model-parallel GPUs.
torch.distributed.all_reduce(
total_norm, op=torch.distributed.ReduceOp.SUM, group=dp_group
)
torch.distributed.all_reduce(
total_norm, op=torch.distributed.ReduceOp.SUM, group=tp_group
)
total_norm = total_norm.item() ** (1.0 / norm_type)
return total_norm
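# Typical call sequence (sketch; `model`, `optimizer`, and the dp/tp process groups are
# assumed to be set up by the caller to match the FSDP/TP meshes):
#
#   params = [p for p in model.parameters() if p.grad is not None]
#   total_norm = get_grad_norm(params, dp_group=dp_group, tp_group=tp_group, norm_type=2)
#   clip_grad_by_total_norm_(params, max_grad_norm=1.0, total_norm=total_norm)
#   optimizer.step()
#
# Unlike torch.nn.utils.clip_grad_norm_, clip_grad_by_total_norm_ expects the pre-computed
# global norm and only rescales the gradients in place.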
def get_logprobs_from_vocab_parallel_logits(
vocab_parallel_logits: DTensor, input_ids: torch.Tensor
):
"""Computes log probabilities from vocabulary-parallel logits.
This function takes logits that are sharded across the vocabulary dimension (tensor parallel)
and computes the log probabilities for the given input IDs.
Args:
vocab_parallel_logits (DTensor): Logits distributed across tensor parallel workers,
with shape [batch_size, seq_len, vocab_size/tp_size].
input_ids (torch.Tensor): Input token IDs for which to compute log probabilities,
with shape [batch_size, seq_len].
Returns:
torch.Tensor: Log probabilities for the given input IDs.
"""
tp_mesh = vocab_parallel_logits.device_mesh
tp_rank: int = tp_mesh.get_local_rank()
vocab_interval_per_rank = vocab_parallel_logits.shape[-1] // tp_mesh.size()
return from_parallel_logits_to_logprobs(
vocab_parallel_logits.to_local(),
input_ids,
vocab_interval_per_rank * tp_rank,
(tp_rank + 1) * vocab_interval_per_rank,
tp_mesh.get_group(),
inference_only=not torch.is_grad_enabled(),
)
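# Example (sketch): when lm_head is parallelized with ColwiseParallel(output_layouts=Shard(-1),
# use_local_output=False) as in the plans above, the model's logits arrive here as a DTensor
# sharded over the vocabulary dimension, e.g.
#
#   logits = model(input_ids).logits   # DTensor with Shard(-1) placement on the TP mesh
#   logprobs = get_logprobs_from_vocab_parallel_logits(logits, input_ids)
#
# The distributed softmax/gather over the vocab shards happens inside
# from_parallel_logits_to_logprobs.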