# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Optional
import torch
from torch.distributed.tensor import DTensor, distribute_tensor
@torch.no_grad()
def _compute_distributed_log_softmax(
vocab_parallel_logits: torch.Tensor, group: torch.distributed.ProcessGroup
) -> torch.Tensor:
"""Compute a stable distributed log softmax across tensor parallel workers.
Taken from: https://github.com/NVIDIA/NeMo-Aligner/blob/9faab404f21994a7eb1d6ed5890b76152b941636/nemo_aligner/utils/distributed.py#L265
Args:
vocab_parallel_logits (torch.Tensor): Logits tensor with shape [batch_size, seq_length, vocab_size//TP]
where TP is the tensor parallel size.
group (torch.distributed.ProcessGroup): Process group for the all-reduce operations.
Returns:
torch.Tensor: Log softmax output with the same shape as input, but values represent
log probabilities normalized across the full vocabulary dimension.
"""
logits_max = torch.amax(vocab_parallel_logits, dim=-1, keepdim=True)
torch.distributed.all_reduce(
logits_max,
op=torch.distributed.ReduceOp.MAX,
group=group,
)
# Subtract the maximum value.
vocab_parallel_logits = vocab_parallel_logits - logits_max
sum_exp_logits = vocab_parallel_logits.exp().sum(-1, keepdim=True).float()
torch.distributed.all_reduce(
sum_exp_logits,
op=torch.distributed.ReduceOp.SUM,
group=group,
)
return vocab_parallel_logits - sum_exp_logits.log_().to(vocab_parallel_logits.dtype)
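
# Illustrative sketch (not part of the original module): with a single tensor
# parallel worker the two all-reduces above are no-ops, so the function reduces
# to the familiar max-shifted log-softmax. This local reference exists only to
# show why the result agrees with torch.log_softmax.
def _example_local_log_softmax(logits: torch.Tensor) -> torch.Tensor:
    """Single-worker equivalent of _compute_distributed_log_softmax."""
    logits_max = torch.amax(logits, dim=-1, keepdim=True)
    shifted = logits - logits_max  # subtract the max for numerical stability
    sum_exp = shifted.exp().sum(-1, keepdim=True).float()
    return shifted - sum_exp.log_().to(shifted.dtype)
    # Sanity check: torch.allclose(_example_local_log_softmax(x),
    #                              torch.log_softmax(x, dim=-1), atol=1e-6)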
class DistributedLogprob(torch.autograd.Function):
"""Custom autograd function for computing log probabilities in a distributed setting.
Taken from https://github.com/NVIDIA/NeMo-Aligner/blob/9faab404f21994a7eb1d6ed5890b76152b941636/nemo_aligner/utils/distributed.py#L286
"""
@staticmethod
def forward( # pyrefly: ignore[bad-override] Always ignore torch.autograd.Function.forward's type since it's always more specific than the base class
ctx: Any,
vocab_parallel_logits: torch.Tensor,
target: torch.Tensor,
vocab_start_index: int,
vocab_end_index: int,
group: torch.distributed.ProcessGroup,
inference_only: bool = False,
) -> torch.Tensor:
        # Mask targets that fall outside this rank's vocab partition (True means masked).
target_mask = (target < vocab_start_index) | (target >= vocab_end_index)
masked_target = target - vocab_start_index
masked_target[target_mask] = 0
log_softmax_output = _compute_distributed_log_softmax(
vocab_parallel_logits, group=group
)
log_probs = log_softmax_output.clone()
softmax_output = log_softmax_output.exp_()
log_probs = torch.gather(log_probs, -1, masked_target.unsqueeze(-1)).squeeze(-1)
log_probs[target_mask] = 0.0
torch.distributed.all_reduce(
log_probs,
op=torch.distributed.ReduceOp.SUM,
group=group,
)
if not inference_only:
            # Only save tensors for the backward pass when inference_only=False.
ctx.save_for_backward(softmax_output, target_mask, masked_target)
return log_probs
@staticmethod
def backward(
ctx: Any,
*grad_outputs: torch.Tensor,
) -> tuple[torch.Tensor, None, None, None, None, None, None]:
grad_output = grad_outputs[0]
softmax, target_mask, masked_target = ctx.saved_tensors
if softmax.ndim == 3:
B, S, V = softmax.shape
            # Build flat indices directly instead of materializing a one-hot matrix.
row = (
torch.arange(B, device=softmax.device)
.view(-1, 1)
.expand(-1, S)
.reshape(-1)
)
col = torch.arange(S, device=softmax.device).expand(B, -1).reshape(-1)
flat_idx = (row * S + col) * V
flat_chosen = flat_idx.masked_select(
~target_mask.reshape(-1)
) + masked_target.masked_select(~target_mask)
            # neg() allocates a new tensor, so the in-place mul_ below leaves
            # the saved softmax untouched.
grad_input = softmax.neg()
grad_input = grad_input.mul_(grad_output.unsqueeze(-1))
grad_output_selected = grad_output.masked_select(~target_mask)
grad_input.view(-1).scatter_add_(0, flat_chosen, grad_output_selected)
else:
V = softmax.size(-1)
is_chosen = (~target_mask).unsqueeze(-1) * torch.nn.functional.one_hot(
masked_target, num_classes=V
)
grad_input = is_chosen.float().sub_(softmax)
grad_input.mul_(grad_output.unsqueeze(-1))
# if you add an argument to the forward method, then you must add a corresponding None here
return grad_input, None, None, None, None, None, None
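
# Illustrative sketch (not part of the original module): on a single shard with
# no masked targets, the backward above reduces to the classic gradient of a
# gathered log-softmax, d log_softmax(x)[t] / dx = one_hot(t) - softmax(x).
# The dense reference below spells that identity out for [tokens, vocab] inputs.
def _example_logprob_grad_reference(
    logits: torch.Tensor, target: torch.Tensor, grad_output: torch.Tensor
) -> torch.Tensor:
    """Dense reference gradient for log_softmax(logits)[i, target[i]]."""
    softmax = torch.softmax(logits.float(), dim=-1)
    one_hot = torch.nn.functional.one_hot(target, num_classes=logits.size(-1))
    # Same math DistributedLogprob.backward applies shard-locally.
    return (one_hot.float() - softmax) * grad_output.unsqueeze(-1).float()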
def dtensor_from_parallel_logits_to_logprobs(
vocab_parallel_logits: torch.Tensor,
target: DTensor | torch.Tensor,
vocab_start_index: int,
vocab_end_index: int,
tp_group: torch.distributed.ProcessGroup,
inference_only: bool = False,
seq_index: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Get log probabilities from TP+CP sharded vocab logits.
Args:
        vocab_parallel_logits (torch.Tensor): Logits distributed across tensor parallel workers,
            with shape [batch_size, seq_len, vocab_size // tp_size].
        target (DTensor | torch.Tensor): Target token indices with shape [batch_size, seq_len].
NOTE: Must be the unmodified targets as this function will shift them internally.
vocab_start_index (int): Starting vocabulary index for this worker's partition.
vocab_end_index (int): Ending vocabulary index for this worker's partition.
tp_group (torch.distributed.ProcessGroup): Process group for distributed communication.
inference_only (bool, optional): If True, tensors won't be saved for backward pass. Defaults to False.
        seq_index (Optional[torch.Tensor]): Sequence index tensor with shape [seq_len].
            Only provided for CP-sharded logits; it describes how the tensor is sharded
            along the sequence dimension.
Returns:
torch.Tensor: Log probabilities tensor with shape [batch_size, seq_len-1].
The sequence dimension is reduced by 1 due to the target shifting.
"""
cp_size = 1
if (
isinstance(target, DTensor)
and target.device_mesh.mesh_dim_names is not None
and "cp" in target.device_mesh.mesh_dim_names
):
cp_dim_index = target.device_mesh.mesh_dim_names.index("cp")
cp_size = target.device_mesh.shape[cp_dim_index]
if cp_size > 1:
assert seq_index is not None, "seq_index must be provided for cp sharded logits"
target_shape = torch.Size(target.shape)
cp_mesh = target.device_mesh
cp_placements = target.placements
_, sorted_indices = torch.sort(seq_index)
# Recover the original order of the target
target = target.full_tensor()[:, sorted_indices]
target = target.roll(shifts=-1, dims=-1)[:, seq_index]
# Reshard
target = distribute_tensor(target, cp_mesh, cp_placements)
target = target.to_local()
else:
target = target.roll(shifts=-1, dims=-1)
probs: torch.Tensor = DistributedLogprob.apply( # type: ignore
vocab_parallel_logits,
target,
vocab_start_index,
vocab_end_index,
tp_group,
inference_only,
).contiguous()
if cp_size > 1:
# probs is sharded on the sequence dimension.
# Get full sequence tensor, vocab dim has been reduced already.
probs_dtensor = DTensor.from_local(probs, cp_mesh, cp_placements)
probs = probs_dtensor.full_tensor()[:, sorted_indices]
assert probs.shape == target_shape
return probs[:, :-1]
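
# Illustrative sketch (not part of the original module): a hypothetical call
# site, assuming torch.distributed is initialized and that the caller already
# knows its TP process group and local vocab range [vocab_start, vocab_end).
# All argument names here are placeholders, not module API.
def _example_dtensor_logprobs_call(
    vocab_parallel_logits: torch.Tensor,  # [B, S, V // TP], local TP shard
    target: torch.Tensor,  # [B, S], full (unshifted) targets
    vocab_start: int,
    vocab_end: int,
    tp_group: torch.distributed.ProcessGroup,
) -> torch.Tensor:
    # With a plain tensor target there is no CP sharding, so seq_index is omitted.
    return dtensor_from_parallel_logits_to_logprobs(
        vocab_parallel_logits,
        target,
        vocab_start,
        vocab_end,
        tp_group,
        inference_only=True,
    )  # -> [B, S - 1]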
def from_parallel_logits_to_logprobs(
vocab_parallel_logits: torch.Tensor,
target: torch.Tensor,
vocab_start_index: int,
vocab_end_index: int,
tp_group: torch.distributed.ProcessGroup,
inference_only: bool = False,
cp_group: Optional[torch.distributed.ProcessGroup] = None,
) -> torch.Tensor:
"""Get log probabilities from TP+CP sharded vocab logits.
Args:
vocab_parallel_logits (torch.Tensor): Logits tensor with shape [batch_size, seq_len // CP, vocab_size // TP]
where TP is the tensor parallel size.
target (torch.Tensor): Target token indices with shape [batch_size, seq_len].
NOTE: Must be the unmodified targets as this function will shift them internally.
vocab_start_index (int): Starting vocabulary index for this worker's partition.
vocab_end_index (int): Ending vocabulary index for this worker's partition.
tp_group (torch.distributed.ProcessGroup): Process group for distributed communication.
inference_only (bool, optional): If True, tensors won't be saved for backward pass. Defaults to False.
cp_group (torch.distributed.ProcessGroup, optional): Context parallelism process group. Defaults to None.
Returns:
torch.Tensor: Log probabilities tensor with shape [batch_size, seq_len-1].
The sequence dimension is reduced by 1 due to the target shifting.
Taken from: https://github.com/NVIDIA/NeMo-Aligner/blob/9faab404f21994a7eb1d6ed5890b76152b941636/nemo_aligner/utils/distributed.py#L354
"""
target = target.roll(shifts=-1, dims=-1)
cp_size = 1 if cp_group is None else torch.distributed.get_world_size(cp_group)
    # Pad the targets to local size * cp_size so they split evenly across CP ranks.
    pad_len = vocab_parallel_logits.shape[1] * cp_size - target.shape[1]
    if pad_len > 0:
        target = torch.nn.functional.pad(target, (0, pad_len), value=0)
    # Shard the targets by context parallelism (a no-op when cp_size == 1).
    cp_rank = 0 if cp_group is None else torch.distributed.get_rank(cp_group)
    target = _get_tokens_on_this_cp_rank(target, cp_rank, cp_size, seq_dim=1)
probs: torch.Tensor = DistributedLogprob.apply( # type: ignore
vocab_parallel_logits,
target,
vocab_start_index,
vocab_end_index,
tp_group,
inference_only,
).contiguous()
if cp_size > 1:
        # Gather the CP-sharded log probabilities back to the full sequence.
        probs = allgather_cp_sharded_tensor(probs, cp_group, seq_dim=1)
if pad_len > 0:
probs = probs[:, :-pad_len]
return probs[:, :-1]
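
# Illustrative sketch (not part of the original module): the roll above aligns
# the logits at position i with the label at position i + 1, which is why the
# caller drops the final position. A tiny trace makes the wrap-around visible.
def _example_next_token_shift() -> torch.Tensor:
    """Show the next-token alignment produced by target.roll(-1)."""
    target = torch.tensor([[10, 11, 12, 13]])
    shifted = target.roll(shifts=-1, dims=-1)  # tensor([[11, 12, 13, 10]])
    # The last column wrapped around to the first token, so probs[:, :-1]
    # discards it after the logprob gather.
    return shifted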
def from_parallel_logits_to_logprobs_packed_sequences(
vocab_parallel_logits: torch.Tensor,
target: torch.Tensor,
cu_seqlens_padded: torch.Tensor,
unpacked_seqlen: int,
vocab_start_index: int,
vocab_end_index: int,
group: torch.distributed.ProcessGroup,
inference_only: bool = False,
cp_group: Optional[torch.distributed.ProcessGroup] = None,
) -> torch.Tensor:
"""Get log probabilities from TP sharded vocab logits for packed sequences.
Args:
vocab_parallel_logits (torch.Tensor): Packed logits tensor with shape [1, T // CP, vocab_size//TP]
where T is the total number of tokens across all packed sequences.
target (torch.Tensor): Packed target token indices with shape [1, T].
NOTE: Must be the unmodified targets as this function will shift them internally.
        cu_seqlens_padded (torch.Tensor): Cumulative padded sequence lengths with shape
            [batch_size + 1]. cu_seqlens_padded[i] is the start of sequence i in the packed format.
unpacked_seqlen (int): The length of the unpacked sequence tensor.
vocab_start_index (int): Starting vocabulary index for this worker's partition.
vocab_end_index (int): Ending vocabulary index for this worker's partition.
group (torch.distributed.ProcessGroup): Process group for distributed communication.
inference_only (bool, optional): If True, tensors won't be saved for backward pass. Defaults to False.
cp_group (torch.distributed.ProcessGroup, optional): Context parallelism process group. Defaults to None.
Returns:
torch.Tensor: Unpacked log probabilities tensor with shape [batch_size, unpacked_seqlen-1].
The total length is reduced by batch_size due to target shifting (one token per sequence).
"""
# Remove batch dimension to work with [T, vocab_size] and [T]
vocab_parallel_logits = vocab_parallel_logits.squeeze(0)
target = target.squeeze(0)
batch_size = cu_seqlens_padded.shape[0] - 1
cp_size = 1 if cp_group is None else torch.distributed.get_world_size(cp_group)
cp_rank = 0 if cp_group is None else torch.distributed.get_rank(cp_group)
# Roll each sequence individually
rolled_targets = torch.zeros(
target.shape[0] // cp_size, dtype=target.dtype, device=target.device
)
for i in range(batch_size):
start_idx = cu_seqlens_padded[i].item()
end_idx = cu_seqlens_padded[i + 1].item()
# Get the sequence targets and roll by -1
seq_targets = target[start_idx:end_idx]
rolled_seq_targets = seq_targets.roll(shifts=-1, dims=0)
rolled_targets[start_idx // cp_size : end_idx // cp_size] = (
_get_tokens_on_this_cp_rank(rolled_seq_targets, cp_rank, cp_size, seq_dim=0)
)
# Add batch dimension back for DistributedLogprob
rolled_targets = rolled_targets.unsqueeze(0)
vocab_parallel_logits = vocab_parallel_logits.unsqueeze(0)
# Apply distributed log probability computation
probs: torch.Tensor = DistributedLogprob.apply( # type: ignore
vocab_parallel_logits,
rolled_targets,
vocab_start_index,
vocab_end_index,
group,
inference_only,
).contiguous()
# Remove batch dimension for filtering
probs = probs.squeeze(0)
# Ensure probs is 1D after squeezing
    if probs.dim() != 1:
        raise ValueError(
            f"Expected probs to be 1D after squeezing, but got shape {probs.shape}."
        )
if cp_size > 1:
# per-sequence cp_allgather
        final_probs = torch.zeros(
            probs.shape[0] * cp_size, dtype=probs.dtype, device=probs.device
        )
for i in range(batch_size):
start_idx = cu_seqlens_padded[i].item()
end_idx = cu_seqlens_padded[i + 1].item()
final_probs[start_idx:end_idx] = allgather_cp_sharded_tensor(
probs[start_idx // cp_size : end_idx // cp_size], cp_group, seq_dim=0
)
probs = final_probs
out_logprobs = torch.zeros(
(batch_size, unpacked_seqlen - 1), dtype=probs.dtype, device=probs.device
)
# Filter out the last token of each sequence
for i in range(batch_size):
start_idx = cu_seqlens_padded[i].item()
end_idx = cu_seqlens_padded[i + 1].item()
# Exclude the last position (which has the rolled target from position 0)
if end_idx - start_idx > 0:
seq_probs = probs[start_idx : end_idx - 1]
# Ensure seq_probs is 1D
if seq_probs.dim() > 1:
seq_probs = seq_probs.squeeze()
# Ensure we don't exceed the unpacked sequence length
seq_len = min(seq_probs.shape[0], unpacked_seqlen - 1)
if seq_len > 0:
out_logprobs[i, :seq_len] = seq_probs[:seq_len]
return out_logprobs
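
# Illustrative sketch (not part of the original module): how cu_seqlens_padded
# indexes a packed [1, T] batch. The lengths here are made up for the example.
def _example_cu_seqlens_layout() -> None:
    """Print the per-sequence token spans of a packed batch."""
    # Two sequences of padded lengths 4 and 6 packed into T = 10 tokens.
    cu_seqlens_padded = torch.tensor([0, 4, 10])
    packed_target = torch.arange(10).unsqueeze(0)  # [1, T]
    for i in range(cu_seqlens_padded.shape[0] - 1):
        start = cu_seqlens_padded[i].item()
        end = cu_seqlens_padded[i + 1].item()
        print(f"sequence {i}: tokens {packed_target[0, start:end].tolist()}")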
def _get_tokens_on_this_cp_rank(
input_ids: torch.Tensor,
cp_rank: int,
cp_size: int,
seq_dim: int = 1,
) -> torch.Tensor:
"""Get tokens on this context parallelism rank.
Assumes that input_ids are already padded to a multiple of cp_size * 2 or cp_size == 1.
    Args:
        input_ids: Input token IDs with the sequence along dimension seq_dim.
        cp_rank: Context parallelism rank.
        cp_size: Context parallelism size.
        seq_dim: Sequence dimension of input_ids. Defaults to 1.
    Returns:
        The tokens owned by this context parallelism rank, with the sequence
        dimension reduced to seq_length // cp_size.
"""
if cp_size == 1:
return input_ids
# load balance for causal attention
shard_size = input_ids.shape[seq_dim] // (cp_size * 2)
shard_inds = (cp_rank, (cp_size * 2) - cp_rank - 1)
# Create slices for each dimension
slices = [slice(None)] * input_ids.dim()
ids_chunks = []
for ind in shard_inds:
slices[seq_dim] = slice(ind * shard_size, (ind + 1) * shard_size)
ids_chunks.append(input_ids[slices])
ids = torch.cat(ids_chunks, dim=seq_dim)
return ids
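
# Illustrative sketch (not part of the original module): which tokens each CP
# rank receives for cp_size=2. Rank r gets chunks r and 2*cp_size - r - 1, so
# early (cheap) and late (expensive) causal-attention positions are balanced.
def _example_cp_sharding() -> None:
    """Print the load-balanced shard owned by each CP rank."""
    ids = torch.arange(8).unsqueeze(0)  # [1, 8], a multiple of 2 * cp_size
    for cp_rank in range(2):
        shard = _get_tokens_on_this_cp_rank(ids, cp_rank, cp_size=2, seq_dim=1)
        print(cp_rank, shard[0].tolist())
    # rank 0 -> [0, 1, 6, 7] (chunks 0 and 3); rank 1 -> [2, 3, 4, 5] (chunks 1 and 2)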
def allgather_cp_sharded_tensor(
    tensor: torch.Tensor,
    cp_group: torch.distributed.ProcessGroup,
    seq_dim: int = 1,
) -> torch.Tensor:
    """All-gather a CP-sharded tensor and undo the load-balancing permutation."""
    return AllGatherCPTensor.apply(tensor, cp_group, seq_dim)
class AllGatherCPTensor(torch.autograd.Function):
    """Autograd-aware all-gather for context-parallel (CP) sharded tensors.

    The backward all-reduces the incoming gradient across the CP group and then
    selects this rank's load-balanced chunks of the sequence dimension.
    """
    @staticmethod
    def forward(
        ctx: Any,
        tensor: torch.Tensor,
        cp_group: torch.distributed.ProcessGroup,
        seq_dim: int = 1,
    ) -> torch.Tensor:
cp_size = torch.distributed.get_world_size(cp_group)
cp_rank_chunks = []
for _ in range(cp_size):
cp_rank_chunks.append(torch.empty_like(tensor))
torch.distributed.all_gather(
tensor_list=cp_rank_chunks, tensor=tensor, group=cp_group
)
# undo the CP load balancing chunking
tensor_chunks = []
for logit_chunk in cp_rank_chunks:
tensor_chunks.extend(torch.chunk(logit_chunk, chunks=2, dim=seq_dim))
chunk_indices = []
for cp_rank in range(cp_size):
chunk_indices.append(cp_rank)
chunk_indices.append(2 * cp_size - cp_rank - 1)
chunks_and_indices = list(zip(tensor_chunks, chunk_indices))
chunks_and_indices = sorted(chunks_and_indices, key=lambda tup: tup[1])
ret_tensor = [chunk for chunk, _ in chunks_and_indices]
ret_tensor = torch.cat(ret_tensor, dim=seq_dim)
        ctx.seq_dim = seq_dim
        ctx.cp_group = cp_group
        return ret_tensor
    @staticmethod
    def backward(
        ctx: Any, grad_output: torch.Tensor
    ) -> tuple[torch.Tensor, None, None]:
cp_size = torch.distributed.get_world_size(ctx.cp_group)
cp_rank = torch.distributed.get_rank(ctx.cp_group)
torch.distributed.all_reduce(grad_output, group=ctx.cp_group)
        # Chunk the sequence dim into 2 * cp_size chunks and select this rank's
        # two chunks using the CP load-balanced indexing.
        seq_dim = ctx.seq_dim
grad_output = grad_output.view(
*grad_output.shape[0:seq_dim],
2 * cp_size,
grad_output.shape[seq_dim] // (2 * cp_size),
*grad_output.shape[(seq_dim + 1) :],
)
index = torch.tensor(
[cp_rank, (2 * cp_size - cp_rank - 1)], device="cpu", pin_memory=True
).cuda(non_blocking=True)
grad_input = grad_output.index_select(seq_dim, index)
grad_input = grad_input.view(
*grad_input.shape[0:seq_dim], -1, *grad_input.shape[(seq_dim + 2) :]
)
        return grad_input, None, None
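
# Illustrative sketch (not part of the original module): the chunk bookkeeping
# used by AllGatherCPTensor.forward. After the all-gather, rank r's two chunks
# carry original positions r and 2*cp_size - r - 1; sorting by those indices
# restores the unsharded sequence order.
def _example_cp_gather_order(cp_size: int = 2) -> list[int]:
    """Return the chunk order produced by concatenating per-rank shards."""
    chunk_indices: list[int] = []
    for cp_rank in range(cp_size):
        chunk_indices.append(cp_rank)
        chunk_indices.append(2 * cp_size - cp_rank - 1)
    return chunk_indices  # cp_size=2 -> [0, 3, 1, 2]; sorted -> [0, 1, 2, 3]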