core.models.retro.base_attention#

Base class for decoder and encoder attention modules.

Module Contents#

Classes#

BaseRetroCrossAttention

Base class for Retro cross attention, for both encoder & decoder layers.

API#

class core.models.retro.base_attention.BaseRetroCrossAttention(
config: megatron.core.models.retro.config.RetroConfig,
submodules: megatron.core.transformer.attention.CrossAttentionSubmodules,
layer_number: int = 1,
attn_mask_type: megatron.core.transformer.enums.AttnMaskType = AttnMaskType.padding,
pg_collection: megatron.core.process_groups_config.ProcessGroupCollection = None,
)#

Bases: megatron.core.transformer.module.MegatronModule

Base class for Retro cross attention, for both encoder & decoder layers.

This class collects the Retro hyperparameters (i.e., number of neighbors, chunk length, and retrieved length) from the config for use in Retro's custom cross attention operators.

Parameters:

config (RetroConfig) – Retro transformer config.

submodules (CrossAttentionSubmodules) – Cross attention submodules.

layer_number (int) – Layer number within the transformer block. Defaults to 1.

attn_mask_type (AttnMaskType) – Attention mask type. Defaults to AttnMaskType.padding.

pg_collection (ProcessGroupCollection) – Process groups used for communication. Defaults to None.

Initialization
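
A minimal sketch of how a subclass might build on this base class. The attribute names (`self.attn`, `self.retro_num_neighbors`, `self.retro_chunk_length`, `self.retro_retrieved_length`) and the forward signature are assumptions inferred from the docstring above, not the library's actual encoder/decoder implementation; verify them against your Megatron-LM version.

```python
from megatron.core.models.retro.base_attention import BaseRetroCrossAttention


class ExampleRetroCrossAttention(BaseRetroCrossAttention):
    """Illustrative subclass; real Retro layers add chunking/padding logic."""

    def forward(self, hidden_states, attention_mask, key_value_states=None):
        # Retro hyperparameters collected by the base class from RetroConfig
        # (assumed attribute names):
        #   self.retro_num_neighbors    - neighbors retrieved per chunk
        #   self.retro_chunk_length     - tokens per decoder chunk
        #   self.retro_retrieved_length - tokens per retrieved neighbor
        #
        # Delegate the attention computation to the wrapped cross-attention
        # module, assumed to be built from `submodules` at init time.
        attention_output, attention_bias = self.attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            key_value_states=key_value_states,
        )
        return attention_output, attention_bias
```

The concrete encoder and decoder cross-attention classes differ mainly in how they reshape hidden states into chunks around this delegated attention call.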