core.models.retro.encoder_attention#

Retro’s cross attention modules for the encoder block.

Module Contents#

Classes#

RetroEncoderCrossAttention

Retro encoder’s cross attention operator.

RetroEncoderBiasDropoutAdd

Retro encoder’s bias-dropout-add operator.

RetroEncoderLayerNorm

Retro encoder’s layernorm operator.

API#

class core.models.retro.encoder_attention.RetroEncoderCrossAttention(
config: megatron.core.models.retro.config.RetroConfig,
submodules: megatron.core.transformer.attention.CrossAttentionSubmodules,
layer_number: int = 1,
attn_mask_type: megatron.core.transformer.enums.AttnMaskType = AttnMaskType.padding,
pg_collection: megatron.core.process_groups_config.ProcessGroupCollection = None,
)#

Bases: megatron.core.models.retro.base_attention.BaseRetroCrossAttention

Retro encoder’s cross attention operator.

See this paper for more details: https://arxiv.org/abs/2112.04426. Neighboring chunks are retrieved from the chunk database, encoded, and used by the decoder layers for chunked cross attention.

Parameters:
  • config (RetroConfig) – Retro config.

  • submodules (CrossAttentionSubmodules) – Cross attention submodules.

  • layer_number (int) – Layer number within transformer block.

  • attn_mask_type (AttnMaskType) – Mask type (‘causal’ or ‘padding’).

  • pg_collection (ProcessGroupCollection) – Optional process group collection.

Initialization
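The following is a minimal conceptual sketch of the attention pattern described in the RETRO paper, in which the tokens of a retrieved neighbor chunk attend to the hidden states of the decoder chunk that retrieved them. All names, shapes, and the use of plain scaled dot-product attention are illustrative assumptions, not the Megatron-Core wiring (which goes through the configured CrossAttentionSubmodules).

```python
import torch
import torch.nn.functional as F

# Illustrative sizes (assumptions): r retrieved tokens per neighbor,
# m tokens per decoder chunk, bs * l chunk instances, hidden size d.
r, m, bs_l, d = 128, 64, 8, 64

neighbor_tokens = torch.randn(bs_l, r, d)  # one retrieved neighbor chunk (queries)
chunk_hidden = torch.randn(bs_l, m, d)     # decoder chunk hidden states (keys/values)

# Plain scaled dot-product cross attention as a stand-in for the real
# projections and core attention provided by CrossAttentionSubmodules.
attn_out = F.scaled_dot_product_attention(neighbor_tokens, chunk_hidden, chunk_hidden)
print(attn_out.shape)  # torch.Size([8, 128, 64])
```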

forward(
hidden_states: torch.Tensor,
attention_mask: torch.Tensor,
key_value_states: torch.Tensor = None,
inference_context: megatron.core.inference.contexts.BaseInferenceContext = None,
*,
inference_params: Optional[megatron.core.inference.contexts.BaseInferenceContext] = None,
) → List[Tuple[torch.Tensor, Optional[torch.Tensor], torch.Tensor]]#

Cross attention for Retro encoder.

Notation:
  • ns – Sequence length.

  • bs – Batch size.

  • d – Hidden size.

  • l – Number of chunks per sample (i.e., seq_length / chunk_length).

  • k – Number of neighbors.

  • r – Number of retrieved tokens (neighbors + continuation).
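For concreteness, a small shape-bookkeeping sketch tying the notation above to typical Retro hyperparameters; the numeric values are illustrative assumptions, not defaults of RetroConfig.

```python
# Illustrative values only; the real ones come from RetroConfig.
seq_length = 2048       # ns
chunk_length = 64       # chunk length used by the decoder
num_neighbors = 2       # k
retrieved_length = 128  # r: neighbor chunk plus its continuation

l = seq_length // chunk_length  # chunks per sample -> 32
# The hidden states entering this module cover every neighbor of every chunk,
# i.e. an effective batch dimension of bs * l * k with sequence length r.
```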

Parameters:
  • hidden_states (Tensor) – Transformer layer hidden states.

  • attention_mask (Tensor) – Attention mask.

  • key_value_states (Tensor) – Hidden states of the corresponding decoder chunks, used as keys and values.

  • inference_context (BaseInferenceContext) – Inference context.

Returns:

List of tuples, where each tuple is (attention_output, attention_bias, residual).
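The list structure can be illustrated with stand-in tensors; one tuple is produced per retrieved neighbor, and the shapes below are assumptions based on the notation above.

```python
import torch

k, r, bs_l, d = 2, 128, 8, 64  # illustrative sizes

# Stand-in for the value returned by forward(): one tuple per neighbor.
outputs = [
    (torch.randn(r, bs_l, d), None, torch.randn(r, bs_l, d))
    for _ in range(k)
]

for attention_output, attention_bias, residual in outputs:
    assert attention_output.shape == residual.shape == (r, bs_l, d)
```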

class core.models.retro.encoder_attention.RetroEncoderBiasDropoutAdd(
config: megatron.core.models.retro.config.RetroConfig,
)#

Bases: megatron.core.transformer.module.MegatronModule

Retro encoder’s bias-dropout-add operator.

This operator applies bias-dropout-add individually on each neighboring chunk that is retrieved from the chunk database.

Parameters:

config (RetroConfig) – Retro config.

Initialization

classmethod _forward(
x_with_bias: List[Tuple[torch.Tensor, Optional[torch.Tensor], torch.Tensor]],
residual: torch.Tensor,
prob: float,
retro_num_neighbors: int,
bias_dropout_add: Callable,
) → torch.Tensor#

Per-chunk bias-dropout-add.

Parameters:
  • x_with_bias (List[Tuple[Tensor, Optional[Tensor], Tensor]]) – Per-neighbor list of (attention output, attention bias, residual) tuples.

  • residual (Tensor) – Transformer layer residual.

  • prob (float) – Dropout probability.

  • retro_num_neighbors (int) – Number of retrieved neighbor chunks (e.g., 2).

  • bias_dropout_add (Callable) – Bias-dropout-add function.

Returns:

Output of bias-dropout-add.
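A simplified, unfused sketch of the per-chunk operation; the fused kernels and exact bookkeeping used by Megatron-Core are not reproduced here, and all shapes are assumptions.

```python
import torch
import torch.nn.functional as F

def bias_dropout_add(x, bias, residual, prob, training):
    # Plain (unfused) bias-dropout-add on a single neighbor chunk.
    out = x if bias is None else x + bias
    out = F.dropout(out, p=prob, training=training)
    return residual + out

retro_num_neighbors, r, bs_l, d = 2, 128, 8, 64

# One (attention_output, attention_bias, residual) tuple per neighbor chunk,
# mirroring the x_with_bias argument documented above.
x_with_bias = [
    (torch.randn(r, bs_l, d), None, torch.randn(r, bs_l, d))
    for _ in range(retro_num_neighbors)
]

# Apply bias-dropout-add independently to each neighbor chunk.
outputs = [
    bias_dropout_add(x, bias, residual, prob=0.1, training=True)
    for x, bias, residual in x_with_bias
]
```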

forward(training: bool, fused: bool) → functools.partial#

Retro encoder bias-dropout-add.

Parameters:
  • training (bool) – If training, then apply dropout.

  • fused (bool) – Fuse bias-dropout-add.

Returns:

A partial function for performing bias-dropout-add.
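A sketch of the partial-function pattern this forward() follows: the chunk-level configuration is bound up front, while the tensors are supplied later by the transformer layer. The helper names and bodies below are illustrative, not the module's internals.

```python
from functools import partial

def per_chunk_bias_dropout_add(x_with_bias, residual, prob, retro_num_neighbors, bias_dropout_add):
    # Placeholder body; the per-chunk residuals come from the tuples, while the
    # standalone `residual` argument mirrors the documented _forward signature.
    return [bias_dropout_add(x, bias, res, prob) for x, bias, res in x_with_bias]

# Bind the configuration now; tensors are passed when the layer calls bda_fn.
bda_fn = partial(
    per_chunk_bias_dropout_add,
    retro_num_neighbors=2,
    bias_dropout_add=lambda x, bias, res, prob: res + (x if bias is None else x + bias),
)
# output = bda_fn(x_with_bias, residual, prob)  # supplied by the calling layer
```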

class core.models.retro.encoder_attention.RetroEncoderLayerNorm(
config: megatron.core.models.retro.config.RetroConfig,
submodules: Type,
**kwargs: dict,
)#

Bases: megatron.core.transformer.module.MegatronModule

Retro encoder’s layernorm operator.

This operator applies layernorm individually on each neighboring chunk that is retrieved from the chunk database, and then concatenates the chunks into a single tensor.

Parameters:
  • config (RetroConfig) – Retro config.

  • submodules (Type) – Layer norm class. (Named ‘submodules’ to fit external interface.)

Initialization
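A minimal sketch of the per-chunk layer norm described above, assuming the neighbor chunks arrive concatenated along the second dimension; the splitting convention and shapes are assumptions for illustration.

```python
import torch

retro_num_neighbors, r, bs_l, d = 2, 128, 8, 64
layer_norm = torch.nn.LayerNorm(d)

# Input: neighbor chunks concatenated into a single tensor.
concatenated = torch.randn(r, bs_l * retro_num_neighbors, d)

# Split into per-neighbor chunks, normalize each, then concatenate again.
chunks = concatenated.chunk(retro_num_neighbors, dim=1)
normalized = [layer_norm(chunk) for chunk in chunks]
output = torch.cat(normalized, dim=1)
assert output.shape == concatenated.shape
```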

forward(input: torch.Tensor) → torch.Tensor#

Per-chunk layer norm.

Parameters:

input (Tensor) – Input chunks, concatenated into a single tensor.

Returns:

Output of the layer norm.