nemo_automodel.components.attention.utils


Module Contents

Functions

| Name | Description |
| --- | --- |
| initialize_attn_module_and_func | Initialize an attention backend module and callable. |
| postprocess_output_for_attn | Postprocess attention output based on attn_impl requirements. |
| preprocess_args_and_kwargs_for_attn | Preprocess attention inputs based on backend requirements. |

API

nemo_automodel.components.attention.utils.initialize_attn_module_and_func(
attn_impl: str,
num_attention_heads: int,
num_qk_channels: int,
num_v_channels: int,
softmax_scale: float,
attn_mask_type: str = 'causal',
qkv_format: str = 'bshd',
num_gqa_groups: int | None = None,
kwargs: typing.Any = {}
) -> tuple[torch.nn.Module | None, typing.Callable[..., torch.Tensor]]

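Initialize the attention backend selected by attn_impl. Returns a (module, func) tuple, where module is a torch.nn.Module or None, and func is the callable that computes attention and returns a torch.Tensor.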
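The (module, func) return contract can be illustrated with a standard-library-only sketch. This is not the library's implementation: the backend name "reference", the helpers reference_attention and make_attn, and the use of nested lists in place of tensors are all hypothetical. A purely functional backend has no state to keep in a module, so the sketch returns None together with a callable that already has softmax_scale and the causal flag bound.

```python
import functools
import math
from typing import Callable, Optional

# Hypothetical stand-in for a single-head [seq_len][channels] tensor.
Matrix = list[list[float]]


def reference_attention(q: Matrix, k: Matrix, v: Matrix, *, scale: float, causal: bool) -> Matrix:
    """Hypothetical single-head scaled dot-product attention on nested lists."""
    out: Matrix = []
    for i, q_row in enumerate(q):
        # Scaled dot product of query i with every key.
        scores = [scale * sum(a * b for a, b in zip(q_row, k_row)) for k_row in k]
        if causal:
            # Causal mask: query i may only attend to keys at positions <= i.
            scores = [s if j <= i else -math.inf for j, s in enumerate(scores)]
        # Numerically stable softmax over the scores.
        m = max(scores)
        weights = [math.exp(s - m) for s in scores]
        total = sum(weights)
        weights = [w / total for w in weights]
        # Output row i is the weighted sum of the value rows.
        out.append([sum(w * v_row[c] for w, v_row in zip(weights, v)) for c in range(len(v[0]))])
    return out


def make_attn(
    attn_impl: str, softmax_scale: float, attn_mask_type: str = "causal"
) -> tuple[Optional[object], Callable[..., Matrix]]:
    """Hypothetical factory mirroring the (module | None, callable) return shape."""
    if attn_impl == "reference":
        func = functools.partial(
            reference_attention, scale=softmax_scale, causal=(attn_mask_type == "causal")
        )
        return None, func  # functional backend: no module holds state
    raise ValueError(f"unsupported attn_impl: {attn_impl!r}")


module, attn = make_attn("reference", softmax_scale=1 / math.sqrt(2))
q = k = [[1.0, 0.0], [0.0, 1.0]]
v = [[1.0, 2.0], [3.0, 4.0]]
print(module)            # None
print(attn(q, k, v)[0])  # [1.0, 2.0]: under the causal mask, position 0 sees only itself
```

Binding configuration into the callable up front keeps the call site uniform: the caller invokes func the same way whether or not a module was created.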

nemo_automodel.components.attention.utils.postprocess_output_for_attn(
x: torch.Tensor,
attn_impl: str
) -> torch.Tensor

Postprocess the attention output x as required by the backend named in attn_impl, and return the resulting tensor.
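As a hypothetical example of output postprocessing, assume a backend returns its output head-major, [batch, heads, seq, dim], while callers expect the bshd layout (the default qkv_format of initialize_attn_module_and_func). The sketch below swaps the head and sequence axes for such backends and passes all other outputs through unchanged. The backend set HEAD_MAJOR_BACKENDS and both helper names are assumptions, and nested lists stand in for tensors.

```python
from typing import Any

# Hypothetical: backends assumed to emit [b][h][s][d] output.
HEAD_MAJOR_BACKENDS = {"reference"}


def swap_axes_1_2(x: list[list[list[Any]]]) -> list[list[list[Any]]]:
    """Turn a nested [b][h][s][...] list into [b][s][h][...]."""
    return [[[x_b[h][s] for h in range(len(x_b))] for s in range(len(x_b[0]))] for x_b in x]


def postprocess_sketch(x, attn_impl: str):
    """Hypothetical postprocessing: restore bshd layout for head-major backends."""
    if attn_impl in HEAD_MAJOR_BACKENDS:
        return swap_axes_1_2(x)
    return x  # outputs already in bshd layout pass through unchanged


x = [[[[1], [2], [3]], [[4], [5], [6]]]]  # b=1, h=2, s=3, d=1
print(postprocess_sketch(x, "reference"))  # [[[[1], [4]], [[2], [5]], [[3], [6]]]]
```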

nemo_automodel.components.attention.utils.preprocess_args_and_kwargs_for_attn(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
attention_mask: torch.Tensor | None,
attn_impl: str,
kwargs: typing.Any = {}
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, dict[str, typing.Any]]

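Preprocess q, k, v, and attention_mask as required by the backend named in attn_impl. Returns a (q, k, v, kwargs) tuple: the possibly transformed query, key, and value tensors, plus a dict[str, Any] of keyword arguments.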
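A hypothetical sketch of the (q, k, v, kwargs) contract: inputs arrive as [batch, seq, heads, dim], an assumed backend "reference" wants [batch, heads, seq, dim], and a 1/0 padding mask is converted to booleans and returned in the kwargs dict under the invented key key_padding_mask. None of these names come from the library; nested lists stand in for tensors so the sketch runs with the standard library alone.

```python
from typing import Any, Optional

# Hypothetical nested-list stand-in for a [b][s][h][d] tensor.
Tensor4D = list[list[list[list[float]]]]


def to_head_major(x: Tensor4D) -> Tensor4D:
    """Swap the sequence and head axes: [b][s][h][d] -> [b][h][s][d]."""
    return [[list(vecs) for vecs in zip(*x_b)] for x_b in x]


def preprocess_sketch(
    q, k, v, attention_mask: Optional[list[list[int]]], attn_impl: str, **kwargs: Any
):
    """Hypothetical preprocessing for a backend that expects head-major inputs."""
    attn_kwargs: dict[str, Any] = dict(kwargs)  # copy so caller's kwargs are untouched
    if attn_impl == "reference":
        q, k, v = (to_head_major(t) for t in (q, k, v))
        if attention_mask is not None:
            # 1/0 padding mask [b][s] -> booleans, True meaning "attend to this key".
            attn_kwargs["key_padding_mask"] = [[bool(m) for m in row] for row in attention_mask]
    return q, k, v, attn_kwargs


q = [[[[1.0], [2.0]], [[3.0], [4.0]]]]  # b=1, s=2, h=2, d=1
q2, k2, v2, kw = preprocess_sketch(q, q, q, [[1, 0]], "reference")
print(q2)  # [[[[1.0], [3.0]], [[2.0], [4.0]]]]
print(kw)  # {'key_padding_mask': [[True, False]]}
```

Returning the derived arguments in a dict, rather than as extra positional values, lets each backend add only the keys it needs while the call site stays the same.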