nemo_automodel.components.models.common.combined_projection.combined_qkv#

Combined QKV attention projection for efficient multi-head attention.

This module provides a mixin class that enables combined QKV projection for any attention module, improving memory efficiency and reducing kernel launch overhead.
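The efficiency claim is straightforward to verify: a single fused linear layer whose weight is the row-wise concatenation of the Q, K, and V weights produces identical outputs with one GEMM and one kernel launch instead of three. A minimal sketch with hypothetical toy shapes (not this module's code):

    import torch
    import torch.nn as nn

    hidden_size, num_heads, num_kv_heads, head_dim = 64, 4, 2, 16
    q_size = num_heads * head_dim        # 64
    kv_size = num_kv_heads * head_dim    # 32

    # Three separate projections (the conventional layout).
    q_proj = nn.Linear(hidden_size, q_size, bias=False)
    k_proj = nn.Linear(hidden_size, kv_size, bias=False)
    v_proj = nn.Linear(hidden_size, kv_size, bias=False)

    # One fused projection whose weight stacks the three weights row-wise.
    qkv_proj = nn.Linear(hidden_size, q_size + 2 * kv_size, bias=False)
    with torch.no_grad():
        qkv_proj.weight.copy_(
            torch.cat([q_proj.weight, k_proj.weight, v_proj.weight], dim=0)
        )

    x = torch.randn(2, 8, hidden_size)  # [batch, seq_len, hidden_size]
    q, k, v = qkv_proj(x).split([q_size, kv_size, kv_size], dim=-1)

    assert torch.allclose(q, q_proj(x), atol=1e-6)
    assert torch.allclose(k, k_proj(x), atol=1e-6)
    assert torch.allclose(v, v_proj(x), atol=1e-6)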

Module Contents#

Classes#

CombinedQKVAttentionMixin

Mixin for combined QKV projection in attention modules.

API#

class nemo_automodel.components.models.common.combined_projection.combined_qkv.CombinedQKVAttentionMixin#

Mixin for combined QKV projection in attention modules.

This mixin ALWAYS uses combined QKV projections for efficiency. Use this with custom transformer attention modules (Llama, Qwen2, etc.).

Usage:

    class MyAttention(CombinedQKVAttentionMixin, nn.Module):
        def __init__(self, config):
            super().__init__()
            # ... other init code ...
            self.setup_qkv_projection(
                hidden_size=config.hidden_size,
                num_attention_heads=config.num_attention_heads,
                num_key_value_heads=config.num_key_value_heads,
                head_dim=self.head_dim,
                bias=config.attention_bias,
            )

        def forward(self, hidden_states, ...):
            query_states, key_states, value_states = self.compute_qkv(hidden_states)
            # ... rest of attention logic ...
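For context, here is one way the elided forward above could be completed into a runnable module. The `o_proj` layer, the SDPA call with `repeat_interleave` for GQA, and the omission of rotary embeddings are illustrative assumptions, not part of this API:

    from types import SimpleNamespace

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    from nemo_automodel.components.models.common.combined_projection.combined_qkv import (
        CombinedQKVAttentionMixin,
    )

    class MyAttention(CombinedQKVAttentionMixin, nn.Module):
        def __init__(self, config):
            super().__init__()
            self.num_heads = config.num_attention_heads
            self.num_kv_heads = config.num_key_value_heads
            self.head_dim = config.hidden_size // config.num_attention_heads
            self.setup_qkv_projection(
                hidden_size=config.hidden_size,
                num_attention_heads=self.num_heads,
                num_key_value_heads=self.num_kv_heads,
                head_dim=self.head_dim,
                bias=config.attention_bias,
            )
            # Output projection: illustrative, not provided by the mixin.
            self.o_proj = nn.Linear(
                config.hidden_size, config.hidden_size, bias=config.attention_bias
            )

        def forward(self, hidden_states):
            b, s, _ = hidden_states.shape
            q, k, v = self.compute_qkv(hidden_states)
            q = q.view(b, s, self.num_heads, self.head_dim).transpose(1, 2)
            k = k.view(b, s, self.num_kv_heads, self.head_dim).transpose(1, 2)
            v = v.view(b, s, self.num_kv_heads, self.head_dim).transpose(1, 2)
            # Expand KV heads for GQA, then plain causal SDPA (no rotary embeddings here).
            k = k.repeat_interleave(self.num_heads // self.num_kv_heads, dim=1)
            v = v.repeat_interleave(self.num_heads // self.num_kv_heads, dim=1)
            out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
            return self.o_proj(out.transpose(1, 2).reshape(b, s, -1))

    config = SimpleNamespace(
        hidden_size=256, num_attention_heads=8, num_key_value_heads=4, attention_bias=False
    )
    attn = MyAttention(config)
    y = attn(torch.randn(2, 16, config.hidden_size))  # [2, 16, 256]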
setup_qkv_projection(
    hidden_size: int,
    num_attention_heads: int,
    num_key_value_heads: int,
    head_dim: int,
    bias: bool = False,
    use_combined_qkv: bool = True,
)#

Set up the combined QKV projection (the combined format is ALWAYS used).

Parameters:
  • hidden_size – Model hidden size

  • num_attention_heads – Number of attention heads

  • num_key_value_heads – Number of key/value heads (for GQA)

  • head_dim – Dimension per attention head

  • bias – Whether to use bias in projections

  • use_combined_qkv – DEPRECATED; always treated as True for custom implementations
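Concretely, these parameters determine the fused layer's output width: num_attention_heads * head_dim rows for Q plus 2 * num_key_value_heads * head_dim rows for K and V. A sketch with illustrative Llama-3-style GQA numbers (the attribute name `qkv_proj` is an assumption, not confirmed by these docs):

    import torch.nn as nn

    # Illustrative GQA configuration (hypothetical values).
    hidden_size, num_attention_heads, num_key_value_heads, head_dim = 4096, 32, 8, 128

    q_size = num_attention_heads * head_dim   # 4096
    kv_size = num_key_value_heads * head_dim  # 1024

    # The single fused layer presumably registered by setup_qkv_projection;
    # `qkv_proj` as the attribute name is an assumption.
    qkv_proj = nn.Linear(hidden_size, q_size + 2 * kv_size, bias=False)  # 4096 -> 6144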

compute_qkv(
    hidden_states: torch.Tensor,
) → tuple[torch.Tensor, torch.Tensor, torch.Tensor]#

Compute Q, K, V from hidden states using combined projection.

Handles tensor parallelism by dynamically computing split sizes based on actual tensor dimensions.

Parameters:
  • hidden_states – Input hidden states [batch, seq_len, hidden_size]

Returns:

Tuple of (query, key, value) tensors, each [batch, seq_len, …]
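The dynamic split can be reconstructed as follows: under N-way tensor parallelism the fused projection is column-sharded, so each rank's output width is 1/N of the full q_size + 2 * kv_size, and the split sizes must shrink by the same factor. A hedged sketch of the idea (not the actual implementation):

    import torch

    def split_qkv(qkv: torch.Tensor, num_heads: int, num_kv_heads: int, head_dim: int):
        full_width = (num_heads + 2 * num_kv_heads) * head_dim
        # Infer the tensor-parallel degree from the actual (possibly sharded) width.
        tp = full_width // qkv.shape[-1]
        q_size = num_heads * head_dim // tp
        kv_size = num_kv_heads * head_dim // tp
        return qkv.split([q_size, kv_size, kv_size], dim=-1)

    # Local shard on one rank of a 4-way tensor-parallel group: width 6144 / 4 = 1536.
    qkv = torch.randn(2, 8, (32 + 2 * 8) * 128 // 4)
    q, k, v = split_qkv(qkv, num_heads=32, num_kv_heads=8, head_dim=128)
    # q: [..., 1024], k: [..., 256], v: [..., 256]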