cuda.tile.mma_scaled
cuda.tile.mma_scaled(x, x_scale, y, y_scale, /, acc)
Block-scaled matrix multiply-accumulate.
Computes a matrix multiply-accumulate in which x and y are multiplied by per-block scale factors along the K dimension before the MMA:
result[i, j] = sum(x[i, k] * x_scale[i, k // B] * y[k, j] * y_scale[k // B, j] for k in range(K)) + acc[i, j]
The scaling block size is B = K // K_s, where K_s is the K dimension of the scale tile. K must be divisible by K_s, and B must be one of the allowed values listed in the table below.
Parameters:
- x (Tile) – LHS input, 2D or 3D, shape [..., M, K].
- x_scale (Tile) – Scale factors for x, shape [..., M, K_s]. All dimensions except K_s must match x exactly.
- y (Tile) – RHS input, 2D or 3D, shape [..., K, N].
- y_scale (Tile) – Scale factors for y, shape [..., K_s, N]. All dimensions except K_s must match y exactly.
- acc (Tile) – Accumulator, shape [..., M, N].
Supported datatypes and scaling block sizes:
| Input (x/y)      | Scale     | Acc/Out | B      |
|------------------|-----------|---------|--------|
| f8e4m3fn, f8e5m2 | f8e8m0fnu | f32     | 32     |
| f4e2m1fn         | f8e8m0fnu | f32     | 16, 32 |
| f4e2m1fn         | f8e4m3fn  | f32     | 16     |
Batch dimensions of x and y are broadcast against each other (same as mma()). x_scale’s batch dimension must match x’s batch exactly, and y_scale’s batch dimension must match y’s batch exactly; both are then broadcast to the output batch shape.
Return type:
Tile of shape [..., M, N].
Examples
Basic usage.
# K = 64 and the scale tiles have K_s = 2, so B = 64 // 2 = 32.
tx = ct.ones((2, 64), ct.float8_e4m3fn)
ty = ct.ones((64, 2), ct.float8_e4m3fn)
sx = ct.full((2, 2), 2.0, ct.float8_e8m0fnu)
sy = ct.full((2, 2), 2.0, ct.float8_e8m0fnu)
acc = ct.zeros((2, 2), ct.float32)
# Each output element sums 64 terms of 1 * 2 * 1 * 2, giving 256.
tz = ct.mma_scaled(tx, sx, ty, sy, acc)
print(f"{tz:.1f}")
import cuda.tile as ct
import torch


@ct.kernel
def kernel():
    # K = 64 and the scale tiles have K_s = 2, so B = 64 // 2 = 32.
    tx = ct.ones((2, 64), ct.float8_e4m3fn)
    ty = ct.ones((64, 2), ct.float8_e4m3fn)
    sx = ct.full((2, 2), 2.0, ct.float8_e8m0fnu)
    sy = ct.full((2, 2), 2.0, ct.float8_e8m0fnu)
    acc = ct.zeros((2, 2), ct.float32)
    # Each output element sums 64 terms of 1 * 2 * 1 * 2, giving 256.
    tz = ct.mma_scaled(tx, sx, ty, sy, acc)
    print(f"{tz:.1f}")


torch.cuda.init()
stream = torch.cuda.current_stream()
ct.launch(stream, (1,), kernel, ())
torch.cuda.synchronize()
Output
[[256.0, 256.0], [256.0, 256.0]]
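For best performance on sm_100, pre-swizzle the scale tensors on the host into the tcgen05.mma scale-factor layout, and undo the swizzle on the loaded tiles inside the kernel: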
# Each swizzled block covers m1 * m2 = 128 rows of M and k1 = 4 columns of K_s.
m1, m2, k1 = 32, 4, 4


def swizzle_32_4_4(scale):
    """Host side: rearrange an (M, K_s) scale tensor into the tcgen05.mma layout.

    The innermost dimensions are (m1=32, m2=4, k1=4) and the outer dimensions
    are (m0=M // (m1 * m2), k0=K_s // k1). The result has shape
    (m0, k0, m1, m2 * k1). M must be a multiple of m1 * m2 and K_s a multiple
    of k1, otherwise the reshape fails.

    Reference: PTX ISA tcgen05.mma scale factor layout.
    https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-mma-scale-factor-a-layout-1x
    """
    M, K_s = scale.shape
    m0 = M // (m1 * m2)
    k0 = K_s // k1
    # Split rows as (m0, m2, m1) and columns as (k0, k1), then move m2 next to k1.
    # Row i of each 32-row output block then holds the scales of rows
    # i, i + 32, i + 64 and i + 96 of the 128-row input block, 4 columns each.
    scale = scale.reshape(m0, m2, m1, k0, k1).permute(0, 3, 2, 1, 4).contiguous()
    return scale.reshape(m0, k0, m1, m2 * k1)


def unswizzle_32_4_4(tile):
    """Kernel side: inverse of ``swizzle_32_4_4``.

    Takes a tile of shape (m0, k0, m1, m2 * k1) loaded from the host-swizzled
    scale tensor. Returns the ``(M, K_s)`` view that ``ct.mma_scaled``
    expects, with M = m0 * m1 * m2 and K_s = k0 * k1.
    """
    m0, k0, _, _ = tile.shape
    m1, m2, k1 = (32, 4, 4)
    return (tile.reshape((m0, k0, m1, m2, k1))
            .permute((0, 3, 2, 1, 4))
            .reshape((m0 * m1 * m2, k0 * k1)))


@ct.kernel
def kernel(X, X_scale, Y, Y_scale, Z,
           TM: ct.Constant[int], TN: ct.Constant[int], TK: ct.Constant[int]):
    x = ct.load(X, index=(0, 0), shape=(TM, TK))
    # Y is stored as (N, K); transpose the loaded tile into the (K, N) RHS layout.
    y = ct.load(Y, index=(0, 0), shape=(TN, TK)).transpose()
    # The leading (1, 1) scale-tile dims hold for the sizes used below:
    # m0 = TM // 128 = 1 and k0 = (TK // SCALE_BLOCK_SIZE) // 4 = 1.
    x_s = ct.load(X_scale, index=(0, 0, 0, 0), shape=(1, 1, m1, m2 * k1))
    x_s = unswizzle_32_4_4(x_s)                # (TM, K_s)
    y_s = ct.load(Y_scale, index=(0, 0, 0, 0), shape=(1, 1, m1, m2 * k1))
    y_s = unswizzle_32_4_4(y_s).transpose()    # (TN, K_s) -> (K_s, TN)
    acc = ct.zeros((TM, TN), ct.float32)
    ct.store(Z, index=(0, 0), tile=ct.mma_scaled(x, x_s, y, y_s, acc))


M, N, K = 128, 128, 128
SCALE_BLOCK_SIZE = 32
K_s = K // SCALE_BLOCK_SIZE
TM, TN, TK = 128, 128, 128

X = torch.ones((M, K), device='cuda').to(torch.float8_e4m3fn)
Y = torch.ones((N, K), device='cuda').to(torch.float8_e4m3fn)
X_scale = torch.full((M, K_s), 2.0, device='cuda').to(torch.float8_e8m0fnu)
X_scale = swizzle_32_4_4(X_scale)
Y_scale = torch.full((N, K_s), 2.0, device='cuda').to(torch.float8_e8m0fnu)
Y_scale = swizzle_32_4_4(Y_scale)
Z = torch.zeros((M, N), dtype=torch.float32, device='cuda')

# `stream` is torch.cuda.current_stream(), created in the setup omitted from this excerpt.
ct.launch(stream, (1, 1, 1),
          kernel, (X, X_scale, Y, Y_scale, Z, TM, TN, TK))
torch.cuda.synchronize()
# Every element is K * (1 * 2 * 1 * 2) = 128 * 4 = 512.
print(Z.unique().tolist())
import cuda.tile as ct
import torch

torch.cuda.init()
stream = torch.cuda.current_stream()

# Each swizzled block covers m1 * m2 = 128 rows of M and k1 = 4 columns of K_s.
m1, m2, k1 = 32, 4, 4


def swizzle_32_4_4(scale):
    """Host side: rearrange an (M, K_s) scale tensor into the tcgen05.mma layout.

    The innermost dimensions are (m1=32, m2=4, k1=4) and the outer dimensions
    are (m0=M // (m1 * m2), k0=K_s // k1). The result has shape
    (m0, k0, m1, m2 * k1). M must be a multiple of m1 * m2 and K_s a multiple
    of k1, otherwise the reshape fails.

    Reference: PTX ISA tcgen05.mma scale factor layout.
    https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-mma-scale-factor-a-layout-1x
    """
    M, K_s = scale.shape
    m0 = M // (m1 * m2)
    k0 = K_s // k1
    # Split rows as (m0, m2, m1) and columns as (k0, k1), then move m2 next to k1.
    # Row i of each 32-row output block then holds the scales of rows
    # i, i + 32, i + 64 and i + 96 of the 128-row input block, 4 columns each.
    scale = scale.reshape(m0, m2, m1, k0, k1).permute(0, 3, 2, 1, 4).contiguous()
    return scale.reshape(m0, k0, m1, m2 * k1)


def unswizzle_32_4_4(tile):
    """Kernel side: inverse of ``swizzle_32_4_4``.

    Takes a tile of shape (m0, k0, m1, m2 * k1) loaded from the host-swizzled
    scale tensor. Returns the ``(M, K_s)`` view that ``ct.mma_scaled``
    expects, with M = m0 * m1 * m2 and K_s = k0 * k1.
    """
    m0, k0, _, _ = tile.shape
    m1, m2, k1 = (32, 4, 4)
    return (tile.reshape((m0, k0, m1, m2, k1))
            .permute((0, 3, 2, 1, 4))
            .reshape((m0 * m1 * m2, k0 * k1)))


@ct.kernel
def kernel(X, X_scale, Y, Y_scale, Z,
           TM: ct.Constant[int], TN: ct.Constant[int], TK: ct.Constant[int]):
    x = ct.load(X, index=(0, 0), shape=(TM, TK))
    # Y is stored as (N, K); transpose the loaded tile into the (K, N) RHS layout.
    y = ct.load(Y, index=(0, 0), shape=(TN, TK)).transpose()
    # The leading (1, 1) scale-tile dims hold for the sizes used below:
    # m0 = TM // 128 = 1 and k0 = (TK // SCALE_BLOCK_SIZE) // 4 = 1.
    x_s = ct.load(X_scale, index=(0, 0, 0, 0), shape=(1, 1, m1, m2 * k1))
    x_s = unswizzle_32_4_4(x_s)                # (TM, K_s)
    y_s = ct.load(Y_scale, index=(0, 0, 0, 0), shape=(1, 1, m1, m2 * k1))
    y_s = unswizzle_32_4_4(y_s).transpose()    # (TN, K_s) -> (K_s, TN)
    acc = ct.zeros((TM, TN), ct.float32)
    ct.store(Z, index=(0, 0), tile=ct.mma_scaled(x, x_s, y, y_s, acc))


M, N, K = 128, 128, 128
SCALE_BLOCK_SIZE = 32
K_s = K // SCALE_BLOCK_SIZE
TM, TN, TK = 128, 128, 128

X = torch.ones((M, K), device='cuda').to(torch.float8_e4m3fn)
Y = torch.ones((N, K), device='cuda').to(torch.float8_e4m3fn)
X_scale = torch.full((M, K_s), 2.0, device='cuda').to(torch.float8_e8m0fnu)
X_scale = swizzle_32_4_4(X_scale)
Y_scale = torch.full((N, K_s), 2.0, device='cuda').to(torch.float8_e8m0fnu)
Y_scale = swizzle_32_4_4(Y_scale)
Z = torch.zeros((M, N), dtype=torch.float32, device='cuda')

ct.launch(stream, (1, 1, 1),
          kernel, (X, X_scale, Y, Y_scale, Z, TM, TN, TK))
# A single synchronize is enough before reading Z on the host.
torch.cuda.synchronize()
# Every element is K * (1 * 2 * 1 * 2) = 128 * 4 = 512.
print(Z.unique().tolist())
Output
[512.0]