nemo_rl.distributed.model_utils

Module Contents

Classes

DistributedLogprob

Custom autograd function for computing log probabilities in a distributed setting.

ChunkedDistributedLogprob

Custom autograd function for computing log probabilities in a distributed setting.

ChunkedDistributedGatherLogprob

Compute distributed log-softmax once and gather logprobs at given global indices.

AllGatherCPTensor

ChunkedDistributedEntropy

Compute H_all = sum_v p_v log p_v across TP with chunking over sequence.

Functions

_compute_distributed_log_softmax

Compute a stable distributed log softmax across tensor parallel workers.

dtensor_from_parallel_logits_to_logprobs

Get log probabilities from TP+CP sharded vocab logits.

from_parallel_logits_to_logprobs

Get log probabilities from TP+CP sharded vocab logits.

from_parallel_logits_to_logprobs_packed_sequences

Get log probabilities from TP sharded vocab logits for packed sequences.

_get_tokens_on_this_cp_rank

Get tokens on this context parallelism rank.

allgather_cp_sharded_tensor

get_logprobs_from_vocab_parallel_logits

Computes log probabilities from vocabulary-parallel logits.

distributed_vocab_topk

Compute global top-k over TP-sharded vocabulary logits.

gather_logits_at_global_indices

Gather student logits at given global token indices under TP+CP sharding.

API

nemo_rl.distributed.model_utils._compute_distributed_log_softmax(
vocab_parallel_logits: torch.Tensor,
group: torch.distributed.ProcessGroup,
) -> torch.Tensor

Compute a stable distributed log softmax across tensor parallel workers.

Taken from: https://github.com/NVIDIA/NeMo-Aligner/blob/9faab404f21994a7eb1d6ed5890b76152b941636/nemo_aligner/utils/distributed.py#L265

Parameters:
  • vocab_parallel_logits (torch.Tensor) – Logits tensor with shape [batch_size, seq_length, vocab_size//TP] where TP is the tensor parallel size.

  • group (torch.distributed.ProcessGroup) – Process group for the all-reduce operations.

Returns:

Log softmax output with the same shape as input, but values represent log probabilities normalized across the full vocabulary dimension.

Return type:

torch.Tensor
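The stable distributed log-softmax described above needs two collectives: a MAX all-reduce for the global maximum and a SUM all-reduce for the global partition function. The following is a minimal single-process sketch in plain Python, not the library's implementation: each simulated TP rank is a list holding its vocab_size // TP slice for one token position, and both all-reduces are modeled as max() and sum() over ranks. The function name is hypothetical.

```python
import math


def simulated_distributed_log_softmax(shards):
    """Single-process sketch of a TP-sharded, numerically stable log-softmax.

    `shards` holds one list of logits per simulated TP rank (that rank's
    vocab_size // TP slice) for a single token position. The two collectives
    are modeled as a max() and a sum() over ranks.
    """
    # All-reduce(MAX): the global max over the full vocabulary.
    global_max = max(max(s) for s in shards)
    # Shift by the global max so exp() cannot overflow.
    shifted = [[x - global_max for x in s] for s in shards]
    # All-reduce(SUM): the global partition function of the shifted logits.
    global_sum = sum(math.exp(x) for s in shifted for x in s)
    log_z = math.log(global_sum)
    # Each rank normalizes only its own slice, so the output keeps the sharded layout.
    return [[x - log_z for x in s] for s in shifted]


# Two simulated ranks, each holding half of a 4-token vocabulary.
out = simulated_distributed_log_softmax([[1.0, 2.0], [3.0, 0.5]])
```

The output has the same sharded layout as the input, yet its values are normalized over the full vocabulary, matching the return description above.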

class nemo_rl.distributed.model_utils.DistributedLogprob

Bases: torch.autograd.Function

Custom autograd function for computing log probabilities in a distributed setting.

Taken from https://github.com/NVIDIA/NeMo-Aligner/blob/9faab404f21994a7eb1d6ed5890b76152b941636/nemo_aligner/utils/distributed.py#L286

static forward(
ctx: Any,
vocab_parallel_logits: torch.Tensor,
target: torch.Tensor,
vocab_start_index: int,
vocab_end_index: int,
group: torch.distributed.ProcessGroup,
inference_only: bool = False,
) -> torch.Tensor
static backward(
ctx: Any,
*grad_outputs: torch.Tensor,
) -> tuple[torch.Tensor, None, None, None, None, None, None]
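Given the vocab_start_index / vocab_end_index arguments, a natural reading of the forward pass is a masked gather: each rank reads the target's log probability only if the target id falls in its own vocabulary range, contributes 0.0 otherwise, and a SUM all-reduce combines the results. The plain-Python sketch below simulates that pattern for one position. It assumes a contiguous vocabulary partition (rank r owns ids [r * V_local, (r + 1) * V_local)) and has not been checked line by line against the linked NeMo-Aligner source. The function name is hypothetical.

```python
import math


def simulated_target_logprob(shards, target):
    """Sketch of the masked gather + all-reduce(SUM) pattern for one position.

    shards: one list of logits per simulated TP rank, each of length V_local.
    target: a global token id in [0, TP * V_local).
    """
    v_local = len(shards[0])
    # Global log-partition via all-reduce(MAX) and all-reduce(SUM), modeled with max()/sum().
    m = max(max(s) for s in shards)
    log_z = m + math.log(sum(math.exp(x - m) for s in shards for x in s))
    total = 0.0
    for rank, s in enumerate(shards):
        # Assumed contiguous partition: rank r owns global ids [r * V_local, (r + 1) * V_local).
        vocab_start_index = rank * v_local
        vocab_end_index = vocab_start_index + v_local
        # A rank that does not own the target masks it out and contributes 0.0.
        if vocab_start_index <= target < vocab_end_index:
            total += s[target - vocab_start_index] - log_z
    # Summing over ranks stands in for the all-reduce(SUM); exactly one rank owns the target.
    return total
```

Because exactly one rank owns each target id, the SUM all-reduce recovers that rank's value without any rank needing the full vocabulary.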
class nemo_rl.distributed.model_utils.ChunkedDistributedLogprob

Bases: torch.autograd.Function

Custom autograd function for computing log probabilities in a distributed setting.

The log-probability computation is chunked along the sequence dimension to reduce the risk of GPU out-of-memory errors, especially during the backward pass. In addition, the float16/bfloat16 to float32 cast of the logits is performed inside the chunk loop, so a full float32 copy of the logits tensor is never materialized.

Adapted from https://github.com/NVIDIA/NeMo-Aligner/blob/9faab404f21994a7eb1d6ed5890b76152b941636/nemo_aligner/utils/distributed.py#L286

static forward(
ctx: Any,
vocab_parallel_logits: torch.Tensor,
target: torch.Tensor,
vocab_start_index: int,
vocab_end_index: int,
chunk_size: int,
tp_group: torch.distributed.ProcessGroup,
inference_only: bool = False,
) -> torch.Tensor
static backward(
ctx: Any,
*grad_outputs: torch.Tensor,
) -> tuple[torch.Tensor, None, None, None, None, None, None]
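The chunking strategy can be illustrated without any distributed machinery. The sketch below is a single-rank simulation in plain Python (TP is omitted so the loop structure stays visible): it walks the sequence in windows of chunk_size positions and converts only the current window, which is where the real class would upcast a float16/bfloat16 chunk to float32. Only chunk_size corresponds to an argument of the documented class; the function name is hypothetical.

```python
import math


def chunked_target_logprobs(logits, targets, chunk_size):
    """Per-position target log probs, processed `chunk_size` positions at a time.

    logits: list over sequence positions, each a full-vocab list of logits
    (a single rank, so TP is omitted to keep the chunking visible).
    targets: one token id per position.
    """
    out = []
    for start in range(0, len(logits), chunk_size):
        # Only this chunk is converted to float; in the real function this is where a
        # float16/bfloat16 chunk would be upcast to float32 instead of the whole tensor.
        chunk = [[float(x) for x in row] for row in logits[start:start + chunk_size]]
        for row, t in zip(chunk, targets[start:start + chunk_size]):
            m = max(row)
            log_z = m + math.log(sum(math.exp(x - m) for x in row))
            out.append(row[t] - log_z)
    return out
```

Each position's log probability depends only on its own row of logits, so the result is independent of chunk_size; only the peak memory changes.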
class nemo_rl.distributed.model_utils.ChunkedDistributedGatherLogprob

Bases: torch.autograd.Function

Compute distributed log-softmax once and gather logprobs at given global indices.

Forward computes per-chunk distributed log-softmax across TP, gathers selected log probabilities at the provided global indices (shape [B, S, K]), and returns a tensor of shape [B, S, K].

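Backward recomputes the per-chunk softmax from the logits and applies the gradient rule dL/dz = -softmax(z) * sum_k(dL/dy_k) + scatter_add(dL/dy_k), where z denotes the logits, y_k the k-th gathered log probability, and scatter_add accumulates each dL/dy_k at its selected vocabulary index.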

static forward(
ctx: Any,
vocab_parallel_logits: torch.Tensor,
global_indices: torch.Tensor,
vocab_start_index: int,
vocab_end_index: int,
chunk_size: int,
tp_group: torch.distributed.ProcessGroup,
inference_only: bool = False,
) -> torch.Tensor
static backward(
ctx: Any,
*grad_outputs: torch.Tensor,
) -> tuple[torch.Tensor, None, None, None, None, None, None]
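The backward rule stated for this class follows from d(log p_i)/dz_j = [i == j] - p_j, summed over the K selected indices. A quick way to convince yourself is a finite-difference check. The sketch below does this for a single position with no TP sharding, using plain Python; it includes a repeated index to show that scatter_add accumulates. All function names are hypothetical.

```python
import math


def log_softmax(z):
    m = max(z)
    log_z = m + math.log(sum(math.exp(x - m) for x in z))
    return [x - log_z for x in z]


def gather_logprob_grad(z, indices, grad_y):
    """Analytic rule: dL/dz = -softmax(z) * sum_k(grad_y[k]) + scatter_add(grad_y at indices)."""
    p = [math.exp(v) for v in log_softmax(z)]
    g_sum = sum(grad_y)
    grad = [-pj * g_sum for pj in p]
    for idx, g in zip(indices, grad_y):
        grad[idx] += g  # scatter_add: repeated indices accumulate
    return grad


def finite_diff_grad(z, indices, grad_y, eps=1e-6):
    """Central differences of L = sum_k grad_y[k] * log_softmax(z)[indices[k]]."""
    def loss(zz):
        lp = log_softmax(zz)
        return sum(g * lp[i] for i, g in zip(indices, grad_y))

    grad = []
    for j in range(len(z)):
        zp = list(z)
        zp[j] += eps
        zm = list(z)
        zm[j] -= eps
        grad.append((loss(zp) - loss(zm)) / (2 * eps))
    return grad


z = [0.3, -1.2, 2.0, 0.7]
indices = [2, 0, 2]  # index 2 is selected twice
grad_y = [1.0, -0.5, 0.25]
analytic = gather_logprob_grad(z, indices, grad_y)
numeric = finite_diff_grad(z, indices, grad_y)
```

A useful side property: since the softmax sums to 1, the analytic gradient over the vocabulary always sums to zero.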
nemo_rl.distributed.model_utils.dtensor_from_parallel_logits_to_logprobs(
vocab_parallel_logits: torch.Tensor,
target: torch.distributed.tensor.DTensor | torch.Tensor,
vocab_start_index: int,
vocab_end_index: int,
tp_group: torch.distributed.ProcessGroup,
inference_only: bool = False,
seq_index: Optional[torch.Tensor] = None,
chunk_size: Optional[int] = None,
) -> torch.Tensor

Get log probabilities from TP+CP sharded vocab logits.

Parameters:
  • vocab_parallel_logits (torch.Tensor) – Logits distributed across tensor parallel workers, with shape [batch_size, seq_len, vocab_size/tp_size].

  • target (DTensor | torch.Tensor) – Target token indices with shape [batch_size, seq_len]. NOTE: Must be the unmodified targets as this function will shift them internally.

  • vocab_start_index (int) – Starting vocabulary index for this worker’s partition.

  • vocab_end_index (int) – Ending vocabulary index for this worker’s partition.

  • tp_group (torch.distributed.ProcessGroup) – Process group for distributed communication.

  • inference_only (bool, optional) – If True, tensors won’t be saved for backward pass. Defaults to False.

  • seq_index (Optional[torch.Tensor]) – Sequence index tensor with shape [seq_len]. Provided only for CP-sharded logits; it describes how the tensor is sharded along the sequence dimension.

  • chunk_size (Optional[int]) – Sequence dimension chunk size for computing the log probabilities.

Returns:

Log probabilities tensor with shape [batch_size, seq_len-1]. The sequence dimension is reduced by 1 due to the target shifting.

Return type:

torch.Tensor

nemo_rl.distributed.model_utils.from_parallel_logits_to_logprobs(
vocab_parallel_logits: torch.Tensor,
target: torch.Tensor,
vocab_start_index: int,
vocab_end_index: int,
tp_group: torch.distributed.ProcessGroup,
inference_only: bool = False,
cp_group: Optional[torch.distributed.ProcessGroup] = None,
chunk_size: Optional[int] = None,
) -> torch.Tensor

Get log probabilities from TP+CP sharded vocab logits.

Parameters:
  • vocab_parallel_logits (torch.Tensor) – Logits tensor with shape [batch_size, seq_len // CP, vocab_size // TP] where TP is the tensor parallel size.

  • target (torch.Tensor) – Target token indices with shape [batch_size, seq_len]. NOTE: Must be the unmodified targets as this function will shift them internally.

  • vocab_start_index (int) – Starting vocabulary index for this worker’s partition.

  • vocab_end_index (int) – Ending vocabulary index for this worker’s partition.

  • tp_group (torch.distributed.ProcessGroup) – Process group for distributed communication.

  • inference_only (bool, optional) – If True, tensors won’t be saved for backward pass. Defaults to False.

  • cp_group (torch.distributed.ProcessGroup, optional) – Context parallelism process group. Defaults to None.

  • chunk_size (int, optional) – Sequence dimension chunk size for computing the log probabilities.

Returns:

Log probabilities tensor with shape [batch_size, seq_len-1]. The sequence dimension is reduced by 1 due to the target shifting.

Return type:

torch.Tensor

Taken from: https://github.com/NVIDIA/NeMo-Aligner/blob/9faab404f21994a7eb1d6ed5890b76152b941636/nemo_aligner/utils/distributed.py#L354
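The "seq_len-1" in the return shape comes from next-token alignment: the logits at position t score the token at position t + 1, so the final position has no target. The sketch below shows that alignment for one unsharded sequence in plain Python. It assumes the usual causal-LM convention (logits[:-1] paired with target[1:]); it omits TP and CP, and the function name is hypothetical.

```python
import math


def shifted_logprobs(logits, target):
    """logits: [seq_len][vocab] for one sequence; target: [seq_len] unshifted token ids.

    Assumed causal-LM alignment: position t's logits score target[t + 1], so the
    last position has no target and the result has length seq_len - 1.
    """
    out = []
    for row, nxt in zip(logits[:-1], target[1:]):
        m = max(row)
        log_z = m + math.log(sum(math.exp(x - m) for x in row))
        out.append(row[nxt] - log_z)
    return out


logits = [[0.0, 0.0, 0.0], [1.0, 2.0, 3.0], [0.0, 5.0, 0.0], [9.0, 9.0, 9.0]]
target = [2, 1, 2, 0]
lp = shifted_logprobs(logits, target)
```

This is why the functions above require the unmodified targets: shifting them beforehand would misalign every position by one.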

nemo_rl.distributed.model_utils.from_parallel_logits_to_logprobs_packed_sequences(
vocab_parallel_logits: torch.Tensor,
target: torch.Tensor,
cu_seqlens_padded: torch.Tensor,
unpacked_seqlen: int,
vocab_start_index: int,
vocab_end_index: int,
group: torch.distributed.ProcessGroup,
inference_only: bool = False,
cp_group: Optional[torch.distributed.ProcessGroup] = None,
chunk_size: Optional[int] = None,
) -> torch.Tensor

Get log probabilities from TP sharded vocab logits for packed sequences.

Parameters:
  • vocab_parallel_logits (torch.Tensor) – Packed logits tensor with shape [1, T // CP, vocab_size//TP] where T is the total number of tokens across all packed sequences.

  • target (torch.Tensor) – Packed target token indices with shape [1, T]. NOTE: Must be the unmodified targets as this function will shift them internally.

  • cu_seqlens_padded (torch.Tensor) – Cumulative (padded) sequence lengths tensor with shape [batch_size + 1]. cu_seqlens_padded[i] indicates the start position of sequence i in the packed format.

  • unpacked_seqlen (int) – The length of the unpacked sequence tensor.

  • vocab_start_index (int) – Starting vocabulary index for this worker’s partition.

  • vocab_end_index (int) – Ending vocabulary index for this worker’s partition.

  • group (torch.distributed.ProcessGroup) – Process group for distributed communication.

  • inference_only (bool, optional) – If True, tensors won’t be saved for backward pass. Defaults to False.

  • cp_group (torch.distributed.ProcessGroup, optional) – Context parallelism process group. Defaults to None.

  • chunk_size (int, optional) – Sequence dimension chunk size for computing the log probabilities.

Returns:

Unpacked log probabilities tensor with shape [batch_size, unpacked_seqlen-1]. Target shifting removes one position per sequence, so the total number of log probabilities is reduced by batch_size.

Return type:

torch.Tensor
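With packed sequences, the next-token shift has to be applied per sequence, so that the last token of one sequence never scores the first token of the next. The plain-Python sketch below illustrates this using cu_seqlens_padded as start offsets (with the final entry marking the end of the last sequence). TP and CP are omitted. The 0.0 padding value and the function name are hypothetical choices for illustration. The sketch also assumes that every segment fits within unpacked_seqlen.

```python
import math


def packed_logprobs(logits, target, cu_seqlens_padded, unpacked_seqlen):
    """logits: packed [T][V] rows (TP and CP omitted); target: packed [T] token ids.

    Each sequence is shifted on its own, so no position scores a token from the
    next sequence. Rows are right-padded with 0.0 (hypothetical) to unpacked_seqlen - 1.
    """
    batch_size = len(cu_seqlens_padded) - 1
    out = []
    for i in range(batch_size):
        start, end = cu_seqlens_padded[i], cu_seqlens_padded[i + 1]
        row = []
        for t in range(start, end - 1):  # the last position of each sequence has no target
            z = logits[t]
            m = max(z)
            log_z = m + math.log(sum(math.exp(x - m) for x in z))
            row.append(z[target[t + 1]] - log_z)
        # Assumes end - start <= unpacked_seqlen, so the padding count is non-negative.
        out.append(row + [0.0] * (unpacked_seqlen - 1 - len(row)))
    return out


logits = [[0.0, 0.0], [2.0, 0.0], [5.0, 0.0], [0.0, 0.0], [0.0, 0.0]]
target = [0, 1, 1, 0, 1]
out = packed_logprobs(logits, target, [0, 3, 5], unpacked_seqlen=3)
```

In this example, position 2 is the last token of the first sequence, so it is skipped rather than paired with target[3].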

nemo_rl.distributed.model_utils._get_tokens_on_this_cp_rank(
input_ids: torch.Tensor,
cp_rank: int,
cp_size: int,
seq_dim: int = 1,
) -> torch.Tensor

Get tokens on this context parallelism rank.

Assumes that either cp_size == 1 or input_ids is already padded to a multiple of cp_size * 2 along the sequence dimension.

Parameters:
  • input_ids – Input token IDs [seq_length, ]

  • cp_rank – Context parallelism rank

  • cp_size – Context parallelism size

  • seq_dim – Sequence dimension of input_ids. Defaults to 1.

Returns:

Tokens on this context parallelism rank [1, seq_length // cp_size]
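The padding requirement (a multiple of cp_size * 2) suggests a load-balanced split. In this scheme, the sequence is cut into 2 * cp_size equal chunks, and rank r takes chunk r together with its mirror chunk 2 * cp_size - 1 - r, which evens out causal-attention work across ranks. The sketch below assumes this common Megatron-style scheme; the source above does not spell out the chunk assignment, so check the implementation before relying on it. It works on a 1-D Python list, and the function name is hypothetical.

```python
def sketch_tokens_on_cp_rank(input_ids, cp_rank, cp_size):
    """Assumed load-balanced CP split: 2 * cp_size equal chunks; rank r keeps
    chunk r and its mirror chunk 2 * cp_size - 1 - r."""
    if cp_size == 1:
        return list(input_ids)
    n_chunks = 2 * cp_size
    if len(input_ids) % n_chunks != 0:
        raise ValueError("pad input_ids to a multiple of 2 * cp_size first")
    c = len(input_ids) // n_chunks
    first = input_ids[cp_rank * c:(cp_rank + 1) * c]
    mirror = n_chunks - 1 - cp_rank
    second = input_ids[mirror * c:(mirror + 1) * c]
    return list(first) + list(second)


ids = list(range(8))
rank0 = sketch_tokens_on_cp_rank(ids, 0, 2)
rank1 = sketch_tokens_on_cp_rank(ids, 1, 2)
```

Under a causal mask, early chunks attend to few keys and late chunks to many, so pairing chunk r with its mirror gives every rank a similar amount of attention work.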

nemo_rl.distributed.model_utils.allgather_cp_sharded_tensor(tensor, cp_group, seq_dim=1)
class nemo_rl.distributed.model_utils.AllGatherCPTensor

Bases: torch.autograd.Function

forward(tensor, cp_group: torch.distributed.ProcessGroup, seq_dim=1)
backward(grad_output)
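All-gathering a CP-sharded tensor involves more than concatenating the per-rank shards: if each shard holds two non-adjacent chunks, the gathered pieces must be put back in sequence order. The following sketch assumes the load-balanced layout in which rank r holds chunk r followed by chunk 2 * cp_size - 1 - r. It takes the list of per-rank shards that an all-gather would return and reassembles the full 1-D sequence. The layout is an assumption not confirmed by the source above, and the function name is hypothetical.

```python
def sketch_allgather_cp(shards):
    """shards[r] is rank r's piece, assumed to be [chunk r, chunk 2*cp_size-1-r].
    Returns the full sequence in its original order."""
    cp_size = len(shards)
    if cp_size == 1:
        return list(shards[0])
    c = len(shards[0]) // 2
    chunks = [None] * (2 * cp_size)
    for r, shard in enumerate(shards):
        chunks[r] = shard[:c]  # front chunk of rank r
        chunks[2 * cp_size - 1 - r] = shard[c:]  # mirrored back chunk of rank r
    return [tok for chunk in chunks for tok in chunk]


full_2 = sketch_allgather_cp([[0, 1, 6, 7], [2, 3, 4, 5]])
full_3 = sketch_allgather_cp([[0, 1, 10, 11], [2, 3, 8, 9], [4, 5, 6, 7]])
```

In an autograd version such as AllGatherCPTensor, the natural backward is the reverse operation: slicing the incoming gradient back to this rank's two chunks.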
nemo_rl.distributed.model_utils.get_logprobs_from_vocab_parallel_logits(
vocab_parallel_logits: torch.distributed.tensor.DTensor,
input_ids: torch.Tensor | torch.distributed.tensor.DTensor,
seq_index: Optional[torch.Tensor] = None,
chunk_size: Optional[int] = None,
)

Computes log probabilities from vocabulary-parallel logits.

This function takes logits that are sharded across the vocabulary dimension (tensor parallel) and computes the log probabilities for the given input IDs.

Parameters:
  • vocab_parallel_logits (DTensor) – Logits distributed across tensor parallel workers, with shape [batch_size, seq_len, vocab_size/tp_size].

  • input_ids (torch.Tensor | DTensor) – Input token IDs for which to compute log probabilities, with shape [batch_size, seq_len].

  • seq_index (Optional[torch.Tensor]) – Sequence index for the input IDs, with shape [sequence_length].

  • chunk_size (Optional[int]) – Sequence dimension chunk size for computing log probabilities.

Returns:

Log probabilities for the given input IDs.

Return type:

torch.Tensor

nemo_rl.distributed.model_utils.distributed_vocab_topk(
vocab_parallel_logits: torch.Tensor,
k: int,
tp_group: torch.distributed.ProcessGroup,
*,
vocab_start_index: int,
vocab_end_index: int,
chunk_size: Optional[int] = None,
) -> tuple[torch.Tensor, torch.Tensor]

Compute global top-k over TP-sharded vocabulary logits.

Parameters:
  • vocab_parallel_logits – [B, S, V_local]

  • k – number of top tokens to select globally

  • tp_group – tensor-parallel process group

  • vocab_start_index – global vocab start for this rank (inclusive)

  • vocab_end_index – global vocab end for this rank (exclusive)

  • chunk_size – optional chunk along sequence dim to bound memory

Returns:

A tuple (topk_vals, topk_global_indices). topk_vals: [B, S, k]; topk_global_indices: [B, S, k] (global token ids).

Return type:

tuple[torch.Tensor, torch.Tensor]
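A global top-k over a sharded vocabulary can be computed in two stages. Every element of the global top-k is also in the local top-k of the rank that owns it. So each rank can take its local top-k, convert the local positions to global ids by adding vocab_start_index, and a final top-k over the gathered TP * k candidates yields the exact answer. The plain-Python sketch below simulates this for one position, assuming a contiguous vocabulary partition. It illustrates the technique and is not the library's code; the function name is hypothetical.

```python
import heapq


def simulated_vocab_topk(shards, k):
    """Two-stage global top-k over TP shards for one position.

    shards: per-rank logit lists of equal length V_local (contiguous partition assumed).
    Returns (topk_vals, topk_global_indices), ordered from largest to smallest.
    """
    v_local = len(shards[0])
    candidates = []
    for rank, s in enumerate(shards):
        vocab_start_index = rank * v_local  # assumed contiguous partition
        # Stage 1: local top-k on each rank, converted to global token ids.
        local = heapq.nlargest(k, range(v_local), key=lambda j: s[j])
        candidates.extend((s[j], vocab_start_index + j) for j in local)
    # Stage 2: after an all-gather of candidates, keep the k best of TP * k values.
    best = heapq.nlargest(k, candidates, key=lambda c: c[0])
    return [v for v, _ in best], [i for _, i in best]


vals, idx = simulated_vocab_topk([[0.1, 5.0, 0.3], [4.0, 0.2, 6.0]], k=2)
```

Communication is limited to TP * k candidates per position rather than the full vocabulary, which is the point of the two-stage design.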

nemo_rl.distributed.model_utils.gather_logits_at_global_indices(
vocab_parallel_logits: torch.Tensor,
global_indices: torch.Tensor,
tp_group: Optional[torch.distributed.ProcessGroup] = None,
cp_group: Optional[torch.distributed.ProcessGroup] = None,
*,
vocab_start_index: int,
vocab_end_index: int,
chunk_size: Optional[int] = None,
) -> torch.Tensor

Gather student logits at given global token indices under TP+CP sharding.

Differentiable w.r.t. vocab_parallel_logits.

Parameters:
  • vocab_parallel_logits – [B, S_cp, V_local] where S_cp is CP sharded sequence length

  • global_indices – [B, S_full, k] where S_full is full sequence length

  • tp_group – Optional tensor-parallel process group. If None, treats logits as full-vocab (no TP) and skips TP all-reduce.

  • vocab_start_index – global vocab start for this rank (inclusive)

  • vocab_end_index – global vocab end for this rank (exclusive)

  • chunk_size – optional chunk along sequence dim to bound memory

  • cp_group – Optional context-parallel process group

Returns:

gathered_logits: [B, S_full, k]

Return type:

torch.Tensor
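Under TP, gathering logits at global ids can follow the same mask-and-sum pattern as the log-probability functions. Each rank reads the logits for the ids it owns, writes 0.0 for the rest, and a SUM all-reduce fills in every entry. The sketch below simulates this for a single position in plain Python. It assumes a contiguous vocabulary partition and leaves out CP and autograd. The function name is hypothetical.

```python
def simulated_gather_logits(shards, global_indices):
    """Return the logit of each global id in global_indices for one position.

    shards: per-rank logit lists of equal length V_local (contiguous partition assumed).
    """
    v_local = len(shards[0])
    out = [0.0] * len(global_indices)
    for rank, s in enumerate(shards):
        vocab_start_index = rank * v_local
        vocab_end_index = vocab_start_index + v_local
        for k, g in enumerate(global_indices):
            # Each rank contributes only the ids it owns; the others stay 0.0.
            if vocab_start_index <= g < vocab_end_index:
                out[k] += s[g - vocab_start_index]
    # Accumulating across ranks stands in for the TP all-reduce(SUM).
    return out


gathered = simulated_gather_logits([[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]], [5, 0, 3])
```

When tp_group is None, the documented function treats the logits as full-vocabulary, which corresponds to the one-shard case of this sketch with no reduction required.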

class nemo_rl.distributed.model_utils.ChunkedDistributedEntropy

Bases: torch.autograd.Function

Compute H_all = sum_v p_v log p_v across TP with chunking over sequence.

Forward returns [B, S] tensor of global entropy; backward propagates through logits.

static forward(
ctx: Any,
vocab_parallel_logits: torch.Tensor,
chunk_size: int,
tp_group: torch.distributed.ProcessGroup,
inference_only: bool = False,
) -> torch.Tensor
static backward(
ctx: Any,
*grad_outputs: torch.Tensor,
) -> tuple[torch.Tensor, None, None, None]
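The quantity sum_v p_v log p_v decomposes over vocabulary shards once the global log-partition is known. Each rank sums p_v log p_v over its own slice, and a SUM all-reduce combines the partial sums. The plain-Python sketch below computes the quantity exactly as written in the docstring for one position, modeling the collectives with max() and sum(). Note that this expression equals the negative of the Shannon entropy, so check the source for the sign convention of the returned tensor. The function name is hypothetical.

```python
import math


def simulated_sum_p_log_p(shards):
    """sum_v p_v * log p_v for one position, with the vocabulary split across TP ranks.

    The global log-partition uses all-reduce(MAX) and all-reduce(SUM), modeled as max()/sum().
    """
    m = max(max(s) for s in shards)
    log_z = m + math.log(sum(math.exp(x - m) for s in shards for x in s))
    # Each rank computes a partial sum over its own slice; log p_v = x - log_z.
    partials = [sum(math.exp(x - log_z) * (x - log_z) for x in s) for s in shards]
    # all-reduce(SUM) of the per-rank partial sums.
    return sum(partials)


uniform = simulated_sum_p_log_p([[0.0, 0.0], [0.0, 0.0]])
mixed = simulated_sum_p_log_p([[3.0, -1.0, 0.0], [0.5, 2.0, 0.0]])
```

For a uniform distribution over 4 tokens the result is log(1/4), and it is never positive for any input.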