GEMM + Amax (SM100)#
This is an experimental API and subject to change.
Overview#
Block-scaled GEMM + amax: a persistent, batched dense GEMM on NVIDIA Blackwell GPUs (SM100+) that supports low-precision inputs (FP8, FP4) with per-block scale factors, producing the full GEMM output C and a global amax reduction. Implemented with CUTLASS/CUTE.
Inputs: quantized A and B (FP8 or FP4), and corresponding scale-factor tensors SFA and SFB that dequantize along the K dimension in groups of size sf_vec_size.
Outputs: the full GEMM result C and Amax.
Shapes#
Inputs
- A: shape (M, K, L)
- B: shape (N, K, L)
- SFA: shape (32, 4, ceil_div(M, 128), 4, ceil_div(K, 4·sf_vec_size), L)
- SFB: shape (32, 4, ceil_div(N, 128), 4, ceil_div(K, 4·sf_vec_size), L)
Outputs
- C: shape (M, N, L)
- Amax: shape (1, 1, 1)
L is the batch dimension.
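The scale-factor shapes follow directly from the rule above. Below is a small helper sketch that computes them; the helper names are illustrative and not part of the library API.
def ceil_div(a: int, b: int) -> int:
    return (a + b - 1) // b

def sf_shape(rows: int, k: int, batches: int, sf_vec_size: int = 32):
    """Shape of SFA (rows = M) or SFB (rows = N) per the layout above."""
    return (32, 4, ceil_div(rows, 128), 4, ceil_div(k, 4 * sf_vec_size), batches)

print(sf_shape(512, 1024, 1))  # (32, 4, 4, 4, 8, 1)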
Equations#
Let the block size along K be sf_vec_size ∈ {16, 32}. Dequantization applies the provided scale factors to groups of sf_vec_size elements along K (within the M/N blocks defined by the atom tiling):
\( \hat{A}[m, k, l] = \operatorname{dequantize}(A[m, k, l], \text{SFA}, \text{sf_vec_size}) \)
\( \hat{B}[n, k, l] = \operatorname{dequantize}(B[n, k, l], \text{SFB}, \text{sf_vec_size}) \)
\( C[m, n, l] = \sum_{k} \hat{A}[m, k, l] \, \hat{B}[n, k, l] \)
\( \mathrm{Amax} = \max_{m, n, l} |C[m, n, l]| \)
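For clarity, here is a hedged NumPy reference of the math above. It assumes the scale factors have already been decoded from the 6-D atom layout into dense per-group arrays scale_a[m, g, l] and scale_b[n, g, l] with g = k // sf_vec_size; the kernel itself consumes the packed layout directly.
import numpy as np

def reference_gemm_amax(a, b, scale_a, scale_b, sf_vec_size=32):
    M, K, L = a.shape
    g = np.arange(K) // sf_vec_size                  # K-group index per column
    a_hat = a.astype(np.float32) * scale_a[:, g, :]  # dequantized A, (M, K, L)
    b_hat = b.astype(np.float32) * scale_b[:, g, :]  # dequantized B, (N, K, L)
    c = np.einsum("mkl,nkl->mnl", a_hat, b_hat)      # (M, N, L)
    amax = np.abs(c).max()                           # global |C| maximum
    return c, amax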
Diagram#
A (M×K×L), SFA                     B (N×K×L), SFB
      │ dequantize(·; SFA)               │ dequantize(·; SFB)
      ▼                                  ▼
 Â (M×K×L)                          B̂ (N×K×L)
      └──────────── GEMM over K ─────────┘
                      │
                      ▼
            C (M×N×L, or packed)
                      │
                      ├── reduce: Amax = max |C|
                      ▼
                Amax (1×1×1)
API Usage#
High-level wrapper#
c, amax = gemm_amax_wrapper_sm100(
    a_tensor,
    b_tensor,
    sfa_tensor,
    sfb_tensor,
    c_major="n",
    c_dtype=torch.float32,
    acc_dtype=torch.float32,
    mma_tiler_mn=(128, 128),
    cluster_shape_mn=(1, 1),
    sf_vec_size=32,
    stream=None,
)
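The sketch below shows one way to prepare FP8 inputs for this wrapper: k-major A/B per the stride rules documented under Parameters, contiguous scale-factor tensors in the documented 6-D layout, and placeholder scale bytes (127 encodes a scale of 1.0 in e8m0). The random data stands in for a real quantization step, and gemm_amax_wrapper_sm100 is assumed to already be in scope.
import torch

M, N, K, L, sf_vec = 256, 256, 512, 1, 32
cdiv = lambda a, b: (a + b - 1) // b

# k-major (M, K, L) / (N, K, L): unit stride along K.
a = torch.randn(L, M, K, device="cuda").to(torch.float8_e4m3fn).permute(1, 2, 0)
b = torch.randn(L, N, K, device="cuda").to(torch.float8_e4m3fn).permute(1, 2, 0)

# Scale factors in the documented 6-D atom layout; int8 is read as float8_e8m0fnu,
# and byte value 127 encodes a scale of 1.0 (placeholder for real quantizer output).
sfa = torch.full((32, 4, cdiv(M, 128), 4, cdiv(K, 4 * sf_vec), L), 127,
                 dtype=torch.int8, device="cuda")
sfb = torch.full((32, 4, cdiv(N, 128), 4, cdiv(K, 4 * sf_vec), L), 127,
                 dtype=torch.int8, device="cuda")

c, amax = gemm_amax_wrapper_sm100(a, b, sfa, sfb, c_major="n",
                                  c_dtype=torch.float32, sf_vec_size=sf_vec)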
Class API#
from cuda.bindings import driver as cuda
op = GemmAmaxSm100(
    sample_a=a,
    sample_b=b,
    sample_sfa=sfa,
    sample_sfb=sfb,
    sample_c=c,
    sample_amax=amax,
    acc_dtype=torch.float32,
    mma_tiler_mn=(128, 128),
    cluster_shape_mn=(1, 1),
    sf_vec_size=32,
)
assert op.check_support()
op.compile(current_stream=None)
op.execute(a, b, sfa, sfb, c, amax, current_stream=None)
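With the class API the sample output tensors are allocated by the caller. A minimal allocation sketch, following the shapes and dtypes documented under Parameters and assuming an n-major C:
import torch

M, N, L = 256, 256, 1

# n-major C (M, N, L): strides (N, 1, M·N), obtained by permuting a contiguous buffer.
c = torch.empty(L, M, N, dtype=torch.float32, device="cuda").permute(1, 2, 0)

# Global amax is a single float32 value of shape (1, 1, 1).
amax = torch.zeros(1, 1, 1, dtype=torch.float32, device="cuda")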
Parameters#
Input/Output tensors#
Input tensor A:
- a_tensor (wrapper) or sample_a / a_tensor (class)
- Shape: (M, K, L)
- Stride: (1, M, M·K) for m-major or (K, 1, M·K) for k-major (see the layout sketch after this list)
- Dtype: {float4_e2m1fn_x2, uint8, float8_e4m3fn, float8_e5m2} (uint8 is interpreted as packed fp4x2)

Input tensor B:
- b_tensor (wrapper) or sample_b / b_tensor (class)
- Shape: (N, K, L)
- Stride: (1, N, N·K) for n-major or (K, 1, N·K) for k-major
- Dtype: must match A

Input tensor SFA:
- sfa_tensor (wrapper) or sample_sfa / sfa_tensor (class)
- Shape: (ATOM_M0, ATOM_M1, ceil_div(M, ATOM_M0·ATOM_M1), ATOM_K, ceil_div(K, ATOM_K·sf_vec_size), L)
- Dtype: {float8_e8m0fnu, float8_e4m3fn, int8} (int8 is interpreted as float8_e8m0fnu)

Input tensor SFB:
- sfb_tensor (wrapper) or sample_sfb / sfb_tensor (class)
- Shape: (ATOM_M0, ATOM_M1, ceil_div(N, ATOM_M0·ATOM_M1), ATOM_K, ceil_div(K, ATOM_K·sf_vec_size), L)
- Dtype: {float8_e8m0fnu, float8_e4m3fn, int8} (int8 is interpreted as float8_e8m0fnu)

Output tensor C:
- Return value (wrapper) or sample_c / c_tensor (class)
- Shape: (M, N, L)
- Stride: (1, M, M·N) for m-major or (N, 1, M·N) for n-major; provided as the c_major argument for the wrapper
- Dtype: {float32, float16, bfloat16, float8_e5m2, float8_e4m3fn, float4_e2m1fn_x2, uint8}; provided as the c_dtype argument for the wrapper

Output tensor Amax:
- Return value (wrapper) or sample_amax / amax_tensor (class)
- Shape: (1, 1, 1)
- Dtype: float32
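The documented strides can be produced in PyTorch by allocating a contiguous buffer with the batch dimension outermost and the unit-stride dimension innermost, then permuting into the (M, K, L) view, as in the sketch below. The same trick applies to B and C.
import torch

M, K, L = 256, 512, 2

# m-major A: strides (1, M, M·K).
a_m_major = torch.empty(L, K, M, device="cuda").permute(2, 1, 0)
assert a_m_major.shape == (M, K, L) and a_m_major.stride() == (1, M, M * K)

# k-major A: strides (K, 1, M·K).
a_k_major = torch.empty(L, M, K, device="cuda").permute(1, 2, 0)
assert a_k_major.shape == (M, K, L) and a_k_major.stride() == (K, 1, M * K)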
Common parameters#
- acc_dtype: torch.dtype. Accumulator dtype. Default: torch.float32 (the only supported value)
- mma_tiler_mn: Tuple[int, int]. Kernel tile size (TILE_M, TILE_N). Default: (128, 128). TILE_M ∈ {128} (TILE_M = 256 is currently disabled); TILE_N ∈ {128, 256}
- cluster_shape_mn: Tuple[int, int]. Thread block cluster shape (CLUSTER_M, CLUSTER_N). Default: (1, 1). Constraint: values in {1, 2, 4}
- sf_vec_size: int. Size of the K-group per scale factor: {16, 32}. Default: 32
- CUDA stream: current_stream in the class API, stream in the wrapper
Wrapper-specific parameters: gemm_amax_wrapper_sm100#
- a_tensor, b_tensor, sfa_tensor, sfb_tensor: see Input/Output tensors
- c_major: str. See Input/Output tensors. Default: "n"
- c_dtype: torch.dtype. See Input/Output tensors. Default: torch.float32
Class-specific parameters: GemmAmaxSm100#
GemmAmaxSm100 (constructor)#
- sample_a, sample_b, sample_sfa, sample_sfb, sample_c, sample_amax: see Input/Output tensors
GemmAmaxSm100.execute#
- a_tensor, b_tensor, sfa_tensor, sfb_tensor, c_tensor, amax_tensor: see Input/Output tensors
- skip_compile: bool. Default: False (see the reuse sketch below)
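A minimal reuse sketch, assuming skip_compile is passed as a keyword to execute as listed above: compile once, then launch repeatedly on tensors with the same shapes, dtypes, and strides as the samples.
op.compile(current_stream=None)
for _ in range(8):
    # New data can be copied into a, b, sfa, sfb between launches.
    op.execute(a, b, sfa, sfb, c, amax, current_stream=None, skip_compile=True)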
Support surface and constraints#
Layouts and strides#
- For A/B ∈ {float4_e2m1fn_x2, uint8} (packed FP4), A and B must be k-major.
- For C ∈ {float4_e2m1fn_x2, uint8} (packed FP4), C must be n-major.
- For all float4_e2m1fn_x2 / uint8 cases, the innermost tensor dimension is halved by the 2x packing, i.e. A is shaped (M, K // 2, L) instead of (M, K, L) (see the sketch after this list).
- A, B, and C must be 16-byte aligned along the contiguous dimension.
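A brief sketch of the packed-FP4 convention above: uint8 stores two FP4 values per byte, so the stored K extent is halved and the tensor must be k-major (unit stride along the packed K dimension).
import torch

M, K, L = 256, 512, 1  # K is the logical (unpacked) extent and must be even

# Stored shape (M, K // 2, L), k-major: contiguous along the packed K dimension.
a_fp4 = torch.empty(L, M, K // 2, dtype=torch.uint8, device="cuda").permute(1, 2, 0)
print(a_fp4.shape)  # torch.Size([256, 256, 1])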
Dtypes#
- A and B must have the same dtype.
- sf_vec_size ∈ {16, 32}, with coupling:
  - sf_dtype == float8_e4m3fn is unsupported with sf_vec_size == 32
  - A/B ∈ {float8_e4m3fn, float8_e5m2} is unsupported with sf_vec_size == 16
- A/B ∈ FP8 and C ∈ FP8 together are currently disabled.
- C ∈ {float4_e2m1fn_x2, uint8} requires A/B ∈ {float4_e2m1fn_x2, uint8}.
Tiling and cluster#
- A/B ∈ {float4_e2m1fn_x2, uint8} with N_tile == 256 requires K > 128.
- The combination mma_tiler_mn == (128, 256), sf_vec_size == 16, and C ∈ {float32, float16, bfloat16} is currently disabled.
Shapes and divisibility#
- SFA/SFB shapes must follow the atom tiling and sf_vec_size rules above.
- When C is packed FP4, use shape (M, ceil_div(N, 2), L) and n-major strides.
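A sketch of gating compilation on the constraints above via check_support(), reusing sample tensors such as those built in the earlier sketches; the raised error type is illustrative.
op = GemmAmaxSm100(
    sample_a=a, sample_b=b, sample_sfa=sfa, sample_sfb=sfb,
    sample_c=c, sample_amax=amax,
    acc_dtype=torch.float32,
    mma_tiler_mn=(128, 128),
    cluster_shape_mn=(1, 1),
    sf_vec_size=32,
)
if not op.check_support():
    raise ValueError("unsupported GEMM + amax configuration for SM100")
op.compile(current_stream=None)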
Environment#
Requires CUDA and a GPU with SM100+ compute capability
Usage examples#
For usage examples, see test cases in test/python/fe_api/test_gemm_amax.py