nemo_automodel.components.attention.flex_attention#
Module Contents#
Classes#
| FlexAttention | FlexAttention module that uses torch.nn.attention.flex_attention. |
Data#
API#
- nemo_automodel.components.attention.flex_attention.FLEX_ATTN_MASK_T#
None
- class nemo_automodel.components.attention.flex_attention.FlexAttention#
Bases: torch.nn.Module
FlexAttention module that uses torch.nn.attention.flex_attention.
This module is a wrapper around torch.nn.attention.flex_attention and implements certain common attention types, such as causal and block_causal.
- Parameters:
attn_mask_type (str) – The type of attention mask. Currently, we support “causal” and “block_causal”. “causal” means the lower triangle of the attention matrix is masked. “block_causal” means the attention matrix is divided into blocks, where block boundaries are defined by the EOS token, and the lower triangle of each block is masked (both mask types are illustrated in the sketch after this parameter list).
fixed_block_size (int | None) – The block size used to perform attention. If specified, each sequence is further divided into blocks, where each block has a maximum size of fixed_block_size. A query will only attend to the keys within the same block.
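The sketch below illustrates the two documented mask types directly against torch.nn.attention.flex_attention. The helper names, the toy token values, and the choice of eos_id are assumptions made for the example; they are not part of this module’s API.

```python
import torch
from torch.nn.attention.flex_attention import create_block_mask

# "causal": keep only the lower triangle of the attention matrix.
def causal_mask_mod(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx

# "block_causal": attend causally, but only within the block (document) a
# token belongs to; block boundaries are derived from EOS positions.
eos_id = 2                                  # assumed EOS id for this toy example
S = 128
tokens = torch.randint(3, 100, (1, S))      # one packed sequence of 128 tokens
tokens[0, 40] = eos_id                      # two documents end inside the sequence
tokens[0, 90] = eos_id
doc_id = (tokens == eos_id).cumsum(dim=-1)  # document index per position

def block_causal_mask_mod(b, h, q_idx, kv_idx):
    same_doc = doc_id[b, q_idx] == doc_id[b, kv_idx]
    return same_doc & (q_idx >= kv_idx)

# Materialize a BlockMask that flex_attention can consume.
block_mask = create_block_mask(
    block_causal_mask_mod, B=1, H=None, Q_LEN=S, KV_LEN=S, device="cpu",
)
```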
- flex_attn: ClassVar[Callable]#
‘compile(…)’
- block_masks: ClassVar[dict[nemo_automodel.components.attention.flex_attention.FLEX_ATTN_MASK_T, torch.nn.attention.flex_attention.BlockMask]]#
None
- forward(
- q: torch.Tensor,
- k: torch.Tensor,
- v: torch.Tensor,
- scale: float | None = None,
- sink_weights: torch.Tensor | None = None,
- sliding_window: int = 0,
- enable_gqa: bool = False,
- )#
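For orientation, the sketch below calls torch.nn.attention.flex_attention directly with the same kinds of arguments this forward accepts (q, k, v, scale, enable_gqa). It does not go through the wrapper itself, since the class-level block_masks mapping suggests masks are registered separately before forward is called; the shapes and head counts are assumptions for the example.

```python
import torch
from torch.nn.attention.flex_attention import flex_attention, create_block_mask

def causal_mask_mod(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx

B, S, D = 2, 256, 64
block_mask = create_block_mask(causal_mask_mod, B=None, H=None,
                               Q_LEN=S, KV_LEN=S, device="cpu")

# Grouped-query attention: 8 query heads share 2 key/value heads.
q = torch.randn(B, 8, S, D)
k = torch.randn(B, 2, S, D)
v = torch.randn(B, 2, S, D)

out = flex_attention(q, k, v, block_mask=block_mask, scale=None, enable_gqa=True)
print(out.shape)  # same shape as q: (2, 8, 256, 64)
```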
- static _get_sliding_window_mask_mod(window: int)#
Returns a mask_mod function that only allows kv_idx ≤ q_idx (causal) and only if (q_idx - kv_idx) ≤ window.
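A minimal sketch of that predicate, written as a local helper (not the module’s internal _get_sliding_window_mask_mod):

```python
def sliding_window_mask_mod(window: int):
    # Causal, and at most `window` positions behind the query.
    def mask_mod(b, h, q_idx, kv_idx):
        return (kv_idx <= q_idx) & ((q_idx - kv_idx) <= window)
    return mask_mod
```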
- static _get_causal_mask_mod() → torch.nn.attention.flex_attention._mask_mod_signature#
- static _get_block_causal_mask_mod(
- batch: torch.Tensor,
- eos_id: int,
- )#
- static _fixed_block_mask_mod(
- mask_mod: torch.nn.attention.flex_attention._mask_mod_signature,
- fixed_block_size: int,
- )#
Given an arbitrary mask_mod, divide the input sequence into blocks and only allow attention within the same block.
- Parameters:
mask_mod – The mask_mod to apply to the documents.
fixed_block_size – The number of tokens in each block.
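One way to express this, sketched as a local helper rather than the module’s exact implementation: the wrapped mask_mod admits a (q_idx, kv_idx) pair only when both indices fall in the same fixed-size block and the inner mask_mod also admits it.

```python
def fixed_block_mask_mod(mask_mod, fixed_block_size: int):
    def wrapped(b, h, q_idx, kv_idx):
        # Same block of `fixed_block_size` tokens ...
        same_block = (q_idx // fixed_block_size) == (kv_idx // fixed_block_size)
        # ... and allowed by the original mask_mod.
        return same_block & mask_mod(b, h, q_idx, kv_idx)
    return wrapped
```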