nemo_automodel.components.models.deepseek_v4.kernels.tilelang_indexer_bwd
Module Contents
Functions
- tl_indexer_bwd_impl
- indexer_bwd_interface – Backward interface for a single batch element.
- batched_indexer_bwd – Batched backward: loops over batch dim.
Data
API
- nemo_automodel.components.models.deepseek_v4.kernels.tilelang_indexer_bwd.BF16
None
- nemo_automodel.components.models.deepseek_v4.kernels.tilelang_indexer_bwd.FP32
None
- nemo_automodel.components.models.deepseek_v4.kernels.tilelang_indexer_bwd.INT32
None
- nemo_automodel.components.models.deepseek_v4.kernels.tilelang_indexer_bwd.pass_configs
None
- nemo_automodel.components.models.deepseek_v4.kernels.tilelang_indexer_bwd.tl_indexer_bwd_impl(
- heads: int,
- dim: int,
- topk: int,
- block_I: int = 32,
- num_stages: int = 0,
- num_threads: int = 128,
)
- nemo_automodel.components.models.deepseek_v4.kernels.tilelang_indexer_bwd.indexer_bwd_interface(
- index_q: torch.Tensor,
- weights: torch.Tensor,
- index_k: torch.Tensor,
- topk_indices: torch.Tensor,
- grad_scores: torch.Tensor,
)
Backward interface for a single batch element.
- Parameters:
index_q – [seq_len, heads, dim] bf16
weights – [seq_len, heads] fp32
index_k – [seq_len_kv, dim] bf16
topk_indices – [seq_len, topk] int32
grad_scores – [seq_len, topk] fp32
- Returns:
grad_q – [seq_len, heads, dim] bf16
grad_w – [seq_len, heads] fp32
grad_k – [seq_len_kv, dim] fp32
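The shape contract above can be checked before the kernel is called. The following is a minimal sketch in plain Python, not part of the module: `single_batch_grad_shapes` is a hypothetical helper name, and it assumes the shapes are passed as tuples such as `tuple(tensor.shape)`. It encodes only the shapes listed in the parameter and return descriptions. It does not check dtypes or the values of `topk_indices`.

```python
from typing import Dict, Tuple

Shape = Tuple[int, ...]


def single_batch_grad_shapes(
    index_q: Shape,
    weights: Shape,
    index_k: Shape,
    topk_indices: Shape,
    grad_scores: Shape,
) -> Dict[str, Shape]:
    """Hypothetical helper: validate the documented single-batch input
    shapes and return the documented gradient shapes."""
    if len(index_q) != 3:
        raise ValueError(f"index_q must be [seq_len, heads, dim], got {index_q}")
    seq_len, heads, dim = index_q

    if tuple(weights) != (seq_len, heads):
        raise ValueError(f"weights must be [seq_len, heads], got {weights}")
    if len(index_k) != 2 or index_k[1] != dim:
        raise ValueError(f"index_k must be [seq_len_kv, dim], got {index_k}")
    if len(topk_indices) != 2 or topk_indices[0] != seq_len:
        raise ValueError(f"topk_indices must be [seq_len, topk], got {topk_indices}")
    if tuple(grad_scores) != tuple(topk_indices):
        raise ValueError(f"grad_scores must match topk_indices, got {grad_scores}")

    seq_len_kv = index_k[0]
    return {
        "grad_q": (seq_len, heads, dim),   # same shape as index_q
        "grad_w": (seq_len, heads),        # same shape as weights
        "grad_k": (seq_len_kv, dim),       # same shape as index_k
    }


# Example sizes are arbitrary illustrations, not values from the library.
shapes = single_batch_grad_shapes(
    index_q=(16, 4, 64),
    weights=(16, 4),
    index_k=(32, 64),
    topk_indices=(16, 8),
    grad_scores=(16, 8),
)
print(shapes)
```

Each gradient has the shape of the input it corresponds to. Per the return description, `grad_k` is fp32 while `index_k` is bf16.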
- nemo_automodel.components.models.deepseek_v4.kernels.tilelang_indexer_bwd.batched_indexer_bwd(
- index_q,
- weights,
- index_k,
- topk_indices,
- grad_scores,
)
Batched backward: loops over batch dim.
- Parameters:
index_q – [seqlen, batch, heads, dim] bf16
weights – [seqlen, batch, heads] fp32
index_k – [seqlen_kv, batch, dim] bf16
topk_indices – [batch, seqlen, topk] int32
grad_scores – [batch, seqlen, topk] fp32
- Returns:
grad_q – [seqlen, batch, heads, dim] bf16
grad_w – [seqlen, batch, heads] fp32
grad_k – [seqlen_kv, batch, dim] fp32
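The batched inputs use two layouts. `index_q`, `weights`, and `index_k` carry `batch` on axis 1, while `topk_indices` and `grad_scores` carry it on axis 0. The sketch below shows how each documented batched shape reduces to a per-element shape once the batch axis is removed. `per_batch_shapes` is a hypothetical helper, not part of the module. That the per-element shapes are the ones passed to `indexer_bwd_interface` is an inference from the "loops over batch dim" description, not something the source states.

```python
from typing import Dict, Tuple

Shape = Tuple[int, ...]

# Axis holding `batch` for each input, taken from the parameter list above.
BATCH_AXIS = {
    "index_q": 1,       # [seqlen, batch, heads, dim]
    "weights": 1,       # [seqlen, batch, heads]
    "index_k": 1,       # [seqlen_kv, batch, dim]
    "topk_indices": 0,  # [batch, seqlen, topk]
    "grad_scores": 0,   # [batch, seqlen, topk]
}


def per_batch_shapes(batched: Dict[str, Shape]) -> Tuple[int, Dict[str, Shape]]:
    """Hypothetical helper: return the batch size and the shape of each
    input after its batch axis is removed."""
    sizes = {name: batched[name][axis] for name, axis in BATCH_AXIS.items()}
    if len(set(sizes.values())) != 1:
        raise ValueError(f"inconsistent batch sizes: {sizes}")
    batch = sizes["index_q"]

    single = {}
    for name, axis in BATCH_AXIS.items():
        shape = tuple(batched[name])
        single[name] = shape[:axis] + shape[axis + 1:]
    return batch, single


# Example sizes are arbitrary illustrations, not values from the library.
batch, single = per_batch_shapes(
    {
        "index_q": (16, 2, 4, 64),
        "weights": (16, 2, 4),
        "index_k": (32, 2, 64),
        "topk_indices": (2, 16, 8),
        "grad_scores": (2, 16, 8),
    }
)
print(batch, single)
```

With tensors, the equivalent slices for batch element `b` would be `index_q[:, b]`, `weights[:, b]`, `index_k[:, b]`, `topk_indices[b]`, and `grad_scores[b]`.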