Grouped GEMM + dSwiGLU (SM100)#
This is an experimental API and subject to change.
Overview#
Grouped GEMM + dSwiGLU fusion: A contiguous grouped block-scaled GEMM fused with a dSwiGLU backward epilogue on NVIDIA Blackwell GPUs (SM100+), designed for MoE (Mixture of Experts) workloads. Implemented with CUTLASS/CUTE.
Groups are contiguous in the M dimension and described by padded_offsets (cumulative aligned end offsets).
This kernel performs:

- Block-scaled grouped GEMM: low-precision GEMM (FP4, FP8) with per-block scale factors across multiple expert groups
- dSwiGLU backward epilogue: fused backward computation using the forward `C` tensor (input/gate interleaved)
- Optional quantized output: produces row and column scale factors for downstream quantization
Shapes#

Inputs

- `A`: contiguous activation tensor across all groups, shape `(valid_m, K, 1)`
- `B`: weight tensor across all groups, shape `(N, K, L)`
- `C`: forward intermediate tensor with interleaved input/gate blocks, shape `(valid_m, 2N, 1)`
- `SFA`: scale factor tensor for A, shape `(32, 4, ceil(valid_m/128), 4, ceil(ceil(K/sf_vec_size)/4), 1)`
- `SFB`: scale factor tensor for B, shape `(32, 4, ceil(N/128), 4, ceil(ceil(K/sf_vec_size)/4), L)`
- `padded_offsets`: cumulative sum of aligned group M sizes, shape `(L,)`; `valid_m = padded_offsets[-1]`
- `alpha`: per-group scaling factors for GEMM, shape `(L,)`
- `beta`: per-group scaling factors for `C`, shape `(L,)`
- `prob`: per-row gating probabilities, shape `(valid_m, 1, 1)`
- `norm_const`: normalization constant for FP8 quantization, shape `(1,)`

Outputs

- `D_row`: row-quantized dSwiGLU output, shape `(valid_m, 2N, 1)`
- `D_col`: column-quantized dSwiGLU output, shape `(valid_m, 2N, 1)`
- `dprob`: gradient of `prob`, shape `(valid_m, 1, 1)`. Must be zero-initialized.
- `SFD_row`: row scale factors (when `d_dtype` is FP8), shape `(32, 4, ceil(valid_m/128), 4, ceil(ceil((2N)/sf_vec_size)/4), 1)`
- `SFD_col`: column scale factors (when `d_dtype` is FP8), shape `(32, 4, ceil((2N)/128), 4, ceil(ceil(valid_m/sf_vec_size)/4), 1)`
- `amax`: per-group amax (when `d_dtype` is bf16/float16), shape `(L, 2, 1)`

Equations#
Step 1: Block-scaled grouped GEMM (per group g with rows m in [padded_offsets[g-1], padded_offsets[g])):
\( \text{ref}[m, n] = \alpha_g^2 \sum_{k} \text{dequantize}(A[m, k], \text{SFA}) \cdot \text{dequantize}(B[n, k, g], \text{SFB}) \)
Step 2: dSwiGLU backward epilogue (performed with 32-column interleaving along 2N):

- Scale `C` by `beta_g` per group and deinterleave into input/gate halves by 32-wide blocks.
- `swish = gate * sigmoid(gate)`
- `dprob` is the sum over 32-column chunks of `swish * input * ref`
- `ab = ref * prob * swish`
- `dswiglu = ref * prob * input * sigmoid(gate) * (1 + gate * (1 - sigmoid(gate)))`
- Interleave `[ab, dswiglu]` back into `D_row`/`D_col` with 32-column blocks.
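The Step 2 equations can be sanity-checked on the host with a small NumPy reference. This is a sketch for a single group only, assuming `ref` is already scaled by `alpha`, `prob` is passed as a flat `(m,)` vector, and input/gate alternate in 32-column blocks along 2N; `dswiglu_backward_ref` is a hypothetical helper, not part of the API:

```python
import numpy as np

def dswiglu_backward_ref(ref, c, prob, beta_g, block=32):
    """Hypothetical single-group NumPy reference for the Step 2 equations.

    ref:  (m, N)  GEMM output, already scaled by alpha
    c:    (m, 2N) forward tensor, input/gate interleaved in 32-col blocks
    prob: (m,)    per-row gating probabilities
    """
    m, n = ref.shape
    cb = (beta_g * c).reshape(m, n // block, 2, block)      # deinterleave
    inp, gate = cb[:, :, 0, :], cb[:, :, 1, :]
    refb = ref.reshape(m, n // block, block)
    sig = 1.0 / (1.0 + np.exp(-gate))
    swish = gate * sig
    dprob = (swish * inp * refb).sum(axis=(1, 2))           # per-row scalar
    p = prob[:, None, None]
    ab = refb * p * swish
    dswiglu = refb * p * inp * sig * (1.0 + gate * (1.0 - sig))
    d = np.stack([ab, dswiglu], axis=2).reshape(m, 2 * n)   # re-interleave
    return d, dprob
```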
Step 3: Optional output quantization (when SFD outputs are generated):
\( \text{SFD\_row}[m, n] = \text{norm\_const} \cdot \max_{k \in \text{block}} |D[m, k]| \cdot \text{rcp\_max} \)
\( D_{\text{quantized}}[m, n] = D[m, n] \cdot \frac{\text{norm\_const}}{\text{SFD\_row}[m, n]} \)
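As a rough host-side sketch of the row-quantization formulas above, assuming `rcp_max = 1/448` (the FP8 E4M3 maximum); `quantize_row_blocks` is a hypothetical helper, not part of the API:

```python
import numpy as np

FP8_E4M3_MAX = 448.0  # assumption: rcp_max = 1 / 448 for e4m3 output

def quantize_row_blocks(d, norm_const=1.0, block=32):
    """Hypothetical sketch of Step 3 row quantization per 32-element block."""
    m, n = d.shape
    blocks = d.reshape(m, n // block, block)
    amax = np.abs(blocks).max(axis=-1)                 # per-block amax
    sfd = norm_const * amax / FP8_E4M3_MAX             # scale factors
    safe = np.where(sfd > 0, sfd, 1.0)
    scale = np.where(sfd > 0, norm_const / safe, 0.0)
    q = np.clip(blocks * scale[:, :, None], -FP8_E4M3_MAX, FP8_E4M3_MAX)
    return q.reshape(m, n), sfd
```

Note that `norm_const` cancels in the quantized values; it only rescales the stored scale factors.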
Diagram#
A (valid_m×K×1) B (N×K×L) padded_offsets
SFA SFB |
| | |
| +------------+ |
| | |
v v v
Dequantize → Grouped GEMM (per group ranges) → ref
|
| × alpha[group_idx]
v
ref (valid_m×N×1)
|
C (valid_m×2N×1) --× beta[group_idx]--> deinterleave 32-col blocks
| |
| swish, sigmoid
| |
+--> dprob (sum over blocks)
|
+--> ab, dswiglu → interleave → D (valid_m×2N×1)
|
+-----------+-----------+
| |
v v
Row Quantize Col Quantize
| |
v v
D_row, SFD_row D_col, SFD_col
API Usage#
High-level Wrapper#
import torch
from cudnn import grouped_gemm_dswiglu_wrapper_sm100
from cuda.bindings import driver as cuda

stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
outputs = grouped_gemm_dswiglu_wrapper_sm100(
a_tensor=a,
b_tensor=b,
c_tensor=c,
sfa_tensor=sfa,
sfb_tensor=sfb,
padded_offsets=padded_offsets,
alpha_tensor=alpha,
beta_tensor=beta,
prob_tensor=prob,
norm_const_tensor=norm_const, # Required when SFD outputs are enabled (FP8 inputs)
acc_dtype=torch.float32,
d_dtype=torch.bfloat16,
cd_major="n",
mma_tiler_mn=(256, 256),
cluster_shape_mn=(2, 1),
sf_vec_size=32,
vector_f32=False,
m_aligned=256,
discrete_col_sfd=False,
epilogue_op=None,
current_stream=stream,
)
# dictionary access:
d_row = outputs["d_row_tensor"] # row-quantized dSwiGLU output
d_col = outputs["d_col_tensor"] # column-quantized dSwiGLU output
dprob = outputs["dprob_tensor"] # dprob output
amax = outputs["amax_tensor"] # per-group amax (when d_dtype is bf16)
sfd_row = outputs["sfd_row_tensor"] # row scale factors (when d_dtype is FP8)
sfd_col = outputs["sfd_col_tensor"] # column scale factors (when d_dtype is FP8)
# or tuple unpacking:
d_row, d_col, dprob, amax, sfd_row, sfd_col = outputs
Class API#
import torch
from cudnn import GroupedGemmDswigluSm100
from cuda.bindings import driver as cuda

stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
api = GroupedGemmDswigluSm100(
sample_a=a,
sample_b=b,
sample_c=c,
sample_d_row=d_row,
sample_d_col=d_col,
sample_sfa=sfa,
sample_sfb=sfb,
sample_padded_offsets=padded_offsets,
sample_alpha=alpha,
sample_beta=beta,
sample_prob=prob,
sample_dprob=dprob,
# Optional quantization outputs
sample_sfd_row=sfd_row, # Required when SFD outputs are enabled
sample_sfd_col=sfd_col, # Required when SFD outputs are enabled
sample_amax=amax, # Required for bf16 output with FP4 input
sample_norm_const=norm_const, # Required when SFD outputs are enabled
# Configuration
acc_dtype=torch.float32,
mma_tiler_mn=(256, 256),
cluster_shape_mn=(2, 1),
sf_vec_size=32,
vector_f32=False,
m_aligned=256,
discrete_col_sfd=False,
epilogue_op=None,
)
assert api.check_support()
api.compile()
api.execute(
a_tensor=a,
b_tensor=b,
c_tensor=c,
d_row_tensor=d_row,
d_col_tensor=d_col,
sfa_tensor=sfa,
sfb_tensor=sfb,
padded_offsets=padded_offsets,
alpha_tensor=alpha,
beta_tensor=beta,
prob_tensor=prob,
dprob_tensor=dprob,
sfd_row_tensor=sfd_row,
sfd_col_tensor=sfd_col,
amax_tensor=amax,
norm_const_tensor=norm_const,
current_stream=stream,
)
Parameters#
Input/Output Tensors#
Input tensor A:

- `a_tensor` (wrapper) or `sample_a`, `a_tensor` (class)
- Shape: `(valid_m, K, 1)`
- Stride: `(K, 1, valid_m·K)` - must be K-major
- Dtype (`ab_dtype`): `{float4_e2m1fn_x2, uint8, float8_e4m3fn, float8_e5m2}`; `uint8` is interpreted as packed FP4 (two FP4 values per byte)

Input tensor B:

- `b_tensor` (wrapper) or `sample_b`, `b_tensor` (class)
- Shape: `(N, K, L)` where `L = num_groups`
- Stride: `(K, 1, N·K)` (K-major) or `(1, N, N·K)` (N-major). Must be K-major for FP4 inputs.
- Dtype (`ab_dtype`): must match A

Input tensor C:

- `c_tensor` (wrapper) or `sample_c`, `c_tensor` (class)
- Shape: `(valid_m, 2N, 1)`
- Stride: `(2N, 1, valid_m·2N)` - must be N-major
- Dtype (`c_dtype`): `{float32, float16, bfloat16, float8_e4m3fn, float8_e5m2}`

Output tensor D_row:

- `d_row_tensor` (class) or returned in wrapper dict
- Shape: `(valid_m, 2N, 1)`
- Stride: `(2N, 1, valid_m·2N)` - must be N-major
- Dtype (`d_dtype`): `{bfloat16, float32}` for FP4 inputs; `{float8_e4m3fn, float8_e5m2}` for FP8 inputs

Output tensor D_col:

- `d_col_tensor` (class) or returned in wrapper dict
- Shape: `(valid_m, 2N, 1)`
- Stride: `(2N, 1, valid_m·2N)` - must match D_row (N-major)
- Dtype: must match D_row

Input tensor prob:

- `prob_tensor` (wrapper) or `sample_prob` (class)
- Shape: `(valid_m, 1, 1)`
- Dtype: `float32`

Output tensor dprob:

- `dprob_tensor` (wrapper) or `sample_dprob` (class)
- Shape: `(valid_m, 1, 1)`
- Dtype: `float32`
- Must be zero-initialized.
Scale factor tensors

SFA (A scale factor):

- `sfa_tensor` (wrapper) or `sample_sfa`, `sfa_tensor` (class)
- Shape: `(32, 4, ceil(valid_m/128), 4, ceil(ceil(K/sf_vec_size)/4), 1)`
- Dtype (`sf_dtype`): `{float8_e8m0fnu, float8_e4m3fn}`

SFB (B scale factor):

- `sfb_tensor` (wrapper) or `sample_sfb`, `sfb_tensor` (class)
- Shape: `(32, 4, ceil(N/128), 4, ceil(ceil(K/sf_vec_size)/4), L)`
- Dtype: must match SFA

SFD_row (D row scale factor, optional):

- `sfd_row_tensor` (wrapper) or `sample_sfd_row`, `sfd_row_tensor` (class)
- Shape: `(32, 4, ceil(valid_m/128), 4, ceil(ceil((2N)/sf_vec_size)/4), 1)`
- Dtype: must match SFA
- Required when: SFD outputs are enabled (FP8 inputs)

SFD_col (D column scale factor, optional):

- `sfd_col_tensor` (wrapper) or `sample_sfd_col`, `sfd_col_tensor` (class)
- Shape: `(32, 4, ceil((2N)/128), 4, ceil(ceil(valid_m/sf_vec_size)/4), 1)`
- Dtype: must match SFA
- Required when: SFD outputs are enabled (FP8 inputs)
Group offsets

padded_offsets: cumulative sum of aligned group M sizes

- Shape: `(L,)` where `L = num_groups`
- Dtype: `int32`
- `padded_offsets[-1]` equals `valid_m`; each offset is a multiple of `m_aligned`
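Building `padded_offsets` on the host might look like the following sketch (`make_padded_offsets` is a hypothetical helper, assuming the per-group M sizes are known):

```python
def make_padded_offsets(group_m_sizes, m_aligned=256):
    """Hypothetical helper: round each group's M up to m_aligned and
    cumulative-sum the aligned end offsets."""
    offsets, acc = [], 0
    for m in group_m_sizes:
        acc += -(-m // m_aligned) * m_aligned  # ceil to m_aligned
        offsets.append(acc)
    return offsets  # offsets[-1] == valid_m
```

The returned list would then be converted to an `int32` device tensor of shape `(L,)`.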
Scaling tensors

alpha: per-group scaling factors

- Shape: `(L,)` where `L = num_groups`
- Dtype: `float32`

beta: per-group scaling factors for `C`

- Shape: `(L,)` where `L = num_groups`
- Dtype: `float32`

amax (optional): per-group max absolute values

- Shape: `(L, 2, 1)`
- Dtype: `float32`
- Required when: `d_dtype ∈ {bfloat16, float16}`

norm_const (optional): normalization constant for FP8 quantization

- Shape: `(1,)`
- Dtype: `float32`
- Required when: `sfd_row_tensor`/`sfd_col_tensor` are provided (FP8 inputs)
Common Parameters#

acc_dtype: torch.dtype

- Accumulator dtype. Must be `torch.float32`

mma_tiler_mn: Tuple[int, int]

- Kernel tile size `(TILE_M, TILE_N)`. Default: `(256, 256)`
- `TILE_M ∈ {64, 128, 256}`, `TILE_N ∈ {128, 256}`

cluster_shape_mn: Tuple[int, int] | None

- Thread block cluster shape `(CLUSTER_M, CLUSTER_N)`
- Constraints: positive powers of 2, both <= 4, `CLUSTER_M × CLUSTER_N <= 16`
- Default: `(2, 1)` when `TILE_M=256`, `(1, 1)` otherwise

sf_vec_size: int

- Scale factor vector size (number of elements per scale factor)
- Allowed values: `{16, 32}`. Default: `16`

vector_f32: bool

- Enable packed f32 operations for improved performance
- Default: `False`

m_aligned: int

- Alignment requirement for the group M dimension
- Must equal `FIX_PAD_SIZE` (256) and be divisible by `mma_tiler_mn[0]`
- Default: `256`

discrete_col_sfd: bool

- If True, generate discrete column scale factors grouped by expert tiles
- Only applies when `sfd_row_tensor`, `sfd_col_tensor`, and `norm_const_tensor` are provided
- No extra inputs are required; this only changes the layout of `sfd_col_tensor`
- Default: `False`

epilogue_op: Optional[str]

- Optional epilogue operation. Valid values: `None`, `"none"`, `"identity"`, `"relu"`, `"srelu"`
- Default: `None`

current_stream

- CUDA stream (`current_stream` in both the class API and the wrapper)
Wrapper-specific Parameters: grouped_gemm_dswiglu_wrapper_sm100#

- `d_dtype: torch.dtype`: Output D tensor data type. Default: `torch.bfloat16`
- `cd_major: str`: Major dimension for C and D tensors. Must be `"n"` (only N-major layout is supported). Default: `"n"`
Wrapper Return Values#

Returns a TupleDict - a dictionary-like object that also supports tuple unpacking and integer indexing.

Dictionary keys (also the tuple unpacking order):

- `d_row_tensor`: row-quantized dSwiGLU output
- `d_col_tensor`: column-quantized dSwiGLU output
- `dprob_tensor`: gradient of `prob`
- `amax_tensor`: per-group amax (when `d_dtype ∈ {bfloat16, float16}`)
- `sfd_row_tensor`: row scale factors (when SFD outputs are enabled)
- `sfd_col_tensor`: column scale factors (when SFD outputs are enabled)
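The described access patterns can be illustrated with a minimal sketch (not the actual cudnn class, just the behavior the wrapper's return value exposes):

```python
class TupleDict(dict):
    """Illustrative sketch: dict access plus tuple unpacking and integer
    indexing in insertion order."""
    def __iter__(self):
        return iter(self.values())           # unpacking yields values
    def __getitem__(self, key):
        if isinstance(key, int):
            return list(self.values())[key]  # integer indexing
        return super().__getitem__(key)
```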
Class-specific Parameters#

GroupedGemmDswigluSm100 (constructor)#

- `sample_a`, `sample_b`, `sample_c`, `sample_d_row`, `sample_d_col`, `sample_sfa`, `sample_sfb`, `sample_padded_offsets`, `sample_alpha`, `sample_beta`, `sample_prob`, `sample_dprob`, `sample_sfd_row`, `sample_sfd_col`, `sample_amax`, `sample_norm_const` - see Input/Output tensors
- Note: `sample_sfd_row`, `sample_sfd_col`, `sample_norm_const` must be all `None` or all not `None`

GroupedGemmDswigluSm100.execute#

- `a_tensor`, `b_tensor`, `c_tensor`, `d_row_tensor`, `d_col_tensor`, `sfa_tensor`, `sfb_tensor`, `padded_offsets`, `alpha_tensor`, `beta_tensor`, `prob_tensor`, `dprob_tensor`, `sfd_row_tensor`, `sfd_col_tensor`, `amax_tensor`, `norm_const_tensor` - see Input/Output tensors. Must have the same layouts as the sample tensors provided to the constructor.
Support Surface and Constraints#

Layouts and Strides#

- `A` must be K-major (contiguous along the K dimension)
- `B` must be K-major or N-major (contiguous along the N dimension). Must be K-major for FP4 inputs.
- `C`, `D_row`, and `D_col` must be N-major
- All tensors must be 16-byte aligned along the contiguous dimension
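These stride rules can be checked on the host; the sketch below uses NumPy arrays as stand-ins (`check_layouts` is a hypothetical helper; NumPy strides are in bytes, so they are converted to elements first):

```python
import numpy as np

def check_layouts(a, b, c):
    """Hypothetical layout check mirroring the stride rules above."""
    def es(t, d):
        return t.strides[d] // t.itemsize   # stride in elements
    assert es(a, 1) == 1, "A must be K-major"
    assert es(b, 1) == 1 or es(b, 0) == 1, "B must be K-major or N-major"
    assert es(c, 1) == 1, "C/D must be N-major"
```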
Data Types#
Input/Weight Types (ab_dtype)#
| Format | `ab_dtype` | `sf_dtype` | `sf_vec_size` | `d_dtype` |
|---|---|---|---|---|
| MXFP8 | `float8_e4m3fn`, `float8_e5m2` | `float8_e8m0fnu` | 32 | `float8_e4m3fn`, `float8_e5m2` |
| NVF4 | `float4_e2m1fn_x2`, `uint8` | {`float8_e8m0fnu`, `float8_e4m3fn`} | {16, 32} | `bfloat16`, `float32` |
Additional Type Constraints#

- `A` and `B` must have the same dtype
- `SFA`, `SFB`, `SFD_row`, and `SFD_col` must have the same dtype
- `D_row` and `D_col` must have the same dtype
- `acc_dtype` must be `float32`
- `sf_dtype=float8_e4m3fn` is incompatible with `sf_vec_size=32`
- FP8 `c_dtype` with `vector_f32=True` is not supported
- FP4 `ab_dtype` only supports `d_dtype ∈ {bfloat16, float32}`
- FP8 `ab_dtype` only supports `d_dtype ∈ {float8_e4m3fn, float8_e5m2}`
Scale Factor Output Requirements#
When
sfd_row_tensor/sfd_col_tensorare provided (FP8 inputs):sfd_row_tensor,sfd_col_tensor, andnorm_const_tensorare all requiredThese must be provided together (all None or all not None)
When
d_dtype ∈ {bfloat16, float16}:amax_tensoris required for tracking per-group max values
Tiling and Cluster#

- `mma_tiler_mn[0] = 256` enables 2-CTA instructions automatically (`use_2cta_instrs=True`)
- When `use_2cta_instrs=True`: `cluster_shape_mn[0]` must be divisible by 2
- `m_aligned` must be divisible by `mma_tiler_mn[0]` to prevent tiles from spanning multiple groups
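A host-side sanity check for these tiling and cluster constraints might look like the following (`check_tiling` is a hypothetical helper, not part of the API; the class API's `check_support()` performs the authoritative checks):

```python
def check_tiling(mma_tiler_mn, cluster_shape_mn, m_aligned=256):
    """Hypothetical validation of the tiling/cluster constraints above."""
    tile_m, tile_n = mma_tiler_mn
    cm, cn = cluster_shape_mn
    assert tile_m in (64, 128, 256) and tile_n in (128, 256)
    # positive powers of 2, both <= 4, product <= 16
    assert cm > 0 and cn > 0
    assert (cm & (cm - 1)) == 0 and (cn & (cn - 1)) == 0
    assert cm <= 4 and cn <= 4 and cm * cn <= 16
    if tile_m == 256:  # 2-CTA instructions enabled
        assert cm % 2 == 0, "cluster_shape_mn[0] must be divisible by 2"
    assert m_aligned % tile_m == 0, "tiles must not span multiple groups"
```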
Shapes and Divisibility#
Nmust be divisible by 32 (32-column blocks for input/gate interleaving)padded_offsetslengthLis the expert count and must be<= 1024Each group’s M dimension is aligned to
m_alignedvalid_m = padded_offsets[-1]determines the actual tensor M dimensionScale factor tensor shapes follow the MMA atom tiling pattern:
(32, 4, ceil(dim/128), 4, ceil(K_groups/4), L)
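The scale-factor layout pattern can be expressed as a small shape helper (`sf_shape` is hypothetical, following the pattern stated above; `outer_dim` is the scaled dimension, `inner_dim` the reduction dimension):

```python
import math

def sf_shape(outer_dim, inner_dim, sf_vec_size, num_groups):
    """Hypothetical helper for the MMA-atom scale-factor layout:
    (32, 4, ceil(outer/128), 4, ceil(ceil(inner/sf_vec_size)/4), L)."""
    return (32, 4, math.ceil(outer_dim / 128), 4,
            math.ceil(math.ceil(inner_dim / sf_vec_size) / 4), num_groups)
```

For example, SFA for `valid_m=512`, `K=256`, `sf_vec_size=32` would use `sf_shape(512, 256, 32, 1)`.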
Environment#
Requires CUDA with SM100+ compute capability (Blackwell GPUs)
Usage Examples#
For usage examples, see the test cases in `test/python/fe_api/test_grouped_gemm_dswiglu.py` and `test/python/fe_api/test_grouped_gemm_dswiglu_utils.py`.