nemo_rl.distributed.model_utils#

Module Contents#

Classes#

DistributedLogprob

Custom autograd function for computing log probabilities in a distributed setting.

Functions#

_compute_distributed_log_softmax

Compute a stable distributed log softmax across tensor parallel workers.

from_parallel_logits_to_logprobs

Get log probabilities from TP sharded vocab logits.

API#

nemo_rl.distributed.model_utils._compute_distributed_log_softmax(
vocab_parallel_logits: torch.Tensor,
group: torch.distributed.ProcessGroup,
) → torch.Tensor[source]#

Compute a stable distributed log softmax across tensor parallel workers.

Taken from: https://github.com/NVIDIA/NeMo-Aligner/blob/9faab404f21994a7eb1d6ed5890b76152b941636/nemo_aligner/utils/distributed.py#L265

Parameters:
  • vocab_parallel_logits (torch.Tensor) – Logits tensor with shape [batch_size, seq_length, vocab_size//TP] where TP is the tensor parallel size.

  • group (torch.distributed.ProcessGroup) – Process group for the all-reduce operations.

Returns:

Log softmax output with the same shape as input, but values represent log probabilities normalized across the full vocabulary dimension.

Return type:

torch.Tensor
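To illustrate the numerics, here is a minimal single-process sketch (not the library's implementation) that simulates the tensor-parallel shards as plain Python lists for one token position. The two all-reduce operations, MAX for stability and SUM for the normalizer, are simulated with ordinary `max()`/`sum()` over the shards; the function name and shard representation are illustrative only.

```python
import math

def distributed_log_softmax_sketch(shards):
    """Sketch of a stable distributed log softmax for one token position.

    `shards` is a list of per-rank logit lists (the vocab dimension split
    across TP ranks). The all-reduce MAX and all-reduce SUM are simulated
    by reducing over the shards in-process.
    """
    # Simulated all-reduce MAX: global max over the full vocabulary.
    global_max = max(max(s) for s in shards)
    # Simulated all-reduce SUM: sum of exp(logit - max) over the full vocab.
    sum_exp = sum(math.exp(x - global_max) for s in shards for x in s)
    log_z = global_max + math.log(sum_exp)
    # Every rank subtracts the same normalizer from its local shard,
    # so the outputs jointly form log probabilities over the full vocab.
    return [[x - log_z for x in s] for s in shards]
```

Subtracting the global max before exponentiating is what makes the reduction numerically stable: no `exp` argument exceeds zero, so overflow cannot occur regardless of the logit scale.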

class nemo_rl.distributed.model_utils.DistributedLogprob(*args, **kwargs)[source]#

Bases: torch.autograd.Function

Custom autograd function for computing log probabilities in a distributed setting.

Taken from https://github.com/NVIDIA/NeMo-Aligner/blob/9faab404f21994a7eb1d6ed5890b76152b941636/nemo_aligner/utils/distributed.py#L286

Initialization

static forward(
ctx,
vocab_parallel_logits: torch.Tensor,
target: torch.Tensor,
vocab_start_index: int,
vocab_end_index: int,
group: torch.distributed.ProcessGroup,
inference_only: bool = False,
)[source]#
static backward(
ctx,
grad_output: torch.Tensor,
) → Tuple[torch.Tensor, None, None, None, None, None, None][source]#
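The key idea in the forward pass can be sketched in-process (again with shards as Python lists; the function name, shard layout, and scalar `target` are illustrative assumptions, not the library's API). Each rank owns a contiguous vocabulary slice `[start, end)`; only the rank that owns the target index contributes its logit, the others contribute zero, and an all-reduce SUM (simulated here) combines the contributions before the shared normalizer is subtracted.

```python
import math

def distributed_target_logprob_sketch(shards, target):
    """Sketch of gathering log p(target) from TP-sharded logits.

    The rank whose vocab partition contains `target` contributes its
    local logit; other ranks contribute 0.0. A simulated all-reduce SUM
    combines them, and the stable distributed normalizer is subtracted.
    """
    # Stable normalizer across the full (sharded) vocabulary.
    global_max = max(max(s) for s in shards)
    log_z = global_max + math.log(
        sum(math.exp(x - global_max) for s in shards for x in s)
    )
    # Per-rank contribution: the target logit if owned, else 0.0.
    contribs, start = [], 0
    for s in shards:
        end = start + len(s)
        contribs.append(s[target - start] if start <= target < end else 0.0)
        start = end
    return sum(contribs) - log_z  # simulated all-reduce SUM
```

This avoids ever materializing the full vocabulary on one rank, which is why the real implementation is an autograd `Function`: the backward pass must route gradients back to only the shard that held the target logit.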
nemo_rl.distributed.model_utils.from_parallel_logits_to_logprobs(
vocab_parallel_logits: torch.Tensor,
target: torch.Tensor,
vocab_start_index: int,
vocab_end_index: int,
group: torch.distributed.ProcessGroup,
inference_only: bool = False,
) → torch.Tensor[source]#

Get log probabilities from TP sharded vocab logits.

Parameters:
  • vocab_parallel_logits (torch.Tensor) – Logits tensor with shape [batch_size, seq_len, vocab_size//TP] where TP is the tensor parallel size.

  • target (torch.Tensor) – Target token indices with shape [batch_size, seq_len]. NOTE: Must be the unmodified targets as this function will shift them internally.

  • vocab_start_index (int) – Starting vocabulary index for this worker's partition.

  • vocab_end_index (int) – Ending vocabulary index for this worker's partition.

  • group (torch.distributed.ProcessGroup) – Process group for distributed communication.

  • inference_only (bool, optional) – If True, tensors won't be saved for backward pass. Defaults to False.

Returns:

Log probabilities tensor with shape [batch_size, seq_len-1]. The sequence dimension is reduced by 1 due to the target shifting.

Return type:

torch.Tensor

Taken from: https://github.com/NVIDIA/NeMo-Aligner/blob/9faab404f21994a7eb1d6ed5890b76152b941636/nemo_aligner/utils/distributed.py#L354
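The internal target shift is worth making concrete: the logits at position t score the token at position t+1, so the output has length seq_len - 1. A minimal single-sequence sketch (vocab left unsharded for clarity; the function name is illustrative) of that alignment:

```python
import math

def from_logits_to_logprobs_sketch(logits, target):
    """Sketch of the target shift: logits[t] scores target[t + 1].

    `logits` is [seq_len][vocab] for one sequence and `target` is the
    unmodified token-id sequence, matching the docstring's requirement
    to pass unshifted targets. The result has length seq_len - 1.
    """
    out = []
    for t in range(len(target) - 1):
        row = logits[t]  # prediction made at position t
        m = max(row)
        log_z = m + math.log(sum(math.exp(x - m) for x in row))
        out.append(row[target[t + 1]] - log_z)  # log p(target[t+1])
    return out
```

Passing already-shifted targets would shift twice and misalign every position, which is why the NOTE on the `target` parameter insists on the unmodified tensor.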