nemo_automodel.components.models.common.combined_projection.combined_qkv#
Combined QKV attention projection for efficient multi-head attention.
This module provides a mixin class that enables combined QKV projection for any attention module, improving memory efficiency and reducing kernel launch overhead.
Module Contents#
Classes#
| CombinedQKVAttentionMixin | Mixin for combined QKV projection in attention modules. |
API#
- class nemo_automodel.components.models.common.combined_projection.combined_qkv.CombinedQKVAttentionMixin#
Mixin for combined QKV projection in attention modules.
This mixin ALWAYS uses combined QKV projections for efficiency. Use this with custom transformer attention modules (Llama, Qwen2, etc.).
Usage:

```python
class MyAttention(CombinedQKVAttentionMixin, nn.Module):
    def __init__(self, config):
        super().__init__()
        # ... other init code ...
        self.setup_qkv_projection(
            hidden_size=config.hidden_size,
            num_attention_heads=config.num_attention_heads,
            num_key_value_heads=config.num_key_value_heads,
            head_dim=self.head_dim,
            bias=config.attention_bias,
        )

    def forward(self, hidden_states, ...):
        query_states, key_states, value_states = self.compute_qkv(hidden_states)
        # ... rest of attention logic ...
```

- setup_qkv_projection(
- hidden_size: int,
- num_attention_heads: int,
- num_key_value_heads: int,
- head_dim: int,
- bias: bool = False,
- use_combined_qkv: bool = True,
- )
Set up the combined QKV projection (ALWAYS uses the combined format).
- Parameters:
hidden_size – Model hidden size
num_attention_heads – Number of attention heads
num_key_value_heads – Number of key/value heads (for GQA)
head_dim – Dimension per attention head
bias – Whether to use bias in projections
use_combined_qkv – DEPRECATED; always True for custom implementations
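For intuition, the combined layer's output width is the sum of the query, key, and value widths, i.e. `(num_attention_heads + 2 * num_key_value_heads) * head_dim`. The sketch below is a minimal, hypothetical reconstruction of that sizing; the actual attribute names and layer construction inside `setup_qkv_projection` may differ.

```python
import torch.nn as nn

def build_combined_qkv(
    hidden_size: int,
    num_attention_heads: int,
    num_key_value_heads: int,
    head_dim: int,
    bias: bool = False,
) -> nn.Linear:
    # One linear layer produces Q, K, and V in a single GEMM, instead of three
    # separate projections (and three kernel launches).
    q_size = num_attention_heads * head_dim
    kv_size = num_key_value_heads * head_dim
    return nn.Linear(hidden_size, q_size + 2 * kv_size, bias=bias)

# Example with Llama-3-8B-style GQA sizes: 32 query heads, 8 KV heads, head_dim 128.
qkv_proj = build_combined_qkv(4096, 32, 8, 128)
print(qkv_proj.weight.shape)  # torch.Size([6144, 4096])
```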
- compute_qkv(
- hidden_states: torch.Tensor,
- )
Compute Q, K, V from hidden states using combined projection.
Handles tensor parallelism by dynamically computing split sizes based on actual tensor dimensions.
- Parameters:
hidden_states – Input hidden states [batch, seq_len, hidden_size]
- Returns:
Tuple of (query, key, value) tensors, each [batch, seq_len, …]
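The following is a minimal sketch of the split step, assuming no tensor parallelism so the global head counts match the local projection width; the actual `compute_qkv` instead derives the split sizes from the real (possibly TP-sharded) tensor dimensions, and the `split_qkv` helper shown here is hypothetical.

```python
import torch

def split_qkv(
    qkv: torch.Tensor,
    num_attention_heads: int,
    num_key_value_heads: int,
    head_dim: int,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    # Slice the combined projection output back into Q, K, and V along the
    # last dimension, using the per-tensor widths.
    q_size = num_attention_heads * head_dim
    kv_size = num_key_value_heads * head_dim
    return torch.split(qkv, [q_size, kv_size, kv_size], dim=-1)

# Combined projection output for the GQA sizes above: (32 + 2 * 8) * 128 = 6144.
qkv = torch.randn(2, 16, 6144)  # [batch, seq_len, q_size + 2 * kv_size]
q, k, v = split_qkv(qkv, 32, 8, 128)
print(q.shape, k.shape, v.shape)  # [2, 16, 4096] [2, 16, 1024] [2, 16, 1024]
```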