nemo_rl.distributed.model_utils
Module Contents
Classes
DistributedLogprob | Custom autograd function for computing log probabilities in a distributed setting.
Functions
_compute_distributed_log_softmax | Compute a stable distributed log softmax across tensor parallel workers.
dtensor_from_parallel_logits_to_logprobs | Get log probabilities from TP+CP sharded vocab logits.
from_parallel_logits_to_logprobs | Get log probabilities from TP+CP sharded vocab logits.
from_parallel_logits_to_logprobs_packed_sequences | Get log probabilities from TP sharded vocab logits for packed sequences.
_get_tokens_on_this_cp_rank | Get tokens on this context parallelism rank.
API
- nemo_rl.distributed.model_utils._compute_distributed_log_softmax(vocab_parallel_logits: torch.Tensor, group: torch.distributed.ProcessGroup)
Compute a stable distributed log softmax across tensor parallel workers.
Taken from: https://github.com/NVIDIA/NeMo-Aligner/blob/9faab404f21994a7eb1d6ed5890b76152b941636/nemo_aligner/utils/distributed.py#L265
- Parameters:
vocab_parallel_logits (torch.Tensor) – Logits tensor with shape [batch_size, seq_length, vocab_size//TP] where TP is the tensor parallel size.
group (torch.distributed.ProcessGroup) – Process group for the all-reduce operations.
- Returns:
Log softmax output with the same shape as input, but values represent log probabilities normalized across the full vocabulary dimension.
- Return type:
torch.Tensor
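The stability trick is to subtract the global maximum before exponentiating, which needs two collectives: an all-reduce with MAX for the per-position maximum and an all-reduce with SUM for the partition sum. The pure-Python sketch below emulates both all-reduces with max() and sum() over a list of per-rank vocab slices instead of calling torch.distributed; the function name distributed_log_softmax and the list-of-lists layout are hypothetical stand-ins for the real tensor code.

```python
import math
from typing import List

# One row of local vocab logits per token position (this rank's slice only).
Shard = List[List[float]]


def distributed_log_softmax(shards: List[Shard]) -> List[Shard]:
    """Simulate a stable log softmax over a vocab split across TP ranks.

    shards[r][t] holds rank r's slice of the logits for token position t.
    The two all-reduces (MAX, then SUM) are emulated with max()/sum() over ranks.
    """
    num_positions = len(shards[0])
    out: List[Shard] = [[[] for _ in range(num_positions)] for _ in shards]
    for t in range(num_positions):
        # Step 1: local max per rank, then "all-reduce(MAX)" -> global max.
        global_max = max(max(shard[t]) for shard in shards)
        # Step 2: local sum of exp(logit - global_max), then "all-reduce(SUM)".
        global_sum = sum(
            sum(math.exp(x - global_max) for x in shard[t]) for shard in shards
        )
        log_z = math.log(global_sum)
        # Step 3: each rank normalizes its own slice; no further communication.
        for r, shard in enumerate(shards):
            out[r][t] = [x - global_max - log_z for x in shard[t]]
    return out
```

Concatenating the per-rank outputs along the vocab dimension gives the same values as a single-device log softmax, which is why the result can be described as normalized across the full vocabulary even though each rank only stores vocab_size // TP columns.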
- class nemo_rl.distributed.model_utils.DistributedLogprob(*args, **kwargs)
Bases: torch.autograd.Function
Custom autograd function for computing log probabilities in a distributed setting.
Taken from https://github.com/NVIDIA/NeMo-Aligner/blob/9faab404f21994a7eb1d6ed5890b76152b941636/nemo_aligner/utils/distributed.py#L286
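A minimal pure-Python sketch of the idea behind a vocab-parallel log-prob autograd function. It assumes (this is not stated above) that the forward pass looks up the target's log probability only on the rank whose vocab slice contains the target, contributes 0.0 on every other rank, and sums across ranks, and that the backward pass applies the standard log-softmax gradient, 1[j == target] - softmax_j, to each local slice without extra communication. The helper names are hypothetical, and the inputs are assumed to already be full-vocabulary-normalized log probabilities.

```python
import math
from typing import List, Tuple


def shard_target_logprob(
    local_logprobs: List[float], target: int, vocab_start: int, vocab_end: int
) -> float:
    """One TP rank's contribution: its log prob for `target`, or 0.0 if not owned."""
    if vocab_start <= target < vocab_end:
        return local_logprobs[target - vocab_start]
    return 0.0  # masked: another rank owns this token id


def distributed_logprob_forward(
    shards: List[List[float]], target: int, bounds: List[Tuple[int, int]]
) -> float:
    # Emulated all-reduce(SUM): exactly one rank contributes a non-zero value.
    return sum(
        shard_target_logprob(s, target, lo, hi) for s, (lo, hi) in zip(shards, bounds)
    )


def distributed_logprob_backward(
    local_logprobs: List[float],
    target: int,
    vocab_start: int,
    vocab_end: int,
    grad_output: float,
) -> List[float]:
    """Gradient w.r.t. this rank's logits: (1[j == target] - softmax_j) * grad_output."""
    grads = []
    for j, lp in enumerate(local_logprobs):
        is_target = 1.0 if vocab_start + j == target else 0.0
        # exp(lp) is the globally normalized softmax, so no all-reduce is needed here.
        grads.append((is_target - math.exp(lp)) * grad_output)
    return grads
```

Because each rank keeps its normalized log probabilities, the backward pass can recover the softmax locally, which is the main reason to write a custom autograd function instead of relying on autograd through the collectives.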
- nemo_rl.distributed.model_utils.dtensor_from_parallel_logits_to_logprobs(vocab_parallel_logits: torch.Tensor, target: torch.distributed.tensor.DTensor | torch.Tensor, vocab_start_index: int, vocab_end_index: int, tp_group: torch.distributed.ProcessGroup, inference_only: bool = False, seq_index: Optional[torch.Tensor] = None)
Get log probabilities from TP+CP sharded vocab logits.
- Parameters:
vocab_parallel_logits (torch.Tensor) – Logits distributed across tensor parallel workers, with shape [batch_size, seq_len, vocab_size/tp_size].
target (DTensor | torch.Tensor) – Target token indices with shape [batch_size, seq_len]. NOTE: Must be the unmodified targets as this function will shift them internally.
vocab_start_index (int) – Starting vocabulary index for this worker’s partition.
vocab_end_index (int) – Ending vocabulary index for this worker’s partition.
tp_group (torch.distributed.ProcessGroup) – Process group for distributed communication.
inference_only (bool, optional) – If True, tensors won’t be saved for backward pass. Defaults to False.
seq_index (Optional[torch.Tensor]) – Sequence index tensor with shape [seq_len]. Only provided for CP-sharded logits, where it records how the tensor is sharded across the sequence dimension. Defaults to None.
- Returns:
Log probabilities tensor with shape [batch_size, seq_len-1]. The sequence dimension is reduced by 1 due to the target shifting.
- Return type:
torch.Tensor
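The "unmodified targets" note and the seq_len-1 output length follow from next-token prediction. The single-device reference below assumes the usual convention (logits at position i score token i + 1), which is not spelled out above; it can serve as a ground truth when checking a sharded result. The function name next_token_logprobs is hypothetical.

```python
import math
from typing import List


def next_token_logprobs(logits: List[List[float]], target: List[int]) -> List[float]:
    """Per-position log p(target[i + 1] | logits[i]) for one unsharded sequence.

    `target` is the unmodified token sequence (length seq_len); the shift by one
    happens here, so the result has length seq_len - 1.
    """
    if len(logits) != len(target):
        raise ValueError("logits and target must have the same sequence length")
    out = []
    for i in range(len(target) - 1):
        row = logits[i]
        m = max(row)
        log_z = m + math.log(sum(math.exp(x - m) for x in row))  # stable logsumexp
        out.append(row[target[i + 1]] - log_z)
    return out
```

Passing already-shifted targets would shift them a second time and compare each position against the token two steps ahead, which is why the parameter description insists on the raw targets.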
- nemo_rl.distributed.model_utils.from_parallel_logits_to_logprobs(vocab_parallel_logits: torch.Tensor, target: torch.Tensor, vocab_start_index: int, vocab_end_index: int, tp_group: torch.distributed.ProcessGroup, inference_only: bool = False, cp_group: Optional[torch.distributed.ProcessGroup] = None)
Get log probabilities from TP+CP sharded vocab logits.
- Parameters:
vocab_parallel_logits (torch.Tensor) – Logits tensor with shape [batch_size, seq_len // CP, vocab_size // TP] where TP is the tensor parallel size.
target (torch.Tensor) – Target token indices with shape [batch_size, seq_len]. NOTE: Must be the unmodified targets as this function will shift them internally.
vocab_start_index (int) – Starting vocabulary index for this worker’s partition.
vocab_end_index (int) – Ending vocabulary index for this worker’s partition.
tp_group (torch.distributed.ProcessGroup) – Process group for distributed communication.
inference_only (bool, optional) – If True, tensors won’t be saved for backward pass. Defaults to False.
cp_group (torch.distributed.ProcessGroup, optional) – Context parallelism process group. Defaults to None.
- Returns:
Log probabilities tensor with shape [batch_size, seq_len-1]. The sequence dimension is reduced by 1 due to the target shifting.
- Return type:
torch.Tensor
Taken from: https://github.com/NVIDIA/NeMo-Aligner/blob/9faab404f21994a7eb1d6ed5890b76152b941636/nemo_aligner/utils/distributed.py#L354
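vocab_start_index and vocab_end_index describe the slice of the vocabulary held by the calling rank. The sketch below assumes an even, contiguous, rank-ordered split with a half-open range [start, end), as in Megatron-style vocab-parallel layers; that split and both helper names are assumptions for illustration. It also shows how global target ids would be mapped to local column indices, with ids owned by other ranks masked out.

```python
from typing import List, Tuple


def vocab_range_for_rank(vocab_size: int, tp_rank: int, tp_size: int) -> Tuple[int, int]:
    """[start, end) of the contiguous vocab slice owned by `tp_rank` (even split assumed)."""
    if vocab_size % tp_size != 0:
        raise ValueError("this sketch assumes vocab_size is divisible by tp_size")
    per_rank = vocab_size // tp_size
    start = tp_rank * per_rank
    return start, start + per_rank


def localize_targets(targets: List[int], start: int, end: int) -> Tuple[List[int], List[bool]]:
    """Map global token ids to local column indices; foreign ids are masked to index 0."""
    mask = [not (start <= t < end) for t in targets]
    local = [0 if masked else t - start for t, masked in zip(targets, mask)]
    return local, mask
```

Masked positions still need a valid local index so that a gather does not go out of bounds; their gathered values are then zeroed, so that only the owning rank contributes to the cross-rank sum.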
- nemo_rl.distributed.model_utils.from_parallel_logits_to_logprobs_packed_sequences(vocab_parallel_logits: torch.Tensor, target: torch.Tensor, cu_seqlens_padded: torch.Tensor, unpacked_seqlen: int, vocab_start_index: int, vocab_end_index: int, group: torch.distributed.ProcessGroup, inference_only: bool = False, cp_group: Optional[torch.distributed.ProcessGroup] = None)
Get log probabilities from TP sharded vocab logits for packed sequences.
- Parameters:
vocab_parallel_logits (torch.Tensor) – Packed logits tensor with shape [1, T // CP, vocab_size//TP] where T is the total number of tokens across all packed sequences.
target (torch.Tensor) – Packed target token indices with shape [1, T]. NOTE: Must be the unmodified targets as this function will shift them internally.
cu_seqlens_padded (torch.Tensor) – Cumulative (padded) sequence lengths tensor with shape [batch_size + 1]. cu_seqlens_padded[i] indicates the start position of sequence i in the packed format.
unpacked_seqlen (int) – The length of the unpacked sequence tensor.
vocab_start_index (int) – Starting vocabulary index for this worker’s partition.
vocab_end_index (int) – Ending vocabulary index for this worker’s partition.
group (torch.distributed.ProcessGroup) – Process group for distributed communication.
inference_only (bool, optional) – If True, tensors won’t be saved for backward pass. Defaults to False.
cp_group (torch.distributed.ProcessGroup, optional) – Context parallelism process group. Defaults to None.
- Returns:
Unpacked log probabilities tensor with shape [batch_size, unpacked_seqlen-1]. The total length is reduced by batch_size due to target shifting (one token per sequence).
- Return type:
torch.Tensor
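The unpacking step can be pictured with cu_seqlens_padded as boundaries. This sketch assumes that packed_logprobs[p] already holds the next-token log probability for packed position p, so the last entry of each sequence would score the first token of the following sequence and is dropped. That accounts for the "one token per sequence" reduction. Padding the rows with 0.0 up to unpacked_seqlen - 1 is also an assumption, and unpack_shifted_logprobs is a hypothetical name.

```python
from typing import List


def unpack_shifted_logprobs(
    packed_logprobs: List[float],
    cu_seqlens_padded: List[int],
    unpacked_seqlen: int,
    pad_value: float = 0.0,
) -> List[List[float]]:
    """Split shifted packed log probs into a [batch_size, unpacked_seqlen - 1] grid.

    packed_logprobs[p] is assumed to be log p(token p + 1 | logits at p) over the
    whole packed stream; the entry at each sequence's last position crosses into
    the next sequence and is discarded.
    """
    batch_size = len(cu_seqlens_padded) - 1
    rows = []
    for i in range(batch_size):
        start, end = cu_seqlens_padded[i], cu_seqlens_padded[i + 1]
        row = packed_logprobs[start : end - 1]  # end - 1 drops the cross-boundary entry
        if len(row) > unpacked_seqlen - 1:
            raise ValueError("sequence is longer than unpacked_seqlen")
        rows.append(row + [pad_value] * (unpacked_seqlen - 1 - len(row)))
    return rows
```

For example, with cu_seqlens_padded = [0, 3, 5], the first sequence covers packed positions 0 to 2 and keeps two log probabilities, and the second covers positions 3 and 4 and keeps one.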
- nemo_rl.distributed.model_utils._get_tokens_on_this_cp_rank(input_ids: torch.Tensor, cp_rank: int, cp_size: int, seq_dim: int = 1)
Get tokens on this context parallelism rank.
Assumes that input_ids is already padded to a multiple of cp_size * 2, or that cp_size == 1.
- Parameters:
input_ids – Input token IDs [seq_length, ]
cp_rank – Context parallelism rank
cp_size – Context parallelism size
seq_dim – Index of the sequence dimension in input_ids. Defaults to 1.
- Returns:
Tokens on this context parallelism rank [1, seq_length // cp_size]
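The padding requirement (a multiple of cp_size * 2) matches the load-balanced split commonly used for context parallelism with causal attention. In that scheme the sequence is cut into 2 * cp_size chunks, and rank r keeps chunk r plus the mirrored chunk 2 * cp_size - 1 - r. The pure-Python sketch below assumes this scheme and works on a 1-D list, not a tensor with a seq_dim argument. The name tokens_on_cp_rank is hypothetical.

```python
from typing import List, Sequence


def tokens_on_cp_rank(input_ids: Sequence[int], cp_rank: int, cp_size: int) -> List[int]:
    """Load-balanced context-parallel split of one 1-D token sequence.

    The sequence is cut into 2 * cp_size equal chunks; rank r keeps chunk r and the
    mirrored chunk 2 * cp_size - 1 - r, so every rank gets a mix of early (cheap)
    and late (expensive) positions under causal attention.
    """
    if cp_size == 1:
        return list(input_ids)
    num_chunks = 2 * cp_size
    if len(input_ids) % num_chunks != 0:
        raise ValueError("input_ids must be padded to a multiple of 2 * cp_size")
    chunk = len(input_ids) // num_chunks
    first = cp_rank
    second = num_chunks - 1 - cp_rank
    return list(input_ids[first * chunk : (first + 1) * chunk]) + list(
        input_ids[second * chunk : (second + 1) * chunk]
    )
```

With eight tokens and cp_size = 2, rank 0 would hold positions 0, 1, 6, 7 and rank 1 positions 2, 3, 4, 5, so each rank ends up with seq_length // cp_size tokens.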