core.models.retro.base_attention#
Base class for decoder and encoder attention modules.
Module Contents#
Classes#
| BaseRetroCrossAttention | Base class for Retro cross attention, for both encoder & decoder layers. |
API#
- class core.models.retro.base_attention.BaseRetroCrossAttention(
- config: megatron.core.models.retro.config.RetroConfig,
- submodules: megatron.core.transformer.attention.CrossAttentionSubmodules,
- layer_number: int = 1,
- attn_mask_type: megatron.core.transformer.enums.AttnMaskType = AttnMaskType.padding,
- pg_collection: megatron.core.process_groups_config.ProcessGroupCollection = None,
- )
Bases: megatron.core.transformer.module.MegatronModule

Base class for Retro cross attention, for both encoder & decoder layers.
This class collects the Retro arguments (i.e., number of neighbors, chunk length, and retrieved length) for use in Retro’s custom cross attention operators.
- Parameters:
config (RetroConfig) – Retro config.
submodules (CrossAttentionSubmodules) – Cross attention submodules.
layer_number (int) – Layer number within transformer block.
attn_mask_type (AttnMaskType) – Mask type (‘causal’ or ‘padding’).
pg_collection (ProcessGroupCollection) – Model communication process groups.
Initialization
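
Below is a minimal subclassing sketch of how a decoder- or encoder-side attention layer might build on this base class. The attribute names used (`attn`, `retro_num_neighbors`, `retro_chunk_length`, `retro_retrieved_length`) are inferred from the parameters and description on this page, not quoted from the Megatron-Core source, so verify them against `megatron/core/models/retro/base_attention.py` before relying on them.

```python
from megatron.core.models.retro.base_attention import BaseRetroCrossAttention


class ExampleRetroCrossAttention(BaseRetroCrossAttention):
    """Hypothetical subclass showing how a decoder- or encoder-side layer can
    build on the base class, which collects the Retro hyperparameters (number
    of neighbors, chunk length, retrieved length) from the RetroConfig."""

    def forward(self, hidden_states, attention_mask, key_value_states=None):
        # The hyperparameters gathered by the base class (assumed attribute
        # names) are what a real implementation uses to chunk the sequence and
        # lay out retrieved neighbors before attending:
        #   self.retro_num_neighbors, self.retro_chunk_length,
        #   self.retro_retrieved_length

        # Delegate the attention computation itself to the wrapped
        # cross-attention module built from the supplied submodules
        # (assumed attribute name `self.attn`).
        return self.attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            key_value_states=key_value_states,
        )
```

Instantiating such a subclass mirrors the signature above: pass a RetroConfig, the CrossAttentionSubmodules for the layer, and optionally the layer number, mask type, and process-group collection. In Megatron-Core, the Retro decoder and encoder cross-attention modules appear to follow this same pattern, keeping the Retro-specific chunking logic out of the generic cross-attention implementation.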