nemo_automodel.components.attention.flex_attention
Module Contents
API
Bases: Module
FlexAttention module that uses torch.nn.attention.flex_attention.
It wraps torch.nn.attention.flex_attention and implements common attention types such as causal and block_causal.
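flex_attention describes a mask as a mask_mod callable. It receives the batch index, head index, query index and key/value index, and it returns whether that attention score is kept. The sketch below writes a causal mask_mod in plain Python and expands it into a dense matrix for inspection. The names causal_mask and materialize_mask are hypothetical illustrations, not part of this module. materialize_mask only imitates what torch.nn.attention.flex_attention.create_mask produces.

```python
def causal_mask(b, h, q_idx, kv_idx):
    # Keep a score only when the key/value position is not in the future.
    return q_idx >= kv_idx


def materialize_mask(mask_mod, q_len, kv_len, b=0, h=0):
    # Hypothetical helper: evaluate mask_mod on every (q_idx, kv_idx) pair.
    return [[bool(mask_mod(b, h, q, kv)) for kv in range(kv_len)] for q in range(q_len)]


for row in materialize_mask(causal_mask, 4, 4):
    # "1" = score kept, "." = score masked out.
    print("".join("1" if keep else "." for keep in row))
# 1...
# 11..
# 111.
# 1111
```

The comparison also works elementwise on integer index tensors, so a mask_mod written this way can be passed directly to torch.nn.attention.flex_attention.create_block_mask.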
Parameters:
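The type of attention mask. Currently, “causal” and “block_causal” are supported. With “causal”, each query attends only to itself and earlier positions: the lower triangle of the attention matrix, including the diagonal, is kept, and everything above the diagonal is masked out. With “block_causal”, the attention matrix is divided into blocks whose boundaries are defined by EOS tokens. The causal mask is applied within each block, and no query attends across a block boundary.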
The block size to be used to perform attention.
If specified, each sequence is further divided into blocks of at most fixed_block_size tokens, and a query attends only to keys within its own block.
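The “block_causal” rule can be sketched in plain Python. The sketch assumes that an EOS token belongs to the document it ends, so the token after it starts a new block. document_ids and make_block_causal_mask are hypothetical names used only for illustration. A score is kept only when both positions fall in the same document and the key is not in the future.

```python
def document_ids(tokens, eos_id):
    # Assign each token the index of its document; an EOS token closes
    # its own document, so the following token starts the next one.
    ids, current = [], 0
    for tok in tokens:
        ids.append(current)
        if tok == eos_id:
            current += 1
    return ids


def make_block_causal_mask(doc_ids):
    def block_causal_mask(b, h, q_idx, kv_idx):
        same_doc = doc_ids[q_idx] == doc_ids[kv_idx]
        # Causal inside the document, nothing across document boundaries.
        return same_doc & (q_idx >= kv_idx)
    return block_causal_mask


EOS = 0
ids = document_ids([5, 7, EOS, 3, 4], EOS)  # [0, 0, 0, 1, 1]
mask = make_block_causal_mask(ids)
for q in range(len(ids)):
    print("".join("1" if mask(0, 0, q, kv) else "." for kv in range(len(ids))))
# 1....
# 11...
# 111..
# ...1.
# ...11
```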
Given an arbitrary mask_mod, divide the input sequence into blocks and allow attention only within the same block.
Parameters:
The mask_mod to apply to the documents.
The number of tokens in each block.
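A minimal sketch of wrapping an arbitrary mask_mod so that attention never crosses a fixed-size block boundary. fixed_block_mask_mod is a hypothetical name, and the sketch is not this module's implementation. It passes global indices to the inner mask_mod. That keeps position-based masks such as a document mask correct, but an implementation could also pass block-local indices instead.

```python
def fixed_block_mask_mod(mask_mod, fixed_block_size):
    # Return a mask_mod that keeps a score only if both positions share a
    # block of fixed_block_size tokens AND the inner mask_mod keeps it.
    def blocked_mask_mod(b, h, q_idx, kv_idx):
        same_block = (q_idx // fixed_block_size) == (kv_idx // fixed_block_size)
        return same_block & mask_mod(b, h, q_idx, kv_idx)
    return blocked_mask_mod


def causal_mask(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx


blocked = fixed_block_mask_mod(causal_mask, 2)
print(blocked(0, 0, 1, 0))  # True: same block, past key
print(blocked(0, 0, 2, 1))  # False: key lies in the previous block
```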
Returns a mask_mod function that keeps a score only if both conditions hold:
- kv_idx ≤ q_idx (causal), and
- (q_idx - kv_idx) ≤ window.
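The two conditions above can be written directly as a mask_mod. sliding_window_causal is a hypothetical name used for illustration. Because the bound is inclusive (≤ window), each query attends to at most window + 1 positions, itself included.

```python
def sliding_window_causal(window):
    def mask_mod(b, h, q_idx, kv_idx):
        causal = kv_idx <= q_idx                  # no future keys
        in_window = (q_idx - kv_idx) <= window    # no keys too far back
        return causal & in_window
    return mask_mod


m = sliding_window_causal(2)
# Query 5 with window 2 sees keys 3, 4 and 5 (window + 1 positions).
print([kv for kv in range(8) if m(0, 0, 5, kv)])  # [3, 4, 5]
```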