core.models.retro.decoder_attention#

Retro’s cross attention modules for the decoder block.

Module Contents#

Classes#

RetroDecoderCrossAttention

Retro decoder’s chunked cross attention operator.

RetroDecoderBiasDropoutAdd

Retro decoder’s bias-dropout-add operator.

API#

class core.models.retro.decoder_attention.RetroDecoderCrossAttention(
config: megatron.core.models.retro.config.RetroConfig,
submodules: megatron.core.transformer.attention.CrossAttentionSubmodules,
layer_number: int = 1,
attn_mask_type: megatron.core.transformer.enums.AttnMaskType = AttnMaskType.padding,
encoder_block_spec: megatron.core.transformer.ModuleSpec = None,
pg_collection: megatron.core.process_groups_config.ProcessGroupCollection = None,
)#

Bases: megatron.core.models.retro.base_attention.BaseRetroCrossAttention

Retro decoder’s chunked cross attention operator.

See this paper for more details: https://arxiv.org/abs/2112.04426. Neighboring chunks retrieved from the chunk database are used here for chunked cross-attention.

**Note about ‘encoder_block_spec’**

Retro is an encoder-decoder model that uses its encoder for encoding neighboring chunks that are retrieved from a chunk database. These encoded neighbors are then used in the decoder stack for performing chunked cross-attention (see paper link above).

In contrast to the T5 model, the encoder and decoder are computationally intertwined, since the input to the encoder is the output of the self-attention of the first decoder layer. As such, the encoder block itself is instantiated within the first Retro decoder layer, in order to receive the self-attention’s output. (Note that only the first decoder layer instantiates an encoder block; the remaining decoder layers use the encoder output from the first decoder layer.)
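As a rough illustration of this arrangement, a hypothetical spec-building loop (the helper name below is illustrative, not one of Megatron’s actual APIs) would hand the encoder block spec to layer 1 only:

```python
# Hypothetical sketch, not Megatron's actual spec-building code: only the
# first Retro decoder layer receives the encoder block spec; the remaining
# layers reuse the encoder output computed by that first layer.
def build_decoder_cross_attn_kwargs(num_layers, encoder_block_spec):
    kwargs_per_layer = []
    for layer_number in range(1, num_layers + 1):
        kwargs_per_layer.append(
            dict(
                layer_number=layer_number,
                # Only layer 1 instantiates the neighbor encoder.
                encoder_block_spec=(
                    encoder_block_spec if layer_number == 1 else None
                ),
            )
        )
    return kwargs_per_layer
```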

Parameters:
  • config (RetroConfig) – Retro config.

  • submodules (CrossAttentionSubmodules) – Cross attention submodules.

  • layer_number (int) – Layer number within transformer block.

  • attn_mask_type (AttnMaskType) – Mask type (‘causal’ or ‘padding’).

  • encoder_block_spec (ModuleSpec) – The first Retro decoder layer is provided with a transformer block spec to construct the neighbor encoder.

  • pg_collection (ProcessGroupCollection) – Model communication process groups.

Initialization
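A minimal construction sketch, assuming `config` (RetroConfig), `submodules` (CrossAttentionSubmodules), and `encoder_spec` (ModuleSpec) have already been built elsewhere:

```python
from megatron.core.models.retro.decoder_attention import RetroDecoderCrossAttention

# Sketch only: `config`, `submodules`, and `encoder_spec` are assumed to be
# built from the usual Retro config and layer-spec machinery.
cross_attn = RetroDecoderCrossAttention(
    config=config,
    submodules=submodules,
    layer_number=1,                    # first decoder layer
    encoder_block_spec=encoder_spec,   # None for all later decoder layers
)
```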

forward(
hidden_states: torch.Tensor,
attention_mask: torch.Tensor,
key_value_states: torch.Tensor = None,
inference_context: megatron.core.inference.contexts.BaseInferenceContext = None,
*,
inference_params: Optional[megatron.core.inference.contexts.BaseInferenceContext] = None,
) → dict#

Cross attention for Retro decoder.

Notation:
  • ns : Sequence length.
  • bs : Batch size.
  • d : Hidden size.
  • l : Number of chunks per sample (i.e., seq_length/chunk_length).
  • m : Number of tokens per chunk.
  • k : Number of neighbors.
  • r : Number of retrieved tokens (neighbors + continuation).

Parameters:
  • hidden_states (Tensor) – Transformer layer hidden states.

  • attention_mask (Tensor) – Attention mask.

  • key_value_states (Tensor) – Neighbor embeddings if first decoder layer, else encoder output.

  • inference_context (BaseInferenceContext) – Inference context.

Returns:

A dict consisting of the attention output and context, along with other scalars necessary for performing the downstream bias-dropout-add.
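A hedged call sketch (tensor layouts follow the notation above but are assumptions, and the exact dict keys are not spelled out here):

```python
# Hypothetical usage in the first decoder layer, where `key_value_states`
# carries the retrieved-neighbor embeddings; later decoder layers would pass
# the encoder output produced by layer 1 instead.
attn_out = cross_attn(
    hidden_states=hidden_states,           # assumed [ns, bs, d]
    attention_mask=attention_mask,
    key_value_states=neighbor_embeddings,  # assumed retrieved-token layout
)
# `attn_out` is a dict with the attention output/bias plus the scalars
# needed by the downstream RetroDecoderBiasDropoutAdd.
```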

class core.models.retro.decoder_attention.RetroDecoderBiasDropoutAdd(
config: megatron.core.models.retro.config.RetroConfig,
)#

Bases: megatron.core.transformer.module.MegatronModule

Retro decoder’s bias-dropout-add operator.

This operator takes care of reshaping and permuting the output from the chunk dimension to the sequence dimension.

Parameters:

config (RetroConfig) – Retro config.

Initialization

classmethod _forward(
x_with_bias: dict,
residual: torch.Tensor,
prob: float,
retro_chunk_length: int,
bias_dropout_add: Callable,
) → torch.Tensor#

Per-chunk bias-dropout-add.

Parameters:
  • x_with_bias (dict) – Attention output and bias, along with other Retro relevant parameters.

  • residual (Tensor) – Transformer layer residual.

  • prob (float) – Dropout probability.

  • retro_chunk_length (int) – Retro chunk length (e.g., 64).

  • bias_dropout_add (Callable) – Bias-dropout-add function.

Returns:

Output of bias-dropout-add.
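As a rough, self-contained illustration of the per-chunk pattern, here is a minimal plain-PyTorch sketch (assumed shapes; it is not the actual Megatron implementation, which differs in its exact padding and layout handling):

```python
import torch
import torch.nn.functional as F

# Minimal sketch: the attention output lives along the chunk dimension,
# assumed [retro_chunk_length, bs * l, d], while the residual lives along
# the sequence dimension, [ns, bs, d] with ns = l * retro_chunk_length.
def per_chunk_bias_dropout_add(attention_output, bias, residual,
                               retro_chunk_length, prob, training=True):
    ns, bs, d = residual.shape
    l = ns // retro_chunk_length  # chunks per sample
    # Reshape the residual from the sequence dimension to the chunk dimension.
    chunked_residual = (
        residual.reshape(l, retro_chunk_length, bs, d)
        .permute(1, 2, 0, 3)
        .reshape(retro_chunk_length, bs * l, d)
    )
    # Bias-dropout-add, applied per chunk.
    out = chunked_residual + F.dropout(attention_output + bias,
                                       p=prob, training=training)
    # Reshape back from the chunk dimension to the sequence dimension.
    return (
        out.reshape(retro_chunk_length, bs, l, d)
        .permute(2, 0, 1, 3)
        .reshape(ns, bs, d)
    )
```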

forward(training: bool, fused: bool) → functools.partial#

Retro decoder bias-dropout-add.

Parameters:
  • training (bool) – If training, then apply dropout.

  • fused (bool) – Fuse bias-dropout-add.

Returns:

The partial function for performing bias-dropout-add.
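A usage sketch (the argument names the returned partial expects are inferred from `_forward` above and are an assumption, as are the surrounding variables):

```python
# Hypothetical usage: `forward()` binds the chunk length and the concrete
# (fused or unfused) bias-dropout-add kernel, and the transformer layer later
# calls the returned partial with the attention output, residual, and dropout
# probability.
bda_func = retro_decoder_bda.forward(training=True, fused=False)
hidden_states = bda_func(x_with_bias, residual, config.hidden_dropout)
```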