nemo_rl.models.huggingface.common#

Module Contents#

Classes#

FlashAttentionKwargs

Dataclass to hold FlashAttention v2 kwargs.

ModelFlag

Enum that defines special flags for model-specific behaviors.

Functions#

is_gemma_model

group_and_cat_tensors

Groups and concatenates tensors according to group_sizes, then pads them to form a 2D tensor.

pack_sequences

Packs sequences into rows where each row concatenates multiple sequences.

unpack_tensor

Unpacks a packed tensor into individual sequences padded to the same length.

get_flash_attention_kwargs

Returns kwargs required for FlashAttention v2 forward functions.

Data#

API#

nemo_rl.models.huggingface.common.Tensor#

'TypeVar(…)'

class nemo_rl.models.huggingface.common.FlashAttentionKwargs[source]#

Dataclass to hold FlashAttention v2 kwargs.

cu_seqlens_q: nemo_rl.models.huggingface.common.Tensor#

None

cu_seqlens_k: nemo_rl.models.huggingface.common.Tensor#

None

max_seqlen_q: int#

None

max_seqlen_k: int#

None

class nemo_rl.models.huggingface.common.ModelFlag(*args, **kwds)[source]#

Bases: enum.Enum

Enum that defines special flags for model-specific behaviors.

This enum provides a way to identify models that require special handling or configuration in different parts of the NeMo RL codebase.

Flags:

  • SKIP_DTENSOR_TIED_WEIGHTS_CHECK: Models that should skip the tied weights check for the DTensor Policy even without setting the NRL_SKIP_TIED_WEIGHT_CHECK flag.

  • VLLM_LOAD_FORMAT_AUTO: Models that should use the "auto" load format when initializing vLLM.

Each flag has a matches method that determines if the flag applies to a given model_name.

Initialization

SKIP_DTENSOR_TIED_WEIGHTS_CHECK#

'auto(…)'

VLLM_LOAD_FORMAT_AUTO#

'auto(…)'

matches(model_name: str) bool[source]#
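
A brief usage sketch (not from the source docstring); the model name below is an illustrative placeholder, and only the documented matches() API is assumed.

    from nemo_rl.models.huggingface.common import ModelFlag

    model_name = "my-org/my-model"  # illustrative placeholder

    # matches() returns True when the flag's special handling applies to this model.
    use_auto_load_format = ModelFlag.VLLM_LOAD_FORMAT_AUTO.matches(model_name)
    skip_tied_weight_check = ModelFlag.SKIP_DTENSOR_TIED_WEIGHTS_CHECK.matches(model_name)
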
nemo_rl.models.huggingface.common.is_gemma_model(model_name: str) bool[source]#
nemo_rl.models.huggingface.common.group_and_cat_tensors(
tensors: list[torch.Tensor],
group_sizes: list[int],
padding_value: int = 0,
min_seq_len: int = 0,
) torch.Tensor[source]#

Groups and concatenates tensors according to group_sizes, then pads them to form a 2D tensor.

Each group of 1D tensors is concatenated into a single 1D tensor, and all resulting group tensors are padded to the same length and stacked into a 2D tensor.

Parameters:
  • tensors – List of 1D tensors of varying lengths.

  • group_sizes – List of integers. Each integer specifies how many tensors to group.

  • padding_value – Integer used to pad shorter sequences.

  • min_seq_len – Minimum sequence length; each output row is padded to at least this length.

Returns:

A 2D tensor where each row is a padded concatenation of the grouped tensors.

Example:

    >>> tensors = [
    ...     torch.tensor([1, 2]),
    ...     torch.tensor([3]),
    ...     torch.tensor([4, 5, 6]),
    ...     torch.tensor([7])
    ... ]
    >>> group_sizes = [2, 2]
    >>> group_and_cat_tensors(tensors, group_sizes, padding_value=-1)
    tensor([[ 1,  2,  3, -1, -1],
            [ 4,  5,  6,  7, -1]])

nemo_rl.models.huggingface.common.pack_sequences(
input_ids: torch.Tensor,
input_lengths: torch.Tensor,
packed_sequence_size: list[int],
padding_value: int = 0,
return_attention_mask: bool = True,
min_seq_len: int = 0,
) Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]][source]#

Packs sequences into rows where each row concatenates multiple sequences.

Useful for sequence packing in transformer models (e.g., for SFT training). Returns packed input_ids, packed position_ids, and an optional attention_mask.

Parameters:
  • input_ids (torch.Tensor) – Tensor of shape [num_sequences, max_seq_len]

  • input_lengths (torch.Tensor) – Tensor of shape [num_sequences], containing true lengths

  • packed_sequence_size (List[int]) – How many sequences to pack per row

  • padding_value (int) – Pad value for input_ids

  • return_attention_mask (bool) – Whether to return per-row causal attention mask

  • min_seq_len (int) – Minimum packed sequence length; each packed row is padded to at least this length.

Returns:

  • input_ids_packed (torch.Tensor): [batch_size, max_packed_seq_len]

  • position_ids_packed (torch.Tensor): [batch_size, max_packed_seq_len]

  • attention_mask (Optional[torch.Tensor]): [batch_size, max_len, max_len] if requested

Return type:

Tuple

Example:

    >>> input_ids = torch.tensor([
    ...     [1, 2, 0, 0],  # len 2
    ...     [3, 4, 5, 0],  # len 3
    ...     [6, 0, 0, 0],  # len 1
    ...     [7, 8, 9, 9],  # len 4
    ...     [8, 7, 0, 0],  # len 2
    ...     [6, 0, 0, 0],  # len 1
    ...     [5, 4, 3, 0],  # len 3
    ... ])
    >>> input_lengths = torch.tensor([2, 3, 1, 4, 2, 1, 3])
    >>> packed_sequence_size = [3, 4]
    >>> input_ids_packed, position_ids_packed, attention_mask = pack_sequences(
    ...     input_ids, input_lengths, packed_sequence_size, padding_value=-1, return_attention_mask=True
    ... )
    >>> input_ids_packed
    tensor([[ 1,  2,  3,  4,  5,  6, -1, -1, -1, -1],
            [ 7,  8,  9,  9,  8,  7,  6,  5,  4,  3]])
    >>> position_ids_packed
    tensor([[0, 1, 0, 1, 2, 0, 0, 0, 0, 0],
            [0, 1, 2, 3, 0, 1, 0, 0, 1, 2]])
    >>> attention_mask[0]
    tensor([[ True,  True, False, False, False, False, False, False, False, False],
            [False, False,  True,  True,  True, False, False, False, False, False],
            [False, False, False, False, False,  True, False, False, False, False],
            [False, False, False, False, False, False, False, False, False, False]])
    >>> attention_mask[1]
    tensor([[ True,  True,  True,  True, False, False, False, False, False, False],
            [False, False, False, False,  True,  True,  True, False, False, False],
            [False, False, False, False, False, False,  True,  True,  True,  True],
            [False, False, False, False, False, False, False,  True,  True,  True]])

nemo_rl.models.huggingface.common.unpack_tensor(tensor, input_lengths)[source]#

Unpacks a packed tensor into individual sequences padded to the same length.

Parameters:
  • tensor (torch.Tensor) – Packed tensor of shape [batch_size, packed_seq_len].

  • input_lengths (List[int]) – Original sequence lengths, in the order they were packed.

Returns:

A tensor of shape [num_sequences, max_seq_len], where each row is one unpacked, padded sequence.

Return type:

torch.Tensor

Example:

    >>> packed_tensor = torch.tensor([
    ...     [1, 2, 3, 4, 5, 6, -1, -1],
    ...     [7, 8, 9, 9, 8, 7, 6, -1]
    ... ])
    >>> packed_lengths = [2, 3, 1, 4, 2]
    >>> unpack_tensor(packed_tensor, packed_lengths)
    tensor([[1, 2, 0, 0],
            [3, 4, 5, 0],
            [6, 0, 0, 0],
            [7, 8, 9, 9],
            [8, 7, 0, 0]])
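
A round-trip sketch (not from the source; the recovered values are an assumption based on the examples above): packing a small batch with pack_sequences and then unpacking with the original lengths should restore one padded row per sequence.

    import torch

    from nemo_rl.models.huggingface.common import pack_sequences, unpack_tensor

    input_ids = torch.tensor([[1, 2, 0, 0],
                              [3, 4, 5, 0],
                              [6, 0, 0, 0]])
    input_lengths = torch.tensor([2, 3, 1])

    # Pack all three sequences into a single row; no attention mask needed here.
    packed, position_ids, _ = pack_sequences(
        input_ids, input_lengths, packed_sequence_size=[3], return_attention_mask=False
    )

    # Expected (assumed, per the examples above): one padded row per original sequence,
    # e.g. [[1, 2, 0], [3, 4, 5], [6, 0, 0]].
    unpacked = unpack_tensor(packed, input_lengths)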

nemo_rl.models.huggingface.common.get_flash_attention_kwargs(
input_lengths: torch.Tensor,
) nemo_rl.models.huggingface.common.FlashAttentionKwargs[source]#

Returns kwargs required for FlashAttention v2 forward functions.

Parameters:

input_lengths (torch.Tensor) – [batch_size] containing lengths of each sequence

Returns:

 {
     "cu_seqlens_q": Tensor[int32],
     "cu_seqlens_k": Tensor[int32],
     "max_seqlen_q": int,
     "max_seqlen_k": int
 }

Return type:

nemo_rl.models.huggingface.common.FlashAttentionKwargs
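
A hedged usage sketch (not from the source docstring). The expected field values assume the standard FlashAttention v2 convention of cumulative sequence lengths with a leading zero.

    import torch

    from nemo_rl.models.huggingface.common import get_flash_attention_kwargs

    input_lengths = torch.tensor([2, 3, 1])  # per-sequence lengths in the batch

    fa_kwargs = get_flash_attention_kwargs(input_lengths)

    # Assumed contents for self-attention over these sequences:
    #   cu_seqlens_q == cu_seqlens_k == tensor([0, 2, 5, 6], dtype=torch.int32)
    #   max_seqlen_q == max_seqlen_k == 3
    # These kwargs can be forwarded to FlashAttention v2 varlen attention calls.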