core.transformer.multi_latent_attention#

Module Contents#

Classes#

MLASelfAttentionSubmodules

Submodules for the MLA self-attention layer.

MultiLatentAttention

Multi-Latent Attention layer abstract class.

MLASelfAttention

MLA self-attention layer class.

API#

class core.transformer.multi_latent_attention.MLASelfAttentionSubmodules#

Submodules for the MLA self-attention layer.

linear_q_proj: Union[megatron.core.transformer.spec_utils.ModuleSpec, type]#

None

linear_q_down_proj: Union[megatron.core.transformer.spec_utils.ModuleSpec, type]#

None

linear_q_up_proj: Union[megatron.core.transformer.spec_utils.ModuleSpec, type]#

None

linear_kv_down_proj: Union[megatron.core.transformer.spec_utils.ModuleSpec, type]#

None

linear_kv_up_proj: Union[megatron.core.transformer.spec_utils.ModuleSpec, type]#

None

core_attention: Union[megatron.core.transformer.spec_utils.ModuleSpec, type]#

None

linear_proj: Union[megatron.core.transformer.spec_utils.ModuleSpec, type]#

None

q_layernorm: Union[megatron.core.transformer.spec_utils.ModuleSpec, type]#

None

kv_layernorm: Union[megatron.core.transformer.spec_utils.ModuleSpec, type]#

None
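
These fields are typically populated through a ModuleSpec when building the transformer layer. The snippet below is a minimal sketch, assuming the plain Megatron Core building blocks; the layer specs actually shipped with Megatron may substitute Transformer Engine modules, fused layernorm + linear variants, or different norm classes.

```python
# Illustrative wiring of MLASelfAttentionSubmodules into a ModuleSpec.
# The module choices here are assumptions for the sketch, not the canonical layer spec.
from megatron.core.tensor_parallel import ColumnParallelLinear, RowParallelLinear
from megatron.core.transformer.dot_product_attention import DotProductAttention
from megatron.core.transformer.enums import AttnMaskType
from megatron.core.transformer.identity_op import IdentityOp
from megatron.core.transformer.multi_latent_attention import (
    MLASelfAttention,
    MLASelfAttentionSubmodules,
)
from megatron.core.transformer.spec_utils import ModuleSpec

mla_attention_spec = ModuleSpec(
    module=MLASelfAttention,
    params={"attn_mask_type": AttnMaskType.causal},
    submodules=MLASelfAttentionSubmodules(
        # Low-rank Q compression path (q_lora_rank set); when q_lora_rank is
        # None, linear_q_proj is used instead of the down/up projection pair.
        linear_q_down_proj=ColumnParallelLinear,
        linear_q_up_proj=ColumnParallelLinear,
        linear_kv_down_proj=ColumnParallelLinear,
        linear_kv_up_proj=ColumnParallelLinear,
        core_attention=DotProductAttention,
        linear_proj=RowParallelLinear,
        q_layernorm=IdentityOp,
        kv_layernorm=IdentityOp,
    ),
)
```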

class core.transformer.multi_latent_attention.MultiLatentAttention(
config: megatron.core.transformer.transformer_config.MLATransformerConfig,
submodules: Union[core.transformer.multi_latent_attention.MLASelfAttentionSubmodules],
layer_number: int,
attn_mask_type: megatron.core.transformer.enums.AttnMaskType,
attention_type: str,
cp_comm_type: Optional[str] = None,
pg_collection: megatron.core.process_groups_config.ProcessGroupCollection = None,
)#

Bases: megatron.core.transformer.attention.Attention

Multi-Latent Attention layer abstract class.

This layer contains only the common modules required for the “self attn” and “cross attn” specializations.

Initialization

forward(
hidden_states,
attention_mask,
key_value_states=None,
inference_context=None,
rotary_pos_emb=None,
rotary_pos_cos=None,
rotary_pos_sin=None,
rotary_pos_cos_sin=None,
attention_bias=None,
packed_seq_params=None,
position_ids=None,
sequence_len_offset=None,
*,
inference_params=None,
)#

Forward pass for multi-latent attention.

class core.transformer.multi_latent_attention.MLASelfAttention(
config: megatron.core.transformer.transformer_config.MLATransformerConfig,
submodules: core.transformer.multi_latent_attention.MLASelfAttentionSubmodules,
layer_number: int,
attn_mask_type=AttnMaskType.padding,
cp_comm_type: Optional[str] = None,
pg_collection: megatron.core.process_groups_config.ProcessGroupCollection = None,
)#

Bases: core.transformer.multi_latent_attention.MultiLatentAttention

MLA self-attention layer class.

The self-attention layer takes input of size [s, b, h] and returns output of the same size.

Initialization
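
As a usage sketch: assuming model-parallel state is already initialized, an MLATransformerConfig instance named config is available, and mla_layer was built from a spec such as the one sketched above (all of these names are placeholders), the layer consumes and returns tensors of shape [s, b, h]:

```python
import torch

# s = sequence length, b = micro-batch size, h = hidden size (illustrative values).
s, b, h = 2048, 1, config.hidden_size

hidden_states = torch.randn(s, b, h, device="cuda", dtype=torch.bfloat16)
# Whether an explicit mask is needed depends on attn_mask_type and the core
# attention backend; some backends build the causal mask internally.
attention_mask = None

output, bias = mla_layer(hidden_states, attention_mask)
assert output.shape == (s, b, h)
```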

get_query_key_value_tensors(
hidden_states,
key_value_states=None,
position_ids=None,
packed_seq_params=None,
inference_context=None,
*,
inference_params=None,
)#

Derives query, key and value tensors from hidden_states.

uncompress_kv_from_cache(kv_cached)#

Takes a compressed KV cache entry and uncompresses it into the full key and value tensors.
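
Conceptually, when cache_mla_latents is enabled the inference KV cache stores the compressed latent produced by linear_kv_down_proj; uncompression re-applies the KV layernorm and up-projection to recover the per-head no-RoPE keys and values. A rough sketch of that computation follows (the tensor layout and names are assumptions, not the method's actual signature):

```python
import torch

def uncompress_kv_sketch(kv_cached, kv_layernorm, w_kv_up, num_heads, qk_head_dim, v_head_dim):
    """Illustrative only. kv_cached: [s, b, kv_lora_rank]."""
    s, b, _ = kv_cached.shape
    # Re-apply the layernorm that sits in front of the KV up-projection.
    latent = kv_layernorm(kv_cached)
    # Up-project back to full KV width: [s, b, num_heads * (qk_head_dim + v_head_dim)].
    kv = torch.matmul(latent, w_kv_up.t())
    kv = kv.view(s, b, num_heads, qk_head_dim + v_head_dim)
    # Split each head into the no-positional-embedding part of K and the value.
    k_no_pe, value = torch.split(kv, [qk_head_dim, v_head_dim], dim=-1)
    return k_no_pe, value
```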

prepare_for_absorption()#

Prepare the model for absorption optimization in MLA (Multi-Latent Attention).

This method sets up the necessary components for the absorption technique, which optimizes memory during inference by caching compressed KV latents instead of full KV states. The absorption technique allows efficient decode-only operations by pre-computing certain matrix multiplications.

Note: True absorption is not yet performed here; this support will be added at a later time.

The method performs the following operations:

  1. Splits the fused layernorm + linear layer (linear_kv_up_proj) into separate components.

  2. Extracts and stores the up-projection weights for K and V separately, which are used during the absorption process.

  3. Replaces the identity kv_layernorm with the actual layernorm from the split.

  4. Stores the linear component separately for uncompressing the KV cache during prefill/mixed stages.

This is a one-time setup that should only be called once at initialization when cache_mla_latents is enabled.
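
A conceptual sketch of these four steps is shown below. The helper split_fused_layernorm_linear and the attributes w_uk, w_uv, kv_up_linear, num_heads, qk_head_dim, and v_head_dim are hypothetical names used for illustration; the real method operates on the parallel-linear modules and their sharded weights.

```python
import torch

def prepare_for_absorption_sketch(attn):
    """Illustrative only; not the module's actual implementation."""
    # 1. Split the fused layernorm + linear (linear_kv_up_proj) into components.
    #    split_fused_layernorm_linear is a hypothetical helper.
    layernorm, linear = split_fused_layernorm_linear(attn.linear_kv_up_proj)

    # 2. Store the K and V up-projection weights separately for absorption.
    #    Per head, the up-projection produces qk_head_dim (no-RoPE K) + v_head_dim (V).
    w = linear.weight.view(attn.num_heads, attn.qk_head_dim + attn.v_head_dim, -1)
    attn.w_uk, attn.w_uv = torch.split(w, [attn.qk_head_dim, attn.v_head_dim], dim=1)

    # 3. Replace the identity kv_layernorm with the layernorm recovered from the split.
    attn.kv_layernorm = layernorm

    # 4. Keep the bare linear to uncompress the KV cache during prefill/mixed stages.
    attn.kv_up_linear = linear
```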

backward_dw() → NoReturn#

Executes weight gradient computation.

_backward_kv_proj()#

Computes weight gradients of the KV projection layers.

_backward_q_proj()#

Computes weight gradients of the Q projection layers.

_backward_output_proj()#

Computes weight gradients of the output projection layer.

set_for_recompute_input_layernorm()#

Sets up the attention layer to recompute the input_layernorm. Only needed for fp8/fp4.

clip_qk()#

QK clipping is a technique that clips the query and key attention logits to prevent them from exploding. Following MuonClip usage, the weights are updated by calling this function after each Muon optimizer step.
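
As a rough illustration of the idea (not this module's actual implementation; the threshold tau, the tracked maximum logit, and the exponent split follow the general MuonClip recipe and are assumptions here), the clipping rescales the Q and K projection weights whenever observed attention logits exceed a threshold:

```python
import torch

def muonclip_style_rescale(w_q: torch.Tensor, w_k: torch.Tensor,
                           max_logit: float, tau: float = 100.0, alpha: float = 0.5):
    """Illustrative MuonClip-style rescaling of query/key projection weights.
    max_logit is the largest attention logit observed since the last optimizer step."""
    # Shrink only when the observed logits exceed the threshold tau.
    eta = min(1.0, tau / max_logit)
    # Splitting the factor between Q and K scales their product (the logits) by eta.
    w_q.mul_(eta ** alpha)
    w_k.mul_(eta ** (1.0 - alpha))
    return w_q, w_k
```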

_clip_q_proj_weight(weight)#

Clips q_proj_weight.

_clip_kv_proj_weight(weight)#

Clips kv_proj_weight.