core.transformer.attention#

Module Contents#

Classes#

LinearQkv

Protocol for linear_qkv modules.

LinearQkvBuilder

Protocol for building linear_qkv layers.

LinearLayer

Protocol for linear_q and linear_kv modules.

LinearLayerBuilder

Protocol for building linear_q and linear_kv layers.

CoreAttention

Protocol for core_attention modules.

CoreAttentionBuilder

Protocol for building core_attention layers.

SelfAttentionSubmodules

Configuration class for specifying the submodules of a self-attention.

CrossAttentionSubmodules

Configuration class for specifying the submodules of a cross-attention.

Attention

Attention layer abstract class.

SelfAttention

Self-attention layer class

CrossAttention

Cross-attention layer class

API#

class core.transformer.attention.LinearQkv#

Bases: typing.Protocol

Protocol for linear_qkv modules.

forward(input: torch.Tensor, /) → tuple[torch.Tensor, object]#

Applies linear_qkv.

backward_dw() → None#

Backward pass for the linear_qkv module.
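
For illustration, here is a minimal sketch of a module that structurally satisfies this protocol by wrapping torch.nn.Linear. It is hypothetical and single-GPU; real implementations are tensor-parallel linear layers with fused or delayed weight-gradient updates.

import torch


class NaiveLinearQkv(torch.nn.Module):
    """Illustrative LinearQkv-conforming wrapper (not a Megatron class)."""

    def __init__(self, input_size: int, output_size: int, bias: bool = True):
        super().__init__()
        self.linear = torch.nn.Linear(input_size, output_size, bias=bias)

    def forward(self, input: torch.Tensor) -> tuple[torch.Tensor, object]:
        # The protocol returns (output, bias); bias is None here because it
        # is fused into the matmul rather than returned for skip_bias_add.
        return self.linear(input), None

    def backward_dw(self) -> None:
        # Delayed weight-gradient hook; autograd already handles dW in this
        # sketch, so there is nothing to do.
        pass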

class core.transformer.attention.LinearQkvBuilder#

Bases: typing.Protocol

Protocol for building linear_qkv layers.

__call__(
input_size: int,
output_size: int,
/,
*,
config: core.transformer.transformer_config.TransformerConfig,
init_method: Callable[[torch.Tensor], None],
gather_output: bool,
bias: bool,
skip_bias_add: bool,
is_expert: bool,
tp_comm_buffer_name: str,
tp_group: torch.distributed.ProcessGroup | None = None,
) → core.transformer.attention.LinearQkv#
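
Any callable with this signature that returns a LinearQkv works as a builder; in Megatron specs the slot is typically filled with a tensor-parallel linear class. As a hedged sketch, reusing the hypothetical NaiveLinearQkv above and ignoring the parallelism-related arguments, a plain function also conforms:

def naive_linear_qkv_builder(
    input_size,
    output_size,
    *,
    config,
    init_method,
    gather_output,
    bias,
    skip_bias_add,
    is_expert,
    tp_comm_buffer_name,
    tp_group=None,
):
    # config, gather_output, skip_bias_add, is_expert, tp_comm_buffer_name
    # and tp_group are accepted to match the protocol but ignored here.
    layer = NaiveLinearQkv(input_size, output_size, bias=bias)
    init_method(layer.linear.weight)  # apply the configured initializer
    return layer
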
class core.transformer.attention.LinearLayer#

Bases: typing.Protocol

Protocol for linear_q and linear_kv modules.

forward(input: torch.Tensor, /) → Tuple[torch.Tensor, object]#

Applies linear_q/linear_kv.

class core.transformer.attention.LinearLayerBuilder#

Bases: typing.Protocol

Protocol for building linear_q and linear_kv layers.

__call__(
input_size: int,
output_size: int,
/,
*,
config: core.transformer.transformer_config.TransformerConfig,
init_method: Callable[[torch.Tensor], None],
gather_output: bool,
bias: bool,
skip_bias_add: bool,
is_expert: bool,
) → core.transformer.attention.LinearLayer#

class core.transformer.attention.CoreAttention#

Bases: typing.Protocol

Protocol for core_attention modules.

forward(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attention_mask: Optional[torch.Tensor],
/,
*,
attn_mask_type: core.transformer.enums.AttnMaskType,
attention_bias: Optional[torch.Tensor],
packed_seq_params: Optional[megatron.core.packed_seq_params.PackedSeqParams],
) → torch.Tensor#

Applies dot product attention.
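
As a rough sketch, a module with the following call signature satisfies this protocol. It assumes the [sq, b, np, hn] query/key/value layout and the [sq, b, np*hn] output layout used by Megatron's dot-product attention, handles only the causal case, and ignores attention_mask, attention_bias, and packed_seq_params; it is illustrative, not a drop-in replacement for DotProductAttention or TEDotProductAttention.

import torch


class NaiveCoreAttention(torch.nn.Module):
    """Hypothetical core_attention sketch built on torch SDPA."""

    def forward(
        self,
        query: torch.Tensor,        # [sq, b, np, hn]
        key: torch.Tensor,          # [skv, b, np, hn]
        value: torch.Tensor,        # [skv, b, np, hn]
        attention_mask,
        *,
        attn_mask_type,
        attention_bias=None,
        packed_seq_params=None,
    ) -> torch.Tensor:
        sq, b, num_heads, head_dim = query.shape
        # Reorder to [b, np, s, hn], the layout torch's SDPA expects.
        q, k, v = (t.permute(1, 2, 0, 3) for t in (query, key, value))
        out = torch.nn.functional.scaled_dot_product_attention(
            q, k, v, is_causal=True  # assumes a causal attn_mask_type
        )
        # Back to [sq, b, np * hn], the layout linear_proj consumes.
        return out.permute(2, 0, 1, 3).reshape(sq, b, num_heads * head_dim)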

class core.transformer.attention.CoreAttentionBuilder#

Bases: typing.Protocol

Protocol for building core_attention layers.

__call__(
*,
config: core.transformer.transformer_config.TransformerConfig,
layer_number: int,
attn_mask_type: core.transformer.enums.AttnMaskType,
attention_type: str,
cp_comm_type: Optional[str],
softmax_scale: Optional[float],
pg_collection: Optional[megatron.core.process_groups_config.ProcessGroupCollection],
) → core.transformer.attention.CoreAttention#

class core.transformer.attention.SelfAttentionSubmodules#

Configuration class for specifying the submodules of a self-attention.

linear_qkv: core.transformer.attention.LinearQkvBuilder#

None

core_attention: core.transformer.attention.CoreAttentionBuilder#

None

linear_proj: Union[megatron.core.transformer.spec_utils.ModuleSpec, type]#

None

q_layernorm: Union[megatron.core.transformer.spec_utils.ModuleSpec, type]#

None

k_layernorm: Union[megatron.core.transformer.spec_utils.ModuleSpec, type]#

None
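
As a hedged sketch, the dataclass is typically filled with classes (or ModuleSpecs) that satisfy the builder protocols above, mirroring the GPT layer specs; the import paths below are assumptions and may differ between Megatron Core releases.

from megatron.core.tensor_parallel import ColumnParallelLinear, RowParallelLinear
from megatron.core.transformer.attention import SelfAttentionSubmodules
from megatron.core.transformer.dot_product_attention import DotProductAttention
from megatron.core.transformer.identity_op import IdentityOp

submodules = SelfAttentionSubmodules(
    linear_qkv=ColumnParallelLinear,     # fills the LinearQkvBuilder slot
    core_attention=DotProductAttention,  # fills the CoreAttentionBuilder slot
    linear_proj=RowParallelLinear,       # output projection
    q_layernorm=IdentityOp,              # no extra query/key layer norms
    k_layernorm=IdentityOp,
)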

class core.transformer.attention.CrossAttentionSubmodules#

Configuration class for specifying the submodules of a cross-attention.

linear_q: core.transformer.attention.LinearLayerBuilder#

None

linear_kv: core.transformer.attention.LinearLayerBuilder#

None

core_attention: core.transformer.attention.CoreAttentionBuilder#

None

linear_proj: Union[megatron.core.transformer.spec_utils.ModuleSpec, type]#

None
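
The cross-attention analogue is wired the same way, except that queries and key/values get separate linear layers instead of a fused linear_qkv. As above, this is a sketch and the import paths are assumptions.

from megatron.core.tensor_parallel import ColumnParallelLinear, RowParallelLinear
from megatron.core.transformer.attention import CrossAttentionSubmodules
from megatron.core.transformer.dot_product_attention import DotProductAttention

cross_submodules = CrossAttentionSubmodules(
    linear_q=ColumnParallelLinear,       # projects hidden_states to queries
    linear_kv=ColumnParallelLinear,      # projects key_value_states to K/V
    core_attention=DotProductAttention,
    linear_proj=RowParallelLinear,
)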

class core.transformer.attention.Attention(
config: core.transformer.transformer_config.TransformerConfig,
submodules: Union[core.transformer.attention.SelfAttentionSubmodules, core.transformer.attention.CrossAttentionSubmodules],
layer_number: int,
attn_mask_type: core.transformer.enums.AttnMaskType,
attention_type: str,
cp_comm_type: str | None = None,
pg_collection: megatron.core.process_groups_config.ProcessGroupCollection | None = None,
)#

Bases: megatron.core.transformer.module.MegatronModule, abc.ABC

Attention layer abstract class.

This layer contains only the modules common to the “self attn” and “cross attn” specializations.

Initialization

_checkpointed_attention_forward(
query,
key,
value,
attention_mask,
rotary_pos_emb=None,
attn_mask_type=None,
attention_bias=None,
packed_seq_params=None,
)#

Forward method with selective activation checkpointing.

_allocate_memory(
inference_max_sequence_length,
batch_size,
dim,
dtype,
)#

Allocate memory to store the KV cache during inference.

_get_pp_layer_offset_for_inference()#

Return the pipeline parallel layer offset for inference.

_adjust_key_value_for_inference(
inference_context: megatron.core.inference.contexts.BaseInferenceContext,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
rotary_pos_emb: torch.Tensor,
rotary_pos_cos: Optional[torch.Tensor] = None,
rotary_pos_sin: Optional[torch.Tensor] = None,
rotary_pos_cos_sin: Optional[torch.Tensor] = None,
sequence_len_offset: Optional[int] = None,
*,
inference_params: Optional[megatron.core.inference.contexts.BaseInferenceContext] = None,
) → tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, core.transformer.enums.AttnMaskType, torch.Tensor]#

Saves the generated key and value tensors to the end of the buffers in inference_context. Returns the full-size keys and values from the provided inference_context, as well as the adjusted rotary_pos_emb.

Parameters:
  • query (Tensor) – Query tensor.

  • key (Tensor) – Key tensor.

  • value (Tensor) – Value tensor.

  • rotary_pos_emb (Optional[Union[Tensor, Tuple[Tensor, Tensor]]]) – Rotary embedding tensor(s).

  • rotary_pos_cos (Optional[Tensor]) – Rotary embedding cosine.

  • rotary_pos_sin (Optional[Tensor]) – Rotary embedding sine.

  • rotary_pos_cos_sin (Optional[Tensor]) – Combined rotary embedding cosine and sine. Currently used exclusively for inference with dynamic batching and flashinfer RoPE.

  • sequence_len_offset (Optional[int]) – Sequence length offset used for inference CUDA graphs.

Returns:

(Tuple) query, key, value, rotary_pos_emb, attn_mask_type, block_table.

abstractmethod get_query_key_value_tensors(
hidden_states: torch.Tensor,
key_value_states: torch.Tensor | None,
output_gate: bool = False,
split_qkv: bool = True,
) → tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor] | tuple[torch.Tensor, torch.Tensor, torch.Tensor] | tuple[torch.Tensor, list[int]]#

This method needs to be implemented based on whether the derived class is “self-attn” or “cross-attn”.

flash_decode(
sequence_len_offset: torch.Tensor,
query_layer: torch.Tensor,
key_layer: torch.Tensor,
value_layer: torch.Tensor,
inference_key_memory: torch.Tensor,
inference_value_memory: torch.Tensor,
rotary_cos: torch.Tensor,
rotary_sin: torch.Tensor,
rotary_interleaved: bool = False,
) → tuple[torch.Tensor, torch.Tensor]#

The flash decoding kernel will do the following in a single execution:

  1. Compute RoPE embedding with precomputed cos & sin tensors

  2. Update the KV Cache

  3. Perform the flash attention operation

_flash_attention_3_forward_wrapper(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
max_seqlen_q,
max_seqlen_k,
cu_seqlens_q,
seqlens_k,
block_table,
softmax_scale,
)#

Wrapper for calling the FA3 _flash_attn_forward function. Handles argument conversion for different versions of the _flash_attn_forward API.

flash_decode_and_prefill(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
max_seqlen_q,
max_seqlen_k,
cu_seqlens_q,
cu_seqlens_k,
seqlens_k,
block_table,
is_decode_only,
) → torch.Tensor#

Flash attention kernel for mixed decode and prefill samples.

Parameters:
  • q (Tensor) – Query tensor.

  • k (Tensor) – Key tensor.

  • v (Tensor) – Value tensor.

  • max_seqlen_q (int) – Query total sequence length.

  • max_seqlen_k (int) – Key total sequence length.

  • cu_seqlens_q (Tensor) – Cumulative query sequence lengths.

  • cu_seqlens_k (Tensor) – Cumulative key sequence lengths.

  • seqlens_k (Tensor) – Key sequence lengths.

  • block_table (Tensor) – KV cache block ids for all samples.

  • is_decode_only (bool) – True if batch is decode only.

Returns:

(Tensor) Attention output.

forward(
hidden_states: torch.Tensor,
attention_mask: torch.Tensor,
key_value_states: Optional[torch.Tensor] = None,
inference_context: Optional[megatron.core.inference.contexts.BaseInferenceContext] = None,
rotary_pos_emb: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None,
rotary_pos_cos: Optional[torch.Tensor] = None,
rotary_pos_sin: Optional[torch.Tensor] = None,
rotary_pos_cos_sin: Optional[torch.Tensor] = None,
attention_bias: Optional[torch.Tensor] = None,
packed_seq_params: Optional[megatron.core.packed_seq_params.PackedSeqParams] = None,
sequence_len_offset: Optional[int] = None,
*,
inference_params: Optional[megatron.core.inference.contexts.BaseInferenceContext] = None,
) → tuple[torch.Tensor, torch.Tensor]#

Perform a forward pass through the attention module.

Parameters:
  • hidden_states (Tensor) – Hidden states.

  • attention_mask (Tensor) – Attention mask.

  • key_value_states (Optional[Tensor]) – Key/value states (for cross attention).

  • inference_context (Optional[BaseInferenceContext]) – Inference context that manages KV cache.

  • rotary_pos_emb (Optional[Union[Tensor, Tuple[Tensor, Tensor]]]) – Rotary embedding tensor(s).

  • rotary_pos_cos (Optional[Tensor]) – Rotary embedding cosine.

  • rotary_pos_sin (Optional[Tensor]) – Rotary embedding sine.

  • rotary_pos_cos_sin (Optional[Tensor]) – Combined rotary embedding cosine and sine. Currently used exclusively for inference with dynamic batching and flashinfer RoPE.

  • attention_bias (Optional[Tensor]) – Attention bias.

  • packed_seq_params (Optional[PackedSeqParams]) – Parameters used for THD format.

  • sequence_len_offset (Optional[int]) – Sequence length offset used for inference CUDA graphs.

Returns:

(Tuple[Tensor, Tensor]) Attention output and bias.
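
A hedged usage sketch of this call, assuming an already-built attention module named attn and an initialized model-parallel/CUDA environment (both hypothetical here): activations are [s, b, h], and the layer returns the attention output together with a bias term (which may be None when the bias is folded into the projection).

import torch

s, b, h = 1024, 2, attn.config.hidden_size
hidden_states = torch.randn(s, b, h, device="cuda", dtype=torch.bfloat16)
attention_mask = None  # rely on the layer's attn_mask_type (e.g. causal)

output, bias = attn(hidden_states, attention_mask)
assert output.shape == (s, b, h)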

_apply_output_gate(x, gate)#

abstractmethod set_for_recompute_input_layernorm()#

Set up the attention layer to recompute input_layernorm. Only needed for fp8.

abstractmethod clip_qk()#

QK clipping clips the query and key attention logits to prevent them from exploding.

class core.transformer.attention.SelfAttention(
config: core.transformer.transformer_config.TransformerConfig,
submodules: core.transformer.attention.SelfAttentionSubmodules,
layer_number: int,
attn_mask_type: core.transformer.enums.AttnMaskType = AttnMaskType.padding,
cp_comm_type: str | None = None,
pg_collection: megatron.core.process_groups_config.ProcessGroupCollection | None = None,
)#

Bases: core.transformer.attention.Attention

Self-attention layer class

Self-attention layer takes input with size [s, b, h] and returns output of the same size.

Initialization
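
A hedged construction sketch: building a SelfAttention layer directly from a TransformerConfig and the SelfAttentionSubmodules shown earlier (submodules is reused from that sketch). In real model code this wiring is normally done through the layer specs and build_module rather than by hand, and the tensor-parallel linear layers require the usual Megatron parallel state to be initialized first.

from megatron.core.transformer.attention import SelfAttention
from megatron.core.transformer.enums import AttnMaskType
from megatron.core.transformer.transformer_config import TransformerConfig

config = TransformerConfig(
    num_layers=2,
    hidden_size=1024,
    num_attention_heads=16,
    use_cpu_initialization=True,
)

self_attn = SelfAttention(
    config=config,
    submodules=submodules,  # SelfAttentionSubmodules from the earlier sketch
    layer_number=1,
    attn_mask_type=AttnMaskType.causal,
)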

run_realtime_tests()#

Performs a consistency check.

This function makes sure that tensors across devices are the same during an experiment. This is often not guaranteed because of silent hardware failures (e.g., memory corruption when loading a checkpoint, or network corruption during data transmission).

(TODO) In the future, more tensors should be checked across the training run, every X iterations; this is left for future work. Exact equality of tensors is probably not required; transmitting hashes is sufficient.

get_query_key_value_tensors(
hidden_states: torch.Tensor,
key_value_states: torch.Tensor | None = None,
output_gate: bool = False,
split_qkv: bool = True,
) → tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor] | tuple[torch.Tensor, torch.Tensor, torch.Tensor] | tuple[torch.Tensor, list[int]]#

Derives the query, key, and value tensors from hidden_states. If output_gate is True, a gate tensor is also derived. If split_qkv is False, the unsplit mixed_qkv tensor is returned.

backward_dw() → None#

Execute weight update operations.

_backward_qkv_proj()#

Update weights for the QKV projection layer.

_backward_output_proj()#

Update weights for the output projection layer.

set_for_recompute_input_layernorm()#

Set up the attention layer to recompute input_layernorm. Only needed for fp8/fp4.

clip_qk()#

QK clipping clips the query and key attention logits to prevent them from exploding. Support for GQA is experimental.

_clip_linear_qkv(weight)#

Apply QK clipping to the linear_qkv layer.

class core.transformer.attention.CrossAttention(
config: core.transformer.transformer_config.TransformerConfig,
submodules: core.transformer.attention.CrossAttentionSubmodules,
layer_number: int,
attn_mask_type: core.transformer.enums.AttnMaskType = AttnMaskType.padding,
cp_comm_type: str | None = None,
pg_collection: megatron.core.process_groups_config.ProcessGroupCollection | None = None,
)#

Bases: core.transformer.attention.Attention

Cross-attention layer class

Cross-attention layer takes input with size [s, b, h] and context with size [s, b, h] and returns output of the same size.

Initialization
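
A hedged usage sketch for cross-attention, with cross_attn standing for an already-built CrossAttention instance (hypothetical here): queries come from hidden_states, while keys and values come from the encoder output passed as key_value_states.

import torch

s_dec, s_enc, b, h = 512, 1024, 2, cross_attn.config.hidden_size
hidden_states = torch.randn(s_dec, b, h, device="cuda", dtype=torch.bfloat16)
encoder_output = torch.randn(s_enc, b, h, device="cuda", dtype=torch.bfloat16)

output, bias = cross_attn(
    hidden_states,
    attention_mask=None,  # or an explicit padding mask
    key_value_states=encoder_output,
)
assert output.shape == (s_dec, b, h)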

get_query_key_value_tensors(
hidden_states: torch.Tensor,
key_value_states: Optional[torch.Tensor],
output_gate: bool = False,
split_qkv: bool = True,
) → Tuple[torch.Tensor, torch.Tensor, torch.Tensor]#

Derives query tensor from hidden_states, and key/value tensors from key_value_states.