nemo_automodel.components.models.glm_moe_dsa.kernels.sparse_mla

Module Contents

Classes

| Name | Description |
| --- | --- |
| SparseMLA | - |

API

class nemo_automodel.components.models.glm_moe_dsa.kernels.sparse_mla.SparseMLA()

Bases: torch.autograd.Function

nemo_automodel.components.models.glm_moe_dsa.kernels.sparse_mla.SparseMLA.backward(
ctx,
grad_output,
grad_lse
)
staticmethod

Parameters:

grad_output

Gradient of the loss with respect to output

Returns:

Gradients for q, kv, and indices (None for indices)
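The "None for indices" convention follows the general `torch.autograd.Function` contract: `backward` returns one value per `forward` input, with `None` for inputs that are not differentiable. A minimal sketch of that pattern is shown here with a toy operator; `ScaledGather` and its arguments are hypothetical and are not part of this module.

```python
import torch


class ScaledGather(torch.autograd.Function):
    """Hypothetical toy op (not SparseMLA): gathers rows of x and scales them."""

    @staticmethod
    def forward(ctx, x, indices, scaling):
        # Keep what backward needs: the integer indices and two plain values.
        ctx.save_for_backward(indices)
        ctx.num_rows = x.shape[0]
        ctx.scaling = scaling
        return x[indices] * scaling

    @staticmethod
    def backward(ctx, grad_output):
        (indices,) = ctx.saved_tensors
        grad_x = torch.zeros(
            ctx.num_rows,
            *grad_output.shape[1:],
            dtype=grad_output.dtype,
            device=grad_output.device,
        )
        # Scatter-add the incoming gradient back to the rows that were gathered.
        grad_x.index_add_(0, indices, grad_output * ctx.scaling)
        # One return value per forward input: x, indices, scaling.
        # Integer indices and the Python float receive None.
        return grad_x, None, None


x = torch.randn(5, 3, requires_grad=True)
idx = torch.tensor([0, 2, 2])
out = ScaledGather.apply(x, idx, 0.5)
out.sum().backward()
```

Row 2 is gathered twice, so its gradient accumulates twice, while rows that were never selected keep a zero gradient.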

nemo_automodel.components.models.glm_moe_dsa.kernels.sparse_mla.SparseMLA.forward(
ctx,
q,
kv,
indices,
scaling
)
staticmethod

Parameters:

q

Query tensor of shape (seq_len, heads, dim_plus_tail_dim)

kv

Key-value tensor of shape (seq_len_kv, kv_group, dim_plus_tail_dim)

indices

Sparse indices tensor of shape (seq_len, kv_group, topk)

Returns:

Output tensor of shape (seq_len, heads, dim)
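To make the documented shapes concrete, the following is a plain PyTorch sketch of what a top-k sparse attention over these tensors could compute. It is not the kernel's implementation, and several points are assumptions: that `indices` holds row positions into `kv`, that heads are split evenly across `kv_group`, that `scaling` multiplies the attention scores before the softmax, and that the values are the first `dim` channels of `kv`. The helper `sparse_mla_reference` and its `dim` argument are hypothetical, and the sketch ignores masking of invalid indices.

```python
import torch


def sparse_mla_reference(q, kv, indices, scaling, dim):
    """Hypothetical dense reference under the assumptions stated above."""
    seq_len, heads, _ = q.shape
    kv_group = kv.shape[1]
    heads_per_group = heads // kv_group
    out = q.new_zeros(seq_len, heads, dim)
    for g in range(kv_group):
        # Rows selected for each query position: (seq_len, topk, dim_plus_tail_dim)
        picked = kv[:, g, :][indices[:, g, :].long()]
        hs = slice(g * heads_per_group, (g + 1) * heads_per_group)
        # Scores use the full dim_plus_tail_dim: (seq_len, heads_per_group, topk)
        scores = torch.einsum("shd,std->sht", q[:, hs, :], picked) * scaling
        probs = torch.softmax(scores, dim=-1)
        # Values are assumed to be the first `dim` channels of the selected rows.
        out[:, hs, :] = torch.einsum("sht,std->shd", probs, picked[..., :dim])
    return out


torch.manual_seed(0)
q = torch.randn(4, 2, 12)                    # (seq_len, heads, dim_plus_tail_dim)
kv = torch.randn(6, 1, 12)                   # (seq_len_kv, kv_group, dim_plus_tail_dim)
indices = torch.randint(0, 6, (4, 1, 3))     # (seq_len, kv_group, topk)
out = sparse_mla_reference(q, kv, indices, scaling=12 ** -0.5, dim=8)
print(out.shape)
```

A reference like this is mainly useful for checking shapes and for comparing against a fused kernel on small inputs; it materializes the gathered rows and is far slower than a dedicated kernel.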