nemo_automodel.components.attention.flex_attention#

Module Contents#

Classes#

FlexAttention

FlexAttention module that uses torch.nn.attention.flex_attention.

Data#

API#

nemo_automodel.components.attention.flex_attention.FLEX_ATTN_MASK_T#

None

class nemo_automodel.components.attention.flex_attention.FlexAttention#

Bases: torch.nn.Module

FlexAttention module that uses torch.nn.attention.flex_attention.

This module is a wrapper around torch.nn.attention.flex_attention and implements common attention mask types, such as causal and block_causal.

Parameters:
  • attn_mask_type (str) – The type of attention mask. Currently, “causal” and “block_causal” are supported. “causal” applies a standard causal mask: each query attends only to keys at or before its own position (the upper triangle of the attention matrix is masked out). “block_causal” divides the attention matrix into blocks, with block boundaries defined by the EOS token, and applies the causal mask within each block.

  • fixed_block_size (int | None) – The block size used to perform attention. If specified, each sequence is further divided into blocks of at most fixed_block_size tokens, and a query attends only to keys within its own block.
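
A minimal construction sketch; the import path follows this page, while the surrounding usage is illustrative:

```python
from nemo_automodel.components.attention.flex_attention import FlexAttention

# Causal attention over each full sequence.
causal_attn = FlexAttention(attn_mask_type="causal")

# Block-causal attention: EOS tokens delimit documents packed into one
# sequence, and attention is causal within each document.
block_causal_attn = FlexAttention(attn_mask_type="block_causal")

# Optionally cap attention to fixed-size blocks within each sequence.
blocked_attn = FlexAttention(attn_mask_type="causal", fixed_block_size=256)
```

Note that block-causal masking additionally needs the batch and EOS token ID to build its block mask (see _get_block_causal_mask_mod below); how that mask is registered is outside the scope of this page.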

flex_attn: ClassVar[Callable]#

‘compile(…)’
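
The elided value suggests the kernel is compiled once at class scope and shared across instances; a plausible sketch of that pattern (an assumption, not the exact call site):

```python
import torch
from torch.nn.attention.flex_attention import flex_attention

# Compile the flex_attention kernel once; all FlexAttention instances
# can then share the compiled callable.
compiled_flex_attention = torch.compile(flex_attention)
```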

block_masks: ClassVar[dict[nemo_automodel.components.attention.flex_attention.FLEX_ATTN_MASK_T, torch.nn.attention.flex_attention.BlockMask]]#

None

forward(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
scale: float | None = None,
sink_weights: torch.Tensor | None = None,
sliding_window: int = 0,
enable_gqa: bool = False,
) torch.Tensor#
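
A hypothetical forward call. The [B, H, S, D] layout, device, and dtype are assumptions here (they follow torch.nn.attention.flex_attention conventions), and the block mask for the configured mask type is presumed to have been built already:

```python
import torch

attn = FlexAttention(attn_mask_type="causal")

B, H, S, D = 2, 8, 1024, 64
q = torch.randn(B, H, S, D, device="cuda", dtype=torch.bfloat16)
k = torch.randn(B, H, S, D, device="cuda", dtype=torch.bfloat16)
v = torch.randn(B, H, S, D, device="cuda", dtype=torch.bfloat16)

# scale=None lets the kernel fall back to the default 1/sqrt(D) scaling;
# enable_gqa=True would permit fewer KV heads than query heads.
out = attn(q, k, v)  # -> [B, H, S, D]
```
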
static _get_sliding_window_mask_mod(window: int)#

Returns a mask_mod function that allows attention only where:

  • kv_idx ≤ q_idx (causal), and

  • (q_idx - kv_idx) ≤ window.
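
A sketch of such a mask_mod, using the (b, h, q_idx, kv_idx) signature expected by torch.nn.attention.flex_attention; the closure shown here is illustrative, not the exact implementation:

```python
def get_sliding_window_mask_mod(window: int):
    # mask_mod returns True where attention is allowed.
    def mask_mod(b, h, q_idx, kv_idx):
        causal = kv_idx <= q_idx
        in_window = (q_idx - kv_idx) <= window
        return causal & in_window

    return mask_mod
```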

static _get_causal_mask_mod() torch.nn.attention.flex_attention._mask_mod_signature#
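
For reference, the standard causal mask_mod is a one-liner under the same signature (a sketch, not necessarily the exact code):

```python
def causal_mask_mod(b, h, q_idx, kv_idx):
    # Allow attention only to the current and earlier positions.
    return q_idx >= kv_idx
```
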
static _get_block_causal_mask_mod(
batch: torch.Tensor,
eos_id: int,
) torch.nn.attention.flex_attention._mask_mod_signature#
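
Per the class description, block-causal masking keys attention on document membership, where EOS tokens delimit documents packed into one sequence. A hedged sketch (the cumsum-based document-ID scheme, and whether the EOS token belongs to the preceding or following document, are assumptions):

```python
import torch

def get_block_causal_mask_mod(batch: torch.Tensor, eos_id: int):
    # Each EOS increments the running document ID, so tokens of the same
    # packed document share an ID. batch has shape [B, S].
    doc_ids = torch.cumsum((batch == eos_id).to(torch.int32), dim=-1)

    def mask_mod(b, h, q_idx, kv_idx):
        same_doc = doc_ids[b, q_idx] == doc_ids[b, kv_idx]
        return same_doc & (q_idx >= kv_idx)

    return mask_mod
```
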
static _fixed_block_mask_mod(
mask_mod: torch.nn.attention.flex_attention._mask_mod_signature,
fixed_block_size: int,
) torch.nn.attention.flex_attention._mask_mod_signature#

Given an arbitrary mask_mod, divides the input sequence into blocks and only allows attention within the same block; see the sketch after the parameter list.

Parameters:
  • mask_mod – The mask_mod (e.g., causal or block-causal) to compose with the fixed-block constraint.

  • fixed_block_size – The number of tokens in each block.
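
A sketch of the composition. Whether indices are remapped relative to the block start before calling the wrapped mask_mod is an implementation detail not documented here; the version below simply restricts the given mask_mod to same-block pairs:

```python
def fixed_block_mask_mod(mask_mod, fixed_block_size: int):
    def blocked_mask_mod(b, h, q_idx, kv_idx):
        # The query and key must fall in the same fixed-size block...
        same_block = (q_idx // fixed_block_size) == (kv_idx // fixed_block_size)
        # ...and must still satisfy the wrapped mask_mod.
        return same_block & mask_mod(b, h, q_idx, kv_idx)

    return blocked_mask_mod
```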