nemo_automodel.components.models.common.packing#

Flash Attention packing support via monkey-patching.

When attn_implementation="flash_attention_2" and neat packing is enabled, the collater produces an indexed attention mask [B, S] where each position contains the 1-based document index (0 = padding). For example::

[1, 1, 2, 2, 2, 0]   # 2 tokens in doc 1, 3 in doc 2, 1 padding

To make HuggingFace’s flash attention path use flash_attn_varlen_func with per-document cu_seqlens, we monkey-patch two functions:

  1. transformers.modeling_flash_attention_utils._get_unpad_data — extracts per-document sequence lengths from the indexed mask and builds cu_seqlens.

  2. transformers.models.qwen3_vl.modeling_qwen3_vl.create_causal_mask — returns the 2D indexed mask as-is, bypassing 4D mask creation.

This is the same approach used by LlamaFactory.
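The patching mechanism itself is ordinary attribute rebinding on the target module. A minimal, generic sketch (the `monkeypatch` helper and the `toy` module are illustrative stand-ins, not part of this module's API):

```python
import types

def monkeypatch(module, name, replacement):
    """Rebind an attribute on a module so existing call sites
    transparently pick up the replacement."""
    original = getattr(module, name)
    setattr(module, name, replacement)
    return original  # keep a handle so the patch can be undone

# Toy module standing in for transformers.modeling_flash_attention_utils.
toy = types.ModuleType("toy")
toy._get_unpad_data = lambda mask: "binary-mask path"

# After patching, anything that calls toy._get_unpad_data gets the
# indexed-mask-aware replacement.
monkeypatch(toy, "_get_unpad_data", lambda mask: "indexed-mask path")
```

Because the patch replaces the module-level attribute rather than a bound method, every model file that imports and calls `_get_unpad_data` through the module sees the replacement without further changes.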

Module Contents#

Functions#

get_seqlens_in_batch

Extract per-document sequence lengths from an indexed attention mask.

get_unpad_data

Prepare indices and cu_seqlens for flash_attn_varlen_func.

_passthrough_create_causal_mask

Replacement for create_causal_mask that passes through packed masks.

get_attn_implementation

Determine the attention backend from model config.

configure_packing

Apply monkey-patches for packed-sequence training with flash_attention_2.

Data#

API#

nemo_automodel.components.models.common.packing.logger#

‘getLogger(…)’

nemo_automodel.components.models.common.packing.get_seqlens_in_batch(attention_mask: torch.Tensor) → torch.Tensor#

Extract per-document sequence lengths from an indexed attention mask.

Parameters:

attention_mask – [B, S] integer tensor where each position contains the 1-based document index (0 = padding).

Returns:

1D tensor of all individual document lengths across the batch.

Example::

>>> get_seqlens_in_batch(torch.tensor([[1, 1, 2, 2, 2, 0],
...                                    [1, 2, 2, 3, 3, 3]]))
tensor([2, 3, 1, 2, 3])
nemo_automodel.components.models.common.packing.get_unpad_data(
attention_mask: torch.Tensor,
) → tuple[torch.Tensor, torch.Tensor, int]#

Prepare indices and cu_seqlens for flash_attn_varlen_func.

This is a drop-in replacement for transformers.modeling_flash_attention_utils._get_unpad_data that handles indexed attention masks (values 1, 2, 3, …) instead of binary (0/1) masks. Each unique non-zero value is treated as a separate document, so flash_attn_varlen_func applies causal attention within each document without cross-document attention.

Returns:

A tuple of three elements:

indices – Indices of non-padding tokens from the flattened sequence.

cu_seqlens – Cumulative sequence lengths (starting from 0).

max_seqlen_in_batch – Largest document length in the batch.

Return type:

tuple[torch.Tensor, torch.Tensor, int]

Example::

>>> get_unpad_data(torch.tensor([[1, 1, 2, 2, 2, 0],
...                              [1, 2, 2, 3, 3, 3]]))
(tensor([0, 1, 2, 3, 4, 6, 7, 8, 9, 10, 11]),
 tensor([ 0,  2,  5,  6,  8, 11], dtype=torch.int32),
 3)
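The documented triple can be reproduced with a plain-Python sketch, using lists as a stand-in for the [B, S] tensor (the helpers below are illustrative, not the module's actual implementation):

```python
from itertools import groupby, accumulate

def unpad_data(mask):
    """Sketch of the indices / cu_seqlens / max_seqlen triple for an
    indexed attention mask (list-of-lists stand-in for [B, S])."""
    flat = [v for row in mask for v in row]
    # Non-padding positions in the flattened [B*S] sequence.
    indices = [i for i, v in enumerate(flat) if v != 0]
    # Per-document lengths: consecutive runs of equal non-zero indices.
    lengths = []
    for row in mask:
        lengths.extend(len(list(g)) for doc, g in groupby(row) if doc != 0)
    # Cumulative lengths, prefixed with 0, as flash_attn_varlen_func expects.
    cu_seqlens = [0] + list(accumulate(lengths))
    return indices, cu_seqlens, max(lengths)

unpad_data([[1, 1, 2, 2, 2, 0],
            [1, 2, 2, 3, 3, 3]])
# -> ([0, 1, 2, 3, 4, 6, 7, 8, 9, 10, 11], [0, 2, 5, 6, 8, 11], 3)
```

The padding token at position 5 is dropped from `indices`, and each document boundary shows up as a step in `cu_seqlens`, which is exactly what keeps attention from crossing documents.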
nemo_automodel.components.models.common.packing._passthrough_create_causal_mask(
config=None,
input_embeds=None,
inputs_embeds=None,
attention_mask=None,
cache_position=None,
past_key_values=None,
position_ids=None,
**kwargs,
)#

Replacement for create_causal_mask that passes through packed masks.

Under FA2, masking is handled inside the attention kernel, so the mask is always passed through unchanged. For non-FA2 backends, packed (indexed) masks are passed through, while normal binary 2D masks are delegated to HF's original mask construction.
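The decision rule can be sketched as follows; the `fa2` flag is a hypothetical parameter for illustration, and a list-of-lists stands in for the [B, S] tensor:

```python
def passthrough_create_causal_mask(attention_mask=None, fa2=True, **kwargs):
    """Sketch of the pass-through rule (not the module's actual code)."""
    if fa2:
        return attention_mask  # FA2 masks inside the kernel: always pass through
    packed = attention_mask is not None and any(
        v > 1 for row in attention_mask for v in row)
    if packed:
        return attention_mask  # indexed packed mask: keep it 2D
    return None  # placeholder: HF's original code would build a 4D mask here
```

An indexed mask is recognizable because it contains values greater than 1 (document indices), whereas a normal binary mask contains only 0s and 1s.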

nemo_automodel.components.models.common.packing.get_attn_implementation(cfg_model)#

Determine the attention backend from model config.

Custom models store it in backend.attn; HF models use attn_implementation.
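The lookup order described above can be sketched like this; the `'sdpa'` fallback is an assumption, and `SimpleNamespace` stands in for real config objects:

```python
from types import SimpleNamespace

def attn_implementation(cfg_model):
    """Sketch: prefer backend.attn (custom models), then
    attn_implementation (HF models), then an assumed 'sdpa' default."""
    backend = getattr(cfg_model, "backend", None)
    if backend is not None and getattr(backend, "attn", None):
        return backend.attn
    return getattr(cfg_model, "attn_implementation", "sdpa")

custom = SimpleNamespace(backend=SimpleNamespace(attn="flash_attention_2"))
hf = SimpleNamespace(attn_implementation="eager")
attn_implementation(custom)  # -> "flash_attention_2"
attn_implementation(hf)      # -> "eager"
```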

nemo_automodel.components.models.common.packing._PACKING_PATCH_MODULES#

['transformers.models.qwen2.modeling_qwen2', 'transformers.models.qwen2_5_vl.modeling_qwen2_5_vl', '…

nemo_automodel.components.models.common.packing.configure_packing(attn_implementation: str = 'sdpa') → None#

Apply monkey-patches for packed-sequence training with flash_attention_2.

Only patches when attn_implementation == "flash_attention_2".

Parameters:

attn_implementation – The attention implementation used by the model.
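The gating behavior can be sketched as below; the real function returns None and mutates the transformers modules in place, so the boolean return here is purely for illustration:

```python
def configure_packing_sketch(attn_implementation: str = "sdpa") -> bool:
    """Sketch of the gating: report whether patches would be applied."""
    if attn_implementation != "flash_attention_2":
        return False  # sdpa / eager: leave transformers untouched
    # Here the real function would rebind
    # transformers.modeling_flash_attention_utils._get_unpad_data and,
    # for each module in _PACKING_PATCH_MODULES, create_causal_mask.
    return True
```

In practice this would be called once, before model construction, with the value returned by get_attn_implementation for the model config.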