nemo_automodel.components.models.glm_moe_dsa.kernels.tilelang_indexer_bwd

View as Markdown

Module Contents

Functions

indexer_bwd_interface

tl_indexer_bwd_impl

Data

BF16

FP32

INT32

pass_configs

API

nemo_automodel.components.models.glm_moe_dsa.kernels.tilelang_indexer_bwd.indexer_bwd_interface(
index_q: torch.Tensor,
weights: torch.Tensor,
index_k: torch.Tensor,
topk_indices: torch.Tensor,
grad_scores: torch.Tensor
)
nemo_automodel.components.models.glm_moe_dsa.kernels.tilelang_indexer_bwd.tl_indexer_bwd_impl(
heads: int,
dim: int,
topk: int,
block_I: int = 32,
num_stages: int = 0,
num_threads: int = 128
)
nemo_automodel.components.models.glm_moe_dsa.kernels.tilelang_indexer_bwd.BF16 = T.bfloat16
nemo_automodel.components.models.glm_moe_dsa.kernels.tilelang_indexer_bwd.FP32 = T.float32
nemo_automodel.components.models.glm_moe_dsa.kernels.tilelang_indexer_bwd.INT32 = T.int32
nemo_automodel.components.models.glm_moe_dsa.kernels.tilelang_indexer_bwd.pass_configs = {tl.PassConfigKey.TL_DISABLE_TMA_LOWER: True, tl.PassConfigKey.TL_DISABLE_WARP_S...