nemo_automodel.components.models.deepseek_v4.layers


DeepSeek V4 Attention Layer.

Architecture (from official inference/model.py):

KV path (K = V, single latent):

    x -> wkv [hidden -> head_dim]    # single KV head, K = V = kv
      -> kv_norm (RMSNorm on head_dim)
      -> apply_rotary_emb on last rope_head_dim dims
    K = V = kv (one latent vector serves both key and value)

Output path (grouped):

    o [bsz, seq, n_heads, head_dim]
      -> reshape [bsz, seq, n_groups, n_heads_per_group * head_dim]
      -> wo_a einsum per group: [n_heads_per_group * head_dim] -> [o_lora_rank]
      -> reshape [bsz, seq, n_groups * o_lora_rank]
      -> wo_b [n_groups * o_lora_rank -> hidden]

attn_sink: learnable per-head scalar bias added to attention-sink position score.

HC (Hyper-Connections): Each Block maintains hc_mult=4 copies of the hidden state. hc_pre reduces [bsz, seq, hc_mult, dim] -> [bsz, seq, dim] via Sinkhorn mixing. hc_post expands [bsz, seq, dim] -> [bsz, seq, hc_mult, dim]. See DeepseekV4HyperConnection.compute_weights and optimized_kernels.dsv4_sinkhorn_normalize for the torch reference and optional TileKernels Sinkhorn path.

Compress-ratio attention (Compressor + Indexer) is wired into DeepseekV4Attention.forward for layers with compress_ratio > 0. All layers share the same sliding-window causal mask on the local KV path.

Module Contents

Classes

| Name | Description |
| --- | --- |
| DeepseekV4Attention | Sliding-window attention + Compressor + Indexer + attention sink. |
| DeepseekV4Compressor | HF PR 45616 port. Long-range KV branch. Pools compress_ratio tokens into one compressed KV. |
| DeepseekV4FP32Parameter | Callable holder for fp32 tensors that need their own FSDP unit. |
| DeepseekV4GroupedLinear | Block-diagonal grouped linear (HF PR 45616 port). |
| DeepseekV4HyperConnection | Per-site HyperConnection mixer (attention or FFN). Ported from transformers/src/transformers/models/deepseek_v4/modular_deepseek_v4.py class DeepseekV4HyperConnection. |
| DeepseekV4HyperHead | Final HC-stream collapse before the shared RMSNorm + lm_head. |
| DeepseekV4Indexer | HF PR 45616 port. Picks the top-k compressed positions per query when compress_ratio == 4. |
| DeepseekV4RotaryEmbedding | V4 rotary embedding. Produces (cos, sin) sized to qk_rope_head_dim. |
| DeepseekV4TrainCache | Training-only cache shim mirroring the three methods DeepseekV4Compressor / DeepseekV4Indexer call on DeepseekV4Cache. |

Functions

| Name | Description |
| --- | --- |
| _apply_partial_rope | Split x along its last dim into nope (first) and rope (last rope_head_dim) slices and rotate only the rope slice. |
| _apply_partial_rope_interleaved | Interleaved RoPE on the last rope_head_dim dims of x (pairs are (2k, 2k+1)). |
| _build_indexer_topk_compressed_mask | Build the additive compressed-position mask for Indexer-selected pool IDs. |
| _dsv4_kernel_backend | Use TileLang DSV4 kernels only when the attention backend requests them. |
| _overlap_transform | Reshape [B, S, ratio, 2*head_dim] -> [B, S, 2*ratio, head_dim] with the cross-window overlap from the DeepSeek inference reference. |
| _overlap_transform_with_cp | - |
| _pool_windows | Softmax-gated sum-pool over ratio consecutive tokens. |
| _query_positions | - |
| _rms_norm_last_dim | RMS-normalize the last dim without materializing an x.square() tensor. |
| _rope_pool_positions | - |
| _rotate_half | Rotate half the hidden dims of the input (Llama / GPT-NeoX style). |
| _yarn_correction_dim | - |
| _yarn_correction_range | - |
| _yarn_linear_ramp | - |
| apply_rotary_pos_emb | Port of transformers.models.llama.modeling_llama.apply_rotary_pos_emb. |
| build_causal_padding_mask | Build a 4D additive causal+padding (+optional sliding-window) mask compatible with eager_attention_with_sink. |
| build_packed_causal_padding_mask | Build a 4D additive block-causal mask from packed-sequence lengths. |
| eager_attention_with_sink | Eager attention with per-head sink: appends an extra softmax column whose logit is module.sinks[h] and whose value-slot is zero. |
| repeat_kv | Port of transformers.models.llama.modeling_llama.repeat_kv. |

API

class nemo_automodel.components.models.deepseek_v4.layers.DeepseekV4Attention(
config: nemo_automodel.components.models.deepseek_v4.config.DeepseekV4Config,
layer_idx: int,
backend: nemo_automodel.components.models.common.BackendConfig | None = None
)

Bases: Module

Sliding-window attention + Compressor + Indexer + attention sink.

Single-head KV (num_key_value_heads=1), grouped low-rank output via DeepseekV4GroupedLinear. Layers with compress_ratio == 0 skip the compressor / indexer and run pure SWA (sliding-window attention).

attention_dropout
backend = backend or BackendConfig()
compress_ratio
compressor
head_dim = config.head_dim
kv_norm
num_heads = config.num_attention_heads
num_key_value_groups = config.num_attention_heads
q_norm
rope_head_dim = config.qk_rope_head_dim
scaling = self.head_dim ** -0.5
sinks: Tensor
sinks_param
sliding_window = int(getattr(config, 'sliding_window', 128) or 128)
wkv
wo_a
wo_b
wq_a
wq_b
nemo_automodel.components.models.deepseek_v4.layers.DeepseekV4Attention.forward(
hidden_states: torch.Tensor,
position_embeddings: tuple[torch.Tensor, torch.Tensor],
attention_mask: torch.Tensor | None = None,
position_embeddings_compress: typing.Optional[tuple[torch.Tensor, torch.Tensor]] = None,
rotary_compress: torch.nn.Module | None = None,
start_pos: int = 0,
position_ids: torch.Tensor | None = None,
kwargs: typing.Any = {}
) -> tuple[torch.Tensor, torch.Tensor | None]
nemo_automodel.components.models.deepseek_v4.layers.DeepseekV4Attention.init_weights(
buffer_device: torch.device,
init_std: float = 0.02
) -> None
nemo_automodel.components.models.deepseek_v4.layers.DeepseekV4Attention.setup_cp_attention(
cp_mesh
) -> None

Model-owned context-parallel hook, called by moe.parallelizer.apply_cp.

DSV4 runs Miles-style CP (contiguous query shard + all-gathered K/V), so there is no TE DotProductAttention to configure — we just record the CP process group that forward uses to all-gather K/V across CP ranks.

class nemo_automodel.components.models.deepseek_v4.layers.DeepseekV4Compressor(
config: nemo_automodel.components.models.deepseek_v4.config.DeepseekV4Config,
compress_ratio: int,
head_dim: int,
backend: nemo_automodel.components.models.common.BackendConfig | None = None
)

Bases: Module

HF PR 45616 port. Long-range KV branch. Pools compress_ratio tokens into one compressed KV; when ratio == 4 the Indexer narrows the pool.

ape: Tensor
ape_param
backend = backend or BackendConfig()
indexer: DeepseekV4Indexer | None
kv_norm
overlap = compress_ratio == 4
rope_head_dim = config.qk_rope_head_dim
wgate
wkv
nemo_automodel.components.models.deepseek_v4.layers.DeepseekV4Compressor._compute_fsdp_group_has_complete_hca_window(
local_has_complete_hca_window: bool,
device: torch.device
) -> bool
nemo_automodel.components.models.deepseek_v4.layers.DeepseekV4Compressor._set_hca_param_sync_group(
process_group
) -> None
nemo_automodel.components.models.deepseek_v4.layers.DeepseekV4Compressor.forward(
hidden_states: torch.Tensor,
q_residual: torch.Tensor | None,
rotary: torch.nn.Module,
position_embeddings: tuple[torch.Tensor, torch.Tensor],
cache: nemo_automodel.components.models.deepseek_v4.layers.DeepseekV4TrainCache,
layer_idx: int,
start_pos: int,
enable_hca_fsdp_graph_alignment: bool = False,
position_ids: torch.Tensor | None = None,
cp_group = None
) -> torch.Tensor
class nemo_automodel.components.models.deepseek_v4.layers.DeepseekV4FP32Parameter(
value: torch.Tensor
)

Bases: Module

Callable holder for fp32 tensors that need their own FSDP unit.

weight = nn.Parameter(value.to(torch.float32))
nemo_automodel.components.models.deepseek_v4.layers.DeepseekV4FP32Parameter.forward(
reference: torch.Tensor | None = None
) -> torch.Tensor
class nemo_automodel.components.models.deepseek_v4.layers.DeepseekV4GroupedLinear(
in_features_per_group: int,
out_features: int,
n_groups: int,
bias: bool = False
)

Bases: Linear

Block-diagonal grouped linear (HF PR 45616 port).

weight parameter has the standard nn.Linear shape [out_features, in_features_per_group] so quantizers keyed on nn.Linear.weight still find it; forward does per-group bmm.
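The block-diagonal behaviour can be illustrated without torch. The following is a minimal pure-Python sketch, not the ported implementation: the helper name grouped_linear is hypothetical, and it assumes the weight rows split into n_groups contiguous chunks of out_features // n_groups rows, with input group g multiplied only by chunk g.

```python
def grouped_linear(x_groups, weight, n_groups):
    """x_groups: n_groups vectors, each of length in_features_per_group.
    weight: out_features rows, each of length in_features_per_group.
    Returns n_groups output vectors of length out_features // n_groups."""
    out_per_group = len(weight) // n_groups
    outputs = []
    for g in range(n_groups):
        # Rows owned by group g (assumed contiguous in the weight matrix).
        rows = weight[g * out_per_group:(g + 1) * out_per_group]
        outputs.append(
            [sum(w * x for w, x in zip(row, x_groups[g])) for row in rows]
        )
    return outputs


# out_features=4, in_features_per_group=2, n_groups=2
weight = [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0], [1.0, -1.0]]
x_groups = [[3.0, 4.0], [5.0, 6.0]]
y = grouped_linear(x_groups, weight, n_groups=2)
```

Each group only ever sees its own slice of the input, which is what makes the equivalent dense matrix block-diagonal.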

nemo_automodel.components.models.deepseek_v4.layers.DeepseekV4GroupedLinear.forward(
x: torch.Tensor
) -> torch.Tensor
class nemo_automodel.components.models.deepseek_v4.layers.DeepseekV4HyperConnection(
hc_mult: int,
hidden_size: int,
hc_sinkhorn_iters: int,
hc_eps: float,
rms_norm_eps: float,
sinkhorn_backend: str = 'torch'
)

Bases: Module

Per-site HyperConnection mixer (attention or FFN). Ported from transformers/src/transformers/models/deepseek_v4/modular_deepseek_v4.py class DeepseekV4HyperConnection.

Owns fn (packed linear), base (bias), and scale (scalar per-head gains). compute_weights produces three mixer tensors:

  • pre [B, S, H] : sigmoid-gated collapse weights
  • post [B, S, H] : sigmoid-gated expand weights
  • comb [B, S, H, H] : doubly-stochastic combination matrix from Sinkhorn-normalising sigmoid gates

All math runs in fp32 regardless of the outer cast policy; parameters cast themselves via .float() on each forward. HF lists these params in _keep_in_fp32_modules_strict — the KAutomodel adapter does the same via submodule-name matching.
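To make "doubly-stochastic combination matrix" concrete, here is a minimal pure-Python sketch of Sinkhorn normalisation applied to sigmoid gates. It is an illustration only: the helper name sinkhorn_normalize, the row-then-column order, the iteration count, and the example logits are all assumptions, and the way the real module uses hc_eps is not modelled.

```python
import math


def sinkhorn_normalize(matrix, iters):
    """Alternately normalise rows and columns of a positive square matrix."""
    m = [row[:] for row in matrix]
    n = len(m)
    for _ in range(iters):
        # Row step: every row sums to 1.
        for i in range(n):
            s = sum(m[i])
            m[i] = [v / s for v in m[i]]
        # Column step: every column sums to 1.
        col_sums = [sum(m[i][j] for i in range(n)) for j in range(n)]
        m = [[m[i][j] / col_sums[j] for j in range(n)] for i in range(n)]
    return m


# Hypothetical 4x4 gate logits (hc_mult = 4 streams).
logits = [
    [0.5, -1.0, 2.0, 0.0],
    [1.5, 0.3, -0.7, 1.0],
    [-2.0, 0.8, 0.1, -0.4],
    [0.0, 1.2, -1.5, 0.6],
]
gates = [[1.0 / (1.0 + math.exp(-z)) for z in row] for row in logits]
comb = sinkhorn_normalize(gates, iters=50)

row_sums = [sum(row) for row in comb]
col_sums = [sum(comb[i][j] for i in range(4)) for j in range(4)]
```

After enough iterations both row_sums and col_sums are close to 1, so mixing the streams with comb neither amplifies nor attenuates the total signal.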

base
fn
scale
nemo_automodel.components.models.deepseek_v4.layers.DeepseekV4HyperConnection.compute_weights(
hidden_streams: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]
nemo_automodel.components.models.deepseek_v4.layers.DeepseekV4HyperConnection.forward(
hidden_streams: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]
class nemo_automodel.components.models.deepseek_v4.layers.DeepseekV4HyperHead(
hc_mult: int,
hidden_size: int,
hc_eps: float,
rms_norm_eps: float
)

Bases: Module

Final HC-stream collapse before the shared RMSNorm + lm_head. Ported from modular_deepseek_v4.py class DeepseekV4HyperHead.

Sigmoid-weighted sum over the hc_mult streams (no Sinkhorn). Used once at the end of DeepseekV4Model.forward to go from [B, S, H, D] back to [B, S, D].
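The collapse itself is a plain weighted sum. The sketch below is a hedged pure-Python illustration for a single token: collapse_streams is a hypothetical name, the gate logits are taken as given (how hc_fn, hc_base and hc_scale produce them is not shown in this page), and hc_eps is ignored.

```python
import math


def collapse_streams(streams, gate_logits):
    """streams: H vectors of length D; gate_logits: H scalars.
    Returns one vector of length D."""
    weights = [1.0 / (1.0 + math.exp(-g)) for g in gate_logits]
    dim = len(streams[0])
    # No normalisation across streams: the weights are independent sigmoids.
    return [sum(w * s[d] for w, s in zip(weights, streams)) for d in range(dim)]


# Two streams (H=2), hidden size D=2, zero logits -> each weight is 0.5.
out = collapse_streams([[1.0, 2.0], [3.0, 4.0]], [0.0, 0.0])
```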

hc_base
hc_fn
hc_scale
nemo_automodel.components.models.deepseek_v4.layers.DeepseekV4HyperHead.forward(
x: torch.Tensor
) -> torch.Tensor
class nemo_automodel.components.models.deepseek_v4.layers.DeepseekV4Indexer(
config: nemo_automodel.components.models.deepseek_v4.config.DeepseekV4Config,
backend: nemo_automodel.components.models.common.BackendConfig | None = None
)

Bases: Module

HF PR 45616 port. Picks the top-k compressed positions per query when compress_ratio == 4. Owns its own pool at index_head_dim plus a query projection + weights_proj head-mixer.

ape: Tensor
ape_param
backend = backend or BackendConfig()
compress_ratio = 4
head_dim = config.index_head_dim
index_topk = config.index_topk
kv_norm
n_heads = config.index_n_heads
rope_head_dim = config.qk_rope_head_dim
softmax_scale = self.head_dim ** -0.5
weights_proj
wgate
wkv
wq_b
nemo_automodel.components.models.deepseek_v4.layers.DeepseekV4Indexer.forward(
hidden_states: torch.Tensor,
q_residual: torch.Tensor,
rotary: torch.nn.Module,
position_embeddings: tuple[torch.Tensor, torch.Tensor],
cache: nemo_automodel.components.models.deepseek_v4.layers.DeepseekV4TrainCache,
layer_idx: int,
start_pos: int,
position_ids: torch.Tensor | None = None,
cp_group = None
) -> torch.LongTensor
class nemo_automodel.components.models.deepseek_v4.layers.DeepseekV4RotaryEmbedding(
rope_theta: float,
head_dim: int,
partial_rotary_factor: float,
attention_scaling: float = 1.0,
device: torch.device | None = None,
rope_scaling: dict | None = None
)

Bases: Module

V4 rotary embedding. Produces (cos, sin) sized to qk_rope_head_dim (via partial_rotary_factor = qk_rope_head_dim / head_dim), matching HF.

YaRN: when rope_scaling is a YaRN-typed dict ({"type": "yarn", "factor": F, "original_max_position_embeddings": L0, "beta_fast": ..., "beta_slow": ...}), modify inv_freq per dsv4flash/inference/model.py:precompute_freqs_cis — frequency interpolation with a smooth linear ramp between beta_fast/beta_slow correction dims. Used by the compress-rope (theta=160000) on layers with compress_ratio > 0. The main rope (theta=10000, used only on sliding-window layers) gets rope_scaling=None because the reference builds it with original_seq_len=0 for those layers.
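The YaRN adjustment described above can be sketched in pure Python. This follows the commonly published YaRN correction formulas and is only an approximation of what the module does: the function names mirror the private helpers listed on this page but are re-implemented here, and the example values for dim, factor and original_max_position_embeddings are hypothetical.

```python
import math


def yarn_correction_dim(num_rotations, dim, base, max_seq_len):
    # Dimension index whose wavelength completes num_rotations over max_seq_len.
    return dim * math.log(max_seq_len / (num_rotations * 2 * math.pi)) / (2 * math.log(base))


def yarn_correction_range(low_rot, high_rot, dim, base, max_seq_len):
    low = math.floor(yarn_correction_dim(low_rot, dim, base, max_seq_len))
    high = math.ceil(yarn_correction_dim(high_rot, dim, base, max_seq_len))
    return max(low, 0), min(high, dim - 1)


def yarn_linear_ramp(min_v, max_v, dim):
    if min_v == max_v:
        max_v += 0.001  # avoid division by zero
    return [min(1.0, max(0.0, (i - min_v) / (max_v - min_v))) for i in range(dim)]


def yarn_inv_freq(dim, base, factor, original_max_pos, beta_fast, beta_slow):
    inv_freq = [1.0 / base ** (i / dim) for i in range(0, dim, 2)]
    low, high = yarn_correction_range(beta_fast, beta_slow, dim, base, original_max_pos)
    ramp = yarn_linear_ramp(low, high, dim // 2)
    # ramp == 0: keep the original frequency; ramp == 1: fully interpolated.
    return [f / factor * r + f * (1.0 - r) for f, r in zip(inv_freq, ramp)]


# Hypothetical settings; only theta=160000 comes from the text above.
plain = [1.0 / 160000.0 ** (i / 64) for i in range(0, 64, 2)]
scaled = yarn_inv_freq(dim=64, base=160000.0, factor=16.0,
                       original_max_pos=65536, beta_fast=32, beta_slow=1)
corr_range = yarn_correction_range(32, 1, 64, 160000.0, 65536)
```

High-frequency dimensions (below the lower correction index) are left untouched, low-frequency dimensions (above the upper index) are divided by factor, and the ramp blends the two in between.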

inv_freq: Tensor
nemo_automodel.components.models.deepseek_v4.layers.DeepseekV4RotaryEmbedding.forward(
x: torch.Tensor,
position_ids: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]
class nemo_automodel.components.models.deepseek_v4.layers.DeepseekV4TrainCache()

Training-only cache shim mirroring the three methods DeepseekV4Compressor / DeepseekV4Indexer call on DeepseekV4Cache.

KAutomodel training forward is stateless — we never persist KV or compressor windows across calls. Each DeepseekV4Attention.forward creates a fresh cache instance, which holds per-layer scratch dicts for the duration of the call. When a full window hasn’t accumulated yet we return an empty tensor and let the downstream code handle it.

compressor_state: list[dict] = []
indexer_state: list[dict] = []
nemo_automodel.components.models.deepseek_v4.layers.DeepseekV4TrainCache._branch_state(
state_key: str,
layer_idx: int
) -> dict
nemo_automodel.components.models.deepseek_v4.layers.DeepseekV4TrainCache.accumulate_windows(
kv: torch.Tensor,
gate: torch.Tensor,
layer_idx: int,
state_key: str,
ratio: int,
start_pos: int
) -> tuple[torch.Tensor, torch.Tensor, int]
nemo_automodel.components.models.deepseek_v4.layers.DeepseekV4TrainCache.update_pool(
new_pooled: torch.Tensor,
layer_idx: int,
state_key: str
) -> torch.Tensor
nemo_automodel.components.models.deepseek_v4.layers._apply_partial_rope(
x: torch.Tensor,
cos: torch.Tensor,
sin: torch.Tensor,
rope_head_dim: int
) -> torch.Tensor

Split x along its last dim into nope (first) and rope (last rope_head_dim) slices, rotate only the rope slice with INTERLEAVED pair-RoPE (pairs (2k, 2k+1)), concat back.

The DSV4-Flash released checkpoint uses interleaved RoPE end-to-end (see dsv4flash/inference/model.py:apply_rotary_emb — complex multiplication on view_as_complex of pairs). HF transformers PR 45616 / PR 45643 ship a Llama-style rotate_half here instead, which pairs (d, d+rd/2). Same algebra but a different dim-to-frequency mapping — the released weights expect the interleaved layout, so the Llama-style helper produces wrong activations on the released checkpoint (verified empirically: kv_post_rope cosine drops from 0.9999 to 0.866 after one block under Llama-style; matches at >0.999 under interleaved).
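The difference between the two layouts is easiest to see on a tiny vector. The sketch below is a pure-Python illustration with hypothetical helper names (rope_interleaved, rope_rotate_half); it applies the same per-pair angles under both pairings and shows that the outputs differ.

```python
import math


def rope_interleaved(x, angles):
    """Rotate pairs (2k, 2k+1) by angles[k]."""
    out = list(x)
    for k, a in enumerate(angles):
        x0, x1 = x[2 * k], x[2 * k + 1]
        out[2 * k] = x0 * math.cos(a) - x1 * math.sin(a)
        out[2 * k + 1] = x0 * math.sin(a) + x1 * math.cos(a)
    return out


def rope_rotate_half(x, angles):
    """Llama-style: rotate pairs (d, d + len(x) // 2) by angles[d]."""
    half = len(x) // 2
    out = list(x)
    for d, a in enumerate(angles):
        x0, x1 = x[d], x[d + half]
        out[d] = x0 * math.cos(a) - x1 * math.sin(a)
        out[d + half] = x1 * math.cos(a) + x0 * math.sin(a)
    return out


x = [1.0, 0.0, 0.0, 0.0]
angles = [math.pi / 2, 0.0]
inter = rope_interleaved(x, angles)   # energy moves from dim 0 to dim 1
half = rope_rotate_half(x, angles)    # energy moves from dim 0 to dim 2
```

Both functions are norm-preserving rotations, so neither is "wrong" in isolation; they only disagree on which dimensions share a frequency, which is why the weights must be paired with the layout they were trained under.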

nemo_automodel.components.models.deepseek_v4.layers._apply_partial_rope_interleaved(
x: torch.Tensor,
cos: torch.Tensor,
sin: torch.Tensor,
rope_head_dim: int
) -> torch.Tensor

Interleaved RoPE on the last rope_head_dim dims of x (pairs are (2k, 2k+1)). Matches the DeepSeek inference reference’s complex-mul formulation in dsv4flash/inference/model.py:apply_rotary_emb: the released DSV4-Flash weights were trained with this layout, NOT the Llama-style rotate_half layout HF transformers PR 45616/45643 still uses (pairs (d, d+rd/2)).

Inverse rotation: pass -sin instead of sin (caller’s responsibility — same as our existing inverse-rope call site).

Parameters:

x
torch.Tensor

[..., rope_head_dim] (or larger trailing dim with rope on the last rope_head_dim slice). Typical attention-layout shapes: [B, H, S, D] for q/k or [B, 1, S, D] for shared-KV.

cos, sin

Shape [B, S, rope_head_dim], as produced by the Llama-style cat([freqs, freqs], -1) rotary. Only the first half is used: it holds the unique per-pair frequencies, and the second half is a duplicate that only the Llama-style helper needs.

rope_head_dim
int

Must be even.
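As a hedged illustration of the parameter contract above (rope on the trailing slice, only the first half of cos/sin used, inverse via -sin), here is a pure-Python sketch for a single vector. The function name matches the documented helper but the body is a re-implementation written for this example, not the module's code.

```python
import math


def apply_partial_rope_interleaved(x, cos, sin, rope_head_dim):
    assert rope_head_dim % 2 == 0, "rope_head_dim must be even"
    n_pairs = rope_head_dim // 2
    # cos/sin arrive as cat([freqs, freqs]); keep the unique first half.
    cos, sin = cos[:n_pairs], sin[:n_pairs]
    split = len(x) - rope_head_dim
    nope, rope = x[:split], x[split:]
    rotated = []
    for k in range(n_pairs):
        a, b = rope[2 * k], rope[2 * k + 1]
        rotated += [a * cos[k] - b * sin[k], a * sin[k] + b * cos[k]]
    return nope + rotated


angles = [0.3, 1.1]
cos_full = [math.cos(a) for a in angles] * 2   # duplicated halves
sin_full = [math.sin(a) for a in angles] * 2

x = [9.0, 1.0, 2.0, 3.0, 4.0]                  # 1 nope dim + 4 rope dims
y = apply_partial_rope_interleaved(x, cos_full, sin_full, rope_head_dim=4)
neg_sin = [-s for s in sin_full]               # caller negates sin for the inverse
back = apply_partial_rope_interleaved(y, cos_full, neg_sin, rope_head_dim=4)
```

The nope prefix passes through unchanged, and rotating by -sin undoes the forward rotation up to floating-point error.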

nemo_automodel.components.models.deepseek_v4.layers._build_indexer_topk_compressed_mask(
attention_mask: torch.Tensor,
indexer_topk: torch.Tensor,
n_pooled: int
) -> torch.Tensor

Build the additive compressed-position mask for Indexer-selected pool IDs.

nemo_automodel.components.models.deepseek_v4.layers._dsv4_kernel_backend(
backend: nemo_automodel.components.models.common.BackendConfig
) -> str

Use TileLang DSV4 kernels only when the attention backend requests them.

nemo_automodel.components.models.deepseek_v4.layers._overlap_transform(
tensor: torch.Tensor,
head_dim: int,
fill_value: float
) -> torch.Tensor

Reshape [B, S, ratio, 2*head_dim] -> [B, S, 2*ratio, head_dim] with the cross-window overlap from the DeepSeek inference reference (Compressor.overlap_transform in dsv4flash/inference/model.py:307-314).

Window 0 has no previous block, so its [:ratio] slice is left at fill_value (0 for the kv tensor, -inf for the score tensor so softmax masks it out).
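The reshape can be sketched on nested lists for one batch element. This is an illustration under a loudly stated assumption: it assumes the first head_dim channels of each row are what the next window sees in its [:ratio] slots, and the last head_dim channels fill the window's own [ratio:] slots. Which half plays which role is not stated on this page, so treat the channel split as hypothetical.

```python
def overlap_transform(windows, head_dim, fill_value):
    """windows: [S][ratio][2*head_dim] -> [S][2*ratio][head_dim]."""
    n_windows = len(windows)
    ratio = len(windows[0])
    out = [
        [[fill_value] * head_dim for _ in range(2 * ratio)]
        for _ in range(n_windows)
    ]
    for w in range(n_windows):
        for t in range(ratio):
            # Assumed: second half of the channels belongs to the current window.
            out[w][ratio + t] = windows[w][t][head_dim:]
            if w > 0:
                # Assumed: first half of the previous window overlaps into this one.
                out[w][t] = windows[w - 1][t][:head_dim]
    return out


# Two windows, ratio=2, head_dim=1 (so each row has 2 channels).
windows = [[[1.0, 10.0], [2.0, 20.0]], [[3.0, 30.0], [4.0, 40.0]]]
kv = overlap_transform(windows, head_dim=1, fill_value=0.0)
score = overlap_transform(windows, head_dim=1, fill_value=float("-inf"))
```

With fill_value set to -inf for the score tensor, the empty [:ratio] slots of window 0 receive zero weight after softmax, matching the note above.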

nemo_automodel.components.models.deepseek_v4.layers._overlap_transform_with_cp(
tensor: torch.Tensor,
head_dim: int,
fill_value: float,
cp_group
) -> torch.Tensor
nemo_automodel.components.models.deepseek_v4.layers._pool_windows(
kv: torch.Tensor,
gate: torch.Tensor,
ape: torch.Tensor,
ratio: int,
head_dim: int,
overlap: bool = False,
cp_group = None
) -> torch.Tensor

Softmax-gated sum-pool over ratio consecutive tokens.

Non-overlap mode (HF PR 45616 layout, ratio==128 in V4-Flash): Input kv/gate of shape [B, length, head_dim]. Reshape to [B, length/ratio, ratio, head_dim] and pool over the ratio axis.

Overlap mode (DeepSeek inference reference layout, ratio==4 in V4-Flash): Input kv/gate of shape [B, length, 2*head_dim] (wkv/wgate project to 2*head_dim so each window can carry both its own kv and a half-overlap into the next window). Reshape to [B, length/ratio, ratio, 2*head_dim], apply _overlap_transform to remap to [B, length/ratio, 2*ratio, head_dim], then pool over the 2*ratio axis. Each compressed token thus aggregates 2*ratio = 8 raw tokens (the ratio tokens of the current window plus the ratio tokens of the previous window), which gives the smoother compression boundaries that the released checkpoint was trained with.

HF PR 45616 omits the overlap path entirely; the released DSV4-Flash safetensors have ape/wkv/wgate shapes that only match the overlap layout ([ratio, 2*head_dim] and [2*head_dim, hidden]), so we must support it here to load the released weights.
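For the non-overlap mode, the pooling reduces to a per-channel softmax over each window followed by a weighted sum. The pure-Python sketch below covers one batch element and rests on two assumptions that this page does not confirm: that ape is added to the gate logits before the softmax, and that the softmax is taken independently per channel. The function name pool_windows is a re-implementation for illustration only.

```python
import math


def pool_windows(kv, gate, ape, ratio):
    """kv, gate: [length][head_dim]; ape: [ratio][head_dim].
    Returns [length // ratio][head_dim]."""
    head_dim = len(kv[0])
    pooled = []
    for start in range(0, len(kv) - ratio + 1, ratio):
        out = []
        for d in range(head_dim):
            # Assumed: positional bias ape is added to the gate logits.
            logits = [gate[start + t][d] + ape[t][d] for t in range(ratio)]
            peak = max(logits)
            exps = [math.exp(v - peak) for v in logits]
            total = sum(exps)
            out.append(sum(e / total * kv[start + t][d] for t, e in enumerate(exps)))
        pooled.append(out)
    return pooled


# length=4, head_dim=1, ratio=2; zero gates and zero ape give a plain mean.
kv = [[1.0], [3.0], [5.0], [9.0]]
zeros = [[0.0], [0.0], [0.0], [0.0]]
pooled = pool_windows(kv, zeros, ape=[[0.0], [0.0]], ratio=2)
```

In overlap mode the same pooling would run over the 2*ratio axis produced by _overlap_transform instead of the ratio axis.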

nemo_automodel.components.models.deepseek_v4.layers._query_positions(
position_ids: torch.Tensor | None,
batch: int,
seq_len: int,
device: torch.device,
cp_group = None
) -> torch.Tensor
nemo_automodel.components.models.deepseek_v4.layers._rms_norm_last_dim(
x: torch.Tensor,
eps: float
) -> torch.Tensor

RMS-normalize the last dim without materializing an x.square() tensor.
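For reference, the quantity being computed is x / sqrt(mean(x^2) + eps). The pure-Python sketch below accumulates the sum of squares as a scalar, which is the list-level analogue of avoiding an intermediate squared tensor; placing eps inside the square root follows the usual RMSNorm convention and is an assumption about this helper.

```python
import math


def rms_norm_last_dim(x, eps):
    sum_sq = 0.0
    for v in x:
        sum_sq += v * v          # running scalar, no squared copy of x
    inv_rms = 1.0 / math.sqrt(sum_sq / len(x) + eps)
    return [v * inv_rms for v in x]


y = rms_norm_last_dim([3.0, 4.0], eps=0.0)
mean_sq = sum(v * v for v in y) / len(y)
```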

nemo_automodel.components.models.deepseek_v4.layers._rope_pool_positions(
pool_length: int,
pool_base: int,
ratio: int,
device: torch.device,
batch: int
) -> torch.Tensor
nemo_automodel.components.models.deepseek_v4.layers._rotate_half(
x: torch.Tensor
) -> torch.Tensor

Rotate half the hidden dims of the input (Llama / GPT-NeoX style).

nemo_automodel.components.models.deepseek_v4.layers._yarn_correction_dim(
num_rotations: float,
dim: int,
base: float,
max_seq_len: int
) -> float
nemo_automodel.components.models.deepseek_v4.layers._yarn_correction_range(
low_rot: float,
high_rot: float,
dim: int,
base: float,
max_seq_len: int
) -> tuple[int, int]
nemo_automodel.components.models.deepseek_v4.layers._yarn_linear_ramp(
min_v: float,
max_v: float,
dim: int,
device = None
) -> torch.Tensor
nemo_automodel.components.models.deepseek_v4.layers.apply_rotary_pos_emb(
q: torch.Tensor,
k: torch.Tensor,
cos: torch.Tensor,
sin: torch.Tensor,
position_ids: torch.Tensor | None = None,
unsqueeze_dim: int = 1
) -> tuple[torch.Tensor, torch.Tensor]

Port of transformers.models.llama.modeling_llama.apply_rotary_pos_emb.

nemo_automodel.components.models.deepseek_v4.layers.build_causal_padding_mask(
attention_mask: torch.Tensor | None,
seq_len: int,
dtype: torch.dtype,
device: torch.device,
batch_size: int = 1,
sliding_window: int | None = None
) -> torch.Tensor | None

Build a 4D additive causal+padding (+optional sliding-window) mask compatible with eager_attention_with_sink.

Mirrors HF’s create_sliding_window_causal_mask (used in DeepseekV4Model.forward): each query at position i attends only to keys at positions [max(0, i - sliding_window + 1), i]. The DSV4-Flash weights were trained with this banding on every layer, so dropping it makes the softmax see a different distribution than training and degrades loss.

Returns: [B, 1, S, S] additive mask of the given dtype (0 where a position is kept, a large negative value where it is masked).
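The banding rule can be written out directly. The sketch below is a pure-Python illustration for a single sequence with no padding; sliding_window_causal_mask is a hypothetical helper, and it uses -inf where the real function uses a large negative value of the requested dtype.

```python
def sliding_window_causal_mask(seq_len, sliding_window, neg=float("-inf")):
    """Row i keeps keys j with max(0, i - sliding_window + 1) <= j <= i."""
    return [
        [0.0 if max(0, i - sliding_window + 1) <= j <= i else neg
         for j in range(seq_len)]
        for i in range(seq_len)
    ]


mask = sliding_window_causal_mask(seq_len=5, sliding_window=2)
# Indices of the keys each query may attend to.
kept = [[j for j, v in enumerate(row) if v == 0.0] for row in mask]
```

With sliding_window=2 every query sees itself and at most one earlier token, so no row ever keeps more than sliding_window keys.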

nemo_automodel.components.models.deepseek_v4.layers.build_packed_causal_padding_mask(
seq_lens: torch.Tensor,
seq_len: int,
dtype: torch.dtype,
device: torch.device,
sliding_window: int | None = None
) -> torch.Tensor

Build a 4D additive block-causal mask from packed-sequence lengths.
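A block-causal mask lets each packed sequence attend only to its own earlier tokens. The following pure-Python sketch shows one plausible reading of that idea: packed_block_causal_mask is a hypothetical helper, it ignores trailing padding and the batch/head dimensions, and the way the real function treats seq_len and sliding_window may differ.

```python
def packed_block_causal_mask(seq_lens, sliding_window=None, neg=float("-inf")):
    # Segment id of every packed position, e.g. [2, 3] -> [0, 0, 1, 1, 1].
    seg = [s for s, n in enumerate(seq_lens) for _ in range(n)]
    total = len(seg)
    mask = [[neg] * total for _ in range(total)]
    for i in range(total):
        for j in range(i + 1):                      # causal: j <= i
            same_doc = seg[i] == seg[j]
            in_window = sliding_window is None or i - j < sliding_window
            if same_doc and in_window:
                mask[i][j] = 0.0
    return mask


mask = packed_block_causal_mask([2, 3])
kept = [[j for j, v in enumerate(row) if v == 0.0] for row in mask]
```

Position 2 starts the second sequence, so it keeps only itself even though positions 0 and 1 precede it in the packed buffer.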

nemo_automodel.components.models.deepseek_v4.layers.eager_attention_with_sink(
module: torch.nn.Module,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attention_mask: torch.Tensor | None,
scaling: float,
dropout: float = 0.0,
kwargs = {}
) -> tuple[torch.Tensor, torch.Tensor]

Eager attention with per-head sink: appends an extra softmax column whose logit is module.sinks[h] and whose value-slot is zero. Ported verbatim from HF PR 45616.
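The effect of the sink column can be seen on a single query row. The following is a minimal pure-Python sketch rather than the ported code: softmax_with_sink is a hypothetical helper, and scores stands for one row of already scaled and masked attention logits for one head.

```python
import math


def softmax_with_sink(scores, sink_logit):
    logits = scores + [sink_logit]          # append the sink column
    peak = max(logits)
    exps = [math.exp(v - peak) for v in logits]
    total = sum(exps)
    # The sink's value slot is zero, so its weight is simply dropped.
    return [e / total for e in exps[:-1]]


weights = softmax_with_sink([0.0, 0.0], sink_logit=0.0)
```

Because the sink absorbs part of the probability mass, the remaining weights sum to less than 1, which lets a head attend to "nothing" when no key is relevant.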

nemo_automodel.components.models.deepseek_v4.layers.repeat_kv(
hidden_states: torch.Tensor,
n_rep: int
) -> torch.Tensor

Port of transformers.models.llama.modeling_llama.repeat_kv.
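For readers unfamiliar with the Llama helper, repeat_kv duplicates each key/value head n_rep times so that it lines up with the query heads. The sketch below shows only the head-axis ordering in pure Python; the list-based repeat_kv here is a stand-in written for this example, with strings in place of per-head tensors.

```python
def repeat_kv(heads, n_rep):
    """heads: list of per-head items. Each head is repeated n_rep times
    consecutively, i.e. [a, b] with n_rep=2 becomes [a, a, b, b]."""
    if n_rep == 1:
        return list(heads)
    return [head for head in heads for _ in range(n_rep)]


expanded = repeat_kv(["kv0", "kv1"], n_rep=2)
# With a single shared KV head, every query head reads the same latent.
shared = repeat_kv(["kv"], n_rep=4)
```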