nemo_automodel.components.models.deepseek_v32.layers#

DeepSeek V3.2 Layers.

Contains the DeepseekV32Indexer for top-k sparse attention selection and DeepseekV32MLA, which integrates the indexer with Multi-head Latent Attention.

Module Contents#

Classes#

DeepseekV32Indexer

Indexer for top-k sparse attention selection.

DeepseekV32MLA

Multi-head Latent Attention with Indexer for sparse attention.

Functions#

_rotate_activation

Apply Hadamard rotation activation.

API#

nemo_automodel.components.models.deepseek_v32.layers._rotate_activation(x: torch.Tensor) → torch.Tensor#

Apply Hadamard rotation activation.

Reference: https://github.com/deepseek-ai/DeepSeek-V3.2-Exp/blob/main/inference/model.py#L424-L428

Parameters:

x – Input tensor (must be bfloat16).

Returns:

Rotated tensor.
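
For intuition, here is a minimal pure-PyTorch sketch of a Hadamard rotation over the last dimension, assuming a power-of-two dimension and orthonormal 1/sqrt(dim) scaling. The referenced DeepSeek implementation may instead rely on a fused fast-Hadamard kernel, so treat this only as an illustration of the math, not as the module's actual code.

```python
import math

import torch


def hadamard_rotate_sketch(x: torch.Tensor) -> torch.Tensor:
    """Illustrative fast Walsh-Hadamard transform over the last dimension.

    Assumptions (not taken from the source): power-of-two last dim and
    1/sqrt(dim) scaling so the transform is an orthonormal rotation.
    """
    n = x.shape[-1]
    assert n & (n - 1) == 0, "Hadamard transform needs a power-of-two last dim"
    y = x.float().reshape(-1, n)  # upcast; bfloat16 butterflies lose precision
    h = 1
    while h < n:
        # Butterfly step: combine paired blocks of width h as (a + b, a - b).
        y = y.view(-1, n // (2 * h), 2, h)
        a, b = y[:, :, 0, :], y[:, :, 1, :]
        y = torch.stack((a + b, a - b), dim=2).reshape(-1, n)
        h *= 2
    return (y / math.sqrt(n)).to(x.dtype).view(x.shape)


# Example: rotate bfloat16 activations with a 128-wide last dimension.
out = hadamard_rotate_sketch(torch.randn(2, 16, 128, dtype=torch.bfloat16))
```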

class nemo_automodel.components.models.deepseek_v32.layers.DeepseekV32Indexer(
config: nemo_automodel.components.models.deepseek_v32.config.DeepseekV32Config,
backend: nemo_automodel.components.models.common.BackendConfig,
)#

Bases: torch.nn.Module

Indexer for top-k sparse attention selection.

Based on the official DeepSeek V3.2 training implementation. Computes attention scores between queries and keys with per-head weights, applies ReLU activation, then selects the top-k positions to attend to.

Key features:

  • Uses LayerNorm (not RMSNorm) for key normalization

  • Has a weights_proj that learns per-head importance weights

  • Optional Hadamard transform (rotate_activation) on Q and K

  • ReLU activation on attention scores before weighting
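
As a rough illustration of that scoring-and-selection flow, the hedged sketch below computes query-key scores, applies ReLU, weights each head, sums over heads, and takes the top-k positions. The shapes and the shared (per-position rather than per-head) keys are illustrative assumptions; the real module additionally applies its projections, RoPE, the key LayerNorm, and optionally the Hadamard rotation before scoring.

```python
import torch
import torch.nn.functional as F

B, S, H, D, topk = 2, 16, 4, 64, 8
q = torch.randn(B, S, H, D)        # indexer queries, one set per head (assumed shape)
k = torch.randn(B, S, D)           # indexer keys, shared across heads (assumed shape)
head_w = torch.randn(B, S, H)      # per-head importance from a weights_proj-style layer

scores = torch.einsum("bshd,btd->bhst", q, k)           # [B, H, S, S] raw scores
scores = F.relu(scores)                                 # ReLU before weighting
scores = scores * head_w.transpose(1, 2).unsqueeze(-1)  # apply per-head weights
index_scores = scores.sum(dim=1)                        # combine heads -> [B, S, S]

topk_indices = index_scores.topk(topk, dim=-1).indices  # [B, S, topk], as in forward()
```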

Initialization

forward(
x: torch.Tensor,
q_resid: torch.Tensor,
freqs_cis: torch.Tensor,
attention_mask: torch.Tensor | None = None,
**attn_kwargs: Any,
) → torch.Tensor#

Compute top-k indices for sparse attention.

Parameters:
  • x – Hidden states [B, S, hidden] or [T, hidden] for thd format

  • q_resid – Q lora residual from MLA [B, S, q_lora_rank] or [T, q_lora_rank]

  • freqs_cis – RoPE frequencies

  • attention_mask – Optional attention mask

  • **attn_kwargs – Additional attention kwargs (cu_seqlens, etc.)

Returns:

Indices of top-k positions [B, S, topk] or [T, topk]

Return type:

topk_indices

init_weights(init_std: float = 0.02)#

class nemo_automodel.components.models.deepseek_v32.layers.DeepseekV32MLA(
config: nemo_automodel.components.models.deepseek_v32.config.DeepseekV32Config,
backend: nemo_automodel.components.models.common.BackendConfig,
)#

Bases: torch.nn.Module

Multi-head Latent Attention with Indexer for sparse attention.

This extends the V3 MLA with an Indexer module that performs top-k selection for sparse attention. The indexer uses the q_lora residual and hidden states to compute which positions to attend to.
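
The hedged sketch below illustrates that integration end to end: a stand-in for the indexer produces top-k key indices, those indices become an additive -inf/0 mask, and a standard scaled-dot-product attention call consumes the mask. All tensors and shapes here are illustrative stand-ins, not the class's real attributes.

```python
import torch
import torch.nn.functional as F

B, H, S, D, topk = 2, 4, 16, 32, 8
q = torch.randn(B, H, S, D)
k = torch.randn(B, H, S, D)
v = torch.randn(B, H, S, D)

# Stand-in for DeepseekV32Indexer.forward(): one set of key indices per query.
topk_indices = torch.randint(0, S, (B, S, topk))

# Stand-in for _build_sparse_mask(): -inf everywhere, 0 at the selected keys.
sparse_mask = torch.full((B, 1, S, S), float("-inf")).scatter_(
    -1, topk_indices.unsqueeze(1), 0.0
)

# Sparse attention: keys outside the top-k set contribute nothing after softmax.
out = F.scaled_dot_product_attention(q, k, v, attn_mask=sparse_mask)  # [B, H, S, D]
```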

Initialization

_build_sparse_mask(
topk_indices: torch.Tensor,
seq_len: int,
qkv_format: str,
bsz: int = 1,
n_heads: int = 1,
dtype: torch.dtype = torch.bfloat16,
attention_mask: torch.Tensor | None = None,
union_across_batches: bool = False,
) → torch.Tensor#

Build a sparse attention mask/bias from top-k indices.

Creates a mask tensor where non-top-k positions are set to -inf. Works for both TE (core_attention_bias) and SDPA (attn_mask).

Uses the same efficient pattern as the official DeepSeek inference code: torch.full(..., -inf).scatter_(-1, topk_indices, 0)

Parameters:
  • topk_indices – Indices of top-k positions [B, S, topk] or [T, topk]

  • seq_len – Sequence length

  • qkv_format – 'bshd' or 'thd'

  • bsz – Batch size (only used for bshd format)

  • n_heads – Number of attention heads to expand to

  • dtype – Data type for the output tensor

  • attention_mask – Optional attention mask to combine with (for SDPA)

  • union_across_batches – If True, union top-k across batches (for TE); if False, keep per-batch masks (for SDPA)

Returns:

Mask tensor with shape:

  • [1, n_heads, S, S] if union_across_batches=True

  • [B, n_heads, S, S] if union_across_batches=False (bshd)

  • [1, n_heads, T, T] for thd format

Return type:

sparse_mask
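
The hedged sketch below spells out the scatter pattern quoted in the description above for the bshd case, together with one plausible reading of union_across_batches (keep a key position if any batch selected it, via an element-wise max over the batch dimension). Causal masking and the thd path are omitted; this is an illustration, not the method's implementation.

```python
import torch

B, S, topk, n_heads = 2, 8, 4, 16
topk_indices = torch.randint(0, S, (B, S, topk))              # [B, S, topk] (bshd)

# -inf everywhere, then unmask (set to 0) each query's top-k key positions.
base = torch.full((B, S, S), float("-inf"), dtype=torch.bfloat16)
base.scatter_(-1, topk_indices, 0.0)

# union_across_batches=False (SDPA path): per-batch masks, [B, n_heads, S, S].
per_batch = base.unsqueeze(1).expand(B, n_heads, S, S)

# union_across_batches=True (TE path, assumed semantics): a key stays unmasked
# if ANY batch selected it, i.e. an element-wise max over the batch dimension.
union = base.amax(dim=0, keepdim=True).unsqueeze(1).expand(1, n_heads, S, S)
```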

forward(
x: torch.Tensor,
freqs_cis: torch.Tensor,
attention_mask: torch.Tensor | None = None,
**attn_kwargs: Any,
)#

init_weights(_buffer_device: torch.device, init_std: float = 0.02)#