Grouped GEMM + dGLU (SM100)#
This is an experimental API and subject to change.
Overview#
Unified Grouped GEMM + dGLU fusion: A block-scaled grouped GEMM fused with a dGLU backward epilogue (dSwiGLU or dGeGLU) on NVIDIA Blackwell GPUs (SM100+), designed for MoE (Mixture of Experts) workloads. Implemented with CUTLASS/CUTE.
This is a unified API that supports both weight layout modes:
- Dense mode: all expert weights packed into a single contiguous `(N, K, L)` tensor
- Discrete mode: per-expert weight pointers (no weight stacking required)
And both backward activation functions:
- dSwiGLU: `act_func="dswiglu"` (default)
- dGeGLU: `act_func="dgeglu"`
Groups are contiguous in the M dimension and described by padded_offsets (cumulative aligned end offsets).
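As an illustration, `padded_offsets` can be built on the host by rounding each group's M size up to `m_aligned` (256; see Common Parameters) and accumulating the aligned end offsets. The helper below is a sketch for exposition, not part of the API:

```python
def make_padded_offsets(group_sizes, m_aligned=256):
    # Round each group's M size up to a multiple of m_aligned, then
    # accumulate the aligned end offsets (hypothetical host-side helper).
    offsets, end = [], 0
    for size in group_sizes:
        end += -(-size // m_aligned) * m_aligned  # ceil-divide, then re-scale
        offsets.append(end)
    return offsets

# Three expert groups with 100, 300, and 17 valid rows:
# -> [256, 768, 1024], so valid_m = 1024
offsets = make_padded_offsets([100, 300, 17])
```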
This kernel performs:
- Block-scaled grouped GEMM: low-precision GEMM (FP4, FP8) with per-block scale factors across multiple expert groups
- dGLU backward epilogue: fused backward computation using the forward `C` tensor (input/gate interleaved)
- Optional quantized output: produces row and column scale factors for downstream quantization
Shapes#
Equations#
Inputs:

- `A`: contiguous activation tensor across all groups, shape `(valid_m, K, 1)`
- `B` (dense): weight tensor across all groups, shape `(N, K, L)`
- `B` (discrete): per-expert weight pointers, `b_ptrs` of shape `(num_experts,)`, int64
- `C`: forward intermediate tensor with interleaved input/gate blocks, shape `(valid_m, 2N, 1)`
- `SFA`: scale factor tensor for A, shape `(32, 4, ceil(valid_m/128), 4, ceil(ceil(K/sf_vec_size)/4), 1)`
- `SFB` (dense): scale factor tensor for B, shape `(32, 4, ceil(N/128), 4, ceil(ceil(K/sf_vec_size)/4), L)`
- `SFB` (discrete): per-expert SFB pointers, `sfb_ptrs` of shape `(num_experts,)`, int64
- `padded_offsets`: cumulative sum of aligned group M sizes, shape `(L,)`; `valid_m = padded_offsets[-1]`
- `alpha`: per-group scaling factors for the GEMM, shape `(L,)`
- `beta`: per-group scaling factors for `C`, shape `(L,)`
- `prob`: per-row gating probabilities, shape `(valid_m, 1, 1)`
- `norm_const`: normalization constant for FP8 quantization, shape `(1,)`
Outputs:

- `D_row`: row-quantized dGLU output, shape `(valid_m, 2N, 1)`
- `D_col`: column-quantized dGLU output, shape `(valid_m, 2N, 1)`
- `dprob`: gradient of `prob`, shape `(valid_m, 1, 1)`; must be zero-initialized
- `dbias` (optional): per-expert bias gradient tensor, shape `(L, 2N, 1)`
- `SFD_row`: row scale factors (when `d_dtype` is FP8), shape `(32, 4, ceil(valid_m/128), 4, ceil(ceil((2N)/sf_vec_size)/4), 1)`
- `SFD_col`: column scale factors (when `d_dtype` is FP8), shape `(32, 4, ceil((2N)/128), 4, ceil(ceil(valid_m/sf_vec_size)/4), 1)`
- `amax`: per-group amax (when `d_dtype` is bf16/float16), shape `(L, 2, 1)`
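The nested `ceil` terms in the scale-factor shapes above can be evaluated with a small helper (a sketch for checking shapes, not part of the API):

```python
import math

def sf_shape(rows, cols, sf_vec_size, num_groups=1):
    # (32, 4, ceil(rows/128), 4, ceil(ceil(cols/sf_vec_size)/4), L),
    # matching the SFA/SFB/SFD layouts described above.
    return (32, 4, math.ceil(rows / 128), 4,
            math.ceil(math.ceil(cols / sf_vec_size) / 4), num_groups)

# SFA for valid_m=768, K=512, sf_vec_size=32 -> (32, 4, 6, 4, 4, 1)
print(sf_shape(768, 512, 32))
```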
Step 1: Block-scaled grouped GEMM (per group `g`, covering rows `m` in `[padded_offsets[g-1], padded_offsets[g])`):
\( \text{ref}[m, n] = \alpha_g^2 \sum_{k} \text{dequantize}(A[m, k], \text{SFA}) \cdot \text{dequantize}(B[n, k, g], \text{SFB}) \)
Step 2: dGLU backward epilogue (performed with 32-column interleaving along 2N):

Scale `C` by `beta_g` per group and deinterleave it into input/gate halves by 32-wide blocks.

For dSwiGLU (`act_func="dswiglu"`):

- `swish = gate * sigmoid(gate)`
- `dprob += sum(swish * input * ref)` over 32-column chunks
- `ab = ref * prob * swish`
- `dswiglu = ref * prob * input * sigmoid(gate) * (1 + gate * (1 - sigmoid(gate)))`
- Interleave `[ab, dswiglu]` back into `D_row`/`D_col` with 32-column blocks.

For dGeGLU (`act_func="dgeglu"`): uses `sigmoid(1.702 * gate)` scaling in the backward computation.
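The per-element dSwiGLU math of Step 2 can be written as a scalar reference (a plain-Python sketch of the formulas above; the kernel itself operates on 32-column blocks and accumulates `dprob` across them):

```python
import math

def dswiglu_ref(ref, inp, gate, prob, beta=1.0):
    # Scalar form of Step 2 for one (input, gate) element pair; C has
    # already been deinterleaved into `inp` and `gate` halves.
    inp, gate = beta * inp, beta * gate
    sig = 1.0 / (1.0 + math.exp(-gate))
    swish = gate * sig
    dprob_contrib = swish * inp * ref                              # summed into dprob[m]
    ab = ref * prob * swish                                        # input-half gradient
    dswiglu = ref * prob * inp * sig * (1.0 + gate * (1.0 - sig))  # gate-half gradient
    return ab, dswiglu, dprob_contrib
```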
Step 3: Optional output quantization (when SFD outputs are generated):

\( \text{SFD\_row}[m, n] = \text{norm\_const} \cdot \max_{k \in \text{block}} |D[m, k]| \cdot \text{rcp\_max} \)

\( D_{\text{quantized}}[m, n] = D[m, n] \cdot \frac{\text{norm\_const}}{\text{SFD\_row}[m, n]} \)

Here `rcp_max` is the reciprocal of the maximum representable value of the output format, so the block amax maps to that maximum after scaling.
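A scalar sketch of the row-quantization step, assuming the format maximum is 448 (float8_e4m3fn); this is illustrative only, not the kernel's implementation:

```python
def row_quantize(row, norm_const, fmt_max=448.0):
    # One scale factor per row block: scale so the block amax maps to fmt_max.
    amax = max(abs(x) for x in row)
    if amax == 0.0:
        return 0.0, list(row)                         # nothing to scale
    sfd = norm_const * amax / fmt_max                 # SFD_row = norm_const * amax * rcp_max
    quantized = [x * norm_const / sfd for x in row]   # D * norm_const / SFD_row
    return sfd, quantized
```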
Diagram#
```
A (valid_m×K×1)    B (N×K×L) or b_ptrs    padded_offsets
SFA                SFB or sfb_ptrs              |
 |                        |                     |
 |           +------------+                     |
 |           |                                  |
 v           v                                  v
Dequantize → Grouped GEMM (per group ranges) → ref
                          |
                          | × alpha[group_idx]
                          v
                   ref (valid_m×N×1)
                          |
C (valid_m×2N×1) --× beta[group_idx]--> deinterleave 32-col blocks
                          |                     |
                          |              swish, sigmoid
                          |                     |
                          +--> dprob (sum over blocks)
                          |
                          +--> ab, dswiglu → interleave → D (valid_m×2N×1)
                                                |
                                    +-----------+-----------+
                                    |                       |
                                    v                       v
                              Row Quantize            Col Quantize
                                    |                       |
                                    v                       v
                           D_row, SFD_row          D_col, SFD_col
```
API Usage#
High-level Wrapper#
Dense mode:
```python
import torch
from cudnn import grouped_gemm_dglu_wrapper_sm100
from cuda.bindings import driver as cuda

stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)

outputs = grouped_gemm_dglu_wrapper_sm100(
    a_tensor=a,
    c_tensor=c,
    sfa_tensor=sfa,
    padded_offsets=padded_offsets,
    alpha_tensor=alpha,
    beta_tensor=beta,
    prob_tensor=prob,
    dprob_tensor=dprob,
    generate_dbias=True,
    # Dense mode weights:
    b_tensor=b,
    sfb_tensor=sfb,
    # Common:
    norm_const_tensor=norm_const,
    acc_dtype=torch.float32,
    d_dtype=torch.bfloat16,
    cd_major="n",
    mma_tiler_mn=(256, 256),
    cluster_shape_mn=(2, 1),
    sf_vec_size=32,
    act_func="dswiglu",
    epilogue_op=None,
    current_stream=stream,
)

# Dictionary access:
d_row = outputs["d_row_tensor"]
d_col = outputs["d_col_tensor"]
dprob = outputs["dprob_tensor"]
dbias = outputs["dbias_tensor"]
amax = outputs["amax_tensor"]
sfd_row = outputs["sfd_row_tensor"]
sfd_col = outputs["sfd_col_tensor"]

# Or tuple unpacking:
d_row, d_col, dprob, dbias, amax, sfd_row, sfd_col = outputs
```
Discrete mode:
```python
outputs = grouped_gemm_dglu_wrapper_sm100(
    a_tensor=a,
    c_tensor=c,
    sfa_tensor=sfa,
    padded_offsets=padded_offsets,
    alpha_tensor=alpha,
    beta_tensor=beta,
    prob_tensor=prob,
    dprob_tensor=dprob,
    # Discrete mode weights:
    b_ptrs=b_ptrs,
    sfb_ptrs=sfb_ptrs,
    n=n_dim,
    b_dtype=torch.uint8,
    b_major="k",
    # Common:
    act_func="dgeglu",
    current_stream=stream,
)
```
Class API#
Dense mode:
```python
import torch
from cudnn import GroupedGemmDgluSm100
from cuda.bindings import driver as cuda

stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)

api = GroupedGemmDgluSm100(
    sample_a=a,
    sample_c=c,
    sample_d_row=d_row,
    sample_d_col=d_col,
    sample_sfa=sfa,
    sample_padded_offsets=padded_offsets,
    sample_alpha=alpha,
    sample_beta=beta,
    sample_prob=prob,
    sample_dprob=dprob,
    sample_dbias=dbias,
    # Dense mode:
    sample_b=b,
    sample_sfb=sfb,
    # Optional quantization outputs
    sample_sfd_row=sfd_row,
    sample_sfd_col=sfd_col,
    sample_amax=amax,
    sample_norm_const=norm_const,
    # Configuration
    acc_dtype=torch.float32,
    mma_tiler_mn=(256, 256),
    act_func="dswiglu",
    epilogue_op=None,
)

assert api.check_support()
api.compile()
api.execute(
    a_tensor=a, c_tensor=c, d_row_tensor=d_row, d_col_tensor=d_col,
    sfa_tensor=sfa, padded_offsets=padded_offsets, alpha_tensor=alpha,
    beta_tensor=beta, prob_tensor=prob, dprob_tensor=dprob, dbias_tensor=dbias,
    b_tensor=b, sfb_tensor=sfb,
    sfd_row_tensor=sfd_row, sfd_col_tensor=sfd_col,
    amax_tensor=amax, norm_const_tensor=norm_const,
    current_stream=stream,
)
```
In the class API, dbias generation is specialized at compile time: if sample_dbias is omitted, dbias_tensor must also be omitted at execute().
Discrete mode:
```python
api = GroupedGemmDgluSm100(
    sample_a=a,
    sample_c=c,
    sample_d_row=d_row,
    sample_d_col=d_col,
    sample_sfa=sfa,
    sample_padded_offsets=padded_offsets,
    sample_alpha=alpha,
    sample_beta=beta,
    sample_prob=prob,
    sample_dprob=dprob,
    # Discrete mode:
    num_experts=num_experts,
    b_shape=(n, k),
    b_dtype=torch.uint8,
    # Configuration
    act_func="dgeglu",
    b_major="k",
    epilogue_op="relu",
)

assert api.check_support()
api.compile()
api.execute(
    a_tensor=a, c_tensor=c, d_row_tensor=d_row, d_col_tensor=d_col,
    sfa_tensor=sfa, padded_offsets=padded_offsets, alpha_tensor=alpha,
    beta_tensor=beta, prob_tensor=prob, dprob_tensor=dprob,
    b_ptrs=b_ptrs, sfb_ptrs=sfb_ptrs,
    current_stream=stream,
)
```
Parameters#
Weight Mode#
The weight mode is auto-detected from constructor arguments:
- Dense: provide `sample_b` and `sample_sfb` (contiguous weight tensors)
- Discrete: provide `num_experts`, `b_shape`, and `b_dtype` (per-expert pointer mode)

Providing both or neither raises `ValueError`.
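The auto-detection behaves roughly as follows (a hypothetical sketch, not the library's code):

```python
def detect_weight_mode(sample_b=None, sample_sfb=None,
                       num_experts=None, b_shape=None, b_dtype=None):
    # Dense mode needs contiguous weight tensors; discrete mode needs the
    # per-expert pointer metadata. Exactly one set must be provided.
    dense = sample_b is not None and sample_sfb is not None
    discrete = all(v is not None for v in (num_experts, b_shape, b_dtype))
    if dense == discrete:  # both or neither
        raise ValueError("provide either dense (sample_b, sample_sfb) or "
                         "discrete (num_experts, b_shape, b_dtype) arguments")
    return "dense" if dense else "discrete"
```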
Input/Output Tensors#
Input tensor A: `a_tensor` / `sample_a`

- Shape: `(valid_m, K, 1)`
- Stride: `(K, 1, valid_m·K)` – must be K-major
- Dtype (`ab_dtype`): `{float4_e2m1fn_x2, uint8, float8_e4m3fn, float8_e5m2}`

Input tensor B (dense mode): `b_tensor` / `sample_b`

- Shape: `(N, K, L)` where `L = num_groups`
- Stride: K-major or N-major; must be K-major for FP4
- Dtype: must match A

Input B pointers (discrete mode): `b_ptrs`

- Shape: `(num_experts,)` – 1-D int64 device tensor of per-expert B data pointers

Input tensor C: `c_tensor` / `sample_c`

- Shape: `(valid_m, 2N, 1)` – forward activation with interleaved input/gate
- Stride: `(2N, 1, valid_m·2N)` – must be N-major
- Dtype: `{float32, float16, bfloat16, float8_e4m3fn, float8_e5m2}`

Output tensor D_row: `d_row_tensor` / `sample_d_row`

- Shape: `(valid_m, 2N, 1)`
- Stride: `(2N, 1, valid_m·2N)` – must be N-major
- Dtype (`d_dtype`): `{bfloat16, float32}` for FP4; `{float8_e4m3fn, float8_e5m2}` for FP8

Output tensor D_col: `d_col_tensor` / `sample_d_col`

- Shape: `(valid_m, 2N, 1)` – must match D_row dtype and stride

Input tensor prob: `prob_tensor` / `sample_prob`

- Shape: `(valid_m, 1, 1)`, dtype `float32`

Output tensor dprob: `dprob_tensor` / `sample_dprob`

- Shape: `(valid_m, 1, 1)`, dtype `float32`
- Must be zero-initialized

Scaling tensors: `alpha` shape `(L,)`, `beta` shape `(L,)`, `amax` shape `(L, 2, 1)`, `norm_const` shape `(1,)`
Common Parameters#
- `acc_dtype`: must be `torch.float32`
- `mma_tiler_mn`: kernel tile size. Default: `(256, 256)`. `TILE_M ∈ {128, 256}`, `TILE_N = 256`
- `cluster_shape_mn`: thread block cluster shape. Default: `(2, 1)` when `TILE_M=256`, `(1, 1)` otherwise
- `sf_vec_size`: scale factor vector size, `{16, 32}`. Default: `16`
- `vector_f32`: enable packed f32 operations. Default: `False`
- `m_aligned`: must be `256`. Default: `256`
- `discrete_col_sfd`: generate discrete col-major scale factors. Default: `False`
- `act_func`: backward activation function, `"dswiglu"` (default) or `"dgeglu"`
- `b_major` (discrete only): B tensor major dimension, `"k"` (default) or `"n"`. Must be `"k"` for FP4.
- `epilogue_op`: optional post-processing, `None` (default), `"identity"`, `"relu"`, or `"srelu"`
Wrapper-specific Parameters#
- `d_dtype`: output D tensor data type. Default: `torch.bfloat16`
- `cd_major`: must be `"n"`. Default: `"n"`
- `n` (discrete only): B weight N dimension
- `b_dtype` (discrete only): B weight data type
Wrapper Return Values#
Returns a TupleDict (supports both dictionary access and tuple unpacking):

- `d_row_tensor`: row-quantized dGLU output
- `d_col_tensor`: column-quantized dGLU output
- `dprob_tensor`: gradient of prob
- `dbias_tensor`: per-expert bias gradient (when `generate_dbias=True`)
- `amax_tensor`: per-group amax (when `d_dtype` is bf16/fp16)
- `sfd_row_tensor`: row scale factors (when SFD enabled)
- `sfd_col_tensor`: column scale factors (when SFD enabled)
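Both access styles shown in the wrapper example work on the returned object; a minimal sketch of the behavior (not the library's class):

```python
class TupleDict(dict):
    # Iterating yields values in insertion order, so the result can be
    # unpacked like a tuple while still supporting key lookup.
    def __iter__(self):
        return iter(self.values())

outputs = TupleDict(d_row_tensor="d_row", d_col_tensor="d_col")
first, second = outputs           # tuple-style unpacking
by_key = outputs["d_row_tensor"]  # dict-style access
```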
Support Surface and Constraints#
Layouts and Strides#
- `A` must be K-major
- `B` must be K-major (dense) or K/N-major (discrete); must be K-major for FP4
- `C`, `D_row`, `D_col` must be N-major
- All tensors must be 16-byte aligned along the contiguous dimension
Data Types#
| Format | ab_dtype | sf_dtype | sf_vec_size | d_dtype |
|---|---|---|---|---|
| MXFP8 | `{float8_e4m3fn, float8_e5m2}` | `float8_e8m0fnu` | 32 | `{float8_e4m3fn, float8_e5m2}` |
| NVF4 | `float4_e2m1fn_x2` | `{float8_e8m0fnu, float8_e4m3fn}` | {16, 32} | `{bfloat16, float32}` |
Additional Type Constraints#
- `A` and `B` must have the same dtype
- Scale factor tensors must have the same dtype
- `D_row` and `D_col` must have the same dtype
- `dbias` must be `bfloat16`
- `sf_dtype=float8_e4m3fn` is incompatible with `sf_vec_size=32`
- FP8 `c_dtype` with `vector_f32=True` is not supported
- FP4 `ab_dtype` only supports `d_dtype` in `{bfloat16, float32}`
- For non-dbias paths, FP4 `ab_dtype` with `sf_vec_size=16` and `d_dtype=float32` is not supported
- FP8 `ab_dtype` only supports `d_dtype` in `{float8_e4m3fn, float8_e5m2}`
Shapes and Divisibility#
- `N` must be divisible by 32 (32-column blocks for input/gate interleaving)
- Expert count must be `<= 1024`
- Each group's M dimension is aligned to `m_aligned` (256)
- In the class API, `dbias` is compiled in only when `sample_dbias` is provided; passing a runtime `dbias_tensor` without `sample_dbias` raises `ValueError`
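These constraints can be checked up front on the host; the helper below is a hypothetical sketch, not part of the API:

```python
def validate_problem(n, num_experts, padded_offsets, m_aligned=256):
    # Mirrors the divisibility rules listed above.
    if n % 32 != 0:
        raise ValueError("N must be divisible by 32")
    if num_experts > 1024:
        raise ValueError("expert count must be <= 1024")
    prev = 0
    for end in padded_offsets:
        if (end - prev) % m_aligned != 0:
            raise ValueError("each group's M size must be a multiple of m_aligned")
        prev = end
    return True
```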
Environment#
Requires CUDA with SM100+ compute capability (Blackwell GPUs)
Usage Examples#
For usage examples, see the test cases in `test/python/fe_api/test_grouped_gemm_dglu.py` (dense mode, unified API) and `test/python/fe_api/test_discrete_grouped_gemm_dswiglu.py` (discrete mode).