core.transformer.multi_latent_attention#
Module Contents#
Classes#
| Class | Description |
|---|---|
| MLASelfAttentionSubmodules | Submodules for the MLA self-attention layer. |
| MultiLatentAttention | Multi-Latent Attention layer abstract class. |
| MLASelfAttention | MLA Self-attention layer class. |
API#
- class core.transformer.multi_latent_attention.MLASelfAttentionSubmodules#
Submodules for the MLA self-attention layer.
- linear_q_proj: Union[megatron.core.transformer.spec_utils.ModuleSpec, type]#
None
- linear_q_down_proj: Union[megatron.core.transformer.spec_utils.ModuleSpec, type]#
None
- linear_q_up_proj: Union[megatron.core.transformer.spec_utils.ModuleSpec, type]#
None
- linear_kv_down_proj: Union[megatron.core.transformer.spec_utils.ModuleSpec, type]#
None
- linear_kv_up_proj: Union[megatron.core.transformer.spec_utils.ModuleSpec, type]#
None
- core_attention: Union[megatron.core.transformer.spec_utils.ModuleSpec, type]#
None
- linear_proj: Union[megatron.core.transformer.spec_utils.ModuleSpec, type]#
None
- q_layernorm: Union[megatron.core.transformer.spec_utils.ModuleSpec, type]#
None
- kv_layernorm: Union[megatron.core.transformer.spec_utils.ModuleSpec, type]#
None
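For orientation, here is a minimal sketch of how these submodules might be wired into a ModuleSpec. The concrete layer classes used below (ColumnParallelLinear, RowParallelLinear, DotProductAttention, IdentityOp) are illustrative assumptions; production specs typically plug in Transformer Engine variants and real norm classes for q_layernorm/kv_layernorm, and linear_q_proj is typically used instead of the down/up pair when no low-rank Q projection is configured (also an assumption here).

```python
from megatron.core.tensor_parallel import ColumnParallelLinear, RowParallelLinear
from megatron.core.transformer.dot_product_attention import DotProductAttention
from megatron.core.transformer.enums import AttnMaskType
from megatron.core.transformer.identity_op import IdentityOp
from megatron.core.transformer.multi_latent_attention import (
    MLASelfAttention,
    MLASelfAttentionSubmodules,
)
from megatron.core.transformer.spec_utils import ModuleSpec

# Illustrative spec only: every submodule slot is filled with a generic layer
# class; real MLA specs usually use Transformer Engine equivalents.
mla_attention_spec = ModuleSpec(
    module=MLASelfAttention,
    params={"attn_mask_type": AttnMaskType.causal},
    submodules=MLASelfAttentionSubmodules(
        linear_q_down_proj=ColumnParallelLinear,
        linear_q_up_proj=ColumnParallelLinear,
        linear_kv_down_proj=ColumnParallelLinear,
        linear_kv_up_proj=ColumnParallelLinear,
        core_attention=DotProductAttention,
        linear_proj=RowParallelLinear,
        q_layernorm=IdentityOp,
        kv_layernorm=IdentityOp,
    ),
)
```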
- class core.transformer.multi_latent_attention.MultiLatentAttention(
- config: megatron.core.transformer.transformer_config.MLATransformerConfig,
- submodules: Union[core.transformer.multi_latent_attention.MLASelfAttentionSubmodules],
- layer_number: int,
- attn_mask_type: megatron.core.transformer.enums.AttnMaskType,
- attention_type: str,
- cp_comm_type: Optional[str] = None,
- pg_collection: megatron.core.process_groups_config.ProcessGroupCollection = None,
)#

Bases: megatron.core.transformer.attention.Attention

Multi-Latent Attention layer abstract class.
This layer only contains common modules required for the “self attn” and “cross attn” specializations.
Initialization
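The layer is configured through MLATransformerConfig rather than the base TransformerConfig. Below is a hedged sketch of building such a config; the low-rank and head-dimension values are illustrative assumptions modelled on DeepSeek-style MLA, not recommended settings, and field defaults may differ across versions.

```python
from megatron.core.transformer.transformer_config import MLATransformerConfig

# Illustrative values only (assumptions): q_lora_rank / kv_lora_rank size the
# low-rank Q and KV projections; qk_head_dim + qk_pos_emb_head_dim is the
# per-head query/key width, and v_head_dim is the per-head value width.
mla_config = MLATransformerConfig(
    num_layers=2,
    hidden_size=1024,
    num_attention_heads=16,
    q_lora_rank=512,
    kv_lora_rank=256,
    qk_head_dim=128,
    qk_pos_emb_head_dim=64,
    v_head_dim=128,
)
```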
- forward(
- hidden_states,
- attention_mask,
- key_value_states=None,
- inference_context=None,
- rotary_pos_emb=None,
- rotary_pos_cos=None,
- rotary_pos_sin=None,
- rotary_pos_cos_sin=None,
- attention_bias=None,
- packed_seq_params=None,
- position_ids=None,
- sequence_len_offset=None,
- *,
- inference_params=None,
)#

Forward pass for multi-latent attention.
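A hedged sketch of invoking the layer, assuming `attn` is an already-built MLASelfAttention (for example constructed from `mla_attention_spec` and `mla_config` above) living on a CUDA device. The (output, bias) return tuple and the [b, 1, s, s] boolean mask layout mirror other Megatron-Core attention layers, but treat both as assumptions.

```python
import torch

# Assumed: `attn` is an initialized MLASelfAttention built from `mla_config`
# above; tensors use the [s, b, h] layout used throughout Megatron-Core.
s, b, h = 128, 2, mla_config.hidden_size
hidden_states = torch.randn(s, b, h, device="cuda", dtype=torch.bfloat16)

# Boolean mask, True = masked out; with a causal attn_mask_type the mask can
# often be None instead (assumption).
attention_mask = torch.triu(
    torch.ones(s, s, device="cuda", dtype=torch.bool), diagonal=1
).expand(b, 1, s, s)

output, bias = attn(hidden_states, attention_mask)
print(output.shape)  # expected: torch.Size([128, 2, 1024]), same [s, b, h] as the input
```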
- class core.transformer.multi_latent_attention.MLASelfAttention(
- config: megatron.core.transformer.transformer_config.MLATransformerConfig,
- submodules: core.transformer.multi_latent_attention.MLASelfAttentionSubmodules,
- layer_number: int,
- attn_mask_type=AttnMaskType.padding,
- cp_comm_type: Optional[str] = None,
- pg_collection: megatron.core.process_groups_config.ProcessGroupCollection = None,
)#

Bases: core.transformer.multi_latent_attention.MultiLatentAttention

MLA Self-attention layer class.
Self-attention layer takes input with size [s, b, h] and returns output of the same size.
Initialization
- get_query_key_value_tensors(
- hidden_states,
- key_value_states=None,
- position_ids=None,
- packed_seq_params=None,
- inference_context=None,
- *,
- inference_params=None,
)#

Derives query, key, and value tensors from hidden_states.
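As a rough mental model, here is a runnable toy sketch of the DeepSeek-style projections these submodules implement. Every dimension is made up, plain matmuls stand in for the linear_* submodules, the layernorms are omitted, `rope` is a placeholder, and fusing the shared rotary key into the KV down-projection is an assumption rather than a statement about this implementation.

```python
import torch


def rope(t):
    return t  # placeholder for the rotary position embedding


# Toy dimensions (assumptions; not the library defaults).
s, b, h, n = 8, 1, 256, 4                        # seq, batch, hidden, heads
q_rank, kv_rank, d_qk, d_rope, d_v = 64, 32, 16, 8, 16

x = torch.randn(s, b, h)
# Plain matmuls stand in for the linear_* submodules (no bias, no layernorm).
w_q_down = torch.randn(q_rank, h)
w_q_up = torch.randn(n * (d_qk + d_rope), q_rank)
w_kv_down = torch.randn(kv_rank + d_rope, h)     # latent + shared rotary key (assumed fused)
w_kv_up = torch.randn(n * (d_qk + d_v), kv_rank)

# Q path: down-project, up-project to per-head [d_qk + d_rope], split off the rope part.
q = ((x @ w_q_down.t()) @ w_q_up.t()).view(s, b, n, d_qk + d_rope)
q_no_pe, q_pe = q.split([d_qk, d_rope], dim=-1)

# KV path: one down-projection yields the compressed latent plus the rotary key;
# the latent is up-projected into the no-RoPE key and the value.
kv_latent, k_pe = (x @ w_kv_down.t()).split([kv_rank, d_rope], dim=-1)
kv = (kv_latent @ w_kv_up.t()).view(s, b, n, d_qk + d_v)
k_no_pe, value = kv.split([d_qk, d_v], dim=-1)

query = torch.cat([q_no_pe, rope(q_pe)], dim=-1)
key = torch.cat([k_no_pe, rope(k_pe).unsqueeze(2).expand(-1, -1, n, -1)], dim=-1)
print(query.shape, key.shape, value.shape)       # [s, b, n, d_qk+d_rope] x2, [s, b, n, d_v]
```

The useful property is that only kv_latent and k_pe need to be kept around to rebuild key and value, which is what the cache-related methods below exploit.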
- uncompress_kv_from_cache(kv_cached)#
Takes a compressed KV latent and uncompresses it into full key and value tensors.
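A conceptual sketch of that decompression, with assumed toy shapes rather than the library's internal tensors: the cached latent is pushed through the stored KV up-projection and split back into the no-RoPE key and the value.

```python
import torch

s, b, n, kv_rank, d_qk, d_v = 8, 1, 4, 32, 16, 16   # toy sizes (assumptions)
kv_cached = torch.randn(s, b, kv_rank)               # compressed latent from the cache
w_kv_up = torch.randn(n * (d_qk + d_v), kv_rank)     # stored KV up-projection weight

kv = (kv_cached @ w_kv_up.t()).view(s, b, n, d_qk + d_v)
k_no_pe, value = kv.split([d_qk, d_v], dim=-1)       # recovered per-head K (no RoPE) and V
```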
- prepare_for_absorption()#
Prepare the model for absorption optimization in MLA (Multi-Latent Attention).
This method sets up the necessary components for the absorption technique, which optimizes memory during inference by caching compressed KV latents instead of full KV states. The absorption technique allows efficient decode-only operations by pre-computing certain matrix multiplications.
Note (Peter): Right now we are not doing true absorption. We will add this support at a later time.
The method performs the following operations:
- Splits the fused layernorm + linear layer (linear_kv_up_proj) into separate components.
- Extracts and stores the up-projection weights for K and V separately, which are used during the absorption process.
- Replaces the identity kv_layernorm with the actual layernorm from the split.
- Stores the linear component separately for uncompressing the KV cache during prefill/mixed stages.
This is a one-time setup that should only be called once at initialization when cache_mla_latents is enabled.
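A back-of-the-envelope comparison of why caching the latent pays off; the head count and dimensions are DeepSeek-style numbers used purely as assumptions, and whether the shared rotary key is stored next to the latent is also an assumption.

```python
# Per-token, per-layer KV-cache size in elements (assumed DeepSeek-style sizes).
n_heads, d_qk, d_rope, d_v = 128, 128, 64, 128
kv_lora_rank = 512

full_kv = n_heads * ((d_qk + d_rope) + d_v)  # cache full K and V for every head
latent = kv_lora_rank + d_rope               # cache compressed latent + shared rotary key

print(full_kv, latent, round(full_kv / latent, 1))   # 40960 576 71.1
```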
- backward_dw() → NoReturn#
Execute weight gradient computation
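A heavily hedged sketch of where backward_dw might sit in a training step, assuming a setup in which weight-gradient computation has been deferred so that loss.backward() only produces activation gradients; the enabling configuration is not shown and the pattern is illustrative, not a documented contract.

```python
# Illustrative only (assumption: wgrad computation for `mla_layer` is deferred).
output, bias = mla_layer(hidden_states, attention_mask)
loss = output.float().pow(2).mean()   # toy loss
loss.backward()                       # activation gradients
mla_layer.backward_dw()               # then Q/KV/output-projection weight gradients
```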
- _backward_kv_proj()#
Computes weight gradients of KV projection layers
- _backward_q_proj()#
Computes weight gradients of Q projection layers
- _backward_output_proj()#
Computes weight gradients of output projection layer
- set_for_recompute_input_layernorm()#
Sets up the attention layer to recompute the input_layernorm. Only needed for fp8/fp4.
- clip_qk()#
QK clipping clips the query and key attention logits to prevent them from exploding. Following MuonClip usage, the weights are updated by calling this function after the Muon optimizer step.
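The docstring implies a call site immediately after the Muon optimizer step; below is a hedged sketch of that placement, where model, data_loader, and muon_optimizer are illustrative names and the model.decoder.layers traversal assumes a GPTModel-style module tree.

```python
# Illustrative training-loop placement (names are assumptions, not the API).
for batch in data_loader:
    loss = model(**batch)
    loss.backward()
    muon_optimizer.step()        # Muon updates the projection weights
    muon_optimizer.zero_grad()
    # QK clipping: rescale the Q/K projection weights so attention logits
    # stay bounded after the update.
    for layer in model.decoder.layers:
        layer.self_attention.clip_qk()
```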
- _clip_q_proj_weight(weight)#
Clip q_proj_weight
- _clip_kv_proj_weight(weight)#
Clip kv_proj_weight