nemo_rl.models.dtensor.parallelize#

Module Contents#

Classes#

RotaryEmbedParallel

Custom SequenceParallel class for Qwen2 / Gemma3 rotary embeddings, whose inputs are tuples rather than single tensors.

Functions#

_parallelize_gemma3

Parallelizes a Gemma3ForCausalLM or Gemma3ForConditionalGeneration model across data parallel dimensions.

_parallelize_llama

Parallelizes a LlamaForCausalLM model across data and tensor parallel dimensions.

_parallelize_qwen

Parallelizes a Qwen2ForCausalLM or Qwen3ForCausalLM model across data and tensor parallel dimensions.

_parallelize_model

Parallelize a model using DTensor.

to_local_if_dtensor

Returns the local shard of the given tensor if it is a DTensor.

clip_grad_by_total_norm_

Clips gradient of an iterable of parameters by total norm.

get_grad_norm

Calculate the norm of gradients.

get_logprobs_from_vocab_parallel_logits

Computes log probabilities from vocabulary-parallel logits.

Data#

API#

class nemo_rl.models.dtensor.parallelize.RotaryEmbedParallel(
*,
sequence_dim: int = 1,
use_local_output: bool = False,
)[source]#

Bases: torch.distributed.tensor.parallel.SequenceParallel

Custom SequenceParallel class for Qwen2 / Gemma3 rotary embeddings, whose inputs are tuples rather than single tensors.

Initialization

static _prepare_input_fn(sequence_sharding, mod, inputs, device_mesh)[source]#
static _prepare_output_fn(use_local_output, mod, outputs, device_mesh)[source]#
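
A minimal sketch of how this class might be wired into a tensor-parallel plan. The module path `model.rotary_emb` and the surrounding use of `parallelize_module` are assumptions based on the Hugging Face Qwen2/Gemma3 module layout, not taken from this module's source:

```python
# Hypothetical sketch: placing RotaryEmbedParallel in a TP plan.
# Run under torchrun with 2 ranks. "model.rotary_emb" follows the
# Hugging Face Qwen2 layout and is an assumption, not confirmed here.
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import parallelize_module
from transformers import AutoModelForCausalLM

tp_mesh = init_device_mesh("cuda", (2,), mesh_dim_names=("tp",))
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B")

parallelize_module(
    model,
    tp_mesh,
    {
        # The rotary embedding forward receives a tuple, which the stock
        # SequenceParallel style cannot handle; this subclass prepares
        # each tensor in the tuple instead.
        "model.rotary_emb": RotaryEmbedParallel(use_local_output=True),
    },
)
```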
nemo_rl.models.dtensor.parallelize._parallelize_gemma3(
model: Union[transformers.models.gemma3.modeling_gemma3.Gemma3ForCausalLM, transformers.models.gemma3.modeling_gemma3.Gemma3ForConditionalGeneration],
dp_mesh: torch.distributed.device_mesh.DeviceMesh,
tp_mesh: torch.distributed.device_mesh.DeviceMesh,
mp_policy: torch.distributed.fsdp.MixedPrecisionPolicy,
offload_policy: torch.distributed.fsdp.OffloadPolicy,
sequence_parallel: bool = False,
activation_checkpointing: bool = False,
)[source]#

Parallelizes a Gemma3ForCausalLM or Gemma3ForConditionalGeneration model across data parallel dimensions.

Tensor parallelism is not supported for Gemma3 models because of tied word embeddings.

nemo_rl.models.dtensor.parallelize._parallelize_llama(
model: transformers.models.llama.modeling_llama.LlamaForCausalLM,
dp_mesh: torch.distributed.device_mesh.DeviceMesh,
tp_mesh: torch.distributed.device_mesh.DeviceMesh,
mp_policy: torch.distributed.fsdp.MixedPrecisionPolicy,
offload_policy: torch.distributed.fsdp.OffloadPolicy,
sequence_parallel: bool = False,
activation_checkpointing: bool = False,
)[source]#

Parallelizes a LlamaForCausalLM model across data and tensor parallel dimensions.

nemo_rl.models.dtensor.parallelize._parallelize_qwen(
model: Union[transformers.models.qwen2.modeling_qwen2.Qwen2ForCausalLM, transformers.models.qwen3.modeling_qwen3.Qwen3ForCausalLM],
dp_mesh: torch.distributed.device_mesh.DeviceMesh,
tp_mesh: torch.distributed.device_mesh.DeviceMesh,
mp_policy: torch.distributed.fsdp.MixedPrecisionPolicy,
offload_policy: torch.distributed.fsdp.OffloadPolicy,
sequence_parallel: bool = False,
activation_checkpointing: bool = False,
)[source]#

Parallelizes a Qwen2ForCausalLM or Qwen3ForCausalLM model across data and tensor parallel dimensions.
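
All three `_parallelize_*` helpers share this signature. A hedged sketch of calling one directly, assuming a 2D `("dp", "tp")` device mesh; in normal use `_parallelize_model` (below) selects the right helper for you:

```python
# Hedged sketch: invoking a per-architecture helper directly.
# Launched via torchrun with 8 ranks (4 dp x 2 tp).
import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import MixedPrecisionPolicy, OffloadPolicy

mesh = init_device_mesh("cuda", (4, 2), mesh_dim_names=("dp", "tp"))

_parallelize_qwen(
    model,  # a Qwen2ForCausalLM / Qwen3ForCausalLM instance
    dp_mesh=mesh["dp"],
    tp_mesh=mesh["tp"],
    mp_policy=MixedPrecisionPolicy(
        param_dtype=torch.bfloat16, reduce_dtype=torch.float32
    ),
    offload_policy=OffloadPolicy(),  # no-op; CPUOffloadPolicy() to offload
    sequence_parallel=True,
    activation_checkpointing=True,
)
```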

nemo_rl.models.dtensor.parallelize.PARALLIZE_FUNCTIONS#

None
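
The value renders as `None` here, but given the per-architecture helpers above and the `ValueError` that `_parallelize_model` raises for unsupported architectures, this is presumably a dispatch table from model class to parallelization function. A speculative sketch of its shape, not the actual definition:

```python
# Speculative sketch only; the real keys and values are not shown
# on this page.
PARALLIZE_FUNCTIONS = {
    Gemma3ForCausalLM: _parallelize_gemma3,
    LlamaForCausalLM: _parallelize_llama,
    Qwen2ForCausalLM: _parallelize_qwen,
    Qwen3ForCausalLM: _parallelize_qwen,
}
```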

nemo_rl.models.dtensor.parallelize._parallelize_model(
model: Union[transformers.models.qwen2.modeling_qwen2.Qwen2ForCausalLM, transformers.models.llama.modeling_llama.LlamaForCausalLM],
dp_mesh: torch.distributed.device_mesh.DeviceMesh,
tp_mesh: torch.distributed.device_mesh.DeviceMesh,
param_dtype: torch.dtype,
sequence_parallel: bool = False,
activation_checkpointing: bool = False,
cpu_offload: bool = False,
)[source]#

Parallelize a model using DTensor.

Parameters:
  • model (Union[Qwen2ForCausalLM, LlamaForCausalLM]) – The model to parallelize.

  • dp_mesh (DeviceMesh) – Device mesh for data parallelism.

  • tp_mesh (DeviceMesh) – Device mesh for tensor parallelism.

  • param_dtype (torch.dtype) – Data type for model parameters.

  • sequence_parallel (bool, optional) – Whether to use sequence parallelism. Defaults to False.

  • activation_checkpointing (bool, optional) – Whether to use activation checkpointing. Defaults to False.

  • cpu_offload (bool, optional) – Whether to enable CPU offloading for FSDP. Defaults to False.

Returns:

The parallelized model.

Raises:

ValueError – If the model type is not supported for parallelization.
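
A hedged end-to-end sketch of the typical call, assuming a 2D `("dp", "tp")` mesh; the checkpoint name is illustrative only:

```python
# Sketch of the main entry point, launched via torchrun (4 dp x 2 tp).
import torch
from torch.distributed.device_mesh import init_device_mesh
from transformers import AutoModelForCausalLM

mesh = init_device_mesh("cuda", (4, 2), mesh_dim_names=("dp", "tp"))
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-7B")

model = _parallelize_model(
    model,
    dp_mesh=mesh["dp"],
    tp_mesh=mesh["tp"],
    param_dtype=torch.bfloat16,
    sequence_parallel=True,
    activation_checkpointing=True,
    cpu_offload=False,
)
# Architectures without a registered parallelize function raise ValueError.
```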

nemo_rl.models.dtensor.parallelize.to_local_if_dtensor(
tensor: Union[torch.Tensor, torch.distributed.tensor.DTensor],
) → torch.Tensor[source]#

Returns the local shard of the given tensor if it is a DTensor.

Taken and modified from: https://github.com/NVIDIA/Megatron-LM/blob/605f618f237cda8fa80132bc2ccff933512d5a0d/megatron/core/utils.py#L746
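
A minimal sketch of the documented behavior, assuming a small 1D mesh; plain tensors pass through unchanged:

```python
# Run under torchrun with 2 ranks.
import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import Shard, distribute_tensor

tp_mesh = init_device_mesh("cuda", (2,), mesh_dim_names=("tp",))

sharded = distribute_tensor(torch.randn(8, 4), tp_mesh, placements=[Shard(0)])
local = to_local_if_dtensor(sharded)   # the rank-local shard, a plain torch.Tensor
passthrough = to_local_if_dtensor(local)  # already a torch.Tensor; returned as-is
```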

nemo_rl.models.dtensor.parallelize.clip_grad_by_total_norm_(
parameters: Union[List[Union[torch.Tensor, torch.distributed.tensor.DTensor]], Union[torch.Tensor, torch.distributed.tensor.DTensor]],
max_grad_norm: Union[int, float],
total_norm: float,
dtype: torch.dtype = torch.float32,
)[source]#

Clips gradient of an iterable of parameters by total norm.

Taken and modified from: https://github.com/NVIDIA/Megatron-LM/blob/a695b2bd2a0ca9ca63385a48c41a1c5a033cdd1e/megatron/core/optimizer/clip_grads.py#L138

Note that the gradients are modified in place.

Parameters:
  • parameters (Union[List[Union[torch.Tensor, DTensor]], Union[torch.Tensor, DTensor]]) – An iterable of Tensors or DTensors, or a single Tensor or DTensor that will have gradients normalized.

  • max_grad_norm (Union[float, int]) – Maximum norm of the gradients.

  • total_norm (float) – The pre-computed total norm of the gradients to use for scaling.
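
For reference, the upstream Megatron implementation this was modified from scales every gradient by `max_grad_norm / (total_norm + 1e-6)` when that coefficient falls below 1. A sketch of that rule; the modified copy in this module may differ in details:

```python
# Scaling rule of the upstream Megatron implementation (sketch only).
clip_coeff = max_grad_norm / (total_norm + 1.0e-6)
if clip_coeff < 1.0:
    for p in parameters:
        p.grad.detach().mul_(clip_coeff)  # in place, as the docstring notes
```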

nemo_rl.models.dtensor.parallelize.get_grad_norm(
parameters: Union[List[Union[torch.Tensor, torch.distributed.tensor.DTensor]], Union[torch.Tensor, torch.distributed.tensor.DTensor]],
dp_group: torch.distributed.ProcessGroup,
tp_group: torch.distributed.ProcessGroup,
norm_type: Union[int, float] = 2,
dtype: torch.dtype = torch.float32,
) → float[source]#

Calculate the norm of gradients.

Taken and modified from: https://github.com/NVIDIA/Megatron-LM/blob/a695b2bd2a0ca9ca63385a48c41a1c5a033cdd1e/megatron/core/optimizer/clip_grads.py#L51

Parameters:
  • parameters (Union[List[Union[torch.Tensor, DTensor]], Union[torch.Tensor, DTensor]]) – An iterable of Tensors or DTensors, or a single Tensor or DTensor that will have gradient norm calculated.

  • dp_group (torch.distributed.ProcessGroup) – Process group for data parallel communication.

  • tp_group (torch.distributed.ProcessGroup) – Process group for tensor parallel communication.

  • norm_type (Union[int, float]) – Type of the used p-norm. Can be 'inf' for infinity norm.

Returns:

Total norm of the gradients (viewed as a single vector).

Return type:

float
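
A hedged sketch pairing get_grad_norm with clip_grad_by_total_norm_ in a training step. The `mesh`, `model`, `loss`, and `optimizer` names are assumed context, with the process groups taken from a `("dp", "tp")` mesh as above:

```python
# Hedged sketch: measure the global grad norm, then clip against it.
loss.backward()

params = [p for p in model.parameters() if p.grad is not None]
total_norm = get_grad_norm(
    params,
    dp_group=mesh["dp"].get_group(),
    tp_group=mesh["tp"].get_group(),
    norm_type=2,
)
clip_grad_by_total_norm_(params, max_grad_norm=1.0, total_norm=total_norm)
optimizer.step()
```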

nemo_rl.models.dtensor.parallelize.get_logprobs_from_vocab_parallel_logits(
vocab_parallel_logits: torch.distributed.tensor.DTensor,
input_ids: torch.Tensor,
)[source]#

Computes log probabilities from vocabulary-parallel logits.

This function takes logits that are sharded across the vocabulary dimension (tensor parallel) and computes the log probabilities for the given input IDs.

Parameters:
  • vocab_parallel_logits (DTensor) – Logits distributed across tensor parallel workers, with shape [batch_size, seq_len, vocab_size/tp_size].

  • input_ids (torch.Tensor) – Input token IDs for which to compute log probabilities, with shape [batch_size, seq_len].

Returns:

Log probabilities for the given input IDs.

Return type:

torch.Tensor
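
A hedged usage sketch, assuming the model's output logits arrive as a vocab-sharded DTensor under tensor parallelism:

```python
# Hedged sketch: per-token logprobs without materializing the full
# vocab dimension on every rank.
logits = model(input_ids).logits  # assumed: DTensor sharded on the vocab dim
token_logprobs = get_logprobs_from_vocab_parallel_logits(logits, input_ids)
# Presumably [batch_size, seq_len]; whether positions are shifted to align
# labels with next-token predictions is not stated on this page.
```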