nemo_automodel.components.attention.utils#

Module Contents#

Functions#

initialize_attn_module_and_func

Initialize an attention module and its attention function for the requested backend.

preprocess_args_and_kwargs_for_attn

Preprocess attention inputs based on backend requirements.

postprocess_output_for_attn

Postprocess attention output based on attn_impl requirements.

API#

nemo_automodel.components.attention.utils.initialize_attn_module_and_func(
attn_impl: str,
num_attention_heads: int,
num_qk_channels: int,
num_v_channels: int,
softmax_scale: float,
attn_mask_type: str = 'causal',
qkv_format: str = 'bshd',
num_gqa_groups: int | None = None,
**kwargs,
) → tuple[torch.nn.Module | None, Callable]#
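The `(module, function)` return shape suggests that stateless backends (e.g. torch SDPA) return `None` for the module and only a callable. The sketch below illustrates that contract with a naive scaled-dot-product attention; the function name, backend dispatch, and internals are illustrative assumptions, not the library's actual implementation.

```python
import math

import torch


def initialize_attn_sketch(attn_impl: str, softmax_scale: float):
    # Hypothetical sketch of the (module, function) contract:
    # SDPA-style backends are stateless, so no nn.Module is needed.
    if attn_impl == "sdpa":
        def attn_func(q, k, v):
            # Naive scaled dot-product attention over (batch, heads, seq, dim).
            scores = torch.matmul(q, k.transpose(-2, -1)) * softmax_scale
            return torch.matmul(torch.softmax(scores, dim=-1), v)
        return None, attn_func
    raise NotImplementedError(f"unsupported attn_impl: {attn_impl}")


head_dim = 64
module, attn_func = initialize_attn_sketch("sdpa", softmax_scale=1 / math.sqrt(head_dim))
q = torch.randn(2, 4, 16, head_dim)  # (batch, heads, seq, head_dim)
out = attn_func(q, q, q)
```

A stateful backend (e.g. a Transformer Engine attention layer) would instead return the constructed `torch.nn.Module` in the first slot.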
nemo_automodel.components.attention.utils.preprocess_args_and_kwargs_for_attn(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
attention_mask: torch.Tensor | None,
attn_impl: str,
**kwargs,
) → tuple[torch.Tensor, torch.Tensor, torch.Tensor, dict]#

Preprocess attention inputs based on backend requirements.
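A typical reason such a preprocessing hook exists is that backends disagree on tensor layout and on how masks are passed. The sketch below is a hypothetical illustration of that idea, assuming `bshd` inputs, an SDPA-style backend that wants `(batch, heads, seq, dim)` plus a dense mask kwarg, and a flash-attention-style backend that keeps `bshd` and drops the mask; the real backend rules may differ.

```python
import torch


def preprocess_sketch(q, k, v, attention_mask, attn_impl, **kwargs):
    # Hypothetical: torch SDPA expects (batch, heads, seq, dim) and an
    # attn_mask kwarg, while flash-attention-style kernels consume the
    # original bshd layout and encode masking differently.
    if attn_impl == "sdpa":
        q, k, v = (t.transpose(1, 2).contiguous() for t in (q, k, v))
        extra_kwargs = {"attn_mask": attention_mask, **kwargs}
    else:
        extra_kwargs = dict(kwargs)  # keep bshd layout, drop the dense mask
    return q, k, v, extra_kwargs


q = k = v = torch.randn(2, 16, 4, 64)  # bshd: (batch, seq, heads, head_dim)
q2, k2, v2, extra = preprocess_sketch(q, k, v, None, "sdpa")
```

Returning the extra kwargs as a `dict` lets the caller splat them straight into the attention function chosen by `initialize_attn_module_and_func`.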

nemo_automodel.components.attention.utils.postprocess_output_for_attn(
x: torch.Tensor,
attn_impl: str,
) → torch.Tensor#

Postprocess attention output based on attn_impl requirements.
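The postprocessing step is the mirror image of preprocessing: the backend's output layout is converted back to what the surrounding model expects. As a hedged sketch, assuming an SDPA-style backend that returns `(batch, heads, seq, dim)` and a model that wants a flat `(batch, seq, hidden)` tensor:

```python
import torch


def postprocess_sketch(x: torch.Tensor, attn_impl: str) -> torch.Tensor:
    # Hypothetical: SDPA-style backends return (batch, heads, seq, dim);
    # transpose back to (batch, seq, heads, dim) before merging heads.
    if attn_impl == "sdpa":
        x = x.transpose(1, 2)
    b, s, h, d = x.shape
    return x.reshape(b, s, h * d)  # hidden size = num_heads * head_dim


out = postprocess_sketch(torch.randn(2, 4, 16, 64), "sdpa")
```

`reshape` (rather than `view`) is used because the transposed tensor is no longer contiguous.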