Discrete Grouped GEMM + dSwiGLU (SM100)#
This is an experimental API and subject to change.
Overview#
Discrete Grouped GEMM + dGLU backward fusion: A block-scaled grouped GEMM fused with a dSwiGLU/dGeGLU backward epilogue on NVIDIA Blackwell GPUs (SM100+), designed for MoE workloads where each expert weight/scale lives in a separate allocation.
This API uses per-expert pointer tensors instead of packed (N, K, L) tensors:
- `b_ptrs`: device `int64` tensor of per-expert B pointers
- `sfb_ptrs`: device `int64` tensor of per-expert SFB pointers
Groups are contiguous in the M dimension and described by padded_offsets (cumulative aligned end offsets).
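The offset layout can be sketched in plain Python (the helper name is illustrative; the kernel only consumes the resulting offsets, and `m_aligned=256` matches the `FIX_PAD_SIZE` default described below):

```python
# Sketch: build padded_offsets from raw per-expert token counts.
# Each group's M extent is rounded up to m_aligned, and padded_offsets
# holds the cumulative aligned end offsets; valid_m = offsets[-1].
def build_padded_offsets(group_m_sizes, m_aligned=256):
    offsets, end = [], 0
    for m in group_m_sizes:
        end += -(-m // m_aligned) * m_aligned  # ceil-align each group
        offsets.append(end)
    return offsets

offsets = build_padded_offsets([100, 300, 1])  # -> [256, 768, 1024]
```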
This kernel performs:
- Block-scaled grouped GEMM backward core using per-expert pointer inputs
- dGLU backward epilogue (`act_func` in {"dswiglu", "dgeglu"}) using `c_tensor`, `beta`, `prob`, and `dprob`
- Optional quantized output with row/column scale factors
Shapes#
Inputs
- `A`: contiguous activation/gradient tensor across all groups, shape `(valid_m, K, 1)`
- `B_g`: expert-`g` weight tensor referenced by `b_ptrs[g]`, logical shape `(N, K)` (or `(N, K, 1)`)
- `C`: forward activation tensor consumed by the backward epilogue, shape `(valid_m, N/2, 1)`
- `SFA`: scale factor tensor for A, shape `(32, 4, ceil(valid_m/128), 4, ceil(ceil(K/sf_vec_size)/4), 1)`
- `SFB_g`: expert-`g` B scale tensor referenced by `sfb_ptrs[g]`, shape `(32, 4, ceil(N/128), 4, ceil(ceil(K/sf_vec_size)/4), 1)`
- `padded_offsets`: cumulative sum of aligned group M sizes, shape `(L,)`; `valid_m = padded_offsets[-1]`
- `alpha`: per-group scaling factors, shape `(L,)`
- `beta`: per-group scaling factors for `C`, shape `(L,)`
- `prob`: per-row gating probabilities, shape `(valid_m, 1, 1)`
- `dprob`: probability gradient output buffer, shape `(valid_m, 1, 1)`; must be zero-initialized
- `norm_const`: normalization constant for FP8 quantization, shape `(1,)`
Outputs
- `D_row`: row-quantized dGLU output, shape `(valid_m, N, 1)`
- `D_col`: column-quantized dGLU output, shape `(valid_m, N, 1)`
- `dprob`: probability gradient output, updated in place, shape `(valid_m, 1, 1)`
- `SFD_row`: row scale factors (when SFD outputs are enabled; the wrapper auto-enables this for FP8-input configs), shape `(32, 4, ceil(valid_m/128), 4, ceil(ceil(N/sf_vec_size)/4), 1)`
- `SFD_col`: column scale factors (same enablement), shape `(32, 4, ceil(N/128), 4, ceil(ceil(valid_m/sf_vec_size)/4), 1)`
- `amax`: per-group amax (optional; the wrapper provides it when `d_dtype` is bf16/fp16), shape `(L, 2, 1)`
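The 5D scale-factor shapes above follow one formula with the roles of rows and columns swapped between tensors. A small helper (the name `sf_shape` is illustrative, not part of the API) makes the pattern explicit:

```python
import math

# Sketch: compute the 5D scale-factor layouts listed above from problem
# sizes: (32, 4, ceil(rows/128), 4, ceil(ceil(cols/sf_vec_size)/4), 1).
def sf_shape(rows, cols, sf_vec_size):
    return (32, 4, math.ceil(rows / 128), 4,
            math.ceil(math.ceil(cols / sf_vec_size) / 4), 1)

valid_m, K, N, sf_vec_size = 1024, 512, 768, 32
sfa_shape = sf_shape(valid_m, K, sf_vec_size)      # SFA
sfb_shape = sf_shape(N, K, sf_vec_size)            # one expert's SFB
sfd_row_shape = sf_shape(valid_m, N, sf_vec_size)  # SFD_row
sfd_col_shape = sf_shape(N, valid_m, sf_vec_size)  # SFD_col (rows/cols swapped)
```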
Equations#
Step 1: Block-scaled grouped GEMM (per group g with rows m in [padded_offsets[g-1], padded_offsets[g])):
\( \text{ref}[m, n] = \alpha_g^2 \sum_{k} \text{dequantize}(A[m, k], \text{SFA}) \cdot \text{dequantize}(B_g[n, k], \text{SFB}_g) \)
Step 2: dGLU backward epilogue (equations shown for act_func="dswiglu"):
For each epilogue tile, the kernel reads paired activation fragments from C and forms gate/input branches:
- `gate = beta_g * C_gate`
- `up = beta_g * C_up`
- `sig = sigmoid(gate)`
- `swish = gate * sig`
Then:
\( dprob \mathrel{+}= \text{reduce}_{32}\left(\text{ref} \cdot up \cdot swish\right) \)
\( d_{\text{gate}} = \text{ref} \cdot prob \cdot up \cdot sig \cdot \left(1 + gate \cdot (1 - sig)\right) \)
\( d_{\text{up}} = \text{ref} \cdot prob \cdot swish \)
[d_gate, d_up] are interleaved back into full-width D_row/D_col in 32-column blocks.
For act_func="dgeglu", the derivative math switches to dGeGLU.
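The dswiglu equations above can be sketched with NumPy on plain per-branch arrays. This is a simplified model: the kernel's 32-column interleaving and the `reduce_32` grouping are omitted (a full per-row reduce is used instead), all arrays are assumed to be `(m, n_half)`, and the `alpha` scaling of `ref` is taken as already applied:

```python
import numpy as np

# Sketch of the per-element dSwiGLU backward math, per branch.
def dswiglu_backward(ref, c_gate, c_up, prob, beta_g):
    gate = beta_g * c_gate
    up = beta_g * c_up
    sig = 1.0 / (1.0 + np.exp(-gate))
    swish = gate * sig
    dprob = np.sum(ref * up * swish, axis=1, keepdims=True)  # per-row reduce
    d_gate = ref * prob * up * sig * (1.0 + gate * (1.0 - sig))
    d_up = ref * prob * swish
    return d_gate, d_up, dprob

rng = np.random.default_rng(0)
m, n_half = 4, 8
ref = rng.standard_normal((m, n_half))
c_gate = rng.standard_normal((m, n_half))
c_up = rng.standard_normal((m, n_half))
prob = rng.standard_normal((m, 1))
d_gate, d_up, dprob = dswiglu_backward(ref, c_gate, c_up, prob, 1.0)
```

Note that `sig * (1 + gate * (1 - sig))` is exactly the derivative of `swish(gate) = gate * sigmoid(gate)`, so `d_gate` can be checked against a finite-difference derivative of `sum(ref * prob * up * swish(gate))`.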
Step 3: Optional output quantization (when SFD outputs are generated):
\( \text{SFD_row}[m, n] = \text{norm_const} \cdot \max_{k \in \text{block}} |D[m, k]| \cdot \text{rcp_max} \)
\( D_{\text{quantized}}[m, n] = D[m, n] \cdot \frac{\text{norm_const}}{\text{SFD_row}[m, n]} \)
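A simplified per-row-block version of this quantization step can be sketched as follows. The flat `(m, n/sf_vec_size)` scale layout here is a simplification of the kernel's 5D layout, and treating `rcp_max` as the reciprocal of the output format's maximum representable value (448 for `float8_e4m3fn`) is an assumption of this sketch:

```python
import numpy as np

# Sketch: one scale per sf_vec_size-wide block of each row, then scale
# each block so its values fit the quantized output format's range.
def quantize_rows(d, sf_vec_size=32, norm_const=1.0, fmt_max=448.0):
    m, n = d.shape
    blocks = d.reshape(m, n // sf_vec_size, sf_vec_size)
    sfd = norm_const * np.abs(blocks).max(axis=2) / fmt_max  # (m, n/vec)
    dq = blocks * (norm_const / sfd)[:, :, None]             # scaled values
    return dq.reshape(m, n), sfd

d = np.array([[1.0, -2.0, 0.5, 4.0]])
dq, sfd = quantize_rows(d, sf_vec_size=2, fmt_max=4.0)
```

After scaling, every block's maximum magnitude lands exactly at `fmt_max`; an all-zero block would need special handling that this sketch omits.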
Diagram#
A (valid_m×K×1) per-expert B/SFB pointers padded_offsets
SFA b_ptrs, sfb_ptrs |
| | |
| +--------------------------+ |
| | |
v v v
Dequantize → Grouped GEMM (expert selected by row range) → ref
|
C (valid_m×N/2×1) --× beta[group_idx]--> paired C fragments
|
+--> dGLU backward (act_func in {dswiglu, dgeglu})
| + prob
| + dprob accumulation
v
D_row, D_col (valid_m×N×1)
|
+------------+-------------+
| |
v v
Row Quantize Col Quantize
| |
v v
SFD_row SFD_col
API Usage#
High-level Wrapper#
import torch

from cudnn import discrete_grouped_gemm_dswiglu_wrapper_sm100
from cuda.bindings import driver as cuda
stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
b_ptrs = torch.tensor([b.data_ptr() for b in b_list], dtype=torch.int64, device="cuda")
sfb_ptrs = torch.tensor([sfb.data_ptr() for sfb in sfb_list], dtype=torch.int64, device="cuda")
dprob = torch.zeros((valid_m, 1, 1), dtype=torch.float32, device="cuda")
outputs = discrete_grouped_gemm_dswiglu_wrapper_sm100(
a_tensor=a_tensor,
b_ptrs=b_ptrs,
c_tensor=c_tensor, # shape (valid_m, N/2, 1)
sfa_tensor=sfa_tensor,
sfb_ptrs=sfb_ptrs,
padded_offsets=padded_offsets,
alpha_tensor=alpha_tensor,
beta_tensor=beta_tensor,
prob_tensor=prob_tensor,
dprob_tensor=dprob, # output buffer; must be zero-initialized
n=n, # logical full N
b_dtype=b_dtype,
norm_const_tensor=norm_const, # required when SFD outputs are enabled
acc_dtype=torch.float32,
d_dtype=torch.bfloat16,
cd_major="n",
mma_tiler_mn=(256, 256),
cluster_shape_mn=(2, 1),
sf_vec_size=32,
vector_f32=False,
m_aligned=256,
discrete_col_sfd=False,
act_func="dswiglu", # or "dgeglu"
b_major="k", # or "n" (fp8 only)
epilogue_op=None, # None/"none"/"identity"/"relu"/"srelu"
current_stream=stream,
)
# dictionary access:
d_row = outputs["d_row_tensor"]
d_col = outputs["d_col_tensor"]
dprob = outputs["dprob_tensor"]
amax = outputs["amax_tensor"]
sfd_row = outputs["sfd_row_tensor"]
sfd_col = outputs["sfd_col_tensor"]
# or tuple unpacking:
d_row, d_col, dprob, amax, sfd_row, sfd_col = outputs
Class API#
import torch

from cudnn import DiscreteGroupedGemmDswigluSm100
from cuda.bindings import driver as cuda
stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
api = DiscreteGroupedGemmDswigluSm100(
sample_a=sample_a,
num_experts=num_experts,
b_shape=(n, k), # logical (N, K) for one expert
b_dtype=b_dtype,
sample_c=sample_c, # shape (valid_m, N/2, 1)
sample_d_row=sample_d_row,
sample_d_col=sample_d_col,
sample_sfa=sample_sfa,
sample_padded_offsets=sample_padded_offsets,
sample_alpha=sample_alpha,
sample_beta=sample_beta,
sample_prob=sample_prob,
sample_dprob=sample_dprob,
# Optional quantization outputs
sample_sfd_row=sample_sfd_row,
sample_sfd_col=sample_sfd_col,
sample_amax=sample_amax,
sample_norm_const=sample_norm_const,
# Configuration
acc_dtype=torch.float32,
mma_tiler_mn=(256, 256),
cluster_shape_mn=(2, 1),
sf_vec_size=32,
vector_f32=False,
m_aligned=256,
discrete_col_sfd=False,
act_func="dswiglu", # or "dgeglu"
b_major="k", # or "n" (fp8 only)
epilogue_op=None, # None/"none"/"identity"/"relu"/"srelu"
)
assert api.check_support()
api.compile() # descriptor-driven; no runtime tensors required
api.execute(
a_tensor=a_tensor,
b_ptrs=b_ptrs,
c_tensor=c_tensor,
d_row_tensor=d_row_tensor,
d_col_tensor=d_col_tensor,
sfa_tensor=sfa_tensor,
sfb_ptrs=sfb_ptrs,
padded_offsets=padded_offsets,
alpha_tensor=alpha_tensor,
beta_tensor=beta_tensor,
prob_tensor=prob_tensor,
dprob_tensor=dprob_tensor,
sfd_row_tensor=sfd_row_tensor,
sfd_col_tensor=sfd_col_tensor,
amax_tensor=amax_tensor,
norm_const_tensor=norm_const_tensor,
current_stream=stream,
)
Parameters#
Input/Output Tensors#
Input tensor A:
- `a_tensor` (wrapper) or `sample_a`, `a_tensor` (class)
- Shape: `(valid_m, K, 1)`
- Stride: `(K, 1, valid_m*K)` - must be K-major
- Dtype (`ab_dtype`): `{float4_e2m1fn_x2, uint8, float8_e4m3fn, float8_e5m2}`; `uint8` is interpreted as packed FP4
Input tensor B pointers:
- `b_ptrs` (wrapper/class execute)
- Shape: `(L,)` where `L = num_experts`
- Dtype: `int64`, CUDA device tensor
- Each pointer must reference one expert B tensor with logical shape `(N, K)` (or `(N, K, 1)`) and dtype `b_dtype`
- Expert B layout is controlled by `b_major` (`"k"` or `"n"`)
Input tensor SFB pointers:
- `sfb_ptrs` (wrapper/class execute)
- Shape: `(L,)`
- Dtype: `int64`, CUDA device tensor
- Each pointer must reference one expert SFB tensor with shape `(32, 4, ceil(N/128), 4, ceil(ceil(K/sf_vec_size)/4), 1)`
Input tensor C:
- `c_tensor` (wrapper/class)
- Shape: `(valid_m, N/2, 1)`
- Stride: `(N/2, 1, valid_m*(N/2))` - must be N-major
- Dtype (`c_dtype`): `{float32, float16, bfloat16}`
Output tensor D_row:
- `d_row_tensor` (class) or returned in wrapper dict
- Shape: `(valid_m, N, 1)`
- Stride: `(N, 1, valid_m*N)` - must be N-major
- Dtype (`d_dtype`):
  - FP4 inputs: `{float16, bfloat16, float32}`
  - FP8 inputs: `{float16, bfloat16, float8_e4m3fn, float8_e5m2, float4_e2m1fn_x2}`
Output tensor D_col:
- `d_col_tensor` (class) or returned in wrapper dict
- Shape: `(valid_m, N, 1)`
- Stride: `(N, 1, valid_m*N)` - must match D_row (N-major)
- Dtype: must match D_row
Input tensor prob:
- `prob_tensor` (wrapper/class)
- Shape: `(valid_m, 1, 1)`
- Dtype: `float32`
Output tensor dprob:
- `dprob_tensor` (wrapper/class)
- Shape: `(valid_m, 1, 1)`
- Dtype: `float32`
- Must be zero-initialized before kernel execution
Scale factor tensors
SFA (A scale factor):
- `sfa_tensor` (wrapper) or `sample_sfa`, `sfa_tensor` (class)
- Shape: `(32, 4, ceil(valid_m/128), 4, ceil(ceil(K/sf_vec_size)/4), 1)`
- Dtype (`sf_dtype`): `{float8_e8m0fnu, float8_e4m3fn}`
SFD_row (optional):
- `sfd_row_tensor` (wrapper) or `sample_sfd_row`, `sfd_row_tensor` (class)
- Shape: `(32, 4, ceil(valid_m/128), 4, ceil(ceil(N/sf_vec_size)/4), 1)`
- Dtype: must match SFA
- Required when: SFD outputs are enabled
SFD_col (optional):
- `sfd_col_tensor` (wrapper) or `sample_sfd_col`, `sfd_col_tensor` (class)
- Shape: `(32, 4, ceil(N/128), 4, ceil(ceil(valid_m/sf_vec_size)/4), 1)`
- Dtype: must match SFA
- Required when: SFD outputs are enabled
Group offsets
padded_offsets: Cumulative sum of aligned group M sizes
- Shape: `(L,)`
- Dtype: `int32`
- `padded_offsets[-1] = valid_m`
Scaling tensors
alpha: Per-group scaling factors
- Shape: `(L,)`
- Dtype: `float32`
beta: Per-group scaling for `C`
- Shape: `(L,)`
- Dtype: `float32`
amax (optional): Per-group max absolute values
- Shape: `(L, 2, 1)`
- Dtype: `float32`
- If provided, updated in-place; the wrapper auto-allocates it when `d_dtype in {bfloat16, float16}`
norm_const (optional): Normalization constant for FP8 quantization
- Shape: `(1,)`
- Dtype: `float32`
- Required when: `sfd_row_tensor`/`sfd_col_tensor` are provided
Common Parameters#
acc_dtype: torch.dtype
- Accumulator dtype. Must be `torch.float32`
mma_tiler_mn: Tuple[int, int]
- Kernel tile size `(TILE_M, TILE_N)`. Default: `(256, 256)`
- `TILE_M in {128, 256}`, `TILE_N = 256`
cluster_shape_mn: Tuple[int, int] | None
- Thread block cluster shape `(CLUSTER_M, CLUSTER_N)`
- Constraints: positive powers of 2, both <= 4, `CLUSTER_M * CLUSTER_N <= 16`
- Default: `(2, 1)` when `TILE_M=256`, `(1, 1)` otherwise
sf_vec_size: int
- Scale factor vector size
- Allowed values: `{16, 32}`. Default: `16`
vector_f32: bool
- Enable packed f32 operations
- Default: `False`
m_aligned: int
- Alignment requirement for the group M dimension
- Must equal `FIX_PAD_SIZE` (256) and be divisible by `mma_tiler_mn[0]`
- Default: `256`
discrete_col_sfd: bool
- If True, generate discrete column scale factors grouped by expert tiles
- Only applies when SFD outputs are enabled
- Default: `False`
act_func: str
- Activation derivative. Valid values: `"dswiglu"`, `"dgeglu"`
- Default: `"dswiglu"`
b_major: str
- Expert B layout. Valid values: `"k"`, `"n"`
- FP4 inputs require `"k"`
- Default: `"k"`
epilogue_op: Optional[str]
- Optional epilogue transform after the backward math
- Valid values: `None`, `"none"`, `"identity"`, `"relu"`, `"srelu"`
- Default: `None`
current_stream: CUDA stream (class API and wrapper)
Wrapper-specific Parameters: discrete_grouped_gemm_dswiglu_wrapper_sm100#
- `n: int`: Logical full N dimension for expert B / D outputs
- `b_dtype: torch.dtype`: Dtype of expert B tensors referenced by `b_ptrs`
- `d_dtype: torch.dtype`: Output D tensor dtype. Default: `torch.bfloat16`
- `cd_major: str`: Major dimension for C and D tensors. Must be `"n"`
Wrapper Return Values#
Returns a TupleDict - a dictionary-like object that also supports tuple unpacking and integer indexing.
Dictionary keys (also tuple unpacking order):
- `d_row_tensor`: Row-quantized dGLU output
- `d_col_tensor`: Column-quantized dGLU output
- `dprob_tensor`: Probability gradient output
- `amax_tensor`: Per-group amax (when `d_dtype in {bfloat16, float16}`)
- `sfd_row_tensor`: Row scale factors (when SFD outputs are enabled)
- `sfd_col_tensor`: Column scale factors (when SFD outputs are enabled)
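The dual access pattern of the return value can be illustrated with a minimal sketch. This is not the library's implementation; it only shows how a dict can simultaneously support key access, tuple unpacking, and integer indexing:

```python
# Sketch: a dict that also unpacks like a tuple of its values and
# supports positional integer indexing, mirroring the described TupleDict.
class TupleDict(dict):
    def __iter__(self):           # tuple unpacking yields values, not keys
        return iter(self.values())

    def __getitem__(self, key):   # int -> positional, else normal dict lookup
        if isinstance(key, int):
            return list(self.values())[key]
        return super().__getitem__(key)

out = TupleDict(d_row_tensor=1, d_col_tensor=2)
a, b = out                  # tuple unpacking -> values in insertion order
x = out["d_row_tensor"]     # dictionary access
y = out[1]                  # integer indexing
```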
Class-specific Parameters#
DiscreteGroupedGemmDswigluSm100 (constructor)#
`sample_a`, `num_experts`, `b_shape`, `b_dtype`, `sample_c`, `sample_d_row`, `sample_d_col`, `sample_sfa`, `sample_padded_offsets`, `sample_alpha`, `sample_beta`, `sample_prob`, `sample_dprob`, `sample_sfd_row`, `sample_sfd_col`, `sample_amax`, `sample_norm_const` - see Input/Output Tensors.
Notes:
- `sample_sfd_row`, `sample_sfd_col`, `sample_norm_const` must be all `None` or all not `None`
- `b_shape` must be the logical `(N, K)` for one expert (pass logical `K`, not packed `K/2`, for FP4)
DiscreteGroupedGemmDswigluSm100.execute#
`a_tensor`, `b_ptrs`, `c_tensor`, `d_row_tensor`, `d_col_tensor`, `sfa_tensor`, `sfb_ptrs`, `padded_offsets`, `alpha_tensor`, `beta_tensor`, `prob_tensor`, `dprob_tensor`, `sfd_row_tensor`, `sfd_col_tensor`, `amax_tensor`, `norm_const_tensor` - see Input/Output Tensors. Layouts must match the constructor sample descriptors.
Support Surface and Constraints#
Layouts and Strides#
- `A` must be K-major
- Expert `B` layout is selected by `b_major`: `"k"` means K-major; `"n"` means N-major (FP8 configs only)
- `C`, `D_row`, and `D_col` must be N-major
- All tensors must be 16-byte aligned along the contiguous dimension
Data Types#
Input/Weight Types (ab_dtype)#
| Format | `ab_dtype` | `sf_dtype` | `sf_vec_size` | `c_dtype` | `d_dtype` |
|---|---|---|---|---|---|
| MXFP8 | `float8_e4m3fn`, `float8_e5m2` | `float8_e8m0fnu` | 32 | `float32`, `float16`, `bfloat16` | `float16`, `bfloat16`, `float8_e4m3fn`, `float8_e5m2`, `float4_e2m1fn_x2` |
| NVF4 | `float4_e2m1fn_x2`, `uint8` | `{float8_e8m0fnu, float8_e4m3fn}` | `{16, 32}` | `float32`, `float16`, `bfloat16` | `float16`, `bfloat16`, `float32` |

NVF4 with `sf_dtype=float8_e4m3fn` requires `sf_vec_size=16` (see the constraints below).
Additional Type Constraints#
- `b_dtype` must match the `A` dtype
- `SFA`, `SFD_row`, and `SFD_col` must share a dtype
- `D_row` and `D_col` must have the same dtype
- `acc_dtype` must be `float32`
- `prob` and `dprob` must be `float32`
- `c_dtype` must be one of `{float32, float16, bfloat16}`
- `sf_dtype=float8_e4m3fn` with `sf_vec_size=32` is not supported
- FP8 `ab_dtype` with `sf_vec_size=16` is not supported
- FP4 `ab_dtype` requires `b_major="k"`
Scale Factor Output Requirements#
When `sfd_row_tensor`/`sfd_col_tensor` are provided:
- `sfd_row_tensor`, `sfd_col_tensor`, and `norm_const_tensor` are all required
- These must be provided together (all `None` or all not `None`)

`amax_tensor` is optional:
- If provided, it is updated in-place with per-group maxima
- The wrapper auto-allocates it when `d_dtype in {bfloat16, float16}`
Tiling and Cluster#
- `mma_tiler_mn[0] = 256` enables 2-CTA instructions (`use_2cta_instrs=True`)
- `mma_tiler_mn[0] = 128` uses the non-2CTA instruction path
- When `use_2cta_instrs=True`: `cluster_shape_mn[0]` must be divisible by 2
- `m_aligned` must be divisible by `mma_tiler_mn[0]`
- `m_aligned` must equal `FIX_PAD_SIZE=256`
Shapes and Divisibility#
- `D` is produced in 32-column blocks (choose `N` divisible by 64 for standard dGLU layouts)
- `C` is half-width relative to `D`: `shape(C)[1] = N/2`
- `padded_offsets` length `L` is the expert count and must be `<= 1024`
- `valid_m = padded_offsets[-1]` determines the actual M size
- `b_ptrs` and `sfb_ptrs` must be CUDA `int64` tensors with shape `(L,)`
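These divisibility rules can be collected into a small pre-flight check. The helper name and the chosen subset of checks are illustrative (the library performs its own validation via `check_support`):

```python
# Sketch: validate a few of the shape/divisibility constraints listed
# above before launching the kernel.
def check_problem(n, padded_offsets, m_aligned=256, max_experts=1024):
    assert n % 64 == 0, "N should be divisible by 64 for the dGLU layout"
    assert len(padded_offsets) <= max_experts, "expert count L must be <= 1024"
    prev = 0
    for off in padded_offsets:
        assert off % m_aligned == 0 and off >= prev, \
            "offsets must be m_aligned multiples and non-decreasing"
        prev = off
    return padded_offsets[-1]  # valid_m

valid_m = check_problem(n=768, padded_offsets=[256, 768, 1024])
```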
Environment#
Requires CUDA with SM100+ compute capability (Blackwell GPUs)
Usage Examples#
For runnable examples and validation, see:
- `test/python/fe_api/test_discrete_grouped_gemm_dswiglu.py`
- `test/python/fe_api/test_discrete_grouped_gemm_dswiglu_utils.py`