nemo_rl.models.huggingface.common#
Module Contents#
Classes#

| FlashAttentionKwargs | Dataclass to hold FlashAttention v2 kwargs. |
| ModelFlag | Enum that defines special flags for model-specific behaviors. |

Functions#

| group_and_cat_tensors | Groups and concatenates tensors according to group_sizes, then pads them to form a 2D tensor. |
| pack_sequences | Packs sequences into rows where each row concatenates multiple sequences. |
| unpack_tensor | Unpacks a packed tensor into individual sequences padded to the same length. |
| get_flash_attention_kwargs | Returns kwargs required for FlashAttention v2 forward functions. |
Data#
API#
- nemo_rl.models.huggingface.common.Tensor#
TypeVar(…)
- class nemo_rl.models.huggingface.common.FlashAttentionKwargs[source]#
Dataclass to hold FlashAttention v2 kwargs.
- cu_seqlens_q: nemo_rl.models.huggingface.common.Tensor#
None
- cu_seqlens_k: nemo_rl.models.huggingface.common.Tensor#
None
- max_seqlen_q: int#
None
- max_seqlen_k: int#
None
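For illustration, a populated instance for three sequences of lengths 2, 3, and 1 could look as follows. This is a minimal sketch assuming standard dataclass construction; the values are illustrative and are not produced by the class itself:

>>> import torch
>>> fa_kwargs = FlashAttentionKwargs(
...     cu_seqlens_q=torch.tensor([0, 2, 5, 6], dtype=torch.int32),  # cumulative query lengths, with leading 0
...     cu_seqlens_k=torch.tensor([0, 2, 5, 6], dtype=torch.int32),  # cumulative key lengths
...     max_seqlen_q=3,  # length of the longest query sequence
...     max_seqlen_k=3,  # length of the longest key sequence
... )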
- class nemo_rl.models.huggingface.common.ModelFlag(*args, **kwds)[source]#
Bases: enum.Enum
Enum that defines special flags for model-specific behaviors.
This enum provides a way to identify models that require special handling or configuration in different parts of the NeMo RL codebase.
Flags:

- SKIP_DTENSOR_TIED_WEIGHTS_CHECK: Models that should skip the tied weights check for the DTensor Policy even without setting the NRL_SKIP_TIED_WEIGHT_CHECK flag.
- VLLM_LOAD_FORMAT_AUTO: Models that should use the "auto" load format when initializing vLLM.

Each flag has a matches method that determines if the flag applies to a given model_name.
- SKIP_DTENSOR_TIED_WEIGHTS_CHECK#
auto(…)
- VLLM_LOAD_FORMAT_AUTO#
auto(…)
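A typical check uses the matches method described above; a minimal sketch, with a placeholder model name:

>>> model_name = "some-org/some-model"  # hypothetical model name
>>> if ModelFlag.VLLM_LOAD_FORMAT_AUTO.matches(model_name):
...     load_format = "auto"  # use vLLM's "auto" load format for this model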
- nemo_rl.models.huggingface.common.group_and_cat_tensors(
- tensors: list[torch.Tensor],
- group_sizes: list[int],
- padding_value: int = 0,
- min_seq_len: int = 0,
- )[source]#
Groups and concatenates tensors according to group_sizes, then pads them to form a 2D tensor.
Each group of 1D tensors is concatenated into a single 1D tensor, and all resulting group tensors are padded to the same length and stacked into a 2D tensor.
- Parameters:
tensors – List of 1D tensors of varying lengths.
group_sizes – List of integers; each entry specifies how many consecutive tensors are concatenated into one group.
padding_value – Integer used to pad shorter rows.
min_seq_len – Minimum length to pad each output row to.
- Returns:
A 2D tensor where each row is a padded concatenation of the grouped tensors.
.. rubric:: Example
>>> tensors = [
...     torch.tensor([1, 2]),
...     torch.tensor([3]),
...     torch.tensor([4, 5, 6]),
...     torch.tensor([7]),
... ]
>>> group_sizes = [2, 2]
>>> group_and_cat_tensors(tensors, group_sizes, padding_value=-1)
tensor([[ 1,  2,  3, -1, -1],
        [ 4,  5,  6,  7, -1]])
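The min_seq_len argument is assumed to set a floor on the padded row length; under that assumption, the same call with min_seq_len=8 would pad both rows out to eight columns:

>>> group_and_cat_tensors(tensors, group_sizes, padding_value=-1, min_seq_len=8)
tensor([[ 1,  2,  3, -1, -1, -1, -1, -1],
        [ 4,  5,  6,  7, -1, -1, -1, -1]])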
- nemo_rl.models.huggingface.common.pack_sequences(
- input_ids: torch.Tensor,
- input_lengths: torch.Tensor,
- packed_sequence_size: list[int],
- padding_value: int = 0,
- return_attention_mask: bool = True,
- min_seq_len: int = 0,
- )[source]#
Packs sequences into rows where each row concatenates multiple sequences.
Useful for sequence packing in transformer models (e.g., for SFT training). Returns packed input_ids, packed position_ids, and an optional attention_mask.
- Parameters:
input_ids (torch.Tensor) – Tensor of shape [num_sequences, max_seq_len].
input_lengths (torch.Tensor) – Tensor of shape [num_sequences], containing the true length of each sequence.
packed_sequence_size (List[int]) – How many sequences to pack into each row.
padding_value (int) – Pad value for input_ids.
return_attention_mask (bool) – Whether to return a per-row causal attention mask.
min_seq_len (int) – Minimum length to pad each packed row to.
- Returns:
input_ids_packed (torch.Tensor): [batch_size, max_packed_seq_len]
position_ids_packed (torch.Tensor): [batch_size, max_packed_seq_len]
attention_mask (Optional[torch.Tensor]): [batch_size, max_len, max_len] if requested
- Return type:
Tuple
.. rubric:: Example
>>> input_ids = torch.tensor([
...     [1, 2, 0, 0],  # len 2
...     [3, 4, 5, 0],  # len 3
...     [6, 0, 0, 0],  # len 1
...     [7, 8, 9, 9],  # len 4
...     [8, 7, 0, 0],  # len 2
...     [6, 0, 0, 0],  # len 1
...     [5, 4, 3, 0],  # len 3
... ])
>>> input_lengths = torch.tensor([2, 3, 1, 4, 2, 1, 3])
>>> packed_sequence_size = [3, 4]
>>> input_ids_packed, position_ids_packed, attention_mask = pack_sequences(
...     input_ids, input_lengths, packed_sequence_size,
...     padding_value=-1, return_attention_mask=True,
... )
>>> input_ids_packed
tensor([[ 1,  2,  3,  4,  5,  6, -1, -1, -1, -1],
        [ 7,  8,  9,  9,  8,  7,  6,  5,  4,  3]])
>>> position_ids_packed
tensor([[0, 1, 0, 1, 2, 0, 0, 0, 0, 0],
        [0, 1, 2, 3, 0, 1, 0, 0, 1, 2]])
>>> attention_mask[0]
tensor([[ True,  True, False, False, False, False, False, False, False, False],
        [False, False,  True,  True,  True, False, False, False, False, False],
        [False, False, False, False, False,  True, False, False, False, False],
        [False, False, False, False, False, False, False, False, False, False]])
>>> attention_mask[1]
tensor([[ True,  True,  True,  True, False, False, False, False, False, False],
        [False, False, False, False,  True,  True,  True, False, False, False],
        [False, False, False, False, False, False,  True,  True,  True,  True],
        [False, False, False, False, False, False, False,  True,  True,  True]])
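When the dense [max_len, max_len] mask is not needed, for example when attention is computed with FlashAttention's variable-length kernels, one plausible pairing is to skip the mask and build the kernel kwargs from the same lengths; a sketch under that assumption:

>>> packed_ids, packed_pos, _ = pack_sequences(
...     input_ids, input_lengths, packed_sequence_size,
...     return_attention_mask=False,
... )
>>> fa_kwargs = get_flash_attention_kwargs(input_lengths)  # cu_seqlens/max_seqlen for the same batch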
- nemo_rl.models.huggingface.common.unpack_tensor(tensor, input_lengths)[source]#
Unpacks a packed tensor into individual sequences padded to the same length.
- Parameters:
tensor (torch.Tensor) – Packed tensor of shape [batch_size, packed_seq_len].
input_lengths (List[int]) – Original sequence lengths in the order they were packed.
- Returns:
[num_sequences, max_seq_len], each row is one unpacked and padded sequence.
- Return type:
torch.Tensor
.. rubric:: Example
>>> packed_tensor = torch.tensor([
...     [1, 2, 3, 4, 5, 6, -1, -1],
...     [7, 8, 9, 9, 8, 7, 6, -1],
... ])
>>> packed_lengths = [2, 3, 1, 4, 2]
>>> unpack_tensor(packed_tensor, packed_lengths)
tensor([[1, 2, 0, 0],
        [3, 4, 5, 0],
        [6, 0, 0, 0],
        [7, 8, 9, 9],
        [8, 7, 0, 0]])
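unpack_tensor acts as the inverse of pack_sequences; a minimal round-trip sketch, assuming zero padding on both sides:

>>> packed, _, _ = pack_sequences(input_ids, input_lengths, packed_sequence_size)
>>> recovered = unpack_tensor(packed, input_lengths)  # one row per original sequence, right-padded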
- nemo_rl.models.huggingface.common.get_flash_attention_kwargs(
- input_lengths: torch.Tensor,
- )[source]#
Returns kwargs required for FlashAttention v2 forward functions.
- Parameters:
input_lengths (torch.Tensor) – [batch_size] tensor containing the length of each sequence.
- Returns:
{ "cu_seqlens_q": Tensor[int32], "cu_seqlens_k": Tensor[int32], "max_seqlen_q": int, "max_seqlen_k": int }
- Return type:
Dict[str, torch.Tensor | int]
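For intuition about the returned structure, the cumulative-lengths layout can be reproduced by hand; this sketch illustrates the format only and is not this function's implementation:

>>> import torch
>>> input_lengths = torch.tensor([2, 3, 1])
>>> # cu_seqlens prepends 0 to the running total of the lengths
>>> torch.nn.functional.pad(torch.cumsum(input_lengths, dim=0), (1, 0)).to(torch.int32)
tensor([0, 2, 5, 6], dtype=torch.int32)
>>> int(input_lengths.max())  # max_seqlen_q / max_seqlen_k
3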