core.transformer.attention#
Module Contents#
Classes#
| Class | Description |
|---|---|
| SelfAttentionSubmodules | Configuration class for specifying the submodules of a self-attention. |
| CrossAttentionSubmodules | Configuration class for specifying the submodules of a cross-attention. |
| Attention | Attention layer abstract class. |
| SelfAttention | Self-attention layer class. |
| CrossAttention | Cross-attention layer class. |
API#
- class core.transformer.attention.SelfAttentionSubmodules#
Configuration class for specifying the submodules of a self-attention.
- linear_qkv: Union[megatron.core.transformer.spec_utils.ModuleSpec, type]#
None
- core_attention: Union[megatron.core.transformer.spec_utils.ModuleSpec, type]#
None
- linear_proj: Union[megatron.core.transformer.spec_utils.ModuleSpec, type]#
None
- q_layernorm: Union[megatron.core.transformer.spec_utils.ModuleSpec, type]#
None
- k_layernorm: Union[megatron.core.transformer.spec_utils.ModuleSpec, type]#
None
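For orientation, the sketch below shows how these fields are typically populated when building a layer spec. The concrete submodule classes (ColumnParallelLinear, DotProductAttention, RowParallelLinear) mirror the non-Transformer-Engine "local" GPT layer spec and are illustrative assumptions; Transformer Engine equivalents can be substituted.

```python
# Illustrative sketch: assembling a SelfAttentionSubmodules spec and wrapping it
# in a ModuleSpec. The submodule classes below follow the "local" (non-TE) spec.
from megatron.core.tensor_parallel import ColumnParallelLinear, RowParallelLinear
from megatron.core.transformer.attention import SelfAttention, SelfAttentionSubmodules
from megatron.core.transformer.dot_product_attention import DotProductAttention
from megatron.core.transformer.enums import AttnMaskType
from megatron.core.transformer.spec_utils import ModuleSpec

self_attention_spec = ModuleSpec(
    module=SelfAttention,
    params={"attn_mask_type": AttnMaskType.causal},
    submodules=SelfAttentionSubmodules(
        linear_qkv=ColumnParallelLinear,     # fused QKV projection
        core_attention=DotProductAttention,  # unfused core attention
        linear_proj=RowParallelLinear,       # output projection
        q_layernorm=None,                    # optional query layernorm
        k_layernorm=None,                    # optional key layernorm
    ),
)
```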
- class core.transformer.attention.CrossAttentionSubmodules#
Configuration class for specifying the submodules of a cross-attention.
- linear_q: Union[megatron.core.transformer.spec_utils.ModuleSpec, type]#
None
- linear_kv: Union[megatron.core.transformer.spec_utils.ModuleSpec, type]#
None
- core_attention: Union[megatron.core.transformer.spec_utils.ModuleSpec, type]#
None
- linear_proj: Union[megatron.core.transformer.spec_utils.ModuleSpec, type]#
None
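By analogy, a cross-attention spec separates the query projection from the fused key/value projection. Again, the concrete classes below are illustrative assumptions mirroring the local spec.

```python
# Illustrative sketch: a CrossAttentionSubmodules spec (classes are assumptions).
from megatron.core.tensor_parallel import ColumnParallelLinear, RowParallelLinear
from megatron.core.transformer.attention import CrossAttentionSubmodules
from megatron.core.transformer.dot_product_attention import DotProductAttention

cross_attention_submodules = CrossAttentionSubmodules(
    linear_q=ColumnParallelLinear,   # projects hidden_states to queries
    linear_kv=ColumnParallelLinear,  # projects key_value_states to fused key/value
    core_attention=DotProductAttention,
    linear_proj=RowParallelLinear,   # output projection
)
```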
- class core.transformer.attention.Attention(
- config: core.transformer.transformer_config.TransformerConfig,
- submodules: Union[core.transformer.attention.SelfAttentionSubmodules, core.transformer.attention.CrossAttentionSubmodules],
- layer_number: int,
- attn_mask_type: core.transformer.enums.AttnMaskType,
- attention_type: str,
- cp_comm_type: str = None,
- pg_collection: megatron.core.process_groups_config.ProcessGroupCollection = None,
Bases: megatron.core.transformer.module.MegatronModule, abc.ABC
Attention layer abstract class.
This layer only contains common modules required for the “self attn” and “cross attn” specializations.
Initialization
- _checkpointed_attention_forward(
- query,
- key,
- value,
- attention_mask,
- rotary_pos_emb=None,
- attn_mask_type=None,
- attention_bias=None,
- packed_seq_params=None,
Forward method with selective activation checkpointing.
- _allocate_memory(
- inference_max_sequence_length,
- batch_size,
- dim,
- dtype,
Allocate memory to store kv cache during inference.
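As a rough illustration of what this allocation amounts to, the hypothetical helper below preallocates one cache buffer for the full decode horizon; the function name, the 3-D layout, and the device handling are assumptions, not the actual implementation (which may split dim into heads and head size).

```python
import torch

def allocate_kv_buffer(inference_max_sequence_length, batch_size, dim, dtype):
    """Hypothetical sketch of a KV-cache allocation: one buffer sized for the
    maximum decode length, assuming a [max_seq_len, batch, dim] layout on the
    current CUDA device."""
    return torch.empty(
        inference_max_sequence_length,
        batch_size,
        dim,
        dtype=dtype,
        device=torch.cuda.current_device(),
    )
```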
- _get_pp_layer_offset_for_inference()#
Return the pipeline parallel layer offset for inference.
- _adjust_key_value_for_inference(
- inference_context: megatron.core.inference.contexts.BaseInferenceContext,
- query: torch.Tensor,
- key: torch.Tensor,
- value: torch.Tensor,
- rotary_pos_emb: torch.Tensor,
- rotary_pos_cos: Optional[torch.Tensor] = None,
- rotary_pos_sin: Optional[torch.Tensor] = None,
- rotary_pos_cos_sin: Optional[torch.Tensor] = None,
- sequence_len_offset: Optional[int] = None,
- *,
- inference_params: Optional[megatron.core.inference.contexts.BaseInferenceContext] = None,
Saves the generated key and value tensors to the end of the buffers in inference_context. Returns the full-size keys and values from the provided inference_context, as well as the adjusted rotary_pos_emb.
- Parameters:
query (Tensor) – Query tensor.
key (Tensor) – Key tensor.
value (Tensor) – Value tensor.
rotary_pos_emb (Optional[Union[Tensor, Tuple[Tensor, Tensor]]]) – Rotary embedding tensor(s).
rotary_pos_cos (Optional[Tensor]) – Rotary embedding cosine.
rotary_pos_sin (Optional[Tensor]) – Rotary embedding sine.
rotary_pos_cos_sin (Optional[Tensor]) – Combined rotary embedding cosine and sine. Currently used exclusively for inference with dynamic batching and flashinfer RoPE.
sequence_len_offset (Optional[int]) – Sequence length offset used for inference CUDA graphs.
- Returns:
A tuple of (query, key, value, rotary_pos_emb, attn_mask_type, block_table).
- abstractmethod get_query_key_value_tensors(
- hidden_states,
- key_value_states,
- split_qkv=True,
This method needs to be implemented based on whether the derived class is “self-attn” or “cross-attn”.
- flash_decode(
- sequence_len_offset: torch.Tensor,
- query_layer: torch.Tensor,
- key_layer: torch.Tensor,
- value_layer: torch.Tensor,
- inference_key_memory: torch.Tensor,
- inference_value_memory: torch.Tensor,
- rotary_cos: torch.Tensor,
- rotary_sin: torch.Tensor,
- rotary_interleaved: bool = False,
The flash decoding kernel will do the following in a single execution (see the sketch below):
- Compute the RoPE embedding with precomputed cos & sin tensors
- Update the KV cache
- Perform the flash attention operation
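This corresponds closely to the fused decode path in the flash-attn library. The sketch below shows an equivalent standalone call, assuming flash-attn is installed and exposes flash_attn_with_kvcache with the keyword arguments used here; the actual integration inside this method may differ.

```python
# Sketch only: fused RoPE + KV-cache update + attention in one kernel call.
# Shapes (assumed): q/k_new/v_new are [batch, 1, heads, head_dim] for a single
# decode step; k_cache/v_cache are [batch, max_seq, kv_heads, head_dim].
from flash_attn import flash_attn_with_kvcache

def decode_step(q, k_new, v_new, k_cache, v_cache,
                rotary_cos, rotary_sin, cache_seqlens,
                rotary_interleaved=False):
    return flash_attn_with_kvcache(
        q,                            # new queries for this decode step
        k_cache, v_cache,             # preallocated caches, updated in place
        k=k_new, v=v_new,             # new key/value written at cache_seqlens
        rotary_cos=rotary_cos,        # precomputed RoPE cosine table
        rotary_sin=rotary_sin,        # precomputed RoPE sine table
        cache_seqlens=cache_seqlens,  # per-sample current sequence lengths
        rotary_interleaved=rotary_interleaved,
        causal=True,
    )
```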
- flash_decode_and_prefill(
- q: torch.Tensor,
- k: torch.Tensor,
- v: torch.Tensor,
- max_seqlen_q,
- max_seqlen_k,
- cu_seqlens_q,
- cu_seqlens_k,
- seqlens_k,
- block_table,
Flash attention kernel for mixed decode and prefill samples.
- Parameters:
q (Tensor) – Query tensor.
k (Tensor) – Key tensor.
v (Tensor) – Value tensor.
max_seqlen_q (int) – Query total sequence length.
max_seqlen_k (int) – Key total sequence length.
cu_seqlens_q (Tensor) – Cumulative query sequence lengths.
cu_seqlens_k (Tensor) – Cumulative key sequence lengths.
seqlens_k (Tensor) – Key sequence lengths.
block_table (Tensor) – KV cache block ids for all samples.
- Returns:
(Tensor) Attention output.
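The cu_seqlens_* arguments follow the usual varlen convention: one more entry than the number of samples, holding prefix sums of the per-sample lengths. A small illustrative helper (the function name is ours, not part of this API):

```python
import torch

def to_cu_seqlens(seqlens: torch.Tensor) -> torch.Tensor:
    """Convert per-sample lengths, e.g. [3, 5, 2], into cumulative offsets
    [0, 3, 8, 10] as expected by varlen flash-attention kernels."""
    return torch.nn.functional.pad(
        torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0)
    )

cu_seqlens_k = to_cu_seqlens(torch.tensor([3, 5, 2]))  # tensor([0, 3, 8, 10], dtype=torch.int32)
```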
- forward(
- hidden_states: torch.Tensor,
- attention_mask: torch.Tensor,
- key_value_states: Optional[torch.Tensor] = None,
- inference_context: Optional[megatron.core.inference.contexts.BaseInferenceContext] = None,
- rotary_pos_emb: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None,
- rotary_pos_cos: Optional[torch.Tensor] = None,
- rotary_pos_sin: Optional[torch.Tensor] = None,
- rotary_pos_cos_sin: Optional[torch.Tensor] = None,
- attention_bias: Optional[torch.Tensor] = None,
- packed_seq_params: Optional[megatron.core.packed_seq_params.PackedSeqParams] = None,
- sequence_len_offset: Optional[int] = None,
- *,
- inference_params: Optional[megatron.core.inference.contexts.BaseInferenceContext] = None,
Perform a forward pass through the attention module.
- Parameters:
hidden_states (Tensor) – Hidden states.
attention_mask (Tensor) – Attention mask.
key_value_states (Optional[Tensor]) – Key/value states (for cross attention).
inference_context (Optional[BaseInferenceContext]) – Inference context that manages KV cache.
rotary_pos_emb (Optional[Union[Tensor, Tuple[Tensor, Tensor]]]) – Rotary embedding tensor(s).
rotary_pos_cos (Optional[Tensor]) – Rotary embedding cosine.
rotary_pos_sin (Optional[Tensor]) – Rotary embedding sine.
rotary_pos_cos_sin (Optional[Tensor]) – Combined rotary embedding cosine and sine. Currently used exclusively for inference with dynamic batching and flashinfer RoPE.
attention_bias (Optional[Tensor]) – Attention bias.
packed_seq_params (Optional[PackedSeqParams]) – Parameters used for THD format.
sequence_len_offset (Optional[int]) – Sequence length offset used for inference CUDA graphs.
- Returns:
(Tuple[Tensor, Tensor]) Attention output and bias.
- abstractmethod set_for_recompute_input_layernorm()#
Set the attention layer for recompute input_layernorm. Only needed for fp8.
- abstractmethod clip_qk()#
QK clipping is a technique that clips the query and key attention logits to prevent them from exploding.
- class core.transformer.attention.SelfAttention(
- config: core.transformer.transformer_config.TransformerConfig,
- submodules: core.transformer.attention.SelfAttentionSubmodules,
- layer_number: int,
- attn_mask_type=AttnMaskType.padding,
- cp_comm_type: str = None,
- pg_collection: megatron.core.process_groups_config.ProcessGroupCollection = None,
Bases: core.transformer.attention.Attention
Self-attention layer class.
Self-attention layer takes input with size [s, b, h] and returns output of the same size.
Initialization
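Putting the pieces together, the sketch below builds a SelfAttention layer from the self_attention_spec shown earlier and runs a forward pass on [s, b, h] activations. It assumes torch.distributed and Megatron-Core's model-parallel state have already been initialized (e.g. a one-process world with tensor and pipeline parallel size 1); the configuration values and the mask construction are illustrative assumptions.

```python
# Sketch only: wiring and calling a SelfAttention layer (see assumptions above).
import torch
from megatron.core.transformer.spec_utils import build_module
from megatron.core.transformer.transformer_config import TransformerConfig

config = TransformerConfig(
    num_layers=1,
    hidden_size=256,
    num_attention_heads=8,
    use_cpu_initialization=True,  # assumption: keeps the sketch launcher-free
)

# `self_attention_spec` is the ModuleSpec sketched under SelfAttentionSubmodules above.
attention = build_module(self_attention_spec, config=config, layer_number=1)

s, b, h = 16, 2, config.hidden_size
hidden_states = torch.randn(s, b, h)

# Assumption: boolean mask where True marks positions to exclude (causal pattern).
causal_mask = torch.triu(torch.ones(s, s, dtype=torch.bool), diagonal=1)
attention_mask = causal_mask.view(1, 1, s, s).expand(b, 1, s, s)

output, bias = attention(hidden_states=hidden_states, attention_mask=attention_mask)
print(output.shape)  # torch.Size([16, 2, 256]): same [s, b, h] as the input
```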
- run_realtime_tests()#
Performs a consistency check.
This function makes sure that tensors across devices are the same during an experiment. This is often not guaranteed because of silent hardware failures (e.g., memory corruption when loading a checkpoint, or network traffic corruption during data transmission).
(TODO) In the future, more tensors should be checked across the training run, and checked every X iterations. This is left for future work. Exact equality of tensors is probably not required; transmitting hashes would be sufficient.
- get_query_key_value_tensors(
- hidden_states,
- key_value_states=None,
- split_qkv=True,
Derives query, key and value tensors from hidden_states. If split_qkv=False, the unsplit mixed_qkv tensor is returned instead.
- backward_dw() NoReturn#
Execute weight update operations.
- _backward_qkv_proj()#
Update weights for the QKV projection layer.
- _backward_output_proj()#
Update weights for the output projection layer.
- set_for_recompute_input_layernorm()#
Set the attention layer for recompute input_layernorm. Only needed for fp8/fp4.
- clip_qk()#
QK clipping is a technique that clips the query and key attention logits to prevent them from exploding. This function is experimental on GQA.
- _clip_linear_qkv(weight)#
Apply QK clipping to the linear_qkv layer.
- class core.transformer.attention.CrossAttention(
- config: core.transformer.transformer_config.TransformerConfig,
- submodules: core.transformer.attention.CrossAttentionSubmodules,
- layer_number: int,
- attn_mask_type=AttnMaskType.padding,
- cp_comm_type: str = None,
- pg_collection: megatron.core.process_groups_config.ProcessGroupCollection = None,
Bases: core.transformer.attention.Attention
Cross-attention layer class.
Cross-attention layer takes input with size [s, b, h] and context with size [s, b, h] and returns output of the same size.
Initialization
- get_query_key_value_tensors(
- hidden_states,
- key_value_states,
- split_qkv=True,
Derives the query tensor from hidden_states, and the key/value tensors from key_value_states.