Source code for nemo_rl.distributed.model_utils

# Copyright (c) 2025, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Tuple

import torch


@torch.no_grad()
def _compute_distributed_log_softmax(
    vocab_parallel_logits: torch.Tensor, group: torch.distributed.ProcessGroup
) -> torch.Tensor:
    """Compute a stable distributed log softmax across tensor parallel workers.

    Taken from: https://github.com/NVIDIA/NeMo-Aligner/blob/9faab404f21994a7eb1d6ed5890b76152b941636/nemo_aligner/utils/distributed.py#L265

    Args:
        vocab_parallel_logits (torch.Tensor): Logits tensor with shape
            [batch_size, seq_length, vocab_size//TP] where TP is the tensor parallel size.
        group (torch.distributed.ProcessGroup): Process group for the all-reduce operations.

    Returns:
        torch.Tensor: Log softmax output with the same shape as the input, but values
            represent log probabilities normalized across the full vocabulary dimension.
    """
    logits_max = torch.amax(vocab_parallel_logits, dim=-1, keepdim=True)
    torch.distributed.all_reduce(
        logits_max,
        op=torch.distributed.ReduceOp.MAX,
        group=group,
    )

    # Subtract the maximum value.
    vocab_parallel_logits = vocab_parallel_logits - logits_max

    sum_exp_logits = vocab_parallel_logits.exp().sum(-1, keepdim=True).float()

    torch.distributed.all_reduce(
        sum_exp_logits,
        op=torch.distributed.ReduceOp.SUM,
        group=group,
    )

    return vocab_parallel_logits - sum_exp_logits.log_().to(vocab_parallel_logits.dtype)
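
# Note: the function above is the tensor-parallel analogue of the usual numerically
# stable identity log_softmax(x) = (x - max(x)) - log(sum(exp(x - max(x)))), with the
# max and the sum of exponentials each all-reduced over the vocab-sharded dimension.
# A minimal single-device sketch of that identity (illustration only, not part of this module):
#
#   x = torch.randn(2, 3, 8)
#   x_max = torch.amax(x, dim=-1, keepdim=True)               # all-reduced with MAX in the TP case
#   shifted = x - x_max
#   log_denom = shifted.exp().sum(-1, keepdim=True).log()     # sum is all-reduced with SUM in the TP case
#   assert torch.allclose(shifted - log_denom, torch.log_softmax(x, dim=-1), atol=1e-6)
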
class DistributedLogprob(torch.autograd.Function):
    """Custom autograd function for computing log probabilities in a distributed setting.

    Taken from https://github.com/NVIDIA/NeMo-Aligner/blob/9faab404f21994a7eb1d6ed5890b76152b941636/nemo_aligner/utils/distributed.py#L286
    """

    @staticmethod
    def forward(
        ctx,
        vocab_parallel_logits: torch.Tensor,
        target: torch.Tensor,
        vocab_start_index: int,
        vocab_end_index: int,
        group: torch.distributed.ProcessGroup,
        inference_only: bool = False,
    ):
        # Mask out-of-partition target ids (True means the id is not in this rank's vocab shard).
        target_mask = (target < vocab_start_index) | (target >= vocab_end_index)
        masked_target = target - vocab_start_index
        masked_target[target_mask] = 0

        log_softmax_output = _compute_distributed_log_softmax(
            vocab_parallel_logits, group=group
        )
        log_probs = log_softmax_output.clone()
        softmax_output = log_softmax_output.exp_()

        log_probs = torch.gather(log_probs, -1, masked_target.unsqueeze(-1)).squeeze(-1)
        log_probs[target_mask] = 0.0

        torch.distributed.all_reduce(
            log_probs,
            op=torch.distributed.ReduceOp.SUM,
            group=group,
        )

        if not inference_only:
            # Only save tensors for backward when inference_only=False.
            ctx.save_for_backward(softmax_output, target_mask, masked_target)

        return log_probs

    @staticmethod
    def backward(
        ctx, grad_output: torch.Tensor
    ) -> Tuple[torch.Tensor, None, None, None, None, None, None]:
        softmax, target_mask, masked_target = ctx.saved_tensors
        partition_vocab_size = softmax.size(-1)

        # 1 if it's the chosen log prob, 0 otherwise
        is_chosen = (~target_mask).unsqueeze(-1) * torch.nn.functional.one_hot(
            masked_target, num_classes=partition_vocab_size
        )

        grad_input = is_chosen.float().sub_(softmax)
        grad_input.mul_(grad_output.unsqueeze(dim=-1))

        # if you add an argument to the forward method, then you must add a corresponding None here
        return grad_input, None, None, None, None, None, None
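
# The gradient above follows from d/d logit_j [ log softmax(logits)_t ] = 1{j == t} - softmax(logits)_j;
# the ~target_mask factor keeps the one-hot term only on the rank whose vocab shard owns the target
# token, while the -softmax term applies on every shard. Single-device sanity sketch of the identity
# (illustration only, not part of this module):
#
#   logits = torch.randn(5, requires_grad=True)
#   t = 2
#   torch.log_softmax(logits, dim=-1)[t].backward()
#   expected = torch.nn.functional.one_hot(torch.tensor(t), num_classes=5).float() - torch.softmax(logits.detach(), dim=-1)
#   assert torch.allclose(logits.grad, expected, atol=1e-6)
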
def from_parallel_logits_to_logprobs(
    vocab_parallel_logits: torch.Tensor,
    target: torch.Tensor,
    vocab_start_index: int,
    vocab_end_index: int,
    group: torch.distributed.ProcessGroup,
    inference_only: bool = False,
) -> torch.Tensor:
    """Get log probabilities from TP sharded vocab logits.

    Args:
        vocab_parallel_logits (torch.Tensor): Logits tensor with shape
            [batch_size, seq_len, vocab_size//TP] where TP is the tensor parallel size.
        target (torch.Tensor): Target token indices with shape [batch_size, seq_len].
            NOTE: Must be the unmodified targets as this function will shift them internally.
        vocab_start_index (int): Starting vocabulary index for this worker's partition.
        vocab_end_index (int): Ending vocabulary index for this worker's partition.
        group (torch.distributed.ProcessGroup): Process group for distributed communication.
        inference_only (bool, optional): If True, tensors won't be saved for the backward pass.
            Defaults to False.

    Returns:
        torch.Tensor: Log probabilities tensor with shape [batch_size, seq_len-1].
            The sequence dimension is reduced by 1 due to the target shifting.

    Taken from: https://github.com/NVIDIA/NeMo-Aligner/blob/9faab404f21994a7eb1d6ed5890b76152b941636/nemo_aligner/utils/distributed.py#L354
    """
    target = target.roll(shifts=-1, dims=-1)
    probs = DistributedLogprob.apply(
        vocab_parallel_logits,
        target,
        vocab_start_index,
        vocab_end_index,
        group,
        inference_only,
    ).contiguous()
    return probs[:, :-1]
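
# Example usage (a hedged sketch; `tp_group`, `tp_rank`, `tp_size`, `full_vocab_size`,
# `sharded_logits`, and `token_ids` are placeholders, and torch.distributed is assumed to be
# initialized with a tensor-parallel process group over which the vocabulary is sharded):
#
#   partition = full_vocab_size // tp_size
#   vocab_start = tp_rank * partition
#   vocab_end = vocab_start + partition
#   # sharded_logits: [batch, seq_len, partition] from this rank's slice of the output layer
#   # token_ids:      [batch, seq_len] full, unshifted target token ids, replicated on all TP ranks
#   token_logprobs = from_parallel_logits_to_logprobs(
#       sharded_logits,
#       token_ids,
#       vocab_start_index=vocab_start,
#       vocab_end_index=vocab_end,
#       group=tp_group,
#       inference_only=True,  # skip saving tensors when no backward pass is needed
#   )
#   # token_logprobs: [batch, seq_len - 1]; entry [b, i] is log p(token_ids[b, i + 1] | prefix up to i)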