nemo_automodel.components.models.deepseek_v4.kernels.tilelang_indexer_bwd

Module Contents

Functions

tl_indexer_bwd_impl

indexer_bwd_interface – Backward interface for a single batch element.

batched_indexer_bwd – Batched backward: loops over batch dim.

Data

API

nemo_automodel.components.models.deepseek_v4.kernels.tilelang_indexer_bwd.BF16

None

nemo_automodel.components.models.deepseek_v4.kernels.tilelang_indexer_bwd.FP32

None

nemo_automodel.components.models.deepseek_v4.kernels.tilelang_indexer_bwd.INT32

None

nemo_automodel.components.models.deepseek_v4.kernels.tilelang_indexer_bwd.pass_configs

None

nemo_automodel.components.models.deepseek_v4.kernels.tilelang_indexer_bwd.tl_indexer_bwd_impl(
heads: int,
dim: int,
topk: int,
block_I: int = 32,
num_stages: int = 0,
num_threads: int = 128,
)

nemo_automodel.components.models.deepseek_v4.kernels.tilelang_indexer_bwd.indexer_bwd_interface(
index_q: torch.Tensor,
weights: torch.Tensor,
index_k: torch.Tensor,
topk_indices: torch.Tensor,
grad_scores: torch.Tensor,
)

Backward interface for a single batch element.

Parameters:
  • index_q – [seq_len, heads, dim] bf16

  • weights – [seq_len, heads] fp32

  • index_k – [seq_len_kv, dim] bf16

  • topk_indices – [seq_len, topk] int32

  • grad_scores – [seq_len, topk] fp32

Returns:

  • grad_q – [seq_len, heads, dim] bf16

  • grad_w – [seq_len, heads] fp32

  • grad_k – [seq_len_kv, dim] fp32

Return type:

tuple (grad_q, grad_w, grad_k) of torch.Tensor
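The docstring does not say which forward pass these gradients belong to. The sketch below is a small PyTorch reference, and it rests on an explicit assumption: the indexer score has the lightning-indexer form score[t, s] = Σ_h weights[t, h] · ReLU(index_q[t, h] · index_k[s]). That score is evaluated only at the selected positions s = topk_indices[t, i], and grad_scores[t, i] is the upstream gradient of the selected score. The sketch also assumes every index is a valid row of index_k, with no padding or mask value. The name ref_indexer_bwd is hypothetical and not part of this module. The code describes the math, not how the TileLang kernel is organized.

```python
import torch


def ref_indexer_bwd(index_q, weights, index_k, topk_indices, grad_scores):
    """Hypothetical fp32 reference for one batch element.

    ASSUMPTION (not stated by the module): the forward is
        score[t, i] = sum_h w[t, h] * relu(q[t, h] . k[topk_indices[t, i]])
    and all indices are valid rows of index_k.
    Shapes: q [S, H, D], w [S, H], k [S_kv, D], idx [S, K], g [S, K].
    """
    q = index_q.float()
    w = weights.float()
    k = index_k.float()
    idx = topk_indices.long()
    g = grad_scores.float()

    k_sel = k[idx]                                       # selected keys: [S, K, D]
    logits = torch.einsum("shd,skd->shk", q, k_sel)      # per-head dot products: [S, H, K]
    relu = logits.clamp_min(0)
    active = (logits > 0).to(q.dtype)                    # ReLU derivative mask

    # d score / d w[t, h] = relu(q[t, h] . k[s]), summed over the selected s
    grad_w = torch.einsum("sk,shk->sh", g, relu)
    # Gradient reaching each logit: g[t, i] * w[t, h] * 1[logit > 0]
    g_logit = g[:, None, :] * w[:, :, None] * active     # [S, H, K]
    grad_q = torch.einsum("shk,skd->shd", g_logit, k_sel)
    # Per-(t, i) key gradients, then scatter-add into the (possibly repeated) key rows
    grad_k_sel = torch.einsum("shk,shd->skd", g_logit, q)  # [S, K, D]
    grad_k = torch.zeros_like(k)
    grad_k.index_add_(0, idx.reshape(-1), grad_k_sel.reshape(-1, k.shape[-1]))

    # Cast to the documented output dtypes: grad_q like index_q, grad_w/grad_k fp32
    return grad_q.to(index_q.dtype), grad_w, grad_k
```

This reference runs on any device in fp32, so it can be compared against indexer_bwd_interface on small inputs. Use a tolerance that allows for the kernel's bf16 inputs and its bf16 grad_q output.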

nemo_automodel.components.models.deepseek_v4.kernels.tilelang_indexer_bwd.batched_indexer_bwd(
index_q,
weights,
index_k,
topk_indices,
grad_scores,
)

Batched backward: loops over batch dim.

Parameters:
  • index_q – [seqlen, batch, heads, dim] bf16

  • weights – [seqlen, batch, heads] fp32

  • index_k – [seqlen_kv, batch, dim] bf16

  • topk_indices – [batch, seqlen, topk] int32

  • grad_scores – [batch, seqlen, topk] fp32

Returns:

  • grad_q – [seqlen, batch, heads, dim] bf16

  • grad_w – [seqlen, batch, heads] fp32

  • grad_k – [seqlen_kv, batch, dim] fp32

Return type:

tuple (grad_q, grad_w, grad_k) of torch.Tensor
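The batched variant uses a mixed layout. index_q, weights and index_k are sequence-first, with the batch axis at dim 1. topk_indices and grad_scores are batch-first. The sketch below shows one way a loop over the batch dimension can slice each element into the single-element layout and stack the results back at dim 1, which matches the documented output shapes. The source does not confirm that batched_indexer_bwd forwards each slice to a single-element backward, or that it makes slices contiguous. batched_bwd_sketch and its per_element_bwd argument are hypothetical names.

```python
import torch


def batched_bwd_sketch(index_q, weights, index_k, topk_indices, grad_scores, per_element_bwd):
    """Hypothetical batch loop; per_element_bwd takes the single-element layout.

    Assumed input layouts (from the docstring):
      index_q [S, B, H, D], weights [S, B, H], index_k [S_kv, B, D]  (sequence-first)
      topk_indices [B, S, K], grad_scores [B, S, K]                  (batch-first)
    """
    grads_q, grads_w, grads_k = [], [], []
    for b in range(index_q.shape[1]):
        gq, gw, gk = per_element_bwd(
            index_q[:, b].contiguous(),    # [S, H, D]
            weights[:, b].contiguous(),    # [S, H]
            index_k[:, b].contiguous(),    # [S_kv, D]
            topk_indices[b].contiguous(),  # [S, K]
            grad_scores[b].contiguous(),   # [S, K]
        )
        grads_q.append(gq)
        grads_w.append(gw)
        grads_k.append(gk)
    # Re-insert the batch axis at dim 1 to match the documented output layout
    return (
        torch.stack(grads_q, dim=1),  # [S, B, H, D]
        torch.stack(grads_w, dim=1),  # [S, B, H]
        torch.stack(grads_k, dim=1),  # [S_kv, B, D]
    )
```

Any function with the indexer_bwd_interface signature can be passed as per_element_bwd, for example a CPU reference when testing shapes without a GPU.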