nemo_rl.models.dtensor.parallelize#
Module Contents#
Classes#
| Class | Description |
|---|---|
| RotaryEmbedParallel | Custom SequenceParallel class for Qwen2 / Gemma3 rotary embeddings because the input is a tuple. |
Functions#
| Function | Description |
|---|---|
| _parallelize_gemma3 | Parallelizes a Gemma3ForCausalLM model across data parallel dimensions. |
| _parallelize_llama | Parallelizes a LlamaForCausalLM model across data and tensor parallel dimensions. |
| _parallelize_qwen | Parallelizes a Qwen2ForCausalLM model across data and tensor parallel dimensions. |
| _parallelize_model | Parallelize a model using DTensor. |
| to_local_if_dtensor | Returns the local shard of the given tensor if it is a DTensor. |
| clip_grad_by_total_norm_ | Clips gradients of an iterable of parameters by total norm. |
| get_grad_norm | Calculate the norm of gradients. |
| get_logprobs_from_vocab_parallel_logits | Computes log probabilities from vocabulary-parallel logits. |
Data#
API#
- class nemo_rl.models.dtensor.parallelize.RotaryEmbedParallel(
- *,
- sequence_dim: int = 1,
- use_local_output: bool = False,
- )
Bases: torch.distributed.tensor.parallel.SequenceParallel
Custom SequenceParallel class for Qwen2 / Gemma3 rotary embeddings because the input is a tuple.
Initialization
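As with any SequenceParallel style, an instance would typically appear as an entry in a parallelize_module plan. The sketch below is hypothetical: the submodule path "model.rotary_emb" and the keyword values are illustrative, and `model` and `tp_mesh` are assumed to exist from the surrounding setup.

```python
from torch.distributed.tensor.parallel import parallelize_module

from nemo_rl.models.dtensor.parallelize import RotaryEmbedParallel

# Hypothetical plan entry: route the rotary-embedding submodule (illustrative
# path "model.rotary_emb") through RotaryEmbedParallel so its tuple output is
# handled correctly under sequence parallelism.
plan = {"model.rotary_emb": RotaryEmbedParallel(sequence_dim=1, use_local_output=True)}
parallelize_module(model, tp_mesh, plan)
```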
- nemo_rl.models.dtensor.parallelize._parallelize_gemma3(
- model: Union[transformers.models.gemma3.modeling_gemma3.Gemma3ForCausalLM, transformers.models.gemma3.modeling_gemma3.Gemma3ForConditionalGeneration],
- dp_mesh: torch.distributed.device_mesh.DeviceMesh,
- tp_mesh: torch.distributed.device_mesh.DeviceMesh,
- mp_policy: torch.distributed.fsdp.MixedPrecisionPolicy,
- offload_policy: torch.distributed.fsdp.OffloadPolicy,
- sequence_parallel: bool = False,
- activation_checkpointing: bool = False,
- )
Parallelizes a Gemma3ForCausalLM model across data parallel dimensions.
Tensor parallelism is not supported for Gemma3 models because of tied word embeddings.
- nemo_rl.models.dtensor.parallelize._parallelize_llama(
- model: transformers.models.llama.modeling_llama.LlamaForCausalLM,
- dp_mesh: torch.distributed.device_mesh.DeviceMesh,
- tp_mesh: torch.distributed.device_mesh.DeviceMesh,
- mp_policy: torch.distributed.fsdp.MixedPrecisionPolicy,
- offload_policy: torch.distributed.fsdp.OffloadPolicy,
- sequence_parallel: bool = False,
- activation_checkpointing: bool = False,
- )
Parallelizes a LlamaForCausalLM model across data and tensor parallel dimensions.
- nemo_rl.models.dtensor.parallelize._parallelize_qwen(
- model: Union[transformers.models.qwen2.modeling_qwen2.Qwen2ForCausalLM, transformers.models.qwen3.modeling_qwen3.Qwen3ForCausalLM],
- dp_mesh: torch.distributed.device_mesh.DeviceMesh,
- tp_mesh: torch.distributed.device_mesh.DeviceMesh,
- mp_policy: torch.distributed.fsdp.MixedPrecisionPolicy,
- offload_policy: torch.distributed.fsdp.OffloadPolicy,
- sequence_parallel: bool = False,
- activation_checkpointing: bool = False,
- )
Parallelizes a Qwen2ForCausalLM model across data and tensor parallel dimensions.
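The three per-architecture helpers above share this signature; only the accepted model classes differ. Below is a hedged sketch of calling one of them directly, using the policy classes named in the signatures. The variables `model`, `dp_mesh`, and `tp_mesh` are assumed to exist from the surrounding setup, and the mixed-precision settings are illustrative; in practice _parallelize_model (documented below) is the higher-level entry point.

```python
import torch
from torch.distributed.fsdp import MixedPrecisionPolicy, OffloadPolicy

from nemo_rl.models.dtensor.parallelize import _parallelize_qwen

# Illustrative policies: bf16 parameters with fp32 gradient reduction, no offload.
mp_policy = MixedPrecisionPolicy(param_dtype=torch.bfloat16, reduce_dtype=torch.float32)

# Assumes `model` is a loaded Qwen2ForCausalLM and `dp_mesh` / `tp_mesh` are
# existing DeviceMesh objects.
_parallelize_qwen(
    model,
    dp_mesh=dp_mesh,
    tp_mesh=tp_mesh,
    mp_policy=mp_policy,
    offload_policy=OffloadPolicy(),
    sequence_parallel=False,
    activation_checkpointing=True,
)
```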
- nemo_rl.models.dtensor.parallelize.PARALLIZE_FUNCTIONS#
None
- nemo_rl.models.dtensor.parallelize._parallelize_model(
- model: Union[transformers.models.qwen2.modeling_qwen2.Qwen2ForCausalLM, transformers.models.llama.modeling_llama.LlamaForCausalLM],
- dp_mesh: torch.distributed.device_mesh.DeviceMesh,
- tp_mesh: torch.distributed.device_mesh.DeviceMesh,
- param_dtype: torch.dtype,
- sequence_parallel: bool = False,
- activation_checkpointing: bool = False,
- cpu_offload: bool = False,
- )
Parallelize a model using DTensor.
- Parameters:
model (Union[Qwen2ForCausalLM, LlamaForCausalLM]) – The model to parallelize.
dp_mesh (DeviceMesh) – Device mesh for data parallelism.
tp_mesh (DeviceMesh) – Device mesh for tensor parallelism.
param_dtype (torch.dtype) – Data type for model parameters.
sequence_parallel (bool, optional) – Whether to use sequence parallelism. Defaults to False.
activation_checkpointing (bool, optional) – Whether to use activation checkpointing. Defaults to False.
cpu_offload (bool, optional) – Whether to enable CPU offloading for FSDP. Defaults to False.
- Returns:
The parallelized model.
- Raises:
ValueError – If the model type is not supported for parallelization.
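A minimal end-to-end sketch, assuming eight GPUs split into four data-parallel and two tensor-parallel ranks and a Hugging Face Qwen2 checkpoint; the mesh sizes and model name are illustrative only, and the underscore prefix marks the function as internal to the module.

```python
import torch
from torch.distributed.device_mesh import init_device_mesh
from transformers import AutoModelForCausalLM

from nemo_rl.models.dtensor.parallelize import _parallelize_model

# Illustrative 4 (data-parallel) x 2 (tensor-parallel) mesh over 8 GPUs.
mesh = init_device_mesh("cuda", (4, 2), mesh_dim_names=("dp", "tp"))

model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-1.5B", torch_dtype=torch.bfloat16)

model = _parallelize_model(
    model,
    dp_mesh=mesh["dp"],
    tp_mesh=mesh["tp"],
    param_dtype=torch.bfloat16,
    sequence_parallel=False,
    activation_checkpointing=True,
    cpu_offload=False,
)
```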
- nemo_rl.models.dtensor.parallelize.to_local_if_dtensor(
- tensor: Union[torch.Tensor, torch.distributed.tensor.DTensor],
- )
Returns the local shard of the given tensor if it is a DTensor.
Taken and modified from: https://github.com/NVIDIA/Megatron-LM/blob/605f618f237cda8fa80132bc2ccff933512d5a0d/megatron/core/utils.py#L746
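A minimal equivalent-behavior sketch (not the library implementation): unwrap a DTensor to its local shard and pass regular tensors through unchanged.

```python
from typing import Union

import torch
from torch.distributed.tensor import DTensor


def to_local_if_dtensor_sketch(tensor: Union[torch.Tensor, DTensor]) -> torch.Tensor:
    # DTensor -> its local shard; a plain torch.Tensor is returned unchanged.
    return tensor.to_local() if isinstance(tensor, DTensor) else tensor
```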
- nemo_rl.models.dtensor.parallelize.clip_grad_by_total_norm_(
- parameters: Union[List[Union[torch.Tensor, torch.distributed.tensor.DTensor]], Union[torch.Tensor, torch.distributed.tensor.DTensor]],
- max_grad_norm: Union[int, float],
- total_norm: float,
- dtype: torch.dtype = torch.float32,
- )
Clips gradients of an iterable of parameters by total norm.
Taken and modified from: https://github.com/NVIDIA/Megatron-LM/blob/a695b2bd2a0ca9ca63385a48c41a1c5a033cdd1e/megatron/core/optimizer/clip_grads.py#L138
Note that the gradients are modified in place.
- Parameters:
parameters (Union[List[Union[torch.Tensor, DTensor]], Union[torch.Tensor, DTensor]]) – An iterable of Tensors or DTensors, or a single Tensor or DTensor that will have gradients normalized.
max_grad_norm (Union[float, int]) – Maximum norm of the gradients.
total_norm (float) – The pre-computed total norm of the gradients to use for scaling.
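The operation is a uniform in-place rescale. Below is a minimal sketch of that behavior under the standard Megatron-style convention (scale by max_grad_norm / total_norm only when the total norm exceeds the limit); it is not the exact library code.

```python
from typing import Iterable

import torch


def clip_by_total_norm_sketch(
    parameters: Iterable[torch.Tensor], max_grad_norm: float, total_norm: float
) -> None:
    # Uniformly rescale every gradient in place when the pre-computed total
    # norm exceeds max_grad_norm; otherwise leave the gradients untouched.
    clip_coeff = max_grad_norm / (total_norm + 1.0e-6)
    if clip_coeff < 1.0:
        for p in parameters:
            if p.grad is not None:
                p.grad.detach().mul_(clip_coeff)
```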
- nemo_rl.models.dtensor.parallelize.get_grad_norm(
- parameters: Union[List[Union[torch.Tensor, torch.distributed.tensor.DTensor]], Union[torch.Tensor, torch.distributed.tensor.DTensor]],
- dp_group: torch.distributed.ProcessGroup,
- tp_group: torch.distributed.ProcessGroup,
- norm_type: Union[int, float] = 2,
- dtype: torch.dtype = torch.float32,
- )
Calculate the norm of gradients.
Taken and modified from: https://github.com/NVIDIA/Megatron-LM/blob/a695b2bd2a0ca9ca63385a48c41a1c5a033cdd1e/megatron/core/optimizer/clip_grads.py#L51
- Parameters:
parameters (Union[List[Union[torch.Tensor, DTensor]], Union[torch.Tensor, DTensor]]) – An iterable of Tensors or DTensors, or a single Tensor or DTensor that will have gradient norm calculated.
dp_group (torch.distributed.ProcessGroup) – Process group for data parallel communication.
tp_group (torch.distributed.ProcessGroup) – Process group for tensor parallel communication.
norm_type (Union[int, float]) – Type of the used p-norm. Can be 'inf' for infinity norm.
- Returns:
Total norm of the gradients (viewed as a single vector)
- Return type:
float
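A usage sketch tying the two gradient utilities together, assuming a 2-D ("dp", "tp") device mesh as in the parallelization example above (the dimension names and clipping threshold are assumptions) and a `model` that has already been parallelized and run backward.

```python
# Collect parameters that actually received gradients in the backward pass.
params = [p for p in model.parameters() if p.grad is not None]

# Compute the global gradient norm across data- and tensor-parallel ranks.
total_norm = get_grad_norm(
    params,
    dp_group=mesh["dp"].get_group(),
    tp_group=mesh["tp"].get_group(),
    norm_type=2,
)

# Clip the gradients in place using the pre-computed total norm.
clip_grad_by_total_norm_(params, max_grad_norm=1.0, total_norm=total_norm)
```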
- nemo_rl.models.dtensor.parallelize.get_logprobs_from_vocab_parallel_logits(
- vocab_parallel_logits: torch.distributed.tensor.DTensor,
- input_ids: torch.Tensor,
- )
Computes log probabilities from vocabulary-parallel logits.
This function takes logits that are sharded across the vocabulary dimension (tensor parallel) and computes the log probabilities for the given input IDs.
- Parameters:
vocab_parallel_logits (DTensor) – Logits distributed across tensor parallel workers, with shape [batch_size, seq_len, vocab_size/tp_size].
input_ids (torch.Tensor) – Input token IDs for which to compute log probabilities, with shape [batch_size, seq_len].
- Returns:
Log probabilities for the given input IDs.
- Return type:
torch.Tensor
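A brief usage sketch, assuming `vocab_parallel_logits` is the DTensor produced by a tensor-parallel forward pass of the parallelized model and `input_ids` is the corresponding token batch.

```python
# Shapes follow the parameter descriptions above:
#   vocab_parallel_logits: [batch_size, seq_len, vocab_size / tp_size] per rank (DTensor)
#   input_ids:             [batch_size, seq_len]
token_logprobs = get_logprobs_from_vocab_parallel_logits(vocab_parallel_logits, input_ids)
# `token_logprobs` is a regular torch.Tensor of per-token log probabilities.
```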