nemo_automodel.components.models.glm_moe_dsa.kernels.indexer

View as Markdown

Module Contents

Classes

| Name | Description |
| --- | --- |
| IndexerFunction | - |

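Functions

| Name | Description |
| --- | --- |
| generate_varlen_mask_params | - |
| lighting_indexer | - |
| pytorch_extract_topk_scores | - |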

API

class nemo_automodel.components.models.glm_moe_dsa.kernels.indexer.IndexerFunction()

Bases: Function

nemo_automodel.components.models.glm_moe_dsa.kernels.indexer.IndexerFunction.backward(
ctx,
grad_scores,
grad_indices
)
staticmethod
nemo_automodel.components.models.glm_moe_dsa.kernels.indexer.IndexerFunction.forward(
ctx,
index_q: torch.Tensor,
index_k: torch.Tensor,
weights: torch.Tensor,
cu_seqlen_ks: torch.Tensor,
cu_seqlen_ke: torch.Tensor,
topk: int,
topk_indices: torch.Tensor | None = None
)
staticmethod
nemo_automodel.components.models.glm_moe_dsa.kernels.indexer.generate_varlen_mask_params(
cu_seqlens
)
nemo_automodel.components.models.glm_moe_dsa.kernels.indexer.lighting_indexer(
index_q: torch.Tensor,
index_k: torch.Tensor,
weights: torch.Tensor,
cu_seqlen_ks: torch.Tensor,
cu_seqlen_ke: torch.Tensor,
topk: int,
topk_indices: torch.Tensor | None = None
)
nemo_automodel.components.models.glm_moe_dsa.kernels.indexer.pytorch_extract_topk_scores(
logits,
topk_indices,
dim = -1
)