core.transformer.dot_product_attention#

Module Contents#

Classes#

DotProductAttention

Region where selective activation recomputation is applied. This region is memory-intensive but less compute-intensive, which makes activation checkpointing more efficient for large models (20B+ parameters). See Reducing Activation Recomputation in Large Transformer Models (https://arxiv.org/abs/2205.05198) for details.

API#

class core.transformer.dot_product_attention.DotProductAttention(
config: megatron.core.transformer.transformer_config.TransformerConfig,
layer_number: int,
attn_mask_type: megatron.core.transformer.enums.AttnMaskType,
attention_type: str,
attention_dropout: Optional[float] = None,
softmax_scale: Optional[float] = None,
cp_comm_type: Optional[str] = None,
pg_collection: Optional[megatron.core.process_groups_config.ProcessGroupCollection] = None,
)#

Bases: megatron.core.transformer.module.MegatronModule

Region where selective activation recomputation is applied. This region is memory-intensive but less compute-intensive, which makes activation checkpointing more efficient for large models (20B+ parameters). See Reducing Activation Recomputation in Large Transformer Models (https://arxiv.org/abs/2205.05198) for details.

We use the following notation:

h: hidden size
n: number of attention heads
p: number of tensor model-parallel partitions
b: batch size
s: sequence length

Initialization
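
As a hedged, illustrative sketch (not taken from the source docs), the module is constructed from a populated TransformerConfig. The config values below are hypothetical minimums, and real usage typically requires Megatron's model-parallel state to be initialized first:

```python
import torch
from megatron.core.transformer.transformer_config import TransformerConfig
from megatron.core.transformer.enums import AttnMaskType
from megatron.core.transformer.dot_product_attention import DotProductAttention

# Hypothetical minimal config; a real setup may need more fields and
# assumes Megatron's parallel state (tensor/model-parallel groups) is
# already initialized, since the module queries it when pg_collection
# is not provided.
config = TransformerConfig(
    num_layers=2,
    hidden_size=64,          # h
    num_attention_heads=4,   # n
    attention_dropout=0.1,
)

attention = DotProductAttention(
    config=config,
    layer_number=1,
    attn_mask_type=AttnMaskType.causal,
    attention_type="self",
)
```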

forward(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attention_mask: Optional[torch.Tensor],
attn_mask_type: Optional[megatron.core.transformer.enums.AttnMaskType] = None,
attention_bias: Optional[torch.Tensor] = None,
packed_seq_params: Optional[megatron.core.packed_seq_params.PackedSeqParams] = None,
)#

Forward pass: computes dot-product attention over the query, key, and value tensors.
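
Continuing the sketch above, a hypothetical forward call, assuming Megatron's [sequence, batch, heads, head_dim] input layout (exact layouts and mask semantics may differ across versions and backends):

```python
# Assumed input shapes: [s, b, n, h/n] = [8, 2, 4, 16]
# for the hypothetical config above (h=64, n=4).
s, b, n, hn = 8, 2, 4, 16
query = torch.randn(s, b, n, hn)
key = torch.randn(s, b, n, hn)
value = torch.randn(s, b, n, hn)

# With a causal attn_mask_type, an explicit attention_mask may be None.
context = attention(query, key, value, attention_mask=None)
# Expected output shape: [s, b, h] = [8, 2, 64], with heads merged.
```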

sharded_state_dict(
prefix: str = '',
sharded_offsets: Tuple[Tuple[int, int, int], ...] = (),
metadata: Optional[dict] = None,
) → megatron.core.dist_checkpointing.mapping.ShardedStateDict#

Sharded state dict for the learnable softmax offset parameter.
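
A hedged sketch of invoking this method inside a distributed-checkpointing save path; the prefix string is hypothetical, and whether the dict contains entries depends on whether the learnable softmax offset is enabled in your configuration:

```python
# Hypothetical prefix mirroring the module's position in a full model.
sharded_sd = attention.sharded_state_dict(
    prefix="decoder.layers.0.self_attention.core_attention.",
)
for key, sharded_obj in sharded_sd.items():
    print(key, type(sharded_obj).__name__)
```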