nemo_automodel.components.attention.flex_attention


Module Contents

Classes

Name | Description
FlexAttention | FlexAttention module that uses torch.nn.attention.flex_attention.

Data

FLEX_ATTN_MASK_T

API

class nemo_automodel.components.attention.flex_attention.FlexAttention()

Bases: Module

FlexAttention module that uses torch.nn.attention.flex_attention.

This module wraps torch.nn.attention.flex_attention and implements common attention types such as causal and block_causal.

Parameters:

attn_mask_type
str

The type of attention mask. Currently supported values are “causal” and “block_causal”. “causal” means each query attends only to keys at the same or earlier positions, so only the lower triangle of the attention matrix is used. “block_causal” means the attention matrix is divided into blocks, where each block boundary is defined by an EOS token, and the causal (lower-triangular) mask is applied within each block.

fixed_block_size
int | None

The block size used to perform attention. If specified, each sequence is further divided into blocks of at most fixed_block_size tokens. A query attends only to keys within the same block.
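
The two mask types described above can be illustrated with a minimal plain-Python sketch. It is not the module's implementation: the real mask mods operate on tensors, while the helpers here (causal_mask_mod, make_block_causal_mask_mod, dense_mask) are hypothetical stand-ins that follow the (b, h, q_idx, kv_idx) calling convention of a mask_mod. The sketch also assumes that an EOS token belongs to the document it terminates.

```python
from itertools import accumulate


def causal_mask_mod(b, h, q_idx, kv_idx):
    """A query may attend to keys at the same or an earlier position."""
    return q_idx >= kv_idx


def make_block_causal_mask_mod(batch, eos_id):
    """Build a mask_mod that is causal within each EOS-delimited document."""
    doc_ids = []
    for row in batch:
        # Number of EOS tokens seen up to and including each position.
        eos_seen = list(accumulate(int(tok == eos_id) for tok in row))
        # Shift right by one so an EOS token stays in the document it ends
        # (an assumption of this sketch).
        doc_ids.append([0] + eos_seen[:-1])

    def block_causal_mask_mod(b, h, q_idx, kv_idx):
        same_doc = doc_ids[b][q_idx] == doc_ids[b][kv_idx]
        return same_doc and q_idx >= kv_idx

    return block_causal_mask_mod


def dense_mask(mask_mod, seq_len, b=0, h=0):
    """Materialize a mask_mod as a seq_len x seq_len grid of 0/1 values."""
    return [
        [int(mask_mod(b, h, q, kv)) for kv in range(seq_len)]
        for q in range(seq_len)
    ]


batch = [[5, 6, 2, 7, 8]]  # token id 2 plays the role of EOS here
causal = dense_mask(causal_mask_mod, 5)
block_causal = dense_mask(make_block_causal_mask_mod(batch, eos_id=2), 5)

for row in block_causal:
    print(row)
```

In the block_causal grid, positions 3 and 4 (the tokens after the EOS) cannot attend to positions 0 through 2, while the plain causal grid is a full lower triangle.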

block_masks
dict[FLEX_ATTN_MASK_T, BlockMask] = {}

flex_attn
Callable[..., Any]

mask_key
FLEX_ATTN_MASK_T

nemo_automodel.components.attention.flex_attention.FlexAttention._fixed_block_mask_mod(
mask_mod: torch.nn.attention.flex_attention._mask_mod_signature,
fixed_block_size: int
) -> torch.nn.attention.flex_attention._mask_mod_signature
staticmethod

Given an arbitrary mask_mod, divide the input sequence into blocks and only allow attention within the same block.

Parameters:

mask_mod
_mask_mod_signature

The mask mod to apply to the documents.

fixed_block_size
int

The number of tokens in each block.
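
As a rough illustration of the composition described above, the following plain-Python sketch wraps an inner mask_mod so that attention is confined to fixed-size blocks. It is a sketch under assumptions, not the method's actual body: the names fixed_block_mask_mod and causal_mask_mod are hypothetical, and passing block-relative indices to the inner mask_mod is an assumption of this example.

```python
def causal_mask_mod(b, h, q_idx, kv_idx):
    """Inner mask used for the demonstration: plain causal attention."""
    return q_idx >= kv_idx


def fixed_block_mask_mod(mask_mod, fixed_block_size):
    """Wrap mask_mod so a query only attends to keys in its own block."""

    def blocked_mask_mod(b, h, q_idx, kv_idx):
        # Queries and keys must fall into the same fixed-size block.
        same_block = (q_idx // fixed_block_size) == (kv_idx // fixed_block_size)
        # Assumption: the inner mask sees positions relative to the block start.
        inner = mask_mod(
            b, h, q_idx % fixed_block_size, kv_idx % fixed_block_size
        )
        return same_block and inner

    return blocked_mask_mod


blocked = fixed_block_mask_mod(causal_mask_mod, fixed_block_size=2)
grid = [[int(blocked(0, 0, q, kv)) for kv in range(4)] for q in range(4)]

for row in grid:
    print(row)
```

With a sequence of four tokens and fixed_block_size=2, the grid consists of two independent 2x2 causal blocks on the diagonal.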

nemo_automodel.components.attention.flex_attention.FlexAttention._get_block_causal_mask_mod(
batch: torch.Tensor,
eos_id: int
) -> torch.nn.attention.flex_attention._mask_mod_signature
staticmethod
nemo_automodel.components.attention.flex_attention.FlexAttention._get_causal_mask_mod() -> torch.nn.attention.flex_attention._mask_mod_signature
staticmethod
nemo_automodel.components.attention.flex_attention.FlexAttention._get_sliding_window_mask_mod(
window: int
) -> torch.nn.attention.flex_attention._mask_mod_signature
staticmethod

Returns a mask_mod function that allows a query to attend to a key only when both conditions hold:

  • kv_idx ≤ q_idx (causal)
  • (q_idx - kv_idx) ≤ window
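
The two conditions above can be written as a small plain-Python predicate. This is an illustrative sketch only: the real mask_mod works on index tensors, and the name sliding_window_mask_mod is hypothetical.

```python
def sliding_window_mask_mod(window):
    """Return a mask_mod combining causality with a sliding window."""

    def mask_mod(b, h, q_idx, kv_idx):
        causal = kv_idx <= q_idx
        in_window = (q_idx - kv_idx) <= window
        return causal and in_window

    return mask_mod


mask = sliding_window_mask_mod(window=1)
grid = [[int(mask(0, 0, q, kv)) for kv in range(4)] for q in range(4)]

for row in grid:
    print(row)
```

Because the comparison is inclusive, each query sees itself plus up to window earlier keys, so window=1 yields a band two positions wide.
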
nemo_automodel.components.attention.flex_attention.FlexAttention.forward(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
scale: float | None = None,
sink_weights: torch.Tensor | None = None,
sliding_window: int = 0,
enable_gqa: bool = False
) -> torch.Tensor
nemo_automodel.components.attention.flex_attention.FLEX_ATTN_MASK_T = tuple[str, int | None] | tuple[int, int, int]