core.transformer.attention#

Module Contents#

Classes#

SelfAttentionSubmodules

Configuration class for specifying the submodules of a self-attention layer.

CrossAttentionSubmodules

Configuration class for specifying the submodules of a cross-attention layer.

Attention

Attention layer abstract class.

SelfAttention

Self-attention layer class.

CrossAttention

Cross-attention layer class.

API#

class core.transformer.attention.SelfAttentionSubmodules#

Configuration class for specifying the submodules of a self-attention layer.

linear_qkv: Union[megatron.core.transformer.spec_utils.ModuleSpec, type]#

None

core_attention: Union[megatron.core.transformer.spec_utils.ModuleSpec, type]#

None

linear_proj: Union[megatron.core.transformer.spec_utils.ModuleSpec, type]#

None

q_layernorm: Union[megatron.core.transformer.spec_utils.ModuleSpec, type]#

None

k_layernorm: Union[megatron.core.transformer.spec_utils.ModuleSpec, type]#

None
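
For illustration, the following sketch builds a ModuleSpec for SelfAttention from this submodule configuration. It is not a prescriptive recipe: the concrete module choices (ColumnParallelLinear, DotProductAttention, RowParallelLinear, IdentityOp) are just one plausible combination, and your layer spec may differ.

```python
from megatron.core.tensor_parallel import ColumnParallelLinear, RowParallelLinear
from megatron.core.transformer.attention import SelfAttention, SelfAttentionSubmodules
from megatron.core.transformer.dot_product_attention import DotProductAttention
from megatron.core.transformer.enums import AttnMaskType
from megatron.core.transformer.identity_op import IdentityOp
from megatron.core.transformer.spec_utils import ModuleSpec

# Each submodule field accepts either a ModuleSpec or a plain module class (type).
self_attention_spec = ModuleSpec(
    module=SelfAttention,
    params={"attn_mask_type": AttnMaskType.causal},
    submodules=SelfAttentionSubmodules(
        linear_qkv=ColumnParallelLinear,     # fused query/key/value projection
        core_attention=DotProductAttention,  # core attention implementation
        linear_proj=RowParallelLinear,       # output projection
        q_layernorm=IdentityOp,              # no query layernorm in this sketch
        k_layernorm=IdentityOp,              # no key layernorm in this sketch
    ),
)
```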

class core.transformer.attention.CrossAttentionSubmodules#

Configuration class for specifying the submodules of a cross-attention layer.

linear_q: Union[megatron.core.transformer.spec_utils.ModuleSpec, type]#

None

linear_kv: Union[megatron.core.transformer.spec_utils.ModuleSpec, type]#

None

core_attention: Union[megatron.core.transformer.spec_utils.ModuleSpec, type]#

None

linear_proj: Union[megatron.core.transformer.spec_utils.ModuleSpec, type]#

None
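
Analogously, a hedged sketch of a cross-attention spec; the module choices here are again assumptions, not the only valid configuration.

```python
from megatron.core.tensor_parallel import ColumnParallelLinear, RowParallelLinear
from megatron.core.transformer.attention import CrossAttention, CrossAttentionSubmodules
from megatron.core.transformer.dot_product_attention import DotProductAttention
from megatron.core.transformer.spec_utils import ModuleSpec

cross_attention_spec = ModuleSpec(
    module=CrossAttention,
    submodules=CrossAttentionSubmodules(
        linear_q=ColumnParallelLinear,    # query projection from hidden_states
        linear_kv=ColumnParallelLinear,   # fused key/value projection from key_value_states
        core_attention=DotProductAttention,
        linear_proj=RowParallelLinear,    # output projection
    ),
)
```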

class core.transformer.attention.Attention(
config: core.transformer.transformer_config.TransformerConfig,
submodules: Union[core.transformer.attention.SelfAttentionSubmodules, core.transformer.attention.CrossAttentionSubmodules],
layer_number: int,
attn_mask_type: core.transformer.enums.AttnMaskType,
attention_type: str,
cp_comm_type: str = None,
pg_collection: megatron.core.process_groups_config.ProcessGroupCollection = None,
)#

Bases: megatron.core.transformer.module.MegatronModule, abc.ABC

Attention layer abstract class.

This layer contains only the common modules required for the “self-attn” and “cross-attn” specializations.

Initialization

_checkpointed_attention_forward(
query,
key,
value,
attention_mask,
rotary_pos_emb=None,
attn_mask_type=None,
attention_bias=None,
packed_seq_params=None,
)#

Forward method with selective activation checkpointing.

_allocate_memory(
inference_max_sequence_length,
batch_size,
dim,
dtype,
)#

Allocate memory to store the KV cache during inference.

_get_pp_layer_offset_for_inference()#

Return the pipeline parallel layer offset for inference.

_adjust_key_value_for_inference(
inference_context: megatron.core.inference.contexts.BaseInferenceContext,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
rotary_pos_emb: torch.Tensor,
rotary_pos_cos: Optional[torch.Tensor] = None,
rotary_pos_sin: Optional[torch.Tensor] = None,
rotary_pos_cos_sin: Optional[torch.Tensor] = None,
sequence_len_offset: Optional[int] = None,
*,
inference_params: Optional[megatron.core.inference.contexts.BaseInferenceContext] = None,
) Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]#

Saves the generated key and value tensors to the end of the buffers in inference_context. Returns the full-size keys and values from the provided inference_context, as well as the adjusted rotary_pos_emb.

Parameters:
  • query (Tensor) – Query tensor.

  • key (Tensor) – Key tensor.

  • value (Tensor) – Value tensor.

  • rotary_pos_emb (Optional[Union[Tensor, Tuple[Tensor, Tensor]]]) – Rotary embedding tensor(s).

  • rotary_pos_cos (Optional[Tensor]) – Rotary embedding cosine.

  • rotary_pos_sin (Optional[Tensor]) – Rotary embedding sine.

  • rotary_pos_cos_sin (Optional[Tensor]) – Combined rotary embedding cosine and sine. Currently used exclusively for inference with dynamic batching and flashinfer RoPE.

  • sequence_len_offset (Optional[int]) – Sequence length offset used for inference CUDA graphs.

Returns:

query, key, value, rotary_pos_emb, attn_mask_type, block_table.

Return type:

Tuple

abstractmethod get_query_key_value_tensors(
hidden_states,
key_value_states,
split_qkv=True,
)#

This method needs to be implemented based on whether the derived class is “self-attn” or “cross-attn”.

flash_decode(
sequence_len_offset: torch.Tensor,
query_layer: torch.Tensor,
key_layer: torch.Tensor,
value_layer: torch.Tensor,
inference_key_memory: torch.Tensor,
inference_value_memory: torch.Tensor,
rotary_cos: torch.Tensor,
rotary_sin: torch.Tensor,
rotary_interleaved: bool = False,
)#

The flash decoding kernel will do the following in a single execution:

  1. Compute RoPE embedding with precomputed cos & sin tensors

  2. Update the KV Cache

  3. Perform the flash attention operation

flash_decode_and_prefill(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
max_seqlen_q,
max_seqlen_k,
cu_seqlens_q,
cu_seqlens_k,
seqlens_k,
block_table,
) torch.Tensor#

Flash attention kernel for mixed decode and prefill samples.

Parameters:
  • q (Tensor) – Query tensor.

  • k (Tensor) – Key tensor.

  • v (Tensor) – Value tensor.

  • max_seqlen_q (int) – Query total sequence length.

  • max_seqlen_k (int) – Key total sequence length.

  • cu_seqlens_q (Tensor) – Cumulative query sequence lengths.

  • cu_seqlens_k (Tensor) – Cumulative key sequence lengths.

  • seqlens_k (Tensor) – Key sequence lengths.

  • block_table (Tensor) – KV cache block ids for all samples.

Returns:

(Tensor) Attention output.

forward(
hidden_states: torch.Tensor,
attention_mask: torch.Tensor,
key_value_states: Optional[torch.Tensor] = None,
inference_context: Optional[megatron.core.inference.contexts.BaseInferenceContext] = None,
rotary_pos_emb: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None,
rotary_pos_cos: Optional[torch.Tensor] = None,
rotary_pos_sin: Optional[torch.Tensor] = None,
rotary_pos_cos_sin: Optional[torch.Tensor] = None,
attention_bias: Optional[torch.Tensor] = None,
packed_seq_params: Optional[megatron.core.packed_seq_params.PackedSeqParams] = None,
sequence_len_offset: Optional[int] = None,
*,
inference_params: Optional[megatron.core.inference.contexts.BaseInferenceContext] = None,
) Tuple[torch.Tensor, torch.Tensor]#

Perform a forward pass through the attention module.

Parameters:
  • hidden_states (Tensor) – Hidden states.

  • attention_mask (Tensor) – Attention mask.

  • key_value_states (Optional[Tensor]) – Key/value states (for cross attention).

  • inference_context (Optional[BaseInferenceContext]) – Inference context that manages KV cache.

  • rotary_pos_emb (Optional[Union[Tensor, Tuple[Tensor, Tensor]]]) – Rotary embedding tensor(s).

  • rotary_pos_cos (Optional[Tensor]) – Rotary embedding cosine.

  • rotary_pos_sin (Optional[Tensor]) – Rotary embedding sine.

  • rotary_pos_cos_sin (Optional[Tensor]) – Combined rotary embedding cosine and sine. Currently used exclusively for inference with dynamic batching and flashinfer RoPE.

  • attention_bias (Optional[Tensor]) – Attention bias.

  • packed_seq_params (Optional[PackedSeqParams]) – Parameters used for THD format.

  • sequence_len_offset (Optional[int]) – Sequence length offset used for inference CUDA graphs.

Returns:

(Tuple[Tensor, Tensor]) Attention output and bias.
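
As a minimal, hedged usage sketch: the call below assumes `attention` is an already-constructed concrete subclass instance (e.g. the SelfAttention built further down this page) and that Megatron's distributed and model-parallel state have been initialized elsewhere. The boolean mask convention shown (True = masked out) depends on the configured core_attention module, so treat it as an assumption.

```python
import torch

# Assumes `attention` is a constructed SelfAttention/CrossAttention instance.
seq_len, batch = 128, 2
hidden = attention.config.hidden_size

# Match the module's parameter dtype/device to avoid mismatches.
param = next(attention.parameters())

# Megatron attention expects inputs in [sequence, batch, hidden] layout.
hidden_states = torch.randn(seq_len, batch, hidden, device=param.device, dtype=param.dtype)

# Causal mask; True marks positions that must not be attended to (assumed convention).
attention_mask = torch.triu(
    torch.ones(1, 1, seq_len, seq_len, dtype=torch.bool, device=param.device), diagonal=1
)

output, bias = attention(hidden_states=hidden_states, attention_mask=attention_mask)
# output: [seq_len, batch, hidden]; bias: output-projection bias (may be None).
```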

abstractmethod set_for_recompute_input_layernorm()#

Configure the attention layer to recompute the input_layernorm. Only needed for fp8.

abstractmethod clip_qk()#

QK clipping is a technique that clips the query and key attention logits to prevent them from exploding.

class core.transformer.attention.SelfAttention(
config: core.transformer.transformer_config.TransformerConfig,
submodules: core.transformer.attention.SelfAttentionSubmodules,
layer_number: int,
attn_mask_type=AttnMaskType.padding,
cp_comm_type: str = None,
pg_collection: megatron.core.process_groups_config.ProcessGroupCollection = None,
)#

Bases: core.transformer.attention.Attention

Self-attention layer class.

Self-attention layer takes input with size [s, b, h] and returns output of the same size.

Initialization
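
A hedged construction sketch, calling the documented __init__ directly rather than going through build_module. The config values are illustrative, `self_attention_spec` refers to the ModuleSpec sketched earlier on this page, and torch.distributed plus Megatron's model-parallel state (e.g. via parallel_state.initialize_model_parallel) are assumed to be initialized beforehand.

```python
from megatron.core.transformer.attention import SelfAttention
from megatron.core.transformer.enums import AttnMaskType
from megatron.core.transformer.transformer_config import TransformerConfig

config = TransformerConfig(
    num_layers=2,
    hidden_size=512,
    num_attention_heads=8,
    use_cpu_initialization=True,  # illustrative choice for a small example
)

attention = SelfAttention(
    config=config,
    submodules=self_attention_spec.submodules,  # from the earlier spec sketch
    layer_number=1,
    attn_mask_type=AttnMaskType.causal,
)
```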

run_realtime_tests()#

Performs a consistency check.

This function makes sure that tensors across devices are the same during an experiment. This is often not guaranteed because of silent hardware failures (e.g., memory corruption when loading a checkpoint, or network traffic corruption during data transmission).

(TODO) In the future, more tensors should be checked across the training run, every X iterations; this is left for future work. Exact equality of tensors is probably not required; transmitting and comparing hashes would be sufficient.

get_query_key_value_tensors(
hidden_states,
key_value_states=None,
split_qkv=True,
)#

Derives query, key and value tensors from hidden_states. If split_qkv=False, then the unsplit mixed_qkv tensor is returned.

backward_dw() NoReturn#

Execute weight update operations.

_backward_qkv_proj()#

Update weights for the QKV projection layer.

_backward_output_proj()#

Update weights for the output projection layer.

set_for_recompute_input_layernorm()#

Configure the attention layer to recompute the input_layernorm. Only needed for fp8/fp4.

clip_qk()#

QK clipping is a technique that clips the query and key attention logits to prevent them from exploding (see the illustrative sketch below). This function is experimental for GQA.

_clip_linear_qkv(weight)#

Apply QK clipping to the linear_qkv layer.
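
To make the idea behind clip_qk and _clip_linear_qkv concrete, here is a generic, framework-agnostic sketch of weight-level QK clipping. It is not Megatron's exact implementation, and qk_clip_ is a hypothetical helper: when the largest observed attention logit exceeds a threshold, the query and key projection weights are scaled down so future logits stay bounded.

```python
import torch

def qk_clip_(w_q: torch.Tensor, w_k: torch.Tensor, max_logit: float, threshold: float) -> None:
    """Hypothetical helper illustrating weight-level QK clipping (not the Megatron API)."""
    if max_logit > threshold:
        # A logit is bilinear in the query and key projections, so split the
        # rescaling factor evenly (square root) between the two weight matrices.
        scale = (threshold / max_logit) ** 0.5
        w_q.mul_(scale)
        w_k.mul_(scale)
```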

class core.transformer.attention.CrossAttention(
config: core.transformer.transformer_config.TransformerConfig,
submodules: core.transformer.attention.CrossAttentionSubmodules,
layer_number: int,
attn_mask_type=AttnMaskType.padding,
cp_comm_type: str = None,
pg_collection: megatron.core.process_groups_config.ProcessGroupCollection = None,
)#

Bases: core.transformer.attention.Attention

Cross-attention layer class.

Cross-attention layer takes input with size [s, b, h] and context with size [s, b, h] and returns output of the same size.

Initialization

get_query_key_value_tensors(
hidden_states,
key_value_states,
split_qkv=True,
)#

Derives the query tensor from hidden_states, and the key/value tensors from key_value_states.