# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from enum import Enum, auto
from typing import Optional, Tuple, TypeVar
import torch
from transformers import AutoConfig
Tensor = TypeVar("Tensor", bound=torch.Tensor)
@dataclass
class FlashAttentionKwargs:
"""Dataclass to hold FlashAttention v2 kwargs."""
cu_seqlens_q: Tensor
cu_seqlens_k: Tensor
max_seqlen_q: int
max_seqlen_k: int
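# Illustrative sketch (not part of the original module): FlashAttentionKwargs is a
# plain dataclass, so it can be expanded into keyword arguments for a FlashAttention
# v2 varlen call. The commented-out `flash_attn_varlen_func` line assumes the
# flash-attn package and is shown only to indicate the intended destination of
# these fields; q, k, v are placeholders.
def _flash_attention_kwargs_sketch() -> None:
    from dataclasses import asdict

    kwargs = FlashAttentionKwargs(
        cu_seqlens_q=torch.tensor([0, 3, 8], dtype=torch.int32),
        cu_seqlens_k=torch.tensor([0, 3, 8], dtype=torch.int32),
        max_seqlen_q=5,
        max_seqlen_k=5,
    )
    # out = flash_attn_varlen_func(q, k, v, **asdict(kwargs))
    assert asdict(kwargs)["max_seqlen_q"] == 5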
class ModelFlag(Enum):
"""Enum that defines special flags for model-specific behaviors.
This enum provides a way to identify models that require special handling or
configuration in different parts of the NeMo RL codebase.
Flags:
SKIP_DTENSOR_TIED_WEIGHTS_CHECK: Models that should skip the tied weights check
for the DTensor Policy even without setting the
NRL_SKIP_TIED_WEIGHT_CHECK flag.
VLLM_LOAD_FORMAT_AUTO: Models that should use the "auto" load format when initializing
VLLM.
Each flag has a `matches` method that determines if the flag applies to a given model_name.
"""
SKIP_DTENSOR_TIED_WEIGHTS_CHECK = auto()
VLLM_LOAD_FORMAT_AUTO = auto()
def matches(self, model_name: str) -> bool:
match self:
case ModelFlag.SKIP_DTENSOR_TIED_WEIGHTS_CHECK:
return is_gemma_model(model_name)
case ModelFlag.VLLM_LOAD_FORMAT_AUTO:
return is_gemma_model(model_name)
case _:
raise ValueError(f"Unknown ModelFlag: {self}")
def is_gemma_model(model_name: str) -> bool:
hf_config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
return hasattr(hf_config, "model_type") and hf_config.model_type in [
"gemma2",
"gemma3",
"gemma3_text",
]
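# Usage sketch (not part of the original module): checking which flags apply to a
# model. `matches` delegates to is_gemma_model, which fetches the Hugging Face
# config for `model_name`, so running this needs network access or a local cache;
# the model id below is only an illustrative example.
def _model_flag_sketch() -> None:
    model_name = "google/gemma-2-2b-it"  # example id, not prescribed by this module
    if ModelFlag.VLLM_LOAD_FORMAT_AUTO.matches(model_name):
        print(f"{model_name}: use load_format='auto' when initializing vLLM")
    if ModelFlag.SKIP_DTENSOR_TIED_WEIGHTS_CHECK.matches(model_name):
        print(f"{model_name}: skip the DTensor tied-weights check")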
def group_and_cat_tensors(
tensors: list[torch.Tensor],
group_sizes: list[int],
padding_value: int = 0,
min_seq_len: int = 0,
) -> torch.Tensor:
"""Groups and concatenates tensors according to group_sizes, then pads them to form a 2D tensor.
Each group of 1D tensors is concatenated into a single 1D tensor, and all resulting
group tensors are padded to the same length and stacked into a 2D tensor.
Args:
tensors: List of 1D tensors of varying lengths.
group_sizes: List of integers. Each integer specifies how many tensors to group.
padding_value: Integer used to pad shorter sequences.
min_seq_len: Minimum sequence length.
Returns:
A 2D tensor where each row is a padded concatenation of the grouped tensors.
Example:
>>> tensors = [
... torch.tensor([1, 2]),
... torch.tensor([3]),
... torch.tensor([4, 5, 6]),
... torch.tensor([7])
... ]
>>> group_sizes = [2, 2]
>>> group_and_cat_tensors(tensors, group_sizes, padding_value=-1)
tensor([[ 1,  2,  3, -1],
[ 4,  5,  6,  7]])
"""
grouped = []
index = 0
for size in group_sizes:
group = tensors[index : index + size]
concat = torch.cat(group, dim=0)
grouped.append(concat)
index += size
# Compute the maximum length for padding
max_len = max(t.size(0) for t in grouped)
max_len = max(max_len, min_seq_len)
# Pad each tensor to max_len
padded = torch.stack(
[
torch.nn.functional.pad(t, (0, max_len - t.size(0)), value=padding_value)
for t in grouped
]
)
return padded
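# Brief sketch (not part of the original module) of the `min_seq_len` argument,
# which the docstring example above does not exercise: every row is padded to at
# least `min_seq_len` columns even when all groups are shorter than that.
def _group_and_cat_min_seq_len_sketch() -> None:
    tensors = [torch.tensor([1, 2]), torch.tensor([3])]
    out = group_and_cat_tensors(tensors, group_sizes=[2], padding_value=-1, min_seq_len=6)
    assert out.tolist() == [[1, 2, 3, -1, -1, -1]]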
def pack_sequences(
input_ids: torch.Tensor,
input_lengths: torch.Tensor,
packed_sequence_size: list[int],
padding_value: int = 0,
return_attention_mask: bool = True,
min_seq_len: int = 0,
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
"""Packs sequences into rows where each row concatenates multiple sequences.
Useful for sequence packing in transformer models (e.g. for SFT training). Returns:
packed input_ids, packed position_ids, and optional attention_mask.
Args:
input_ids (torch.Tensor): Tensor of shape [num_sequences, max_seq_len]
input_lengths (torch.Tensor): Tensor of shape [num_sequences], containing true lengths
packed_sequence_size (List[int]): How many sequences to pack per row
padding_value (int): Pad value for input_ids
return_attention_mask (bool): Whether to return per-row causal attention mask
min_seq_len (int): Minimum sequence length.
Returns:
Tuple:
input_ids_packed (torch.Tensor): [batch_size, max_packed_seq_len]
position_ids_packed (torch.Tensor): [batch_size, max_packed_seq_len]
attention_mask (Optional[torch.Tensor]): [batch_size, max_len, max_len] if requested
Example:
>>> input_ids = torch.tensor([
... [1, 2, 0, 0], # len 2
... [3, 4, 5, 0], # len 3
... [6, 0, 0, 0], # len 1
... [7, 8, 9, 9], # len 4
... [8, 7, 0, 0], # len 2
... [6, 0, 0, 0], # len 1
... [5, 4, 3, 0], # len 3
... ])
>>> input_lengths = torch.tensor([2, 3, 1, 4, 2, 1, 3])
>>> packed_sequence_size = [3, 4]
>>> input_ids_packed, position_ids_packed, attention_mask = pack_sequences(
... input_ids, input_lengths, packed_sequence_size, padding_value=-1, return_attention_mask=True
... )
>>> input_ids_packed
tensor([
[ 1, 2, 3, 4, 5, 6, -1, -1, -1, -1],
[ 7, 8, 9, 9, 8, 7, 6, 5, 4, 3]
])
>>> position_ids_packed
tensor([
[0, 1, 0, 1, 2, 0, 0, 0, 0, 0],
[0, 1, 2, 3, 0, 1, 0, 0, 1, 2]
])
>>> attention_mask.shape
torch.Size([2, 10, 10])
>>> # Each packed row gets a causal (lower-triangular) mask over its valid tokens:
>>> # row 0 packs 6 tokens, so attention_mask[0] is lower-triangular over its first
>>> # 6 positions and False elsewhere; row 1 packs 10 tokens, so attention_mask[1]
>>> # is lower-triangular over all 10 positions.
>>> attention_mask[0, :3, :3]
tensor([
[ True, False, False],
[ True,  True, False],
[ True,  True,  True]
])
"""
flat_input_ids = []
position_ids = []
flat_lengths = input_lengths.tolist()
for i, seq_len in enumerate(flat_lengths):
flat_input_ids.append(input_ids[i, :seq_len])
position_ids.append(
torch.arange(seq_len, dtype=torch.long, device=input_ids.device)
)
# Group and pad
input_ids_packed = group_and_cat_tensors(
flat_input_ids, packed_sequence_size, padding_value, min_seq_len=min_seq_len
)
position_ids_packed = group_and_cat_tensors(
position_ids, packed_sequence_size, padding_value=0, min_seq_len=min_seq_len
)
# Compute max length
batch_size, max_seq_len = input_ids_packed.shape
attention_mask = None
if return_attention_mask:
attention_mask = torch.zeros(
(batch_size, max_seq_len, max_seq_len),
dtype=torch.bool,
device=input_ids.device,
)
index = 0
for i, group_size in enumerate(packed_sequence_size):
group_lengths = flat_lengths[index : index + group_size]
total_len = sum(group_lengths)
attention_mask[i, :total_len, :total_len] = torch.tril(
torch.ones(
(total_len, total_len), dtype=torch.bool, device=input_ids.device
)
)
index += group_size
return input_ids_packed, position_ids_packed, attention_mask
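# Sketch (not part of the original module) of how a caller might combine
# pack_sequences with get_flash_attention_kwargs (defined below): each packed row
# gets its own cu_seqlens built from the lengths of the sequences placed in that
# row. This is one plausible composition of the helpers in this file, not
# necessarily how NeMo RL wires them together.
def _pack_with_flash_attention_sketch() -> None:
    input_ids = torch.tensor([[1, 2, 0], [3, 4, 5], [6, 0, 0]])
    input_lengths = torch.tensor([2, 3, 1])
    packed_sequence_size = [2, 1]
    input_ids_packed, position_ids_packed, _ = pack_sequences(
        input_ids, input_lengths, packed_sequence_size, return_attention_mask=False
    )
    # Row 0 holds sequences of length 2 and 3; row 1 holds a single sequence of length 1.
    assert input_ids_packed.tolist() == [[1, 2, 3, 4, 5], [6, 0, 0, 0, 0]]
    # Per-row FlashAttention kwargs built from the per-row sequence lengths.
    row_lengths = [input_lengths[:2], input_lengths[2:]]
    fa_kwargs = [get_flash_attention_kwargs(lengths) for lengths in row_lengths]
    assert fa_kwargs[0].cu_seqlens_q.tolist() == [0, 2, 5]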
# TODO(ahmadki): the function doesn't actually handle returning 2D tensors because none of the
# backends support this, but we should support it anyway.
def unpack_tensor(tensor: torch.Tensor, input_lengths: torch.Tensor) -> torch.Tensor:
"""Unpacks a packed tensor into individual sequences padded to the same length.
Args:
tensor (torch.Tensor): Packed tensor of shape [batch_size, packed_seq_len, ...]. The current
padding logic assumes at least one trailing feature dimension (see the TODO above).
input_lengths (torch.Tensor): Original sequence lengths, in the order they were packed.

Returns:
torch.Tensor: [num_sequences, max_seq_len, ...], each row is one unpacked sequence padded with zeros.

Example:
>>> packed = torch.tensor([[1, 2, 3, 4, 5, 6, 0, 0]]).unsqueeze(-1)  # shape [1, 8, 1]
>>> input_lengths = torch.tensor([2, 3, 1])
>>> unpack_tensor(packed, input_lengths).squeeze(-1)
tensor([[1, 2, 0],
[3, 4, 5],
[6, 0, 0]])
"""
packed_seqlen = tensor.shape[1]
splitsizes = input_lengths.tolist()
splitsizes.append(packed_seqlen - sum(splitsizes))
tensor_split = torch.split(tensor, tuple(splitsizes), dim=1)
max_len = max(input_lengths.tolist()) # max sequence length in the batch
tensor_stacked = []
for t in tensor_split[0:-1]:
padding_needed = max_len - t.shape[1]
tensor_stacked.append(
torch.nn.functional.pad(
t, (0, 0, 0, padding_needed), mode="constant", value=0.0
)
)
return torch.cat(tensor_stacked, dim=0)
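# Round-trip sketch (not part of the original module): unpack a hidden-state-like
# tensor of shape [1, packed_seq_len, hidden] back into per-sequence rows. The
# single packed row and trailing hidden dimension are assumptions matching what the
# current padding logic supports (see the TODO above).
def _unpack_tensor_sketch() -> None:
    hidden = 4
    input_lengths = torch.tensor([2, 3, 1])
    packed = torch.randn(1, 8, hidden)  # 6 real positions plus 2 positions of padding
    unpacked = unpack_tensor(packed, input_lengths)
    assert unpacked.shape == (3, 3, hidden)  # [num_sequences, max_seq_len, hidden]
    # The first sequence's real positions are recovered unchanged.
    assert torch.equal(unpacked[0, :2], packed[0, :2])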
def get_flash_attention_kwargs(input_lengths: torch.Tensor) -> FlashAttentionKwargs:
"""Returns kwargs required for FlashAttention v2 forward functions.
Args:
input_lengths (torch.Tensor): [batch_size] containing lengths of each sequence
Returns:
FlashAttentionKwargs: dataclass holding
cu_seqlens_q / cu_seqlens_k (int32 tensors of shape [batch_size + 1] with cumulative
sequence lengths prefixed by 0) and max_seqlen_q / max_seqlen_k (ints, the maximum
sequence length in the batch). Query and key values are identical for self-attention.
"""
input_lengths_int32 = input_lengths.to(torch.int32)
cu_seqlens = torch.nn.functional.pad(
input_lengths_int32.cumsum(dim=0), (1, 0)
) # prepend 0
max_len = input_lengths.max().item()
return FlashAttentionKwargs(
cu_seqlens_q=cu_seqlens,
cu_seqlens_k=cu_seqlens.clone(), # same for self-attention
max_seqlen_q=max_len,
max_seqlen_k=max_len,
)
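# Minimal sketch (not part of the original module): the cu_seqlens tensors are
# cumulative sequence lengths with a leading 0, which is the layout FlashAttention
# v2's varlen interface indexes into; the actual attention call is omitted since
# flash-attn may not be installed.
def _get_flash_attention_kwargs_sketch() -> None:
    lengths = torch.tensor([3, 5, 2])
    fa_kwargs = get_flash_attention_kwargs(lengths)
    assert fa_kwargs.cu_seqlens_q.tolist() == [0, 3, 8, 10]
    assert fa_kwargs.max_seqlen_q == 5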