Grouped GEMM + sReLU (SM100)#
This is an experimental API and subject to change.
Overview#
Grouped GEMM + sReLU fusion: a grouped block-scaled GEMM fused with a probability-gated squared-ReLU epilogue on NVIDIA Blackwell GPUs (SM100+), designed for MoE-style workloads. The API supports both dense contiguous weights and discrete per-expert weight allocations. Groups are contiguous in the M dimension and described by `padded_offsets`.
This kernel performs:

- Block-scaled grouped GEMM over contiguous expert ranges
- sReLU epilogue using per-row `prob`
- Optional output quantization through `SFD_row`/`SFD_col` or `Amax`
Shapes#
Inputs

- A: contiguous activation tensor across all groups, shape `(valid_m, K, 1)`
- B: dense weight tensor across all groups, shape `(N, K, L)`, or discrete per-expert tensors addressed by `b_ptrs`
- SFA: shape `(32, 4, ceil_div(valid_m, 128), 4, ceil_div(ceil_div(K, sf_vec_size), 4), 1)`
- SFB: dense scale-factor tensor, shape `(32, 4, ceil_div(N, 128), 4, ceil_div(ceil_div(K, sf_vec_size), 4), L)`, or discrete per-expert tensors addressed by `sfb_ptrs`
- padded_offsets: cumulative padded group ends, shape `(L,)`
- alpha: per-group scaling factors, shape `(L,)`
- prob: per-row gating probabilities, shape `(valid_m, 1, 1)`
Outputs

- C: intermediate GEMM result, shape `(valid_m, N, 1)`
- D: row output after sReLU, shape `(valid_m, N, 1)`
- D_col: column output after sReLU, shape `(valid_m, N, 1)`
- SFD_row: shape `(32, 4, ceil_div(valid_m, 128), 4, ceil_div(ceil_div(N, sf_vec_size), 4), 1)` when `D` is FP8
- SFD_col: shape `(32, 4, ceil_div(N, 128), 4, ceil_div(ceil_div(valid_m, sf_vec_size), 4), 1)` when `D` is FP8
- Amax: shape `(L, 1)` when `D` is fp16/bf16
`L` is the expert count and `valid_m = padded_offsets[-1]`.
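As a concrete illustration, here is one plausible way to derive `padded_offsets` and `valid_m` from per-expert token counts. The helper name and the padding rule (rounding each group up to a multiple of `m_aligned`) are assumptions for the sketch, not part of the API:

```python
import torch

def build_padded_offsets(tokens_per_expert, m_aligned=256):
    # Hypothetical helper: round each expert's row count up to a multiple
    # of m_aligned (assumed padding rule), then accumulate the padded
    # group ends, giving one int32 entry per expert.
    ends, total = [], 0
    for t in tokens_per_expert:
        total += (t + m_aligned - 1) // m_aligned * m_aligned
        ends.append(total)
    return torch.tensor(ends, dtype=torch.int32, device="cuda")

padded_offsets = build_padded_offsets([300, 512, 17])  # L = 3 experts
valid_m = int(padded_offsets[-1])                      # total padded rows
```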
Equations#
For rows belonging to expert g:
\( C[m, n] = \alpha_g \sum_k \mathrm{dequantize}(A[m, k], SFA) \cdot \mathrm{dequantize}(B[n, k, g], SFB) \)
\( D[m, n] = \mathrm{prob}[m, 0, 0] \cdot \mathrm{relu}(C[m, n])^2 \)
D_col stores the same logical output in the column-quantized path used by the grouped kernel family. When D is FP8, the kernel also emits SFD_row and SFD_col. When D is fp16/bf16, the kernel can emit per-expert Amax.
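For verification purposes, the epilogue is easy to sketch in plain PyTorch. Dequantization, per-expert `alpha` scaling, and the quantized output paths are omitted; this is a reference for the formula above, not the kernel:

```python
import torch

def srelu_epilogue_reference(c: torch.Tensor, prob: torch.Tensor) -> torch.Tensor:
    # D[m, n] = prob[m, 0, 0] * relu(C[m, n])^2
    # c: (valid_m, N, 1), prob: (valid_m, 1, 1); prob broadcasts over N.
    return prob * torch.relu(c) ** 2
```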
Diagram#
```text
A (valid_m×K×1), SFA     B (N×K×L), SFB      padded_offsets
        |                      |                   |
        | dequantize           | dequantize        |
        +-----------+----------+                   v
                    v                          group idx
     Grouped GEMM over expert ranges  <------------+
                    |
                    | * alpha[group_idx]
                    v
             C (valid_m×N×1)
                    |
                    | relu(C)^2 * prob
                    v
         D / D_col (valid_m×N×1)
                    |
        +-----------+-----------+
        |                       |
        v                       v
 SFD_row / SFD_col            Amax
```
API Usage#
High-level wrapper#
```python
import torch
from cudnn import grouped_gemm_srelu_wrapper_sm100

result = grouped_gemm_srelu_wrapper_sm100(
    a_tensor=a,
    b_tensor=b,  # dense weights: all experts in one (N, K, L) tensor
    sfa_tensor=sfa,
    sfb_tensor=sfb,
    padded_offsets=padded_offsets,
    alpha_tensor=alpha,
    norm_const_tensor=norm_const,
    prob_tensor=prob,
    acc_dtype=torch.float32,
    c_dtype=torch.bfloat16,
    d_dtype=torch.bfloat16,
    cd_major="n",
    mma_tiler_mn=(256, 256),
    cluster_shape_mn=(2, 1),
    sf_vec_size=32,
    vector_f32=False,
    m_aligned=256,
    discrete_col_sfd=False,
    current_stream=None,
)
c, d, d_col, amax, sfd_row, sfd_col = result
```
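A minimal allocation sketch for the inputs used above, assuming FP8 `A`/`B` with `e8m0` scale factors and bf16 outputs. Random data and all-ones scale factors are for illustration only, and `torch.float8_e8m0fnu` requires a recent PyTorch:

```python
import torch

def ceil_div(a, b):
    return (a + b - 1) // b

L, K, N, sf_vec_size = 3, 512, 1024, 32
valid_m = int(padded_offsets[-1])  # see the padding sketch above

# A is naturally k-major: (valid_m, K, 1) with K innermost.
a = torch.randn(valid_m, K, 1, device="cuda").to(torch.float8_e4m3fn)
# B must be k-major: build (L, N, K), then permute to (N, K, L).
b = torch.randn(L, N, K, device="cuda").to(torch.float8_e4m3fn).permute(1, 2, 0)

# 127 is the e8m0 encoding of 2^0 = 1.0, so these scale factors are neutral.
sfa = torch.full((32, 4, ceil_div(valid_m, 128), 4,
                  ceil_div(ceil_div(K, sf_vec_size), 4), 1),
                 127, dtype=torch.uint8, device="cuda").view(torch.float8_e8m0fnu)
sfb = torch.full((32, 4, ceil_div(N, 128), 4,
                  ceil_div(ceil_div(K, sf_vec_size), 4), L),
                 127, dtype=torch.uint8, device="cuda").view(torch.float8_e8m0fnu)

alpha = torch.ones(L, dtype=torch.float32, device="cuda")
prob = torch.rand(valid_m, 1, 1, dtype=torch.float32, device="cuda")
norm_const = torch.ones(1, dtype=torch.float32, device="cuda")
```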
Discrete-weight wrapper#
```python
import torch

result = grouped_gemm_srelu_wrapper_sm100(
    a_tensor=a,
    sfa_tensor=sfa,
    padded_offsets=padded_offsets,
    alpha_tensor=alpha,
    b_ptrs=b_ptrs,      # int64 device tensor of per-expert B pointers
    sfb_ptrs=sfb_ptrs,  # int64 device tensor of per-expert SFB pointers
    n=N,                # n and b_dtype are required in discrete mode
    b_dtype=torch.float4_e2m1fn_x2,
    b_major="k",        # FP4 weights use k-major
    prob_tensor=prob,
    c_dtype=torch.bfloat16,
    d_dtype=torch.bfloat16,
)
```
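Building the pointer tensors is left to the caller. One plausible construction, assuming `b_list`/`sfb_list` hold one CUDA tensor per expert (the list names and the per-expert shapes are assumptions for the sketch):

```python
import torch

def ceil_div(a, b):
    return (a + b - 1) // b

L, N, K, sf_vec_size = 3, 1024, 512, 16

# Hypothetical per-expert allocations. uint8 stands in for packed FP4x2
# data (two 4-bit values per byte along K; see the dtype note under
# "Input tensor A").
b_list = [torch.empty(N, K // 2, 1, dtype=torch.uint8, device="cuda")
          for _ in range(L)]
sfb_list = [torch.empty(32, 4, ceil_div(N, 128), 4,
                        ceil_div(ceil_div(K, sf_vec_size), 4), 1,
                        dtype=torch.float8_e4m3fn, device="cuda")
            for _ in range(L)]

# The kernel dereferences these raw device addresses, one per expert.
b_ptrs = torch.tensor([t.data_ptr() for t in b_list],
                      dtype=torch.int64, device="cuda")
sfb_ptrs = torch.tensor([t.data_ptr() for t in sfb_list],
                        dtype=torch.int64, device="cuda")

# Note: b_ptrs/sfb_ptrs do not own the memory; keep b_list and sfb_list
# alive until the kernel has finished.
```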
Class API#
```python
import torch
from cudnn import GroupedGemmSreluSm100

# Sample tensors bind shapes, dtypes, and layouts at construction time.
op = GroupedGemmSreluSm100(
    sample_a=a,
    sample_b=b,
    sample_c=c,
    sample_d=d,
    sample_sfa=sfa,
    sample_sfb=sfb,
    sample_padded_offsets=padded_offsets,
    sample_alpha=alpha,
    sample_d_col=d_col,
    sample_sfd_row=sfd_row,
    sample_sfd_col=sfd_col,
    sample_amax=amax,
    sample_norm_const=norm_const,
    sample_prob=prob,
    acc_dtype=torch.float32,
    mma_tiler_mn=(256, 256),
    cluster_shape_mn=(2, 1),
    sf_vec_size=32,
    vector_f32=False,
    m_aligned=256,
    discrete_col_sfd=False,
)
assert op.check_support()
op.compile()
op.execute(
    a_tensor=a,
    b_tensor=b,
    c_tensor=c,
    d_tensor=d,
    sfa_tensor=sfa,
    sfb_tensor=sfb,
    padded_offsets=padded_offsets,
    alpha_tensor=alpha,
    d_col_tensor=d_col,
    sfd_row_tensor=sfd_row,
    sfd_col_tensor=sfd_col,
    amax_tensor=amax,
    norm_const_tensor=norm_const,
    prob_tensor=prob,
    current_stream=None,
)
```
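Separating `check_support`/`compile` from `execute` lets you pay the compilation cost once and launch many times. A hedged sketch of that pattern, assuming later calls reuse the sample shapes and dtypes (an assumption based on the sample-tensor construction) and that `refill_inputs` is a hypothetical per-step data update:

```python
num_steps = 10

for _ in range(num_steps):
    refill_inputs(a, b, sfa, sfb, prob)  # hypothetical: update values in place
    op.execute(
        a_tensor=a, b_tensor=b, c_tensor=c, d_tensor=d,
        sfa_tensor=sfa, sfb_tensor=sfb,
        padded_offsets=padded_offsets, alpha_tensor=alpha,
        d_col_tensor=d_col, sfd_row_tensor=sfd_row,
        sfd_col_tensor=sfd_col, amax_tensor=amax,
        norm_const_tensor=norm_const, prob_tensor=prob,
        current_stream=None,
    )
```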
Parameters#
Input/Output tensors#
Input tensor A:

- `a_tensor` (wrapper) or `sample_a`/`a_tensor` (class)
- Shape: `(valid_m, K, 1)`
- Layout: must be k-major
- Dtype: `{float4_e2m1fn_x2, uint8, float8_e4m3fn, float8_e5m2}`
- Note: `uint8` is interpreted as packed `float4_e2m1fn_x2` (FP4x2) data, not integer quantization
Input tensor B:

- `b_tensor` (wrapper) or `sample_b`/`b_tensor` (class)
- Shape: `(N, K, L)`
- Layout: must be k-major
- Dtype: must match `A`
Discrete input B pointers:

- `b_ptrs` (wrapper) or `num_experts`/`b_shape`/`b_dtype` (class)
- `b_ptrs`: 1-D `int64` CUDA tensor containing one data pointer per expert
- `n` and `b_dtype` are required in wrapper discrete mode
- `b_major` may be `"k"` or `"n"` for supported FP8 cases; FP4 uses `"k"`
Input tensor SFA:

- `sfa_tensor` (wrapper) or `sample_sfa`/`sfa_tensor` (class)
- Shape: `(32, 4, ceil_div(valid_m, 128), 4, ceil_div(ceil_div(K, sf_vec_size), 4), 1)`
- Dtype: `{float8_e8m0fnu, float8_e4m3fn}`
Input tensor SFB:

- `sfb_tensor` (wrapper) or `sample_sfb`/`sfb_tensor` (class)
- Shape: `(32, 4, ceil_div(N, 128), 4, ceil_div(ceil_div(K, sf_vec_size), 4), L)`
- Dtype: must match `SFA`
Discrete input SFB pointers:

- `sfb_ptrs`: 1-D `int64` CUDA tensor containing one scale-factor pointer per expert
Input tensor padded_offsets:

- Shape: `(L,)`
- Dtype: `int32`
Input tensor alpha:

- Shape: `(L,)`
- Dtype: `float32`
Input tensor prob:

- Shape: `(valid_m, 1, 1)`
- Dtype: `float32`
- Required
Output tensor C:

- `result["c_tensor"]` (wrapper) or `sample_c`/`c_tensor` (class)
- Shape: `(valid_m, N, 1)`
- Layout: must be n-major
- Dtype: `{float32, float16, bfloat16, float8_e4m3fn, float8_e5m2, float4_e2m1fn_x2}`
Output tensor D:

- `result["d_tensor"]` (wrapper) or `sample_d`/`d_tensor` (class)
- Shape: `(valid_m, N, 1)`
- Layout: must be n-major
- Dtype:
  - FP4 input: `{float16, bfloat16, float32}`
  - FP8 input: `{float16, bfloat16, float8_e4m3fn, float8_e5m2, float4_e2m1fn_x2}`
Output tensor D_col:

- `result["d_col_tensor"]` (wrapper) or `sample_d_col`/`d_col_tensor` (class)
- Shape: `(valid_m, N, 1)`
- Layout: must match `D`
- Dtype: must match `D`
Output tensors SFD_row / SFD_col:

- Dtype: must match `SFA`
- Required when FP8 output scale factors are generated
Output tensor Amax:

- Shape: `(L, 1)`
- Dtype: `float32`
Input tensor Norm Const:

- Shape: `(1,)`
- Dtype: `float32`
- Required when FP8 output scale factors are generated
Common parameters#
- `acc_dtype: torch.dtype`. Only `torch.float32` is supported.
- `mma_tiler_mn: Tuple[int, int]`. `TILE_M` depends on the 1-CTA / 2-CTA mode; `TILE_N ∈ {128, 256}`.
- `cluster_shape_mn: Tuple[int, int] | None`. Default: `(2, 1)` when `TILE_M == 256`, else `(1, 1)`.
- `sf_vec_size: int`. Allowed values: `{16, 32}`.
- `vector_f32: bool`. Enables vectorized f32 operations for supported configurations.
- `m_aligned: int`. Must equal the kernel's fixed pad size, `256`.
- `discrete_col_sfd: bool`. Enables the discrete column-scale-factor path used by grouped FP8.
- `current_stream`. CUDA stream; the same keyword is used in the class API and the wrapper.
Wrapper return values#
Returns a TupleDict with keys:

- `c_tensor`
- `d_tensor`
- `d_col_tensor`
- `amax_tensor`
- `sfd_row_tensor`
- `sfd_col_tensor`

Tuple unpacking order is `(c_tensor, d_tensor, d_col_tensor, amax_tensor, sfd_row_tensor, sfd_col_tensor)`.
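Continuing from the high-level wrapper example, both access styles refer to the same result (assuming the TupleDict exposes the same tensor objects both ways):

```python
# Key-based access and positional unpacking are interchangeable.
d_by_key = result["d_tensor"]
c, d, d_col, amax, sfd_row, sfd_col = result
# d_by_key and d refer to the same output tensor.
```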
Support surface and constraints#
Layouts#
- `A` must be k-major
- `B` must be k-major
- Discrete `B` supports `b_major="k"` and supported FP8 `b_major="n"` configurations
- `C`, `D`, and `D_col` must be n-major
- The wrapper only supports `cd_major="n"`
Dtypes#
- `A` and `B` must have the same dtype
- `SFA`, `SFB`, `SFD_row`, and `SFD_col` must have the same dtype
- Scale-factor dtype constraint: `sf_vec_size == 32` is unsupported when `sf_dtype == float8_e4m3fn`
- Input dtype constraint: FP8 `A`/`B` inputs require `sf_vec_size == 32`
- Grouped FP8 currently requires `discrete_col_sfd=True`
Shapes and environment#
- `prob_tensor` is required
- `m_aligned` must be `256`
- Requires CUDA with SM100+ compute capability
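These constraints can be pre-checked host-side. The sketch below simply encodes the bullets above (illustrative only; the class API's `check_support()` remains the authoritative gate):

```python
import torch

def config_is_supported(ab_dtype, sf_dtype, sf_vec_size, m_aligned):
    # sf_vec_size == 32 is unsupported with float8_e4m3fn scale factors.
    if sf_dtype == torch.float8_e4m3fn and sf_vec_size == 32:
        return False
    # FP8 A/B inputs require sf_vec_size == 32.
    if ab_dtype in (torch.float8_e4m3fn, torch.float8_e5m2) and sf_vec_size != 32:
        return False
    # m_aligned must equal the kernel's fixed pad size.
    if m_aligned != 256:
        return False
    return True
```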
Usage examples#
For end-to-end usage and regression coverage, see:
- `test/python/fe_api/test_grouped_gemm_srelu.py`
- `test/python/fe_api/test_grouped_gemm_srelu_utils.py`