nemo_automodel.components.models.deepseek_v4.kernels.sparse_attention

Autograd wrapper for vendored Miles DeepSeek V4 sparse-attention kernels.

Attribution:

Module Contents

Classes

Name | Description
--- | ---
DeepSeekV4SparseAttention | TileLang sparse MQA attention with custom backward.
DeepSeekV4SparseAttentionHeadChunked | TileLang sparse attention with smaller head groups and fp32 KV-grad accumulation.

Functions

Name | Description
--- | ---
sparse_attn_tilelang | Run vendored Miles DeepSeek V4 TileLang sparse attention.
sparse_attn_tilelang_head_chunked | Run vendored Miles sparse attention in TileLang head chunks.

API

class nemo_automodel.components.models.deepseek_v4.kernels.sparse_attention.DeepSeekV4SparseAttention()

Bases: Function

TileLang sparse MQA attention with custom backward.

nemo_automodel.components.models.deepseek_v4.kernels.sparse_attention.DeepSeekV4SparseAttention.backward(
ctx,
grad_output: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, None, None]
staticmethod

Run the vendored sparse attention backward kernel.
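
The five-element return type follows the general torch.autograd.Function contract: backward returns one entry per forward input, and inputs that are not differentiable (here presumably topk_idxs and sm_scale) receive None. The toy op below illustrates that contract only. ScaledGather and its arguments are hypothetical and are not part of this module.

```python
import torch


class ScaledGather(torch.autograd.Function):
    """Toy op (hypothetical): out = scale * x[idxs]. Only x is differentiable."""

    @staticmethod
    def forward(ctx, x, idxs, scale=None):
        scale = 1.0 if scale is None else scale
        ctx.save_for_backward(idxs)
        ctx.x_shape = x.shape
        ctx.scale = scale
        return scale * x[idxs]

    @staticmethod
    def backward(ctx, grad_output):
        (idxs,) = ctx.saved_tensors
        grad_x = torch.zeros(
            ctx.x_shape, dtype=grad_output.dtype, device=grad_output.device
        )
        # Rows selected more than once accumulate their gradients.
        grad_x.index_add_(0, idxs, ctx.scale * grad_output)
        # One slot per forward input: x gets a gradient, idxs and scale get None.
        return grad_x, None, None


x = torch.randn(5, 3, requires_grad=True)
idxs = torch.tensor([0, 2, 2])
ScaledGather.apply(x, idxs, 0.5).sum().backward()
```

After the call, x.grad is nonzero only in rows 0 and 2, and row 2 receives twice the contribution of row 0 because it was gathered twice.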

nemo_automodel.components.models.deepseek_v4.kernels.sparse_attention.DeepSeekV4SparseAttention.forward(
ctx,
q: torch.Tensor,
kv: torch.Tensor,
attn_sink: torch.Tensor,
topk_idxs: torch.Tensor,
sm_scale: float | None = None
) -> torch.Tensor
staticmethod

Run the vendored sparse attention forward kernel.
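
This page does not state tensor shapes or the exact math, so the following is only a plain PyTorch sketch of what sparse MQA attention with an attention sink commonly computes. Everything in it is an assumption made for illustration: the shapes, the use of -1 to mark unused top-k slots, the default scale of D ** -0.5, kv serving as both keys and values, and the sink acting as an extra softmax logit that contributes no value. The function name sparse_mqa_attention_reference is hypothetical.

```python
import torch


def sparse_mqa_attention_reference(q, kv, attn_sink, topk_idxs, sm_scale=None):
    """Slow reference for sparse MQA attention with a per-head sink logit.

    Assumed (hypothetical) shapes:
      q:         [B, S, H, D]  queries for H heads
      kv:        [B, T, D]     one shared key/value vector per position (MQA)
      attn_sink: [H]           per-head sink logit
      topk_idxs: [B, S, K]     positions into T; -1 marks an unused slot
    """
    B, S, H, D = q.shape
    if sm_scale is None:
        sm_scale = D ** -0.5  # assumed default, not confirmed by the source

    valid = topk_idxs >= 0                                # [B, S, K]
    idx = topk_idxs.clamp(min=0).long()                   # safe gather indices
    batch = torch.arange(B, device=q.device).view(B, 1, 1)
    sel = kv[batch, idx].float()                          # [B, S, K, D]

    scores = torch.einsum("bshd,bskd->bshk", q.float(), sel) * sm_scale
    scores = scores.masked_fill(~valid.unsqueeze(2), float("-inf"))

    # The sink joins the softmax as one extra column but carries no value.
    sink = attn_sink.float().view(1, 1, H, 1).expand(B, S, H, 1)
    probs = torch.softmax(torch.cat([scores, sink], dim=-1), dim=-1)
    out = torch.einsum("bshk,bskd->bshd", probs[..., :-1], sel)
    return out.to(q.dtype)
```

A reference of this kind is mainly useful for checking a fused kernel on small inputs. It materializes the gathered [B, S, K, D] tensor, which a fused kernel would avoid.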

class nemo_automodel.components.models.deepseek_v4.kernels.sparse_attention.DeepSeekV4SparseAttentionHeadChunked()

Bases: Function

TileLang sparse attention with smaller head groups and fp32 KV-grad accumulation.

nemo_automodel.components.models.deepseek_v4.kernels.sparse_attention.DeepSeekV4SparseAttentionHeadChunked.backward(
ctx,
grad_output: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, None, None, None]
staticmethod

Run chunked backward and accumulate shared KV gradients in fp32.
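
Because the KV tensor is shared across heads, every head chunk contributes a partial gradient to it. The sketch below shows why summing those partials in fp32 and casting once at the end can matter when the gradients are bf16. The helper accumulate_kv_grad and the sample values are hypothetical and only illustrate the numerical effect, not the module's actual code path.

```python
import torch


def accumulate_kv_grad(partial_grads, out_dtype):
    """Sum per-chunk KV gradients in fp32, then cast once to out_dtype."""
    acc = torch.zeros_like(partial_grads[0], dtype=torch.float32)
    for g in partial_grads:
        acc += g.float()
    return acc.to(out_dtype)


# Three per-chunk contributions to the same KV-gradient element.
parts = [
    torch.tensor([256.0], dtype=torch.bfloat16),
    torch.tensor([1.0], dtype=torch.bfloat16),
    torch.tensor([1.0], dtype=torch.bfloat16),
]

# fp32 accumulation keeps both small contributions before the final cast.
kv_grad = accumulate_kv_grad(parts, torch.bfloat16)

# A running bf16 sum rounds after every add; near 256 the bf16 spacing is 2,
# so each +1 is at risk of being rounded away.
naive = parts[0] + parts[1] + parts[2]
```

The fp32 buffer costs extra memory for the lifetime of the backward pass, which is the usual trade-off for this kind of accumulation.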

nemo_automodel.components.models.deepseek_v4.kernels.sparse_attention.DeepSeekV4SparseAttentionHeadChunked.forward(
ctx,
q: torch.Tensor,
kv: torch.Tensor,
attn_sink: torch.Tensor,
topk_idxs: torch.Tensor,
max_heads_per_kernel: int,
sm_scale: float | None = None
) -> torch.Tensor
staticmethod

Run the vendored sparse attention forward kernel over head chunks.
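
Head chunking works because attention heads are independent of one another: the query tensor can be sliced along the head axis, each slice processed separately against the shared KV, and the outputs concatenated. The sketch below demonstrates this with a dense stand-in for the kernel. The function names, the [S, H, D] and [T, D] shapes, and the interpretation of max_heads_per_kernel as an upper bound on chunk size are assumptions for illustration.

```python
import torch


def mqa_attention(q, kv, sm_scale):
    """Stand-in per-chunk kernel: dense MQA with kv as both keys and values.

    Assumed shapes: q is [S, H, D], kv is [T, D].
    """
    scores = torch.einsum("shd,td->sht", q, kv) * sm_scale
    return torch.einsum("sht,td->shd", torch.softmax(scores, dim=-1), kv)


def head_chunked(q, kv, max_heads_per_kernel, sm_scale):
    """Run the stand-in kernel on slices of at most max_heads_per_kernel heads."""
    if max_heads_per_kernel <= 0:
        raise ValueError("max_heads_per_kernel must be positive")
    outs = []
    for start in range(0, q.shape[1], max_heads_per_kernel):
        q_chunk = q[:, start:start + max_heads_per_kernel]
        outs.append(mqa_attention(q_chunk, kv, sm_scale))
    # Concatenating along the head axis restores the original layout.
    return torch.cat(outs, dim=1)


torch.manual_seed(0)
q = torch.randn(4, 8, 16)   # 8 heads
kv = torch.randn(10, 16)
scale = 16 ** -0.5

chunked = head_chunked(q, kv, max_heads_per_kernel=3, sm_scale=scale)
full = mqa_attention(q, kv, scale)
```

With 8 heads and a limit of 3, the loop runs chunks of 3, 3, and 2 heads, and chunked matches full up to floating-point noise.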

nemo_automodel.components.models.deepseek_v4.kernels.sparse_attention.sparse_attn_tilelang(
q: torch.Tensor,
kv: torch.Tensor,
attn_sink: torch.Tensor,
topk_idxs: torch.Tensor,
sm_scale: float | None = None
) -> torch.Tensor

Run vendored Miles DeepSeek V4 TileLang sparse attention.

nemo_automodel.components.models.deepseek_v4.kernels.sparse_attention.sparse_attn_tilelang_head_chunked(
q: torch.Tensor,
kv: torch.Tensor,
attn_sink: torch.Tensor,
topk_idxs: torch.Tensor,
max_heads_per_kernel: int,
sm_scale: float | None = None
) -> torch.Tensor

Run vendored Miles sparse attention in TileLang head chunks.