nemo_automodel.components.models.glm_moe_dsa.kernels.sparse_mla
Module Contents
Classes
API
Bases: Function
staticmethod
Parameters:
grad_output: Gradient of the loss with respect to the output.
Returns:
Gradients with respect to q and kv. The gradient for indices is None, since indices are not differentiable.
staticmethod
Parameters:
q: Query tensor of shape (seq_len, heads, dim_plus_tail_dim).
kv: Key-value tensor of shape (seq_len_kv, kv_group, dim_plus_tail_dim).
indices: Sparse indices tensor of shape (seq_len, kv_group, topk).
Returns:
Output tensor of shape (seq_len, heads, dim).
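
The shape contract above can be illustrated with a small NumPy reference. This is only a sketch and not the kernel in this module. It rests on several assumptions the page does not confirm: keys use all dim_plus_tail_dim channels of kv, values use only the first dim channels, heads are split evenly across kv groups, and there is no causal masking or handling of invalid indices. The function name `sparse_mla_reference` and the arguments `dim` and `sm_scale` are hypothetical.

```python
import numpy as np


def sparse_mla_reference(q, kv, indices, dim, sm_scale=None):
    """Hypothetical NumPy reference for top-k sparse attention over a shared KV tensor.

    q:       (seq_len, heads, dim_plus_tail_dim)
    kv:      (seq_len_kv, kv_group, dim_plus_tail_dim)
    indices: (seq_len, kv_group, topk), integer positions into seq_len_kv
    returns: (seq_len, heads, dim)
    """
    seq_len, heads, d_total = q.shape
    _, kv_group, _ = kv.shape
    assert heads % kv_group == 0, "assumes heads divide evenly across kv groups"
    heads_per_group = heads // kv_group
    if sm_scale is None:
        # Assumed default: standard 1/sqrt(d) softmax scaling.
        sm_scale = d_total ** -0.5

    out = np.zeros((seq_len, heads, dim), dtype=q.dtype)
    for g in range(kv_group):
        # Gather the topk KV rows chosen for every query position:
        # (seq_len, topk, dim_plus_tail_dim).
        selected = kv[:, g, :][indices[:, g, :]]
        head_slice = slice(g * heads_per_group, (g + 1) * heads_per_group)

        # Scores use the full key width (dim plus tail): (seq_len, heads_per_group, topk).
        scores = np.einsum("shd,std->sht", q[:, head_slice, :], selected) * sm_scale

        # Numerically stable softmax over the topk axis.
        scores = scores - scores.max(axis=-1, keepdims=True)
        probs = np.exp(scores)
        probs = probs / probs.sum(axis=-1, keepdims=True)

        # Values are assumed to be the first `dim` channels of the gathered rows.
        out[:, head_slice, :] = np.einsum("sht,std->shd", probs, selected[:, :, :dim])
    return out


# Usage with small, arbitrary sizes.
rng = np.random.default_rng(0)
q = rng.standard_normal((4, 8, 6))          # seq_len=4, heads=8, dim+tail=6
kv = rng.standard_normal((10, 2, 6))        # seq_len_kv=10, kv_group=2
indices = rng.integers(0, 10, size=(4, 2, 3))  # topk=3
out = sparse_mla_reference(q, kv, indices, dim=4)
print(out.shape)  # (4, 8, 4)
```

The sketch shows why indices receive no gradient: they only select which rows of kv take part in the softmax, so the differentiable path runs through q and the gathered kv rows alone.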