nemo_automodel.components.models.common.packing


Flash Attention packing support via monkey-patching.

When attn_implementation="flash_attention_2" and neat packing is enabled, the collater produces an indexed attention mask [B, S] where each position contains the 1-based document index (0 = padding). For example::

[1, 1, 2, 2, 2, 0] # 2 tokens in doc 1, 3 in doc 2, 1 padding

To make HuggingFace’s flash attention path use flash_attn_varlen_func with per-document cu_seqlens, we monkey-patch two functions:

  1. transformers.modeling_flash_attention_utils._get_unpad_data — extracts per-document sequence lengths from the indexed mask and builds cu_seqlens.
  2. transformers.models.qwen3_vl.modeling_qwen3_vl.create_causal_mask — returns the 2D indexed mask as-is, bypassing 4D mask creation.

This is the same approach used by LlamaFactory.
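To make the indexed-mask layout concrete, here is a minimal sketch of how one row of such a mask could be built from a list of document lengths. The helper name `build_indexed_mask_row` is hypothetical and is not part of this module; the real collater works on tensors, while this illustration uses plain Python lists.

```python
def build_indexed_mask_row(doc_lens, seq_len):
    """Return one mask row: document i (1-based) is marked with i, padding with 0."""
    if sum(doc_lens) > seq_len:
        raise ValueError("documents do not fit in the row")
    row = []
    for doc_index, length in enumerate(doc_lens, start=1):
        # Every token of a document carries that document's 1-based index.
        row.extend([doc_index] * length)
    # Fill the rest of the row with the padding value 0.
    row.extend([0] * (seq_len - len(row)))
    return row


row = build_indexed_mask_row([2, 3], seq_len=6)
print(row)  # [1, 1, 2, 2, 2, 0]
```

The printed row matches the example above: 2 tokens in doc 1, 3 in doc 2, and 1 padding position.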

Module Contents

Functions

| Name | Description |
| --- | --- |
| _passthrough_create_causal_mask | Replacement for create_causal_mask that passes through packed masks. |
| configure_packing | Apply monkey-patches for packed-sequence training with flash_attention_2. |
| get_attn_implementation | Determine the attention backend from model config. |
| get_seqlens_in_batch | Extract per-document sequence lengths from an indexed attention mask. |
| get_unpad_data | Prepare indices and cu_seqlens for flash_attn_varlen_func. |
| is_indexed_packed_mask | Return True iff attention_mask is an Automodel-style indexed packing mask. |

Data

_PACKING_PATCH_MODULES

logger

API

nemo_automodel.components.models.common.packing._passthrough_create_causal_mask(
config = None,
input_embeds = None,
inputs_embeds = None,
attention_mask = None,
cache_position = None,
past_key_values = None,
position_ids = None,
kwargs = {}
)

Replacement for create_causal_mask that passes through packed masks.

FA2 handles masking internally, so always pass through. For non-FA2 backends, pass through packed masks but delegate normal 2D masks to HF.
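The decision described above can be sketched as a three-way branch. This is an illustration under stated assumptions, not the module's code: the function name `passthrough_mask_sketch`, its parameters, and the `build_dense_mask` callable (standing in for the regular HF mask builder) are all hypothetical, and masks are plain nested lists rather than tensors.

```python
def passthrough_mask_sketch(attn_implementation, attention_mask, build_dense_mask):
    # FA2 does its own masking, so the mask is returned untouched.
    if attn_implementation == "flash_attention_2":
        return attention_mask
    # Any value > 1 marks an indexed packed mask; keep it as-is.
    if attention_mask is not None and any(v > 1 for row in attention_mask for v in row):
        return attention_mask
    # Ordinary 0/1 (or missing) mask: delegate to the regular mask builder.
    return build_dense_mask(attention_mask)


def fake_dense_builder(mask):
    # Hypothetical stand-in for the delegated mask-creation call.
    return ("dense", mask)


packed = [[1, 1, 2, 2, 2, 0]]
binary = [[1, 1, 1, 0]]
print(passthrough_mask_sketch("sdpa", packed, fake_dense_builder) is packed)  # True
print(passthrough_mask_sketch("sdpa", binary, fake_dense_builder)[0])         # dense
```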

nemo_automodel.components.models.common.packing.configure_packing(
attn_implementation: str = 'sdpa'
) -> None

Apply monkey-patches for packed-sequence training with flash_attention_2.

Only patches when attn_implementation == "flash_attention_2".

Parameters:

attn_implementation
str, defaults to 'sdpa'

The attention implementation used by the model.
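The conditional patching pattern might look roughly like the sketch below. It does not touch transformers: `fake_hf_utils` is a hypothetical stand-in namespace for the patched module, and `configure_packing_sketch` and `packed_get_unpad_data` are illustrative names, so this shows only the shape of the technique (swap an attribute when the backend is flash_attention_2, otherwise do nothing).

```python
from types import SimpleNamespace

# Hypothetical stand-in for the module whose _get_unpad_data gets replaced.
fake_hf_utils = SimpleNamespace(_get_unpad_data=lambda mask: "binary-mask path")


def packed_get_unpad_data(mask):
    # Placeholder for the indexed-mask-aware replacement.
    return "indexed-mask path"


def configure_packing_sketch(attn_implementation="sdpa", target=fake_hf_utils):
    """Patch `target` only for flash_attention_2; return whether a patch was applied."""
    if attn_implementation != "flash_attention_2":
        return False
    target._get_unpad_data = packed_get_unpad_data
    return True


print(configure_packing_sketch("sdpa"))               # False, nothing patched
print(configure_packing_sketch("flash_attention_2"))  # True
print(fake_hf_utils._get_unpad_data(None))            # indexed-mask path
```

Because the patch replaces a module-level attribute, it affects every later caller that looks the function up through that module, which is why it is applied once during setup.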

nemo_automodel.components.models.common.packing.get_attn_implementation(
cfg_model
)

Determine the attention backend from model config.

Custom models store it in backend.attn; HF models use attn_implementation.
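A minimal sketch of this lookup, assuming attribute-style config objects: the function name `resolve_attn_implementation`, the precedence (backend.attn checked first), and the 'sdpa' fallback are assumptions made for illustration and may differ from the actual implementation.

```python
from types import SimpleNamespace


def resolve_attn_implementation(cfg_model, default="sdpa"):
    # Custom models: look for backend.attn (assumed to take precedence here).
    backend = getattr(cfg_model, "backend", None)
    attn = getattr(backend, "attn", None)
    if attn is not None:
        return attn
    # HF models: fall back to attn_implementation, then to the assumed default.
    return getattr(cfg_model, "attn_implementation", default)


custom_cfg = SimpleNamespace(backend=SimpleNamespace(attn="te"))
hf_cfg = SimpleNamespace(attn_implementation="flash_attention_2")
print(resolve_attn_implementation(custom_cfg))  # te
print(resolve_attn_implementation(hf_cfg))      # flash_attention_2
```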

nemo_automodel.components.models.common.packing.get_seqlens_in_batch(
attention_mask: torch.Tensor
) -> torch.Tensor

Extract per-document sequence lengths from an indexed attention mask.

Example::

>>> get_seqlens_in_batch(torch.tensor([[1, 1, 2, 2, 2, 0],
...                                    [1, 2, 2, 3, 3, 3]]))
tensor([2, 3, 1, 2, 3])

Parameters:

attention_mask
torch.Tensor

[B, S] integer tensor where each position contains the 1-based document index (0 = padding).

Returns: torch.Tensor

1D tensor of all individual document lengths across the batch.
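The counting step can be illustrated without tensors. The sketch below uses nested Python lists and a hypothetical helper name, `seqlens_in_batch`; it walks each row, counts how many positions carry each document index, and emits the counts row by row in document order.

```python
def seqlens_in_batch(mask):
    """Return the length of every document in the batch, row by row."""
    lengths = []
    for row in mask:
        # Document indices are 1-based; 0 is padding and is never counted.
        for doc_index in range(1, max(row, default=0) + 1):
            count = row.count(doc_index)
            if count:
                lengths.append(count)
    return lengths


lengths = seqlens_in_batch([[1, 1, 2, 2, 2, 0],
                            [1, 2, 2, 3, 3, 3]])
print(lengths)  # [2, 3, 1, 2, 3]
```

The first row contributes lengths 2 and 3, the second row contributes 1, 2, and 3, which agrees with the example output above.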

nemo_automodel.components.models.common.packing.get_unpad_data(
attention_mask: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor, int]

Prepare indices and cu_seqlens for flash_attn_varlen_func.

This is a drop-in replacement for transformers.modeling_flash_attention_utils._get_unpad_data that handles indexed attention masks (values 1, 2, 3, …) instead of binary (0/1) masks. Each unique non-zero value is treated as a separate document, so flash_attn_varlen_func applies causal attention within each document without cross-document attention.

Example::

>>> get_unpad_data(torch.tensor([[1, 1, 2, 2, 2, 0],
...                              [1, 2, 2, 3, 3, 3]]))
(tensor([0, 1, 2, 3, 4, 6, 7, 8, 9, 10, 11]),
 tensor([ 0, 2, 5, 6, 8, 11], dtype=torch.int32),
 3)

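Returns: tuple[torch.Tensor, torch.Tensor, int]

A tuple of three values: the indices of non-padding tokens in the flattened sequence, the cumulative document lengths (cu_seqlens, an int32 tensor starting at 0), and the length of the longest document in the batch.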
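How the three returned values relate to the indexed mask can be worked through with plain lists. This is a sketch only: `unpad_data_sketch` is a hypothetical name, and it assumes the mask is a nested list rather than a tensor. It flattens the mask to find non-padding positions, collects per-document lengths, and takes a running sum with a leading 0 to form the cumulative boundaries.

```python
from collections import Counter
from itertools import accumulate


def unpad_data_sketch(mask):
    # Positions of non-padding tokens in the flattened [B * S] sequence.
    flat = [v for row in mask for v in row]
    indices = [i for i, v in enumerate(flat) if v != 0]

    # Per-document lengths, row by row, in document-index order.
    seqlens = []
    for row in mask:
        counts = Counter(v for v in row if v != 0)
        seqlens.extend(counts[doc] for doc in sorted(counts))

    # Cumulative boundaries: document k spans cu_seqlens[k]:cu_seqlens[k + 1].
    cu_seqlens = list(accumulate(seqlens, initial=0))
    max_seqlen = max(seqlens, default=0)
    return indices, cu_seqlens, max_seqlen


result = unpad_data_sketch([[1, 1, 2, 2, 2, 0],
                            [1, 2, 2, 3, 3, 3]])
print(result)
# ([0, 1, 2, 3, 4, 6, 7, 8, 9, 10, 11], [0, 2, 5, 6, 8, 11], 3)
```

Flattened position 5 is the only padding token, so it is the one index missing from the first list; the five documents have lengths 2, 3, 1, 2, 3, which gives the boundaries 0, 2, 5, 6, 8, 11.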

nemo_automodel.components.models.common.packing.is_indexed_packed_mask(
attention_mask: torch.Tensor | None
) -> bool

Return True iff attention_mask is an Automodel-style indexed packing mask.

The Automodel neat_packed_vlm_collater (and the LLM equivalent) encode packed-sample boundaries by marking document i (1-based) with the integer i and using 0 for padding (e.g. [1, 1, 1, 2, 2, 3, 3, 0, 0]). Any value greater than 1 is therefore a sufficient signal that two or more documents are packed into the same row. A standard 0/1 attention mask never has values > 1.
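The check described above reduces to asking whether any entry exceeds 1. A minimal list-based sketch follows; the name `looks_like_indexed_packed_mask` is hypothetical, and nested lists stand in for the tensor argument.

```python
def looks_like_indexed_packed_mask(attention_mask):
    # A missing mask cannot be a packed mask.
    if attention_mask is None:
        return False
    # Binary masks only contain 0 and 1, so any value > 1 implies packing.
    return any(v > 1 for row in attention_mask for v in row)


print(looks_like_indexed_packed_mask([[1, 1, 1, 2, 2, 3, 3, 0, 0]]))  # True
print(looks_like_indexed_packed_mask([[1, 1, 1, 0, 0]]))              # False
print(looks_like_indexed_packed_mask(None))                           # False
```

One consequence of this test is that a row holding a single document looks identical to an ordinary 0/1 mask, so it is reported as not packed.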

nemo_automodel.components.models.common.packing._PACKING_PATCH_MODULES = ['transformers.models.qwen2.modeling_qwen2', 'transformers.models.qwen2_5_vl.mod...
nemo_automodel.components.models.common.packing.logger = logging.getLogger(__name__)