nemo_rl.distributed.model_utils
Module Contents
Classes
DistributedLogprob | Custom autograd function for computing log probabilities in a distributed setting.
DistributedCrossEntropy | Compute soft-target cross entropy across TP-sharded vocab.
ChunkedDistributedLogprob | Custom autograd function for computing log probabilities in a distributed setting.
DistributedLogprobWithSampling | Custom autograd function for computing log probabilities with top-k/top-p sampling.
ChunkedDistributedLogprobWithSampling | Chunked version of DistributedLogprobWithSampling for memory efficiency.
ChunkedDistributedGatherLogprob | Compute distributed log-softmax once and gather logprobs at given global indices.
ChunkedDistributedEntropy | Compute H_all = sum_v p_v log p_v across TP with chunking over sequence.
ChunkedDistributedHiddenStatesToLogprobs | Compute distributed log-softmax once and gather logprobs at given global indices.
Functions
_compute_distributed_log_softmax | Compute a stable distributed log softmax across tensor parallel workers.
_compute_distributed_softmax | Compute a stable distributed softmax across tensor parallel workers.
dtensor_from_parallel_logits_to_logprobs | Get log probabilities from TP+CP sharded vocab logits.
from_parallel_logits_to_logprobs | Get log probabilities from TP+CP sharded vocab logits.
from_parallel_logits_to_logprobs_packed_sequences | Get log probabilities from TP sharded vocab logits for packed sequences.
_get_tokens_on_this_cp_rank | Get tokens on this context parallelism rank.
get_logprobs_from_vocab_parallel_logits | Computes log probabilities from vocabulary-parallel logits.
get_next_token_logprobs_from_logits | Compute token log-probabilities from logits, handling parallel and non-parallel cases.
distributed_vocab_topk | Compute global top-k over TP-sharded vocabulary logits.
gather_logits_at_global_indices | Gather student logits at given global token indices under TP+CP sharding.
get_distillation_topk_logprobs_from_logits | Compute top-k log probabilities from logits.
| Get log probabilities from TP sharded hidden states.
all_to_all_vp2sq | Convert vocab-parallel logits to batch-sequence-parallel logits via all-to-all.
all_to_all_sq2vp | Convert batch-sequence-parallel logits to vocab-parallel logits via all-to-all.
API
- nemo_rl.distributed.model_utils._compute_distributed_log_softmax(
- vocab_parallel_logits: torch.Tensor,
- group: torch.distributed.ProcessGroup,
Compute a stable distributed log softmax across tensor parallel workers.
Taken from: https://github.com/NVIDIA/NeMo-Aligner/blob/9faab404f21994a7eb1d6ed5890b76152b941636/nemo_aligner/utils/distributed.py#L265
- Parameters:
vocab_parallel_logits (torch.Tensor) – Logits tensor with shape [batch_size, seq_length, vocab_size//TP] where TP is the tensor parallel size.
group (torch.distributed.ProcessGroup) – Process group for the all-reduce operations.
- Returns:
Log softmax output with the same shape as input, but values represent log probabilities normalized across the full vocabulary dimension.
- Return type:
torch.Tensor
- nemo_rl.distributed.model_utils._compute_distributed_softmax(
- vocab_parallel_logits: torch.Tensor,
- group: torch.distributed.ProcessGroup,
Compute a stable distributed softmax across tensor parallel workers.
Taken from: https://github.com/NVIDIA/NeMo-Aligner/blob/9faab404f21994a7eb1d6ed5890b76152b941636/nemo_aligner/utils/distributed.py#L239
- Parameters:
vocab_parallel_logits (torch.Tensor) – Logits tensor with shape [batch_size, seq_length, vocab_size//TP] where TP is the tensor parallel size.
group (torch.distributed.ProcessGroup) – Process group for the all-reduce operations.
- Returns:
Softmax output with the same shape as input, normalized across the full vocabulary.
- Return type:
torch.Tensor
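Both helpers are described as stable distributed (log) softmax computations; the standard way to do this is an all-reduce with MAX over the per-rank maxima, followed by an all-reduce with SUM over the per-rank sums of exponentials. The sketch below simulates that pattern on a single process by treating a list of vocab slices as the TP ranks. The function names are hypothetical, and stacking over the list stands in for torch.distributed.all_reduce on `group`; it is an illustration, not the NeMo RL code.

```python
import torch


def simulated_distributed_log_softmax(shards: list[torch.Tensor]) -> list[torch.Tensor]:
    """Hypothetical single-process stand-in for a TP all-reduce based log softmax.

    Each element of `shards` plays the role of one TP rank's
    [batch, seq, vocab // TP] logits slice.
    """
    shards = [s.float() for s in shards]
    # Step 1: all-reduce(MAX) of the per-rank maxima, for numerical stability.
    global_max = torch.stack([s.amax(dim=-1, keepdim=True) for s in shards]).amax(dim=0)
    shifted = [s - global_max for s in shards]
    # Step 2: all-reduce(SUM) of the per-rank sums of exp(shifted logits).
    sum_exp = torch.stack([s.exp().sum(dim=-1, keepdim=True) for s in shifted]).sum(dim=0)
    # Step 3: every rank normalizes its own slice with the global statistics.
    return [s - sum_exp.log() for s in shifted]


def simulated_distributed_softmax(shards: list[torch.Tensor]) -> list[torch.Tensor]:
    # The softmax is the exponential of the distributed log softmax.
    return [ls.exp() for ls in simulated_distributed_log_softmax(shards)]


torch.manual_seed(0)
full_logits = torch.randn(2, 3, 8)
tp_shards = list(full_logits.chunk(4, dim=-1))  # pretend TP = 4
log_probs = torch.cat(simulated_distributed_log_softmax(tp_shards), dim=-1)
probs = torch.cat(simulated_distributed_softmax(tp_shards), dim=-1)
```

Each rank only ever exchanges [batch, seq, 1] statistics, never its vocab slice, which is why the full-vocabulary result can be formed without gathering the logits.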
- class nemo_rl.distributed.model_utils.DistributedLogprob
Bases: torch.autograd.Function
Custom autograd function for computing log probabilities in a distributed setting.
Taken from https://github.com/NVIDIA/NeMo-Aligner/blob/9faab404f21994a7eb1d6ed5890b76152b941636/nemo_aligner/utils/distributed.py#L286
- static forward(
- ctx: Any,
- vocab_parallel_logits: torch.Tensor,
- target: torch.Tensor,
- vocab_start_index: int,
- vocab_end_index: int,
- group: torch.distributed.ProcessGroup,
- inference_only: bool = False,
- static backward(
- ctx: Any,
- *grad_outputs: torch.Tensor,
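A single-process sketch of what a distributed target-logprob function has to compute, assuming each rank owns a contiguous vocab slice [vocab_start_index, vocab_end_index) with an exclusive end. Targets outside a rank's slice contribute zero locally, a SUM across ranks recovers the full value, and the per-rank backward is (one_hot(target) - softmax) scaled by the upstream gradient. Helper names are hypothetical and the real class may organize the computation differently.

```python
import torch


def _global_lse(shards):
    # logsumexp over the full vocab equals the logsumexp of the per-rank logsumexps
    # (a stand-in for the MAX/SUM all-reduces across the TP group).
    return torch.logsumexp(torch.stack([torch.logsumexp(s, dim=-1) for s in shards]), dim=0)


def _local_index(target, rank, vocab_per_rank):
    start = rank * vocab_per_rank
    in_range = (target >= start) & (target < start + vocab_per_rank)
    # Clamp so that masked-out targets still produce a valid gather index.
    local_idx = (target - start).clamp(0, vocab_per_rank - 1)
    return in_range, local_idx


def logprob_forward(shards, target):
    lse, v = _global_lse(shards), shards[0].shape[-1]
    out = torch.zeros(target.shape)
    for rank, shard in enumerate(shards):
        in_range, idx = _local_index(target, rank, v)
        local = (shard - lse.unsqueeze(-1)).gather(-1, idx.unsqueeze(-1)).squeeze(-1)
        out += torch.where(in_range, local, torch.zeros_like(local))  # all-reduce(SUM)
    return out


def logprob_backward(shards, target, grad_output):
    lse, v = _global_lse(shards), shards[0].shape[-1]
    grads = []
    for rank, shard in enumerate(shards):
        in_range, idx = _local_index(target, rank, v)
        softmax = (shard - lse.unsqueeze(-1)).exp()
        one_hot = torch.zeros_like(shard).scatter_(
            -1, idx.unsqueeze(-1), in_range.unsqueeze(-1).to(shard.dtype)
        )
        # d log p(target) / d logits = one_hot(target) - softmax, times the upstream grad.
        grads.append((one_hot - softmax) * grad_output.unsqueeze(-1))
    return grads


torch.manual_seed(0)
logits = torch.randn(2, 5, 12)
target = torch.randint(0, 12, (2, 5))
shards = list(logits.chunk(3, dim=-1))  # pretend TP = 3
upstream = torch.randn(2, 5)
lp = logprob_forward(shards, target)
grad = torch.cat(logprob_backward(shards, target, upstream), dim=-1)
```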
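- class nemo_rl.distributed.model_utils.DistributedCrossEntropy
Bases: torch.autograd.Function
Compute soft-target cross entropy across TP-sharded vocab.
This returns H(p_target, q_student), which equals the forward KL divergence KL(p_target || q_student) plus the target entropy H(p_target). Because H(p_target) does not depend on the student, both quantities have the same gradient with respect to the student logits. Backward propagates only through the student logits.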
- static forward(
- ctx: Any,
- student_logits: torch.Tensor,
- target_logits: torch.Tensor,
- group: torch.distributed.ProcessGroup,
- inference_only: bool = False,
- static backward(
- ctx: Any,
- *grad_outputs: torch.Tensor,
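The identity behind this docstring, H(p_target, q_student) = KL(p_target || q_student) + H(p_target), and the resulting student gradient q_student - p_target can be checked on a single device. This sketch ignores TP sharding (under TP, the vocab sums would be reduced across `group`) and is not the NeMo RL implementation.

```python
import torch

torch.manual_seed(0)
student_logits = torch.randn(2, 4, 10, requires_grad=True)
target_logits = torch.randn(2, 4, 10)

p_target = torch.softmax(target_logits, dim=-1)
log_p_target = torch.log_softmax(target_logits, dim=-1)
log_q_student = torch.log_softmax(student_logits, dim=-1)

# Soft-target cross entropy H(p_target, q_student) = -sum_v p_v * log q_v, one value per token.
cross_entropy = -(p_target * log_q_student).sum(dim=-1)

# Forward KL and target entropy, for checking H(p, q) = KL(p || q) + H(p).
kl = (p_target * (log_p_target - log_q_student)).sum(dim=-1)
target_entropy = -(p_target * log_p_target).sum(dim=-1)

# The gradient w.r.t. the student logits is (q_student - p_target) per token;
# nothing flows into target_logits because it does not require grad.
cross_entropy.sum().backward()
expected_grad = torch.softmax(student_logits.detach(), dim=-1) - p_target
```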
- class nemo_rl.distributed.model_utils.ChunkedDistributedLogprob
Bases: torch.autograd.Function
Custom autograd function for computing log probabilities in a distributed setting.
The log probability computation is chunked along the sequence dimension to mitigate GPU OOM (especially during the backward pass). In addition, logits are cast from float16 or bfloat16 to float32 inside the chunk loop, so a full float32 copy of the logits tensor is never materialized.
Adapted from https://github.com/NVIDIA/NeMo-Aligner/blob/9faab404f21994a7eb1d6ed5890b76152b941636/nemo_aligner/utils/distributed.py#L286
- static forward(
- ctx: Any,
- vocab_parallel_logits: torch.Tensor,
- target: torch.Tensor,
- vocab_start_index: int,
- vocab_end_index: int,
- chunk_size: int,
- tp_group: torch.distributed.ProcessGroup,
- inference_only: bool = False,
- static backward(
- ctx: Any,
- *grad_outputs: torch.Tensor,
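The chunking and in-loop upcast can be illustrated on a single device. In this sketch (hypothetical function name; no TP reduction and no custom backward) only one [B, chunk_size, V] float32 slice exists at a time, while the full logits tensor stays in bfloat16.

```python
import torch


def chunked_target_logprobs(logits: torch.Tensor, target: torch.Tensor, chunk_size: int) -> torch.Tensor:
    """Log p(target) computed chunk by chunk along the sequence dimension."""
    outputs = []
    for start in range(0, logits.shape[1], chunk_size):
        # Upcast only this chunk to float32, inside the loop.
        chunk = logits[:, start:start + chunk_size].float()
        chunk_target = target[:, start:start + chunk_size]
        log_probs = torch.log_softmax(chunk, dim=-1)
        outputs.append(log_probs.gather(-1, chunk_target.unsqueeze(-1)).squeeze(-1))
    return torch.cat(outputs, dim=1)


torch.manual_seed(0)
logits_bf16 = torch.randn(2, 7, 16).to(torch.bfloat16)
target = torch.randint(0, 16, (2, 7))
chunked = chunked_target_logprobs(logits_bf16, target, chunk_size=3)
```

The last chunk may be shorter than chunk_size (here 7 = 3 + 3 + 1); slicing handles that without special casing.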
- class nemo_rl.distributed.model_utils.DistributedLogprobWithSampling
Bases: torch.autograd.Function
Custom autograd function for computing log probabilities with top-k/top-p sampling.
This function materializes the full vocabulary by converting from vocab-parallel to batch-sequence-parallel layout, applies filtering, and computes log probabilities.
- static forward(
- ctx: Any,
- vocab_parallel_logits: torch.Tensor,
- target: torch.Tensor,
- tp_group: torch.distributed.ProcessGroup,
- top_k: int | None,
- top_p: float,
- inference_only: bool = False,
Forward pass for sampling-based logprob computation.
- Parameters:
vocab_parallel_logits – [B, S, V_local] logits sharded by vocab
target – [B, S] target token ids (already shifted)
tp_group – Tensor parallel process group
top_k – Top-k filtering parameter (None or -1 to disable)
top_p – Top-p filtering parameter (1.0 to disable)
inference_only – If True, don’t save tensors for backward
- Returns:
Log probabilities [B, S]
- static backward(
- ctx: Any,
- *grad_outputs: torch.Tensor,
Backward pass for sampling-based logprob computation.
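The filtering step can be sketched on full-vocabulary logits, that is, after the vocab-parallel to batch-sequence-parallel conversion has materialized the whole vocabulary. The top-p convention used here (drop a token once the probability mass before it exceeds top_p, always keeping the most likely token) and the resulting -inf log probability for filtered targets are assumptions for illustration, not necessarily what NeMo RL does; the function names are hypothetical.

```python
from typing import Optional

import torch


def filter_logits(logits: torch.Tensor, top_k: Optional[int], top_p: float) -> torch.Tensor:
    """Mask logits outside the top-k and top-p (nucleus) sets with -inf."""
    if top_k is not None and top_k > 0:  # None or -1 disables top-k
        kth_value = torch.topk(logits, top_k, dim=-1).values[..., -1:]
        logits = logits.masked_fill(logits < kth_value, float("-inf"))
    if top_p < 1.0:  # 1.0 disables top-p
        sorted_logits, sorted_idx = torch.sort(logits, dim=-1, descending=True)
        sorted_probs = torch.softmax(sorted_logits, dim=-1)
        # Drop a token once the mass *before* it already exceeds top_p.
        drop_sorted = (sorted_probs.cumsum(dim=-1) - sorted_probs) > top_p
        # Map the sorted-order mask back to vocabulary order.
        drop = torch.zeros_like(drop_sorted).scatter(-1, sorted_idx, drop_sorted)
        logits = logits.masked_fill(drop, float("-inf"))
    return logits


def sampled_logprobs(full_vocab_logits, target, top_k, top_p):
    filtered = filter_logits(full_vocab_logits.float(), top_k, top_p)
    return torch.log_softmax(filtered, dim=-1).gather(-1, target.unsqueeze(-1)).squeeze(-1)
```

With both filters disabled this reduces to an ordinary log_softmax plus gather; with top_k=1 the greedy token gets log probability 0 and every other token gets -inf.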
- class nemo_rl.distributed.model_utils.ChunkedDistributedLogprobWithSampling
Bases: torch.autograd.Function
Chunked version of DistributedLogprobWithSampling for memory efficiency.
Uses delayed rematerialization to avoid storing large intermediate tensors.
- static forward(
- ctx: Any,
- vocab_parallel_logits: torch.Tensor,
- target: torch.Tensor,
- tp_group: torch.distributed.ProcessGroup,
- top_k: int | None,
- top_p: float,
- chunk_size: int,
- inference_only: bool = False,
Forward pass with chunked processing.
- Parameters:
vocab_parallel_logits – [B, S, V_local] logits sharded by vocab
target – [B, S] target token ids (already shifted)
tp_group – Tensor parallel process group
top_k – Top-k filtering parameter (None or -1 to disable)
top_p – Top-p filtering parameter (1.0 to disable)
chunk_size – Chunk size for memory efficiency (in sequence dimension)
inference_only – If True, don’t save tensors for backward
- Returns:
Log probabilities [B, S]
- static backward(
- ctx: Any,
- *grad_outputs: torch.Tensor,
Backward pass with chunked rematerialization.
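One way to read "delayed rematerialization" is the recompute-in-backward pattern sketched below: forward saves only the inputs, and backward rebuilds each sequence chunk's log-softmax graph, differentiates it, and discards it before moving to the next chunk. This is a hypothetical single-device class without top-k/top-p filtering or TP communication, not the NeMo RL implementation.

```python
import torch


class ChunkedRematLogprob(torch.autograd.Function):
    """Hypothetical illustration of chunked recomputation in backward."""

    @staticmethod
    def forward(ctx, logits, target, chunk_size):
        # Forward runs without building a graph; only the inputs are saved.
        out = torch.cat(
            [
                torch.log_softmax(logits[:, s:s + chunk_size].float(), dim=-1)
                .gather(-1, target[:, s:s + chunk_size].unsqueeze(-1))
                .squeeze(-1)
                for s in range(0, logits.shape[1], chunk_size)
            ],
            dim=1,
        )
        ctx.save_for_backward(logits, target)
        ctx.chunk_size = chunk_size
        return out

    @staticmethod
    def backward(ctx, grad_output):
        logits, target = ctx.saved_tensors
        grad_logits = torch.empty_like(logits)
        for s in range(0, logits.shape[1], ctx.chunk_size):
            e = s + ctx.chunk_size
            # Rematerialize this chunk's graph, differentiate it, then let it go.
            with torch.enable_grad():
                chunk = logits[:, s:e].detach().requires_grad_(True)
                lp = torch.log_softmax(chunk.float(), dim=-1)
                lp = lp.gather(-1, target[:, s:e].unsqueeze(-1)).squeeze(-1)
                (g,) = torch.autograd.grad(lp, chunk, grad_outputs=grad_output[:, s:e])
            grad_logits[:, s:e] = g
        # No gradients for the integer targets or the chunk size.
        return grad_logits, None, None
```

Peak memory in backward is then bounded by one chunk's [B, chunk_size, V] intermediates instead of the full sequence.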
- class nemo_rl.distributed.model_utils.ChunkedDistributedGatherLogprob
Bases: torch.autograd.Function
Compute distributed log-softmax once and gather logprobs at given global indices.
Forward computes per-chunk distributed log-softmax across TP, gathers selected log probabilities at the provided global indices (shape [B, S, K]), and returns a tensor of shape [B, S, K].
Backward recomputes per-chunk softmax from logits and applies the gradient rule: dL/dz = -softmax * sum_k(dL/dy_k) + scatter_add(dL/dy_k) over selected indices.
- static forward(
- ctx: Any,
- vocab_parallel_logits: torch.Tensor,
- global_indices: torch.Tensor,
- vocab_start_index: int,
- vocab_end_index: int,
- chunk_size: int,
- tp_group: torch.distributed.ProcessGroup,
- inference_only: bool = False,
- static backward(
- ctx: Any,
- *grad_outputs: torch.Tensor,
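The documented gradient rule can be checked against autograd on a single device. With y_k = log_softmax(z)[i_k], each selected index contributes g_k * (one_hot(i_k) - softmax), and summing over k gives exactly the rule stated above. The sketch ignores TP and chunking; under TP, the softmax would use the distributed normalizer and each rank would scatter only the indices inside its own vocab range.

```python
import torch

torch.manual_seed(0)
B, S, V, K = 2, 3, 9, 4
logits = torch.randn(B, S, V, requires_grad=True)
global_indices = torch.randint(0, V, (B, S, K))  # duplicate indices are allowed
grad_y = torch.randn(B, S, K)  # upstream dL/dy_k

# Forward: log-softmax once, then gather K entries per position -> [B, S, K].
y = torch.log_softmax(logits, dim=-1).gather(-1, global_indices)
(y * grad_y).sum().backward()

# Manual backward from the documented rule:
#   dL/dz = -softmax * sum_k(dL/dy_k) + scatter_add(dL/dy_k) over the selected indices
softmax = torch.softmax(logits.detach(), dim=-1)
manual = -softmax * grad_y.sum(dim=-1, keepdim=True)
manual = manual.scatter_add(-1, global_indices, grad_y)
```

scatter_add (rather than scatter) matters when the same token id appears more than once among the K indices: its upstream gradients must accumulate.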
- nemo_rl.distributed.model_utils.dtensor_from_parallel_logits_to_logprobs(
- vocab_parallel_logits: torch.Tensor,
- target: torch.distributed.tensor.DTensor | torch.Tensor,
- vocab_start_index: int,
- vocab_end_index: int,
- tp_group: torch.distributed.ProcessGroup,
- inference_only: bool = False,
- seq_index: Optional[torch.Tensor] = None,
- chunk_size: Optional[int] = None,
- sampling_params: Optional[nemo_rl.algorithms.logits_sampling_utils.TrainingSamplingParams] = None,
Get log probabilities from TP+CP sharded vocab logits.
- Parameters:
vocab_parallel_logits (torch.Tensor) – Logits distributed across tensor parallel workers, with shape [batch_size, seq_len, vocab_size/tp_size].
target (DTensor | torch.Tensor) – Target token indices with shape [batch_size, seq_len]. NOTE: Must be the unshifted targets, as this function shifts them internally.
vocab_start_index (int) – Starting vocabulary index for this worker’s partition.
vocab_end_index (int) – Ending vocabulary index for this worker’s partition.
tp_group (torch.distributed.ProcessGroup) – Process group for distributed communication.
inference_only (bool, optional) – If True, tensors won’t be saved for backward pass. Defaults to False.
seq_index (Optional[torch.Tensor]) – Sequence index tensor with shape [seq_len]. Provided only for CP-sharded logits; it describes how the tensor is sharded along the sequence dimension.
chunk_size (Optional[int]) – Sequence dimension chunk size for computing the log probabilities.
sampling_params (TrainingSamplingParams, optional) – Sampling parameters for Top-k/Top-p filtering and temperature scaling.
- Returns:
Log probabilities tensor with shape [batch_size, seq_len-1]. The sequence dimension is reduced by 1 due to the target shifting.
- Return type:
torch.Tensor
- nemo_rl.distributed.model_utils.from_parallel_logits_to_logprobs(
- vocab_parallel_logits: torch.Tensor,
- target: torch.Tensor,
- vocab_start_index: int,
- vocab_end_index: int,
- tp_group: torch.distributed.ProcessGroup,
- inference_only: bool = False,
- cp_group: Optional[torch.distributed.ProcessGroup] = None,
- chunk_size: Optional[int] = None,
- sampling_params: Optional[nemo_rl.algorithms.logits_sampling_utils.TrainingSamplingParams] = None,
Get log probabilities from TP+CP sharded vocab logits.
- Parameters:
vocab_parallel_logits (torch.Tensor) – Logits tensor with shape [batch_size, seq_len // CP, vocab_size // TP] where TP is the tensor parallel size.
target (torch.Tensor) – Target token indices with shape [batch_size, seq_len]. NOTE: Must be the unmodified targets as this function will shift them internally.
vocab_start_index (int) – Starting vocabulary index for this worker’s partition.
vocab_end_index (int) – Ending vocabulary index for this worker’s partition.
tp_group (torch.distributed.ProcessGroup) – Process group for distributed communication.
inference_only (bool, optional) – If True, tensors won’t be saved for backward pass. Defaults to False.
cp_group (torch.distributed.ProcessGroup, optional) – Context parallelism process group. Defaults to None.
chunk_size (int, optional) – Sequence dimension chunk size for computing the log probabilities.
sampling_params (TrainingSamplingParams, optional) – Sampling parameters for Top-k/Top-p filtering and temperature scaling.
- Returns:
Log probabilities tensor with shape [batch_size, seq_len-1]. The sequence dimension is reduced by 1 due to the target shifting.
- Return type:
torch.Tensor
Taken from: https://github.com/NVIDIA/NeMo-Aligner/blob/9faab404f21994a7eb1d6ed5890b76152b941636/nemo_aligner/utils/distributed.py#L354
- nemo_rl.distributed.model_utils.from_parallel_logits_to_logprobs_packed_sequences(
- vocab_parallel_logits: torch.Tensor,
- target: torch.Tensor,
- cu_seqlens_padded: torch.Tensor,
- unpacked_seqlen: int,
- vocab_start_index: int,
- vocab_end_index: int,
- group: torch.distributed.ProcessGroup,
- inference_only: bool = False,
- cp_group: Optional[torch.distributed.ProcessGroup] = None,
- chunk_size: Optional[int] = None,
- sampling_params: Optional[nemo_rl.algorithms.logits_sampling_utils.TrainingSamplingParams] = None,
- target_is_pre_rolled: bool = False,
Get log probabilities from TP sharded vocab logits for packed sequences.
- Parameters:
vocab_parallel_logits (torch.Tensor) – Packed logits tensor with shape [1, T // CP, vocab_size//TP] where T is the total number of tokens across all packed sequences.
target (torch.Tensor) – Packed target token indices. If target_is_pre_rolled=False, the shape is [1, T] and the targets are unmodified; they are rolled internally. If target_is_pre_rolled=True, the shape is [1, T // CP] and the targets are already rolled and CP-sharded.
cu_seqlens_padded (torch.Tensor) – Cumulative sequence lengths tensor with shape [batch_size + 1]. cu_seqlens_padded[i] indicates the start position of sequence i in the packed format (full, not CP-adjusted).
unpacked_seqlen (int) – The length of the unpacked sequence tensor.
vocab_start_index (int) – Starting vocabulary index for this worker’s partition.
vocab_end_index (int) – Ending vocabulary index for this worker’s partition.
group (torch.distributed.ProcessGroup) – Process group for distributed communication.
inference_only (bool, optional) – If True, tensors won’t be saved for backward pass. Defaults to False.
cp_group (torch.distributed.ProcessGroup, optional) – Context parallelism process group. Defaults to None.
chunk_size (int, optional) – Sequence dimension chunk size for computing the log probabilities.
sampling_params (TrainingSamplingParams, optional) – Sampling parameters for Top-k/Top-p filtering.
target_is_pre_rolled (bool) – If True, target is already shifted and CP-sharded to match vocab_parallel_logits shape, skipping the internal per-sequence roll+CP-shard loop.
- Returns:
Unpacked log probabilities tensor with shape [batch_size, unpacked_seqlen-1]. The total length is reduced by batch_size due to target shifting (one token per sequence).
- Return type:
torch.Tensor
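To make the shapes concrete, here is a single-device reference for the unpacking step that ignores TP, CP, sampling, and target_is_pre_rolled. The function name is hypothetical, and leaving unused output slots at zero is an assumption; the real function may fill them differently.

```python
import torch


def packed_logprobs_reference(logits, target, cu_seqlens_padded, unpacked_seqlen):
    """Hypothetical sketch: logits [1, T, V], target [1, T] (unshifted)."""
    batch_size = cu_seqlens_padded.numel() - 1
    out = torch.zeros(batch_size, unpacked_seqlen - 1)  # assumption: unused slots stay 0
    for i in range(batch_size):
        start, end = int(cu_seqlens_padded[i]), int(cu_seqlens_padded[i + 1])
        # Within sequence i, position t scores token t + 1. The shift never crosses
        # a sequence boundary, so each sequence loses exactly one position.
        seq_logits = logits[0, start:end - 1]
        seq_target = target[0, start + 1:end]
        lp = torch.log_softmax(seq_logits.float(), dim=-1)
        lp = lp.gather(-1, seq_target.unsqueeze(-1)).squeeze(-1)
        out[i, : end - start - 1] = lp
    return out
```

For example, cu_seqlens_padded = [0, 4, 10] describes two packed sequences of lengths 4 and 6; with unpacked_seqlen = 6 the result has shape [2, 5], and the first row holds 3 valid entries.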
- nemo_rl.distributed.model_utils._get_tokens_on_this_cp_rank(
- input_ids: torch.Tensor,
- cp_rank: int,
- cp_size: int,
- seq_dim: int = 1,
Get tokens on this context parallelism rank.
Assumes that input_ids are already padded to a multiple of cp_size * 2 or cp_size == 1.
- Parameters:
input_ids – Input token IDs [seq_length, ]
cp_rank – Context parallelism rank
cp_size – Context parallelism size
seq_dim – Dimension of input_ids along which the sequence is split. Defaults to 1.
- Returns:
Tokens on this context parallelism rank [1, seq_length // cp_size]
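The padding requirement (a multiple of cp_size * 2) suggests the load-balanced "zigzag" split used for causal attention in Megatron-style context parallelism: the sequence is cut into 2 * cp_size chunks and rank r keeps chunks r and 2 * cp_size - 1 - r, so each rank gets a mix of early (cheap) and late (expensive) positions. The sketch below shows that split under this assumption; the function name is hypothetical and it is not guaranteed to match the NeMo RL code.

```python
import torch


def tokens_on_cp_rank(input_ids: torch.Tensor, cp_rank: int, cp_size: int, seq_dim: int = 1) -> torch.Tensor:
    """Hypothetical load-balanced CP split along `seq_dim`."""
    if cp_size == 1:
        return input_ids
    # Assumes the sequence length is a multiple of 2 * cp_size, so chunks are equal.
    chunks = input_ids.chunk(2 * cp_size, dim=seq_dim)
    return torch.cat([chunks[cp_rank], chunks[2 * cp_size - 1 - cp_rank]], dim=seq_dim)
```

With 8 tokens and cp_size = 2, rank 0 keeps positions [0, 1, 6, 7] and rank 1 keeps [2, 3, 4, 5].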
- nemo_rl.distributed.model_utils.allgather_cp_sharded_tensor(tensor, cp_group, seq_dim=1)
- class nemo_rl.distributed.model_utils.AllGatherCPTensor
Bases: torch.autograd.Function
- forward(tensor, cp_group: torch.distributed.ProcessGroup, seq_dim=1)
- backward(grad_output)
- nemo_rl.distributed.model_utils.get_logprobs_from_vocab_parallel_logits(
- vocab_parallel_logits: torch.distributed.tensor.DTensor,
- input_ids: torch.Tensor | torch.distributed.tensor.DTensor,
- seq_index: Optional[torch.Tensor] = None,
- chunk_size: Optional[int] = None,
- sampling_params: Optional[nemo_rl.algorithms.logits_sampling_utils.TrainingSamplingParams] = None,
Computes log probabilities from vocabulary-parallel logits.
This function takes logits that are sharded across the vocabulary dimension (tensor parallel) and computes the log probabilities for the given input IDs.
- Parameters:
vocab_parallel_logits (DTensor) – Logits distributed across tensor parallel workers, with shape [batch_size, seq_len, vocab_size/tp_size].
input_ids (torch.Tensor | DTensor) – Input token IDs for which to compute log probabilities, with shape [batch_size, seq_len].
seq_index (Optional[torch.Tensor]) – Sequence index for the input IDs, with shape [sequence_length].
chunk_size (Optional[int]) – Sequence dimension chunk size for computing log probabilities.
sampling_params (TrainingSamplingParams, optional) – Sampling parameters for Top-k/Top-p filtering and temperature scaling.
- Returns:
Log probabilities for the given input IDs.
- Return type:
torch.Tensor
- nemo_rl.distributed.model_utils.get_next_token_logprobs_from_logits(
- input_ids: torch.Tensor,
- next_token_logits: torch.Tensor,
- seq_index: Optional[torch.Tensor] = None,
- vocab_parallel_rank: Optional[int] = None,
- vocab_parallel_group: Optional[torch.distributed.ProcessGroup] = None,
- context_parallel_group: Optional[torch.distributed.ProcessGroup] = None,
- sampling_params: Optional[nemo_rl.algorithms.logits_sampling_utils.TrainingSamplingParams] = None,
Compute token log-probabilities from logits, handling parallel and non-parallel cases.
This function handles three cases:
1. Vocab parallel (Megatron-style): uses from_parallel_logits_to_logprobs.
2. DTensor: uses get_logprobs_from_vocab_parallel_logits.
3. Non-parallel: applies top-k/top-p filtering, log_softmax, and gather.
- Parameters:
input_ids – Input token IDs of shape [batch_size, seq_len]
next_token_logits – Logits tensor of shape [batch_size, seq_len, vocab_size]
seq_index – Sequence index tensor for DTensor path
vocab_parallel_rank – Rank in the vocab parallel group (required if vocab_parallel_group is provided)
vocab_parallel_group – Process group for vocab parallelism
context_parallel_group – Process group for context parallelism
sampling_params – Sampling parameters for top-k/top-p filtering
- Returns:
Token log-probabilities of shape [batch_size, seq_len - 1]
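For the non-parallel case, the [batch_size, seq_len - 1] output shape follows from the next-token convention: logits at position t score the token at position t + 1, so the last position has no target. A minimal sketch, assuming that convention and omitting top-k/top-p filtering (which would be applied to the logits before log_softmax); the function name is hypothetical.

```python
import torch


def next_token_logprobs_non_parallel(input_ids: torch.Tensor, next_token_logits: torch.Tensor) -> torch.Tensor:
    """Hypothetical non-parallel path without sampling filters."""
    # Drop the last position (nothing to predict) and the first token (nothing predicts it).
    log_probs = torch.log_softmax(next_token_logits[:, :-1].float(), dim=-1)
    targets = input_ids[:, 1:]
    return log_probs.gather(-1, targets.unsqueeze(-1)).squeeze(-1)
```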
- nemo_rl.distributed.model_utils.distributed_vocab_topk(
- vocab_parallel_logits: torch.Tensor,
- k: int,
- tp_group: torch.distributed.ProcessGroup,
- *,
- vocab_start_index: int,
- vocab_end_index: int,
- chunk_size: Optional[int] = None,
Compute global top-k over TP-sharded vocabulary logits.
- Parameters:
vocab_parallel_logits – [B, S, V_local]
k – number of top tokens to select globally
tp_group – tensor-parallel process group
vocab_start_index – global vocab start for this rank (inclusive)
vocab_end_index – global vocab end for this rank (exclusive)
chunk_size – optional chunk along sequence dim to bound memory
- Returns:
A pair (topk_vals, topk_global_indices): topk_vals has shape [B, S, k]; topk_global_indices has shape [B, S, k] and holds global token ids.
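A global top-k over vocab shards can be built from local top-k results: whatever a rank contributes to the global top-k is contained in its own local top-k, so gathering the per-rank candidates (with local indices shifted to global ids) and taking a final top-k is exact. The sketch simulates this on one process, with one list entry per rank and a concatenation standing in for an all-gather over tp_group; the function name is hypothetical and this is not the NeMo RL implementation.

```python
import torch


def simulated_distributed_vocab_topk(shards: list[torch.Tensor], k: int):
    """Global top-k over contiguous vocab shards; each list entry is one TP rank."""
    cand_vals, cand_idx = [], []
    vocab_start = 0
    for shard in shards:
        # Each rank keeps at most k candidates from its own slice.
        local_k = min(k, shard.shape[-1])
        vals, idx = torch.topk(shard, local_k, dim=-1)
        cand_vals.append(vals)
        cand_idx.append(idx + vocab_start)  # local -> global token ids
        vocab_start += shard.shape[-1]
    # All-gather stand-in: every rank sees every candidate, then selects the final top-k.
    all_vals = torch.cat(cand_vals, dim=-1)
    all_idx = torch.cat(cand_idx, dim=-1)
    topk_vals, pos = torch.topk(all_vals, k, dim=-1)
    return topk_vals, all_idx.gather(-1, pos)
```

Only tp_size * k candidates per position are exchanged instead of the full vocabulary.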
- nemo_rl.distributed.model_utils.gather_logits_at_global_indices(
- vocab_parallel_logits: torch.Tensor,
- global_indices: torch.Tensor,
- tp_group: Optional[torch.distributed.ProcessGroup] = None,
- cp_group: Optional[torch.distributed.ProcessGroup] = None,
- *,
- vocab_start_index: int,
- vocab_end_index: int,
- chunk_size: Optional[int] = None,
Gather student logits at given global token indices under TP+CP sharding.
Differentiable w.r.t. vocab_parallel_logits.
- Parameters:
vocab_parallel_logits – [B, S_cp, V_local] where S_cp is CP sharded sequence length
global_indices – [B, S_full, k] where S_full is full sequence length
tp_group – Optional tensor-parallel process group. If None, treats logits as full-vocab (no TP) and skips TP all-reduce.
vocab_start_index – global vocab start for this rank (inclusive)
vocab_end_index – global vocab end for this rank (exclusive)
chunk_size – optional chunk along sequence dim to bound memory
cp_group – Optional context-parallel process group
- Returns:
gathered_logits with shape [B, S_full, k].
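Under TP, each global index is owned by exactly one rank, so a masked local gather followed by a SUM all-reduce reproduces the full-vocabulary gather and remains differentiable. Below is a single-process sketch of the TP part only: CP is not modeled, the function name is hypothetical, and vocab_end_index is treated as exclusive, as documented above.

```python
import torch


def simulated_gather_at_global_indices(shards, global_indices):
    """Each shard is one TP rank's [B, S, V_local] slice of contiguous vocab ids."""
    gathered = torch.zeros(global_indices.shape, dtype=shards[0].dtype)
    vocab_start = 0
    for shard in shards:
        vocab_end = vocab_start + shard.shape[-1]  # exclusive
        in_range = (global_indices >= vocab_start) & (global_indices < vocab_end)
        # Clamp so out-of-range indices still gather something valid (then masked out).
        local_idx = (global_indices - vocab_start).clamp(0, shard.shape[-1] - 1)
        local_vals = shard.gather(-1, local_idx)
        # Summing masked values over ranks plays the role of an all-reduce(SUM).
        gathered = gathered + torch.where(in_range, local_vals, torch.zeros_like(local_vals))
        vocab_start = vocab_end
    return gathered
```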
- nemo_rl.distributed.model_utils.get_distillation_topk_logprobs_from_logits(
- student_logits: torch.Tensor,
- teacher_topk_logits: torch.Tensor,
- teacher_topk_indices: torch.Tensor,
- zero_outside_topk: bool,
- calculate_entropy: bool,
- vocab_parallel_rank: Optional[int] = None,
- vocab_parallel_group: Optional[torch.distributed.ProcessGroup] = None,
- context_parallel_group: Optional[torch.distributed.ProcessGroup] = None,
Compute top-k log probabilities from logits.
- class nemo_rl.distributed.model_utils.ChunkedDistributedEntropy
Bases: torch.autograd.Function
Compute H_all = sum_v p_v log p_v across TP with chunking over sequence.
Forward returns [B, S] tensor of global entropy; backward propagates through logits.
- static forward(
- ctx: Any,
- vocab_parallel_logits: torch.Tensor,
- chunk_size: int,
- tp_group: torch.distributed.ProcessGroup,
- inference_only: bool = False,
- static backward(
- ctx: Any,
- *grad_outputs: torch.Tensor,
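The documented per-token quantity can be computed chunk by chunk along the sequence on a single device. Under TP, log p_v would come from a distributed log softmax and the per-rank partial sums would be all-reduced; this sketch omits that. Note that sum_v p_v log p_v is the negative of the Shannon entropy, so it is never positive. The function name is hypothetical.

```python
import torch


def chunked_sum_p_log_p(logits: torch.Tensor, chunk_size: int) -> torch.Tensor:
    """Per-token sum_v p_v * log p_v for [B, S, V] logits, chunked along S."""
    out = []
    for start in range(0, logits.shape[1], chunk_size):
        log_p = torch.log_softmax(logits[:, start:start + chunk_size].float(), dim=-1)
        # p * log p summed over the vocab; equals -(Shannon entropy), so it is <= 0.
        out.append((log_p.exp() * log_p).sum(dim=-1))
    return torch.cat(out, dim=1)
```

For uniform logits over V tokens the result is log(1/V) = -log V at every position.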
- tensor_parallel_hidden_states: torch.Tensor,
- output_weight_layer: torch.Tensor,
- output_weight: torch.Tensor,
- runtime_gather_output: bool,
- target: torch.Tensor,
- vocab_start_index: int,
- vocab_end_index: int,
- tp_group: torch.distributed.ProcessGroup,
- inference_only: bool = False,
- cp_group: Optional[torch.distributed.ProcessGroup] = None,
- chunk_size: Optional[int] = None,
Get log probabilities from TP sharded hidden states.
- class nemo_rl.distributed.model_utils.ChunkedDistributedHiddenStatesToLogprobs
Bases: torch.autograd.Function
Compute distributed log-softmax once and gather logprobs at given global indices.
- static forward(
- ctx: Any,
- tensor_parallel_hidden_states: torch.Tensor,
- target: torch.Tensor,
- output_weight_layer: torch.Tensor,
- vocab_start_index: int,
- vocab_end_index: int,
- chunk_size: int,
- tp_group: torch.distributed.ProcessGroup,
- inference_only: bool = False,
- static backward(
- ctx: Any,
- *grad_outputs: torch.Tensor,
- nemo_rl.distributed.model_utils.patch_gpt_model_forward_for_linear_ce_fusion(
- *,
- chunk_size: int,
- nemo_rl.distributed.model_utils._gpt_forward_with_linear_ce_fusion(
- self: megatron.core.models.gpt.GPTModel,
- input_ids: torch.Tensor,
- position_ids: torch.Tensor,
- attention_mask: torch.Tensor,
- decoder_input: torch.Tensor = None,
- labels: torch.Tensor = None,
- inference_context: Any = None,
- packed_seq_params: Any = None,
- extra_block_kwargs: Optional[dict] = None,
- runtime_gather_output: Optional[bool] = None,
- *,
- inference_params: Optional[Any] = None,
- loss_mask: Optional[torch.Tensor] = None,
- padding_mask: Optional[torch.Tensor] = None,
- return_logprobs_for_linear_ce_fusion: bool = False,
- nemo_rl.distributed.model_utils.all_to_all_vp2sq(
- vocab_parallel_logits: torch.Tensor,
- tp_group: torch.distributed.ProcessGroup,
Convert vocab-parallel logits to batch-sequence-parallel logits via all-to-all.
Note: This partitions the flattened B*S dimension, not just S. The input vocab_parallel_logits must be a 2D tensor.
Transforms [BS, V_local] -> [BS_local, V] where:
V_local = V / tp_size (vocab is sharded)
BS_local = BS / tp_size (batch-sequence will be sharded)
Requires BS to be divisible by tp_size
- Parameters:
vocab_parallel_logits – [BS, V_local] tensor with vocab dimension sharded
tp_group – Tensor parallel process group
- Returns:
Batch-sequence-parallel logits [BS_local, V] with batch-sequence dimension sharded
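The data movement of this all-to-all can be simulated on one process with a list of per-rank tensors. The sketch assumes rank r receives the r-th contiguous block of BS / tp_size rows from every rank and concatenates those blocks along the vocab dimension in source-rank order; a real multi-process version would use a collective such as torch.distributed.all_to_all_single. The function name is hypothetical.

```python
import torch


def simulated_all_to_all_vp2sq(vp_inputs: list[torch.Tensor]) -> list[torch.Tensor]:
    """vp_inputs[r] is rank r's [BS, V_local] input; returns each rank's [BS_local, V]."""
    tp_size = len(vp_inputs)
    bs = vp_inputs[0].shape[0]
    assert bs % tp_size == 0, "BS must be divisible by tp_size"
    # Rank s splits its rows into tp_size blocks and "sends" block r to rank r.
    sent = [x.chunk(tp_size, dim=0) for x in vp_inputs]
    # Rank r concatenates the received blocks along the vocab dim, ordered by source
    # rank, which restores the full vocabulary for its BS / tp_size rows.
    return [torch.cat([sent[s][r] for s in range(tp_size)], dim=-1) for r in range(tp_size)]
```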
- nemo_rl.distributed.model_utils.all_to_all_sq2vp(
- seq_parallel_logits: torch.Tensor,
- tp_group: torch.distributed.ProcessGroup,
Convert batch-sequence-parallel logits to vocab-parallel logits via all-to-all.
Inverse operation of all_to_all_vp2sq.
Transforms [BS_local, V] -> [BS, V_local] where:
BS_local = BS / tp_size (batch-sequence is sharded)
V_local = V / tp_size (vocab will be sharded)
- Parameters:
seq_parallel_logits – [BS_local, V] tensor with batch-sequence dimension sharded
tp_group – Tensor parallel process group
- Returns:
Vocab-parallel logits [BS, V_local] with vocab dimension sharded
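The inverse data movement can also be simulated on one process. This sketch assumes each rank splits its vocab columns into tp_size contiguous blocks, sends block r to rank r, and that rank r stacks the row blocks it receives in source-rank order; the function name is hypothetical, and a multi-process version would use a collective such as torch.distributed.all_to_all_single.

```python
import torch


def simulated_all_to_all_sq2vp(sq_inputs: list[torch.Tensor]) -> list[torch.Tensor]:
    """sq_inputs[r] is rank r's [BS_local, V]; returns each rank's [BS, V_local]."""
    tp_size = len(sq_inputs)
    assert sq_inputs[0].shape[-1] % tp_size == 0, "V must be divisible by tp_size"
    # Rank s splits its vocab columns into tp_size blocks and "sends" block r to rank r.
    sent = [x.chunk(tp_size, dim=-1) for x in sq_inputs]
    # Rank r stacks the received row blocks in source-rank order, recovering all BS rows
    # for its V / tp_size slice of the vocabulary.
    return [torch.cat([sent[s][r] for s in range(tp_size)], dim=0) for r in range(tp_size)]
```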