nemo_automodel.components.models.common.combined_projection.combined_qkv#

Combined QKV attention projection for efficient multi-head attention.

This module provides a mixin class that enables combined QKV projection for any attention module, improving memory efficiency and reducing kernel launch overhead.

Module Contents#

Classes#

CombinedQKVAttentionMixin

Mixin for combined QKV projection in attention modules.

Functions#

_assert_colwise_parallel

Verify that a combined-projection weight uses ColwiseParallel (Shard(0)) if TP is active.

API#

nemo_automodel.components.models.common.combined_projection.combined_qkv._assert_colwise_parallel(weight: torch.Tensor, name: str) → None#

Verify that a combined-projection weight uses ColwiseParallel (Shard(0)) if TP is active.

Shard(dim=1) is expected from FSDP pre-sharding of combined projections and is excluded from the check — it is temporary and undone by FSDP all-gather before the actual matmul.
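The check itself is internal to the module; as an illustration only, a similar verification can be written against the public DTensor placement API (the helper name check_colwise is hypothetical):

    import torch
    from torch.distributed.tensor import DTensor, Shard

    def check_colwise(weight: torch.Tensor, name: str) -> None:
        # A plain tensor carries no TP placement, so there is nothing to check.
        if not isinstance(weight, DTensor):
            return
        for placement in weight.placements:
            # Shard(0) is ColwiseParallel; Shard(1) is transient FSDP
            # pre-sharding that an all-gather undoes before the matmul.
            if isinstance(placement, Shard) and placement.dim not in (0, 1):
                raise AssertionError(
                    f"{name}: expected ColwiseParallel (Shard(0)), got {placement}"
                )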

class nemo_automodel.components.models.common.combined_projection.combined_qkv.CombinedQKVAttentionMixin#

Mixin for combined QKV projection in attention modules.

This mixin ALWAYS uses combined QKV projections for efficiency. Use this with custom transformer attention modules (Llama, Qwen2, etc.).

Usage:

    class MyAttention(CombinedQKVAttentionMixin, nn.Module):
        def __init__(self, config):
            super().__init__()
            # ... other init code ...
            self.setup_qkv_projection(
                hidden_size=config.hidden_size,
                num_attention_heads=config.num_attention_heads,
                num_key_value_heads=config.num_key_value_heads,
                head_dim=self.head_dim,
                bias=config.attention_bias,
            )

        def forward(self, hidden_states, ...):
            query_states, key_states, value_states = self.compute_qkv(hidden_states)
            # ... rest of attention logic ...
setup_qkv_projection(
hidden_size: int,
num_attention_heads: int,
num_key_value_heads: int,
head_dim: int,
bias: bool = False,
use_combined_qkv: bool = True,
)#

Set up the combined QKV projection (the combined format is always used).

Parameters:
  • hidden_size – Model hidden size

  • num_attention_heads – Number of attention heads

  • num_key_value_heads – Number of key/value heads (for GQA)

  • head_dim – Dimension per attention head

  • bias – Whether to use bias in projections

  • use_combined_qkv – Deprecated; always True for custom implementations
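The exact construction is internal to the mixin, but the output width follows directly from these parameters: Q contributes num_attention_heads * head_dim columns, while K and V each contribute num_key_value_heads * head_dim (GQA). A minimal sketch under that assumption:

    import torch.nn as nn

    def make_combined_qkv(hidden_size: int, num_attention_heads: int,
                          num_key_value_heads: int, head_dim: int,
                          bias: bool = False) -> nn.Linear:
        # A single matmul produces Q, K, and V for every token.
        out_features = (num_attention_heads + 2 * num_key_value_heads) * head_dim
        return nn.Linear(hidden_size, out_features, bias=bias)

For example, with 32 attention heads, 8 KV heads, and head_dim 128, the combined projection has (32 + 2 * 8) * 128 = 6144 output features.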

compute_qkv(
hidden_states: torch.Tensor,
) → tuple[torch.Tensor, torch.Tensor, torch.Tensor]#

Compute Q, K, V from hidden states using combined projection.

The QKV weight uses a KV-head-grouped interleaved layout:

    [Q_group_0 | K_0 | V_0 | Q_group_1 | K_1 | V_1 | …]

This ensures that ColwiseParallel TP sharding gives each rank complete KV-head groups; the split within each group is then a local operation.

Parameters:

hidden_states – Input hidden states [batch, seq_len, hidden_size]

Returns:

Tuple of (query, key, value) tensors, each [batch, seq_len, …]
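To make the grouped layout concrete, the following sketch performs the split described above. It is not the module's actual implementation, and it assumes Q, K, and V are returned with heads flattened into the last dimension:

    import torch

    def split_grouped_qkv(qkv: torch.Tensor, num_attention_heads: int,
                          num_key_value_heads: int, head_dim: int):
        # qkv: [batch, seq_len, (num_attention_heads + 2 * num_key_value_heads) * head_dim],
        # laid out as [Q_group_0 | K_0 | V_0 | Q_group_1 | K_1 | V_1 | ...]
        batch, seq_len, _ = qkv.shape
        q_per_group = num_attention_heads // num_key_value_heads
        group_width = (q_per_group + 2) * head_dim
        # View as one chunk per KV-head group; the split inside a group is local.
        grouped = qkv.view(batch, seq_len, num_key_value_heads, group_width)
        q, k, v = grouped.split([q_per_group * head_dim, head_dim, head_dim], dim=-1)
        q = q.reshape(batch, seq_len, num_attention_heads * head_dim)
        k = k.reshape(batch, seq_len, num_key_value_heads * head_dim)
        v = v.reshape(batch, seq_len, num_key_value_heads * head_dim)
        return q, k, v

Because each rank under ColwiseParallel owns whole KV-head groups, this reshape-and-split never needs to communicate across ranks.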