Grouped GEMM + GLU + Hadamard (SM100)#
This is an experimental API and subject to change.
Overview#
Grouped GEMM + GLU + Hadamard fusion: A contiguous grouped block-scaled GEMM fused with a GLU epilogue, a 16-wide Hadamard transform, and per-expert amax reduction on NVIDIA Blackwell GPUs (SM100+), designed for MoE-style workloads. Groups are contiguous in the M dimension and described by padded_offsets.
This frontend integration is currently wired for the fp4 input path.
This kernel performs:
Block-scaled grouped GEMM over contiguous expert ranges
GLU epilogue using per-row prob
Hadamard transform across the post-GLU output
Per-expert amax reduction on the final output
Shapes#
Inputs
A: contiguous activation tensor across all groups, shape (valid_m, K, 1)
B: weight tensor across all groups, shape (N, K, L)
SFA: shape (32, 4, ceil_div(valid_m, 128), 4, ceil_div(ceil_div(K, sf_vec_size), 4), 1)
SFB: shape (32, 4, ceil_div(N, 128), 4, ceil_div(ceil_div(K, sf_vec_size), 4), L)
padded_offsets: cumulative padded group ends, shape (L,)
alpha: per-group scaling factors, shape (L,)
prob: per-row gating probabilities, shape (valid_m, 1, 1)
bias (optional): per-expert bias tensor, shape (N, L) with stride (1, N)
Hadamard: fixed transform matrix, shape (16, 16)
Outputs
C: intermediate GEMM result before GLU/Hadamard, shape (valid_m, N, 1)
D: output after GLU and Hadamard, shape (valid_m, N / 2, 1)
Amax: per-expert amax, shape (L, 1), when D is fp16/bf16
L is the expert count and valid_m = padded_offsets[-1].
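For example, per-expert row counts are padded up to a multiple of m_aligned (fixed at 256, see Parameters) before being accumulated into padded_offsets. A minimal sketch with illustrative token counts:

import torch

m_aligned = 256                                   # kernel fixed pad size
tokens_per_expert = torch.tensor([100, 700, 1])   # illustrative counts, L = 3
padded = ((tokens_per_expert + m_aligned - 1) // m_aligned) * m_aligned
padded_offsets = torch.cumsum(padded, dim=0).to(torch.int32)
# padded_offsets == [256, 1024, 1280], so valid_m = padded_offsets[-1] == 1280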
Equations#
For rows belonging to expert g:
\( C[m, n] = \alpha_g \sum_k \mathrm{dequantize}(A[m, k], SFA) \cdot \mathrm{dequantize}(B[n, k, g], SFB) \)
Split the N dimension into consecutive 32-column gate/up blocks:
\( G_b = C[:, 2bG:(2b+1)G], \quad U_b = C[:, (2b+1)G:(2b+2)G], \quad G = 32 \)
For SwiGLU (act_func="swiglu"):
\( X[:, bG:(b+1)G] = \mathrm{prob} \cdot U_b \cdot \left(G_b \cdot \sigma(G_b)\right) \)
For GeGLU (act_func="geglu"):
\( X[:, bG:(b+1)G] = \mathrm{prob} \cdot (U_b + 1) \cdot G_b \cdot \sigma(1.702 \cdot G_b) \)
Apply the fixed Hadamard matrix H of size 16 x 16 blockwise over the output:
\( D = X \cdot H \)
When D is fp16/bf16, the kernel also emits per-expert Amax.
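The math above can be sanity-checked against a small PyTorch reference. This is a minimal sketch, assuming A and B are already dequantized to float32 (the kernel itself consumes fp4 data plus the SFA/SFB scale factors) and ignoring the optional bias; all names are illustrative:

import torch

def reference_glu_hadamard(a, b, padded_offsets, alpha, prob, hadamard,
                           act_func="swiglu"):
    """Illustrative float32 reference for the fused math (no quantization)."""
    valid_m, K = a.shape                     # A: (valid_m, K)
    N, _, L = b.shape                        # B: (N, K, L)
    starts = [0] + [int(e) for e in padded_offsets[:-1]]
    # Grouped GEMM over contiguous expert ranges, scaled by alpha[g].
    c = torch.empty(valid_m, N, dtype=torch.float32, device=a.device)
    for g in range(L):
        s, e = starts[g], int(padded_offsets[g])
        c[s:e] = alpha[g] * (a[s:e] @ b[:, :, g].T)
    # Split N into consecutive 32-column gate/up pairs.
    blocks = c.view(valid_m, N // 64, 2, 32)
    gate, up = blocks[:, :, 0, :], blocks[:, :, 1, :]
    p = prob.view(valid_m, 1, 1)
    if act_func == "swiglu":
        x = p * up * (gate * torch.sigmoid(gate))
    else:  # "geglu"
        x = p * (up + 1) * gate * torch.sigmoid(1.702 * gate)
    x = x.reshape(valid_m, N // 2)
    # Blockwise 16-wide Hadamard transform: each 16-column group times H.
    d = (x.view(valid_m, -1, 16) @ hadamard.float()).reshape(valid_m, N // 2)
    # Per-expert amax over the rows of D belonging to each expert.
    amax = torch.empty(L, device=a.device)
    for g in range(L):
        s, e = starts[g], int(padded_offsets[g])
        amax[g] = d[s:e].abs().max()
    return c, d, amax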
Diagram#
A (valid_m×K×1), SFA     B (N×K×L), SFB     padded_offsets
        |                      |                  |
        | dequantize           |                  |
        +----------+-----------+                  |
                   v                              v
   Grouped GEMM over expert ranges  -->  group idx
                   |
                   | * alpha[group_idx]
                   v
           C (valid_m×N×1)
                   |
                   | GLU over paired 32-col blocks
                   |   with per-row prob
                   v
           X (valid_m×N/2×1)
                   |
                   | blockwise Hadamard(16)
                   v
           D (valid_m×N/2×1)
                   |
                   v
            Amax (L×1)
API Usage#
High-level wrapper#
from cudnn import grouped_gemm_glu_hadamard_wrapper_sm100
result = grouped_gemm_glu_hadamard_wrapper_sm100(
a_tensor=a,
b_tensor=b,
sfa_tensor=sfa,
sfb_tensor=sfb,
padded_offsets=padded_offsets,
alpha_tensor=alpha,
prob_tensor=prob,
bias_tensor=bias,
acc_dtype=torch.float32,
c_dtype=torch.bfloat16,
d_dtype=torch.bfloat16,
cd_major="n",
mma_tiler_mn=(256, 256),
cluster_shape_mn=(2, 1),
sf_vec_size=16,
vector_f32=False,
m_aligned=256,
act_func="swiglu",
current_stream=None,
)
c_tensor, d_tensor, amax_tensor = result
The wrapper constructs the fixed Hadamard matrix internally.
Class API#
from cudnn import GroupedGemmGluHadamardSm100
op = GroupedGemmGluHadamardSm100(
sample_a=a,
sample_b=b,
sample_c=c,
sample_d=d,
sample_sfa=sfa,
sample_sfb=sfb,
sample_padded_offsets=padded_offsets,
sample_alpha=alpha,
sample_prob=prob,
sample_amax=amax,
sample_bias=bias,
acc_dtype=torch.float32,
mma_tiler_mn=(256, 256),
cluster_shape_mn=(2, 1),
sf_vec_size=16,
vector_f32=False,
m_aligned=256,
act_func="swiglu",
)
assert op.check_support()
op.compile()
op.execute(
a_tensor=a,
b_tensor=b,
c_tensor=c,
d_tensor=d,
sfa_tensor=sfa,
sfb_tensor=sfb,
padded_offsets=padded_offsets,
alpha_tensor=alpha,
prob_tensor=prob,
amax_tensor=amax,
bias_tensor=bias,
current_stream=None,
)
You may optionally pass a custom sample_hadamard / hadamard_tensor, but the API normalizes it to the fixed 16 x 16 bf16 contiguous layout expected by the kernel. If you do not provide one, the default is the fixed kernel matrix.
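Because compilation is a separate step, the class API lends itself to compiling once and executing many times. A sketch, assuming each call's tensors match the sample shapes and dtypes fixed at construction (an assumed usage pattern, not a documented guarantee):

# Compile once; reuse the compiled op across iterations whose tensors
# match the sample shapes/dtypes. `batches` is an illustrative iterable.
op.compile()
for a, b, sfa, sfb, prob in batches:
    op.execute(
        a_tensor=a, b_tensor=b,
        c_tensor=c, d_tensor=d,
        sfa_tensor=sfa, sfb_tensor=sfb,
        padded_offsets=padded_offsets,
        alpha_tensor=alpha, prob_tensor=prob,
        amax_tensor=amax, bias_tensor=bias,
        current_stream=None,
    )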
Parameters#
Input/Output tensors#
Input tensor A:
a_tensor (wrapper) or sample_a / a_tensor (class)
Shape: (valid_m, K, 1)
Layout: must be k-major
Dtype: {float4_e2m1fn_x2, uint8}
Input tensor B:
b_tensor (wrapper) or sample_b / b_tensor (class)
Shape: (N, K, L)
Layout: must be k-major
Dtype: must match A
Input tensor SFA:
sfa_tensor (wrapper) or sample_sfa / sfa_tensor (class)
Shape: (32, 4, ceil_div(valid_m, 128), 4, ceil_div(ceil_div(K, sf_vec_size), 4), 1)
Dtype: {float8_e8m0fnu, float8_e4m3fn}
Input tensor SFB:
sfb_tensor (wrapper) or sample_sfb / sfb_tensor (class)
Shape: (32, 4, ceil_div(N, 128), 4, ceil_div(ceil_div(K, sf_vec_size), 4), L)
Dtype: must match SFA
Input tensor padded_offsets:
Shape: (L,)
Dtype: int32
Input tensor alpha:
Shape: (L,)
Dtype: float32
Input tensor prob:
Shape: (valid_m, 1, 1)
Dtype: float32
Input tensor bias (optional):
Shape: (N, L)
Stride: (1, N)
Dtype: {float16, bfloat16, float32}
Input tensor Hadamard (optional in class API):
Shape: (16, 16)
Dtype: bfloat16
Layout: normalized to a contiguous 16 x 16 bf16 tensor before compile/execute
Output tensor C:
result["c_tensor"] (wrapper) or sample_c / c_tensor (class)
Shape: (valid_m, N, 1)
Layout: must be n-major
Dtype: {float16, bfloat16}
Output tensor D:
result["d_tensor"] (wrapper) or sample_d / d_tensor (class)
Shape: (valid_m, N / 2, 1)
Layout: must be n-major
Dtype: {float16, bfloat16}
Output tensor Amax:
result["amax_tensor"] (wrapper) or sample_amax / amax_tensor (class)
Shape: (L, 1)
Dtype: float32
Common parameters#
acc_dtype: torch.dtype. Only torch.float32 is supported.
mma_tiler_mn: Tuple[int, int]. Must be (256, 256).
cluster_shape_mn: Tuple[int, int] | None. Default: (2, 1).
sf_vec_size: int. Allowed values: {16, 32}.
vector_f32: bool. Enables vectorized f32 operations for supported configurations.
m_aligned: int. Must equal the kernel fixed pad size, 256.
act_func: str. Allowed values: {"swiglu", "geglu"}.
current_stream: CUDA stream, accepted by both the class API and the wrapper.
Wrapper return values#
Returns a TupleDict with keys:
c_tensor, d_tensor, amax_tensor
Tuple unpacking order is: (c_tensor, d_tensor, amax_tensor).
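Both access styles refer to the same result. A sketch, assuming the TupleDict returns the same underlying tensors either way:

c_tensor, d_tensor, amax_tensor = result   # tuple-style unpacking
assert d_tensor is result["d_tensor"]      # dict-style access to the same tensor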
Support surface and constraints#
Only dense contiguous grouped weights are exposed in this frontend integration.
The wrapper constructs the fixed Hadamard matrix internally.
A and B must be fp4 input tensors.
D is currently supported for {float16, bfloat16}.
N must be divisible by 64.
N / 2 must be divisible by 16.
m_aligned must be 256.
expert_cnt must be <= 1024.
The kernel requires SM100+.
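These constraints can be pre-checked on the host before calling check_support(), which remains the authoritative test. A minimal sketch with illustrative names:

import torch

def preflight(n: int, m_aligned: int, expert_cnt: int) -> None:
    """Illustrative host-side mirror of the documented constraints."""
    assert n % 64 == 0, "N must be divisible by 64"
    assert (n // 2) % 16 == 0, "N / 2 must be divisible by 16"
    assert m_aligned == 256, "m_aligned must be 256"
    assert expert_cnt <= 1024, "expert_cnt must be <= 1024"
    # SM100 corresponds to compute capability 10.x (Blackwell).
    assert torch.cuda.get_device_capability()[0] >= 10, "requires SM100+"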