nemo_automodel.components.attention.utils
Module Contents

Functions
| Function | Description |
|---|---|
| `initialize_attn_module_and_func` | |
| `preprocess_args_and_kwargs_for_attn` | Preprocess attention inputs based on backend requirements. |
| `postprocess_output_for_attn` | Postprocess attention output based on attn_impl requirements. |
API
nemo_automodel.components.attention.utils.initialize_attn_module_and_func(
    attn_impl: str,
    num_attention_heads: int,
    num_qk_channels: int,
    num_v_channels: int,
    softmax_scale: float,
    attn_mask_type: str = 'causal',
    qkv_format: str = 'bshd',
    num_gqa_groups: int | None = None,
    **kwargs,
)
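A minimal usage sketch follows. The backend name `"sdpa"`, the example dimensions, and the assumption that the function returns an attention module together with a callable that computes attention are illustrative, not part of the documented contract.

```python
from nemo_automodel.components.attention.utils import initialize_attn_module_and_func

num_heads, head_dim = 16, 64  # illustrative sizes

# Assumption: a (module, callable) pair is returned; the callable runs attention
# for the selected backend. The backend string "sdpa" is also an assumption.
attn_module, attn_func = initialize_attn_module_and_func(
    attn_impl="sdpa",
    num_attention_heads=num_heads,
    num_qk_channels=head_dim,
    num_v_channels=head_dim,
    softmax_scale=head_dim ** -0.5,  # standard 1/sqrt(d_k) scaling
    attn_mask_type="causal",
    qkv_format="bshd",               # batch, sequence, heads, head-dim
)
```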
nemo_automodel.components.attention.utils.preprocess_args_and_kwargs_for_attn(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    attention_mask: torch.Tensor | None,
    attn_impl: str,
    **kwargs,
)
Preprocess attention inputs based on backend requirements.
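A sketch of how this helper might be wired into a call site, assuming (it is not documented here) that it returns a positional-argument tuple and a keyword-argument dict to pass to the backend attention callable; the `'bshd'` tensor layout and the `"sdpa"` backend name are likewise assumptions.

```python
import torch

from nemo_automodel.components.attention.utils import preprocess_args_and_kwargs_for_attn

batch, seq, heads, dim = 2, 128, 16, 64   # illustrative sizes
q = torch.randn(batch, seq, heads, dim)   # 'bshd' layout assumed
k = torch.randn(batch, seq, heads, dim)
v = torch.randn(batch, seq, heads, dim)

# Assumption: the helper returns (args, kwargs) ready to be passed to the
# attention callable produced by initialize_attn_module_and_func.
args, kwargs = preprocess_args_and_kwargs_for_attn(
    q, k, v,
    attention_mask=None,   # e.g. no padding mask; masking left to the backend
    attn_impl="sdpa",      # backend name is an assumption
)
# output = attn_func(*args, **kwargs)
```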
nemo_automodel.components.attention.utils.postprocess_output_for_attn(
    x: torch.Tensor,
    attn_impl: str,
)
Postprocess attention output based on attn_impl requirements.
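For completeness, a hedged sketch of the final step, assuming the helper simply returns the output tensor rearranged for the chosen `attn_impl`; the shape and backend name are illustrative.

```python
import torch

from nemo_automodel.components.attention.utils import postprocess_output_for_attn

# 'raw_out' stands in for the tensor returned by the backend attention callable;
# its shape here is illustrative only.
raw_out = torch.randn(2, 128, 16, 64)

# Assumption: the helper only rearranges the output into the layout expected by
# the calling model for the chosen attn_impl and returns a torch.Tensor.
out = postprocess_output_for_attn(raw_out, attn_impl="sdpa")
```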