nemo_automodel.components.models.common.combined_projection.combined_qkv#
Combined QKV attention projection for efficient multi-head attention.
This module provides a mixin class that enables combined QKV projection for any attention module, improving memory efficiency and reducing kernel launch overhead.
Module Contents#
Classes#
CombinedQKVAttentionMixin: Mixin for combined QKV projection in attention modules.
Functions#
_assert_colwise_parallel: Verify that a combined-projection weight uses ColwiseParallel (Shard(0)) if TP is active.
API#
- nemo_automodel.components.models.common.combined_projection.combined_qkv._assert_colwise_parallel(weight: torch.Tensor, name: str) -> None#
Verify that a combined-projection weight uses ColwiseParallel (Shard(0)) if TP is active.
Shard(dim=1) is expected from FSDP pre-sharding of combined projections and is excluded from the check — it is temporary and undone by FSDP all-gather before the actual matmul.
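The check described above can be sketched with a simplified placement representation (the real function inspects torch DTensor placements; the helper name and tuple encoding here are illustrative assumptions, not the library's code):

```python
# Simplified stand-in for DTensor placements: ("shard", dim) or ("replicate",).
def assert_colwise_parallel(placements, name):
    """Sketch of the combined-projection sharding check (assumed logic)."""
    for p in placements:
        if p[0] == "shard" and p[1] not in (0, 1):
            # Shard(0) is the ColwiseParallel layout; Shard(1) is tolerated
            # as temporary FSDP pre-sharding, undone by all-gather before
            # the matmul.
            raise AssertionError(
                f"{name}: expected ColwiseParallel (Shard(0)), got Shard({p[1]})"
            )

assert_colwise_parallel([("shard", 0)], "qkv_proj.weight")  # passes
assert_colwise_parallel([("shard", 1)], "qkv_proj.weight")  # passes (FSDP)
```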
- class nemo_automodel.components.models.common.combined_projection.combined_qkv.CombinedQKVAttentionMixin#
Mixin for combined QKV projection in attention modules.
This mixin ALWAYS uses combined QKV projections for efficiency. Use this with custom transformer attention modules (Llama, Qwen2, etc.).
Usage:

```python
class MyAttention(CombinedQKVAttentionMixin, nn.Module):
    def __init__(self, config):
        super().__init__()
        # ... other init code ...
        self.setup_qkv_projection(
            hidden_size=config.hidden_size,
            num_attention_heads=config.num_attention_heads,
            num_key_value_heads=config.num_key_value_heads,
            head_dim=self.head_dim,
            bias=config.attention_bias,
        )

    def forward(self, hidden_states, ...):
        query_states, key_states, value_states = self.compute_qkv(hidden_states)
        # ... rest of attention logic ...
```

- setup_qkv_projection(
- hidden_size: int,
- num_attention_heads: int,
- num_key_value_heads: int,
- head_dim: int,
- bias: bool = False,
- use_combined_qkv: bool = True,
- ) -> None
Set up the combined QKV projection (ALWAYS uses the combined format).
- Parameters:
hidden_size – Model hidden size
num_attention_heads – Number of attention heads
num_key_value_heads – Number of key/value heads (for GQA)
head_dim – Dimension per attention head
bias – Whether to use bias in projections
use_combined_qkv – Deprecated; always True for custom implementations
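As a sketch of the sizes these parameters imply (plain-Python arithmetic, with an illustrative helper name, not the module's actual code): the combined projection's output width is (num_attention_heads + 2 * num_key_value_heads) * head_dim, since every attention head contributes a query and every KV head contributes one key and one value.

```python
# Sketch: output width of the combined QKV projection, derived from the
# setup_qkv_projection parameters (illustrative helper, not library code).
def combined_qkv_out_features(num_attention_heads: int,
                              num_key_value_heads: int,
                              head_dim: int) -> int:
    # Q rows: one head_dim slice per attention head.
    # K and V rows: one head_dim slice each per KV head (GQA).
    return (num_attention_heads + 2 * num_key_value_heads) * head_dim

# Example with Llama-3-8B-style shapes (32 Q heads, 8 KV heads, head_dim 128):
print(combined_qkv_out_features(32, 8, 128))  # -> 6144
```

With num_key_value_heads == num_attention_heads this reduces to the familiar 3 * hidden_size of standard multi-head attention.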
- compute_qkv(
- hidden_states: torch.Tensor,
- )
Compute Q, K, V from hidden states using combined projection.
The QKV weight uses a KV-head-grouped interleaved layout: [Q_group_0 | K_0 | V_0 | Q_group_1 | K_1 | V_1 | …] This ensures ColwiseParallel TP sharding gives each rank complete KV-head groups. We split within each group (a local operation).
- Parameters:
hidden_states – Input hidden states [batch, seq_len, hidden_size]
- Returns:
Tuple of (query, key, value) tensors, each [batch, seq_len, …]
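The KV-head-grouped interleaved layout described above can be sketched with plain-Python index arithmetic (a hypothetical helper operating on row-index ranges rather than tensors; names and structure are illustrative assumptions):

```python
# Sketch of the weight-row layout [Q_group_0 | K_0 | V_0 | Q_group_1 | ...]:
# each KV group holds its query heads' rows, then one K slice, then one V
# slice, so a contiguous dim-0 shard always contains whole groups.
def qkv_row_ranges(num_attention_heads, num_key_value_heads, head_dim):
    heads_per_group = num_attention_heads // num_key_value_heads
    group_rows = (heads_per_group + 2) * head_dim  # Q block + K + V per group
    q, k, v = [], [], []
    for g in range(num_key_value_heads):
        base = g * group_rows
        q.append(range(base, base + heads_per_group * head_dim))
        k.append(range(base + heads_per_group * head_dim,
                       base + (heads_per_group + 1) * head_dim))
        v.append(range(base + (heads_per_group + 1) * head_dim,
                       base + group_rows))
    return q, k, v

# With 4 Q heads, 2 KV heads, head_dim=2: each group spans (2 + 2) * 2 = 8 rows.
q, k, v = qkv_row_ranges(4, 2, 2)
print(list(q[0]), list(k[0]), list(v[0]))
# -> [0, 1, 2, 3] [4, 5] [6, 7]
```

Because each group's Q, K, and V rows are contiguous, splitting the combined projection output into query, key, and value stays a local slice on every TP rank.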