RMSNorm + RHT + Amax (SM100)#
This is an experimental API and subject to change.
Overview#
RMSNorm + RHT + amax: A fused CUTE DSL kernel for NVIDIA Blackwell GPUs (SM100+) that applies RMS normalization, a block-diagonal Hadamard transform with fixed block size 16, and a per-CTA amax reduction.
This frontend integration exposes the kernel as a standard FE-OSS Python API with:
- a class API (`RmsNormRhtAmaxSm100`)
- a wrapper API (`rmsnorm_rht_amax_wrapper_sm100`)
- grouped-gemm-style regression coverage for compile/execute, wrapper use, and cache reuse
Shapes#
Inputs

- `X`: activation tensor, shape `(M, N)`
- `W`: RMSNorm scale tensor, shape `(N,)`

Outputs

- `O`: fused RMSNorm + RHT output tensor, shape `(M, N)`
- `Amax`: per-CTA max-abs tensor, shape `(M / rows_per_cta,)`

`rows_per_cta` is the number of rows reduced into each amax element.
Equations#
For each row m:
\( \mathrm{RMS}(X_m) = \sqrt{\frac{1}{N}\sum_{n=0}^{N-1} X[m, n]^2 + \epsilon} \)
\( Y[m, n] = \frac{X[m, n]}{\mathrm{RMS}(X_m)} \cdot W[n] \)
Then apply the fixed Hadamard transform blockwise over 16-wide chunks:
\( O[m, b] = Y[m, b] \cdot H_{16} / \sqrt{16} \)
where \( H_{16} \) is the \( 16 \times 16 \) Hadamard matrix and \( b \) indexes each 16-element block in the hidden dimension; dividing by \( \sqrt{16} \) makes the blockwise transform orthonormal.
For each CTA covering `rows_per_cta` rows:

\( \mathrm{Amax}[c] = \max_{(m, n)\,\in\,\mathrm{CTA}_c} |O[m, n]| \)

i.e., the maximum absolute value over every element produced by that CTA.
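For validation on small shapes, here is a minimal PyTorch reference sketch of the math above. The helper names (`hadamard_16`, `rmsnorm_rht_amax_reference`) and the float32 accumulation are illustrative assumptions, not part of the API:

```python
import torch

def hadamard_16() -> torch.Tensor:
    # Sylvester construction: H_{2k} = [[H_k, H_k], [H_k, -H_k]]; entries are +/-1.
    h = torch.ones(1, 1)
    for _ in range(4):
        h = torch.cat([torch.cat([h, h], dim=1), torch.cat([h, -h], dim=1)], dim=0)
    return h  # (16, 16)

def rmsnorm_rht_amax_reference(x, w, eps=1e-5, rows_per_cta=2):
    m, n = x.shape
    xf = x.float()  # float32 accumulation is an assumption of this sketch
    rms = torch.sqrt(xf.pow(2).mean(dim=-1, keepdim=True) + eps)
    y = xf / rms * w.float()                              # RMSNorm
    h = (hadamard_16() / 16 ** 0.5).to(y.device)          # H_16 / sqrt(16)
    o = (y.view(m, n // 16, 16) @ h).view(m, n)           # blockwise Hadamard
    amax = o.abs().view(m // rows_per_cta, -1).amax(-1)   # per-CTA max-abs
    return o.to(x.dtype), amax
```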
API Usage#
High-level wrapper#
```python
from cudnn import rmsnorm_rht_amax_wrapper_sm100

result = rmsnorm_rht_amax_wrapper_sm100(
    x_tensor=x,
    w_tensor=w,
    eps=1e-5,
    num_threads=None,   # optional override
    rows_per_cta=None,  # optional override
    current_stream=None,
)
o_tensor, amax_tensor = result
```
When no overrides are supplied, the wrapper uses the upstream-tuned thread table when available and an upstream-style `rows_per_cta` heuristic.
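For context, a minimal end-to-end sketch; the tensor sizes here are illustrative choices that satisfy the divisibility constraints listed later on this page:

```python
import torch
from cudnn import rmsnorm_rht_amax_wrapper_sm100

m, n = 128, 2048  # illustrative sizes; see the support constraints below
x = torch.randn(m, n, dtype=torch.bfloat16, device="cuda")
w = torch.randn(n, dtype=torch.bfloat16, device="cuda")

o_tensor, amax_tensor = rmsnorm_rht_amax_wrapper_sm100(x_tensor=x, w_tensor=w)
```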
Class API#
```python
from cudnn import RmsNormRhtAmaxSm100

op = RmsNormRhtAmaxSm100(
    sample_x=x,
    sample_w=w,
    sample_o=o,
    sample_amax=amax,
    eps=1e-5,
    num_threads=128,
    rows_per_cta=2,
)
assert op.check_support()
op.compile()
op.execute(
    x_tensor=x,
    w_tensor=w,
    o_tensor=o,
    amax_tensor=amax,
    current_stream=None,
)
```
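Since the regression coverage includes cache reuse, a compiled op can presumably be re-executed on fresh tensors with the same shapes and dtypes; a hedged sketch:

```python
# Reuse the compiled kernel on new inputs with identical shapes/dtypes.
x2 = torch.randn_like(x)
op.execute(
    x_tensor=x2,
    w_tensor=w,
    o_tensor=o,
    amax_tensor=amax,
    current_stream=None,
)
```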
Parameters#
Input and output tensors#
`x_tensor` / `sample_x`

- Shape: `(M, N)`
- Layout: row-major contiguous
- Dtype: `torch.bfloat16`

`w_tensor` / `sample_w`

- Shape: `(N,)`
- Layout: contiguous
- Dtype: `torch.bfloat16`

`o_tensor` / `sample_o`

- Shape: `(M, N)`
- Layout: row-major contiguous
- Dtype: `torch.bfloat16`

`amax_tensor` / `sample_amax`

- Shape: `(M / rows_per_cta,)`
- Dtype: `torch.float32`
Common parameters#
- `eps: float`: RMSNorm epsilon. Default: `1e-5`.
- `num_threads: Optional[int]`: threads per CTA. If omitted, the API uses the upstream-tuned table when possible, otherwise a valid fallback search.
- `rows_per_cta: Optional[int]`: rows processed by each CTA. If omitted, the wrapper uses the upstream-style heuristic over `{2, 4, 8}`.
- `current_stream`: CUDA stream.
Wrapper return values#
Returns a TupleDict with keys:
- `o_tensor`
- `amax_tensor`

Tuple unpacking order is `(o_tensor, amax_tensor)`.
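Assuming the TupleDict supports both key access and positional unpacking (an assumption based on the keys and ordering above):

```python
result = rmsnorm_rht_amax_wrapper_sm100(x_tensor=x, w_tensor=w)

# Key access (assumed TupleDict behavior).
o_tensor = result["o_tensor"]
amax_tensor = result["amax_tensor"]

# Positional unpacking, in the documented order.
o_tensor, amax_tensor = result
```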
Support surface and constraints#
- Requires SM100+.
- `N` must be divisible by `16`.
- `N` must be divisible by the resolved `num_threads`.
- `EPT = N / num_threads` must be at least `8` and divisible by `8`.
- `M` must be divisible by `rows_per_cta`.
- Inputs and output are currently bf16 only.
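A small predicate mirroring these shape constraints; this is an illustrative sketch, not the library's `check_support` implementation:

```python
def shapes_supported(m: int, n: int, num_threads: int, rows_per_cta: int) -> bool:
    # N divisible by the Hadamard block size and the resolved thread count.
    if n % 16 != 0 or n % num_threads != 0:
        return False
    # Elements per thread (EPT) must be at least 8 and a multiple of 8.
    ept = n // num_threads
    if ept < 8 or ept % 8 != 0:
        return False
    # Each CTA must reduce a whole number of rows.
    return m % rows_per_cta == 0
```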
The frontend integration matches the upstream RMSNorm kernel semantics; it does not expose full LayerNorm mean/bias behavior.
Verification#
Focused correctness and cache coverage live in:
`test/python/fe_api/test_rmsnorm_rht_amax.py`