core.transformer.dot_product_attention#
Module Contents#
Classes#
DotProductAttention | Region where selective activation recomputation is applied. This region is memory intensive but less compute intensive, which makes activation checkpointing more efficient for LLMs (20B+). See Reducing Activation Recomputation in Large Transformer Models (https://arxiv.org/abs/2205.05198) for more details. |
API#
- class core.transformer.dot_product_attention.DotProductAttention(
- config: megatron.core.transformer.transformer_config.TransformerConfig,
- layer_number: int,
- attn_mask_type: megatron.core.transformer.enums.AttnMaskType,
- attention_type: str,
- attention_dropout: Optional[float] = None,
- softmax_scale: Optional[float] = None,
- cp_comm_type: Optional[str] = None,
- pg_collection: Optional[megatron.core.process_groups_config.ProcessGroupCollection] = None,
- )
Bases:
megatron.core.transformer.module.MegatronModule

Region where selective activation recomputation is applied. This region is memory intensive but less compute intensive, which makes activation checkpointing more efficient for LLMs (20B+). See Reducing Activation Recomputation in Large Transformer Models (https://arxiv.org/abs/2205.05198) for more details.
We use the following notation:

- h: hidden size
- n: number of attention heads
- p: number of tensor model parallel partitions
- b: batch size
- s: sequence length
Initialization
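A minimal construction sketch (not part of this page): it assumes Megatron-core is installed and that model-parallel process groups have already been initialized (for example via megatron.core.parallel_state.initialize_model_parallel) or are supplied through pg_collection. The config values are purely illustrative.

```python
# Illustrative sketch only; assumes model-parallel state is already initialized.
from megatron.core.transformer.transformer_config import TransformerConfig
from megatron.core.transformer.enums import AttnMaskType
from megatron.core.transformer.dot_product_attention import DotProductAttention

# Example configuration (h = 1024, n = 16); values are illustrative.
config = TransformerConfig(
    num_layers=2,
    hidden_size=1024,
    num_attention_heads=16,
    attention_dropout=0.1,
)

attention = DotProductAttention(
    config=config,
    layer_number=1,
    attn_mask_type=AttnMaskType.causal,
    attention_type="self",
)
```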
- forward(
- query: torch.Tensor,
- key: torch.Tensor,
- value: torch.Tensor,
- attention_mask: Optional[torch.Tensor],
- attn_mask_type: Optional[megatron.core.transformer.enums.AttnMaskType] = None,
- attention_bias: Optional[torch.Tensor] = None,
- packed_seq_params: Optional[megatron.core.packed_seq_params.PackedSeqParams] = None,
- )
Forward.
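A hedged sketch of calling forward, continuing the construction example above. The [s, b, n, h/n] tensor layout, device, and dtype are assumptions; verify them against the Megatron-core version in use.

```python
import torch

# Shapes follow the notation on this page:
# s = sequence length, b = batch size, n = attention heads, h = hidden size.
s, b, n, h = 512, 4, 16, 1024
head_dim = h // n

# Assumed [s, b, n, h/n] layout; device and dtype are illustrative choices.
query = torch.randn(s, b, n, head_dim, device="cuda", dtype=torch.bfloat16)
key = torch.randn(s, b, n, head_dim, device="cuda", dtype=torch.bfloat16)
value = torch.randn(s, b, n, head_dim, device="cuda", dtype=torch.bfloat16)

# With a causal attn_mask_type, an explicit attention_mask can typically be None.
context = attention(
    query=query,
    key=key,
    value=value,
    attention_mask=None,
)
# context is expected to have shape [s, b, h] per tensor-parallel partition.
```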
- sharded_state_dict(
- prefix: str = '',
- sharded_offsets: Tuple[Tuple[int, int, int], ...] = (),
- metadata: Optional[dict] = None,
- )
Sharded state dict for the learnable softmax offset parameter.
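A short, hypothetical usage sketch for distributed checkpointing; the prefix value is illustrative and should match how the parent module namespaces this attention block.

```python
# Hypothetical sketch: gather this module's sharded state dict for a
# distributed checkpoint. The prefix mirrors how a parent TransformerLayer
# might namespace its sub-modules; adjust it to your model structure.
sharded_sd = attention.sharded_state_dict(prefix="self_attention.core_attention.")
```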