MoE Grouped Matmul#
Overview#
The MoE Grouped Matmul operation computes a grouped matrix multiplication across experts, as used in Mixture-of-Experts (MoE) layers. Each expert has its own weight matrix, and tokens are routed to experts via first_token_offset.
Three routing modes are supported:
None mode (tokens already routed per expert):
\(\text{Output}[1,\ S \times \text{topK},\ N] = \text{Token}[1,\ S \times \text{topK},\ K]\ \times\ \text{Weight}[E,\ K,\ N]\)
Gather mode (gather tokens from unrouted layout before matmul):
\(\text{Output}[1,\ S \times \text{topK},\ N] = \text{Token}[1,\ S,\ K]\ \times\ \text{Weight}[E,\ K,\ N]\)
Scatter mode (scatter output back to token order after matmul):
\(\text{Output}[1,\ S \times \text{topK},\ N] = \text{Token}[1,\ S \times \text{topK},\ K]\ \times\ \text{Weight}[E,\ K,\ N]\)
where \(E\) = number of experts, \(S\) = number of tokens, \(K\) = hidden size, \(N\) = output (weight) size.
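To make the grouping concrete, here is a minimal NumPy sketch of the None-mode semantics. It assumes a single batch (B = 1), a 2-D token layout, and the convention stated below that `first_token_offset[e]` is the first row owned by expert `e` (the last expert's end is implied by the token count). The function name is illustrative, not part of the cuDNN API:

```python
import numpy as np

def grouped_matmul_none(token, weight, first_token_offset):
    """Reference for None mode: rows of `token` are already sorted by expert.

    token:  (M, K) with M = S * topK routed rows
    weight: (E, K, N), one weight matrix per expert
    first_token_offset: (E,) int32; entry e is the first row of expert e
    """
    E, K, N = weight.shape
    M = token.shape[0]
    out = np.empty((M, N), dtype=np.float32)
    for e in range(E):
        lo = first_token_offset[e]
        # End of expert e's slice: next expert's start, or the total row count.
        hi = first_token_offset[e + 1] if e + 1 < E else M
        out[lo:hi] = token[lo:hi] @ weight[e]
    return out
```

The real operation fuses this per-expert loop into a single grouped kernel; the sketch only pins down which rows meet which weight matrix.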
Tensor Roles by Mode#
| Tensor | Shape | Modes |
|---|---|---|
| Token | (1, S*topK, K) for None/Scatter; (1, S, K) for Gather | All |
| Weight | (E, K, N) | All |
| FirstTokenOffset | (B*E, 1, 1), INT32 | All |
| TokenIndex | (1, S*topK, 1), INT32 | Gather, Scatter |
| TokenKs | (1, S*topK, 1), INT32 | Scatter only |
| top_k | scalar int32 | Scatter only |
Support Matrix#
The support matrix is based on the latest cuDNN backend.
| Operation | Minimum cuDNN | Minimum cublasLt | Datatype (I/O) | Compute type | Fusion pattern |
|---|---|---|---|---|---|
| MoE Grouped Matmul (forward) | 9.18.0 | — | int8, fp8, fp16, bf16, fp32, nvfp4/mxfp8 (9.21.0) | FLOAT | 9.21.0: SwiGLU, AMAX, etc. |
| MoE Grouped Matmul Bwd | 9.22.0 | 13.5 | fp16, bf16 | FLOAT | - |
Important Notes#
- `FirstTokenOffset` contains `B * E` values; the total token count is implicit from the token tensor dimension.
- In Scatter mode, both `TokenIndex` and `TokenKs` are required, and `top_k` must be explicitly provided.
- In Gather mode, `TokenIndex` is required.
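One plausible way to derive `FirstTokenOffset` and `TokenIndex` from a router's top-k choices, assuming a single batch (B = 1): sort the flattened (token, slot) pairs by expert, take a prefix sum of per-expert counts. This is an illustrative sketch, not a cuDNN API; `build_routing` and its arguments are hypothetical names:

```python
import numpy as np

def build_routing(expert_ids, num_experts):
    """expert_ids: (S, topK) int array, the router's top-k expert per token.

    Returns (first_token_offset, token_index) for the routed ordering.
    """
    S, top_k = expert_ids.shape
    flat = expert_ids.reshape(-1)                       # (S * topK,)
    order = np.argsort(flat, kind="stable")             # routed slot -> flat slot
    token_index = order // top_k                        # source token per routed slot
    counts = np.bincount(flat, minlength=num_experts)   # routed rows per expert
    first_token_offset = np.concatenate(([0], np.cumsum(counts)[:-1]))
    return first_token_offset.astype(np.int32), token_index.astype(np.int32)
```

The stable sort keeps routed rows for each expert in original token order, which matches the "first token assigned to expert i" convention of `FirstTokenOffset`.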
MoE Grouped Matmul Forward#
C++ API#
std::shared_ptr<Tensor_attributes>
moe_grouped_matmul(std::shared_ptr<Tensor_attributes> token,
                   std::shared_ptr<Tensor_attributes> weight,
                   std::shared_ptr<Tensor_attributes> first_token_offset,
                   std::shared_ptr<Tensor_attributes> token_index,  // optional, pass nullptr for None mode
                   std::shared_ptr<Tensor_attributes> token_ks,     // optional, pass nullptr unless Scatter mode
                   Moe_grouped_matmul_attributes options);
Moe_grouped_matmul_attributes is a lightweight structure with setters:
Moe_grouped_matmul_attributes& set_name(std::string const&);
// Required: selects the routing mode
Moe_grouped_matmul_attributes& set_mode(MoeGroupedMatmulMode_t mode);
// MoeGroupedMatmulMode_t::NONE — tokens already routed
// MoeGroupedMatmulMode_t::GATHER — gather before matmul
// MoeGroupedMatmulMode_t::SCATTER — scatter after matmul
// Required for SCATTER mode
Moe_grouped_matmul_attributes& set_top_k(int32_t top_k_value);
Moe_grouped_matmul_attributes& set_compute_data_type(DataType_t value);
Python API#
Low-level graph API (cudnn.pygraph)#
output = graph.moe_grouped_matmul(
    token,               # Token tensor
    weight,              # Weight tensor
    first_token_offset,  # Expert routing offsets
    token_index=None,    # Required for Gather/Scatter modes
    token_ks=None,       # Required for Scatter mode
    mode=cudnn.moe_grouped_matmul_mode.NONE,  # NONE, GATHER, or SCATTER
    top_k=1,             # Top-k value; required for Scatter mode
    compute_data_type=cudnn.data_type.NOT_SET,
    name=None,
)
Args:
- `token` (cudnn_tensor): Token data. None/Scatter mode: shape `(1, S*topK, K)`. Gather mode: shape `(1, S, K)`.
- `weight` (cudnn_tensor): Expert weight data with shape `(E, K, N)`.
- `first_token_offset` (cudnn_tensor): INT32 tensor of shape `(B*E, 1, 1)`. The \(i\)-th entry is the index of the first token assigned to expert \(i\).
- `token_index` (Optional[cudnn_tensor]): INT32 tensor of shape `(1, S*topK, 1)`. Maps each routed slot to a source token index. Required for Gather and Scatter modes.
- `token_ks` (Optional[cudnn_tensor]): INT32 tensor of shape `(1, S*topK, 1)`. The expert index for each routed token. Required for Scatter mode.
- `mode` (cudnn.moe_grouped_matmul_mode): Routing mode: `NONE`, `GATHER`, or `SCATTER`.
- `top_k` (int): Top-k routing value. Must be provided for Scatter mode.
- `compute_data_type` (Optional[cudnn.data_type]): Data type for internal computation. Defaults to FLOAT.
- `name` (Optional[str]): Name for the operation.
Returns:
- `output` (cudnn_tensor): Output tensor of shape `(1, M_out, N)`, where `M_out = token_index.shape[1]` for Gather mode, otherwise `token.shape[1]`.
High-level experimental API#
from cudnn.experimental.ops import moe_grouped_matmul
output = moe_grouped_matmul(
    token,               # (1, M, K) torch.Tensor, fp16 or bf16
    weight,              # (E, K, N) torch.Tensor, column-major inner dims
    first_token_offset,  # (B*E, 1, 1) torch.Tensor, INT32
    token_index=None,    # (1, S*topK, 1) INT32; required for gather/scatter
    token_ks=None,       # (1, S*topK, 1) INT32; required for scatter
    mode="none",         # "none", "gather", or "scatter"
    top_k=1,             # required for scatter mode
)
The high-level API handles cuDNN handle management and graph caching automatically. cuDNN graphs are built once per unique (shape, dtype, mode, top_k) configuration and reused across subsequent calls.
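The caching behaviour described above can be pictured as a dictionary keyed on the call configuration. A hypothetical sketch of the idea only; `get_or_build_graph` and `build_fn` are illustrative names, not part of the cuDNN API:

```python
# Cache of built graphs, keyed by the configuration that shaped them.
_graph_cache = {}

def get_or_build_graph(shape, dtype, mode, top_k, build_fn):
    """Build the graph once per unique (shape, dtype, mode, top_k) and reuse it."""
    key = (tuple(shape), str(dtype), mode, top_k)
    if key not in _graph_cache:
        _graph_cache[key] = build_fn()  # expensive: validate/build/plan
    return _graph_cache[key]
```

Because graph building (validation, heuristics, plan compilation) dominates first-call latency, subsequent calls with the same configuration hit the cache and pay only the execution cost.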
Configurable Options#
- Mode (`mode`): Controls how tokens are routed to and from expert weight matrices.
  - `NONE`: Tokens are already ordered by expert (pre-routed). Direct grouped matmul with no reordering.
  - `GATHER`: Tokens are in original (un-routed) order. `TokenIndex` specifies which source token each expert slot reads from.
  - `SCATTER`: Tokens are pre-routed, but the output is scattered back to the original token order. Requires both `TokenIndex` and `TokenKs`.
- TopK (`top_k`): The number of experts each token is routed to. Required in Scatter mode for the scatter-back computation.
- Compute data type (`compute_data_type`): Sets the precision for internal accumulation. Defaults to FLOAT for fp16/bf16 I/O.
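The Gather-mode data flow can be written out as a NumPy reference: routed slot `i` reads its source row through `token_index` and multiplies with the weight of the expert that owns slot `i` per `first_token_offset`. A sketch under the same single-batch, 2-D-layout assumptions as above; the function name is illustrative:

```python
import numpy as np

def grouped_matmul_gather(token, weight, first_token_offset, token_index):
    """Reference for Gather mode.

    token:       (S, K) in original (un-routed) order
    weight:      (E, K, N)
    token_index: (M,) int32, M = S * topK; source token per routed slot
    """
    E, K, N = weight.shape
    M = token_index.shape[0]
    out = np.empty((M, N), dtype=np.float32)
    for e in range(E):
        lo = first_token_offset[e]
        hi = first_token_offset[e + 1] if e + 1 < E else M
        # Gather the source rows for expert e, then apply its weight.
        out[lo:hi] = token[token_index[lo:hi]] @ weight[e]
    return out
```

Compared with None mode, only the row indexing changes; the per-expert weight selection is identical.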
Example (Python)#
import torch
import cudnn

num_experts = 8
token_num = 1024   # S
hidden_size = 512  # K
weight_size = 256  # N

graph = cudnn.pygraph(
    intermediate_data_type=cudnn.data_type.FLOAT,
    compute_data_type=cudnn.data_type.FLOAT,
)

# Token: [1, S, K], bf16, row-major
token_t = graph.tensor(
    name="token",
    dim=[1, token_num, hidden_size],
    stride=[token_num * hidden_size, hidden_size, 1],
    data_type=cudnn.data_type.BFLOAT16,
)

# Weight: [E, K, N], bf16, column-major inner dims (stride[1] == 1)
weight_t = graph.tensor(
    name="weight",
    dim=[num_experts, hidden_size, weight_size],
    stride=[hidden_size * weight_size, 1, hidden_size],
    data_type=cudnn.data_type.BFLOAT16,
)

# FirstTokenOffset: [E, 1, 1], INT32
fto_t = graph.tensor(
    name="first_token_offset",
    dim=[num_experts, 1, 1],
    stride=[1, 1, 1],
    data_type=cudnn.data_type.INT32,
)

output_t = graph.moe_grouped_matmul(
    token_t, weight_t, fto_t,
    mode=cudnn.moe_grouped_matmul_mode.NONE,
    compute_data_type=cudnn.data_type.FLOAT,
    name="moe_fwd",
)
output_t.set_output(True).set_data_type(cudnn.data_type.BFLOAT16)

graph.validate()
graph.build_operation_graph()
graph.create_execution_plans([cudnn.heur_mode.A])
graph.check_support()
graph.build_plans()
MoE Grouped Matmul Backward#
The backward operation computes the weight gradient \(d\text{Weight}\) given the upstream gradient \(d\text{Output}\) and the forward token activations:
\(d\text{Weight}[E,\ K,\ N] = \text{Token}^T[1,\ S,\ K]\ \times\ d\text{Output}[1,\ S,\ N]\)
per expert, where the per-expert token slices are determined by FirstTokenOffset.
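The per-expert slicing above can be made explicit with a short NumPy reference, under the same single-batch, 2-D-layout assumptions used for the forward sketches (the function name is illustrative, not the cuDNN API):

```python
import numpy as np

def grouped_matmul_bwd(doutput, token, first_token_offset):
    """Reference for the weight gradient.

    doutput: (S, N) upstream gradient, same row order as the forward output
    token:   (S, K) forward token activations
    Returns dweight of shape (E, K, N) with
        dweight[e] = token_slice.T @ doutput_slice over expert e's rows.
    """
    E = first_token_offset.shape[0]
    S, K = token.shape
    N = doutput.shape[1]
    dweight = np.zeros((E, K, N), dtype=np.float32)
    for e in range(E):
        lo = first_token_offset[e]
        hi = first_token_offset[e + 1] if e + 1 < E else S
        dweight[e] = token[lo:hi].T @ doutput[lo:hi]
    return dweight
```

An expert that received no tokens (empty slice) simply gets a zero gradient.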
C++ API#
std::shared_ptr<Tensor_attributes>
moe_grouped_matmul_bwd(std::shared_ptr<Tensor_attributes> doutput,
                       std::shared_ptr<Tensor_attributes> token,
                       std::shared_ptr<Tensor_attributes> first_token_offset,
                       Moe_grouped_matmul_bwd_attributes options);
Moe_grouped_matmul_bwd_attributes is a lightweight structure with setters:
Moe_grouped_matmul_bwd_attributes& set_name(std::string const&);
Moe_grouped_matmul_bwd_attributes& set_compute_data_type(DataType_t value);
Python API#
Low-level graph API (cudnn.pygraph)#
dweight = graph.moe_grouped_matmul_bwd(
    doutput,             # Upstream gradient tensor
    token,               # Forward token activations
    first_token_offset,  # Expert routing offsets
    compute_data_type=cudnn.data_type.NOT_SET,
    name=None,
)
Args:
- `doutput` (cudnn_tensor): Upstream gradient with shape `(1, S, N)`, same layout as the forward output.
- `token` (cudnn_tensor): Forward token activations with shape `(1, S, K)`.
- `first_token_offset` (cudnn_tensor): INT32 tensor of shape `(B*E, 1, 1)`, same as used in the forward pass.
- `compute_data_type` (Optional[cudnn.data_type]): Data type for internal accumulation.
- `name` (Optional[str]): Name for the operation.
Returns:
- `dweight` (cudnn_tensor): Weight gradient with shape `(E, K, N)`.
Example (Python)#
import torch
import cudnn

num_experts = 8
token_num = 1024
hidden_size = 512
weight_size = 256

graph = cudnn.pygraph(
    intermediate_data_type=cudnn.data_type.FLOAT,
    compute_data_type=cudnn.data_type.FLOAT,
)

# dOutput: [1, S, N], bf16, row-major
doutput_t = graph.tensor(
    name="doutput",
    dim=[1, token_num, weight_size],
    stride=[token_num * weight_size, weight_size, 1],
    data_type=cudnn.data_type.BFLOAT16,
)

# Token: [1, S, K], bf16, row-major
token_t = graph.tensor(
    name="token",
    dim=[1, token_num, hidden_size],
    stride=[token_num * hidden_size, hidden_size, 1],
    data_type=cudnn.data_type.BFLOAT16,
)

# FirstTokenOffset: [E, 1, 1], INT32
fto_t = graph.tensor(
    name="first_token_offset",
    dim=[num_experts, 1, 1],
    stride=[1, 1, 1],
    data_type=cudnn.data_type.INT32,
)

dweight_t = graph.moe_grouped_matmul_bwd(
    doutput_t, token_t, fto_t,
    compute_data_type=cudnn.data_type.FLOAT,
    name="moe_bwd",
)
# dweight: [E, K, N]
dweight_t.set_output(True).set_data_type(cudnn.data_type.BFLOAT16)

graph.validate()
graph.build_operation_graph()
graph.create_execution_plans([cudnn.heur_mode.A])
graph.check_support()
graph.build_plans()