cuda.tile.mma_scaled

cuda.tile.mma_scaled(x, x_scale, y, y_scale, /, acc)

Block-scaled matrix multiply-accumulate.

Computes a matrix multiply-accumulate in which the inputs are multiplied by block scale factors along the K dimension before the matrix multiply:

result[i, j] = sum(x[i, k] * x_scale[i, k // B] * y[k, j] * y_scale[k // B, j]
                   for k in range(K)) + acc[i, j]

The scaling block size is B = K // K_s, where K_s is the K dimension of the scale tile. K must be divisible by K_s, and B must be one of the allowed values listed in the table below.

Parameters:
  • x (Tile) – LHS input, 2D or 3D [..., M, K].

  • x_scale (Tile) – Scale factors for x, shape [..., M, K_s]. All dimensions except K_s must match x exactly.

  • y (Tile) – RHS input, 2D or 3D [..., K, N].

  • y_scale (Tile) – Scale factors for y, shape [..., K_s, N]. All dimensions except K_s must match y exactly.

  • acc (Tile) – Accumulator [..., M, N].

Supported datatypes and scaling block sizes:

| Input (x/y)      | Scale     | Acc/Out | B      |
|------------------|-----------|---------|--------|
| f8e4m3fn, f8e5m2 | f8e8m0fnu | f32     | 32     |
| f4e2m1fn         | f8e8m0fnu | f32     | 16, 32 |
| f4e2m1fn         | f8e4m3fn  | f32     | 16     |

Batch dimensions of x and y are broadcast against each other (same as mma()). x_scale’s batch dimension must match x’s batch exactly, and y_scale’s batch dimension must match y’s batch exactly; both are then broadcast to the output batch shape.

Return type:

Tile

Examples

Basic usage.

# Excerpt of the kernel body from the full example that follows; these
# statements run inside a @ct.kernel function, with `ct` = cuda.tile.
# K = 64 and the scale tiles have K_s = 2, so the block size is B = 32.
tx = ct.ones((2, 64), ct.float8_e4m3fn)
sx = ct.full((2, 2), 2.0, ct.float8_e8m0fnu)
ty = ct.ones((64, 2), ct.float8_e4m3fn)
sy = ct.full((2, 2), 2.0, ct.float8_e8m0fnu)
acc = ct.zeros((2, 2), ct.float32)
# Each element: sum over 64 k of 1 * 2 * 1 * 2 = 256.
tz = ct.mma_scaled(tx, sx, ty, sy, acc)
print(f"{tz:.1f}")
import cuda.tile as ct
import torch

@ct.kernel
def kernel():
    # Basic mma_scaled demo: all operands are constant tiles built in-kernel,
    # so the kernel takes no arguments.
    # x: [M, K] = [2, 64] ones in f8e4m3fn.
    tx = ct.ones((2, 64), ct.float8_e4m3fn)
    # x_scale: [M, K_s] = [2, 2], giving block size B = K // K_s = 32.
    sx = ct.full((2, 2), 2.0, ct.float8_e8m0fnu)
    # y: [K, N] = [64, 2] ones; y_scale: [K_s, N] = [2, 2].
    ty = ct.ones((64, 2), ct.float8_e4m3fn)
    sy = ct.full((2, 2), 2.0, ct.float8_e8m0fnu)
    # f32 accumulator, as required for these input/scale dtypes.
    acc = ct.zeros((2, 2), ct.float32)
    # Each output element = sum over 64 k of 1 * 2 * 1 * 2 = 256.
    tz = ct.mma_scaled(tx, sx, ty, sy, acc)
    print(f"{tz:.1f}")


# Initialise CUDA, launch `kernel` with a grid of (1,) on the current
# stream with an empty argument tuple, then block until it completes.
torch.cuda.init()
ct.launch(torch.cuda.current_stream(), (1,), kernel, ())
torch.cuda.synchronize()

Output

[[256.0, 256.0], [256.0, 256.0]]

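For best performance on sm_100, swizzle the scale tensors on the host into the tcgen05.mma tensor-memory scale-factor layout, then unswizzle them inside the kernel before calling mma_scaled: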

# Sub-block geometry of the 32x4x4 scale layout (m1 rows x m2 row groups,
# k1 scale columns). The kernel below also reads these to size its scale tiles.
m1, m2, k1 = 32, 4, 4

def swizzle_32_4_4(scale):
    '''
    Prepare the original scale tensor to align with the expected tmem layout.

    The (M, K_s) scale matrix is split into sub-blocks whose innermost
    dimensions are (m1=32, m2=4, k1=4) and whose outer dimensions are
    m0 = M // (m1 * m2) and k0 = K_s // k1. Element ``scale[i, k]`` with
    ``i = i0 * 128 + a * 32 + b`` and ``k = j0 * 4 + c`` is placed at
    ``result[i0, j0, b, a * 4 + c]``.

    Args:
        scale: 2D tensor of shape (M, K_s). M must be a multiple of 128 and
            K_s a multiple of 4.

    Returns:
        Tensor of shape (m0, k0, 32, 16).

    Raises:
        ValueError: if M is not a multiple of 128 or K_s is not a multiple of 4.

    Reference: PTX ISA tcgen05.mma scale factor layout.
    https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-mma-scale-factor-a-layout-1x
    '''
    # Bound locally (as unswizzle_32_4_4 does) so the layout always matches
    # the function name, independent of module-level names.
    m1, m2, k1 = 32, 4, 4
    M, K_s = scale.shape
    if M % (m1 * m2) != 0 or K_s % k1 != 0:
        raise ValueError(
            f"scale shape {tuple(scale.shape)} is not divisible into "
            f"({m1 * m2}, {k1}) blocks"
        )
    m0 = M // (m1 * m2)
    k0 = K_s // k1
    # Rows split as (m0, m2, m1), columns as (k0, k1); the permute reorders
    # them to (m0, k0, m1, m2, k1) before the last two axes are merged.
    scale = scale.reshape(m0, m2, m1, k0, k1).permute(0, 3, 2, 1, 4).contiguous()
    return scale.reshape(m0, k0, m1, m2 * k1)

def unswizzle_32_4_4(tile):
    '''
    Undo ``swizzle_32_4_4`` on the kernel side.

    ``tile`` is a 4D block of the host-swizzled scale tensor shaped
    (m0, k0, 32, 16). The result is the plain (m0 * 128, k0 * 4) scale view
    that ``ct.mma_scaled`` consumes.
    '''
    rows_inner, row_groups, cols_inner = (32, 4, 4)
    outer_m, outer_k, _, _ = tile.shape
    # Re-expose the merged last axis as (row_groups, cols_inner).
    split = tile.reshape((outer_m, outer_k, rows_inner, row_groups, cols_inner))
    # (m0, k0, m1, m2, k1) -> (m0, m2, m1, k0, k1), inverting the host permute.
    reordered = split.permute((0, 3, 2, 1, 4))
    return reordered.reshape((outer_m * rows_inner * row_groups, outer_k * cols_inner))

@ct.kernel
def kernel(X, X_scale, Y, Y_scale, Z,
           TM: ct.Constant[int], TN: ct.Constant[int], TK: ct.Constant[int]):
    # Computes one (TM, TN) block-scaled MMA tile and stores it at tile
    # index (0, 0) of Z. m1, m2, k1 are the module-level layout constants.
    # X is [M, K]; load the first (TM, TK) tile.
    x = ct.load(X, index=(0, 0), shape=(TM, TK))
    # Y is stored [N, K]; transpose its (TN, TK) tile into the [K, N] RHS.
    y = ct.load(Y, index=(0, 0), shape=(TN, TK)).transpose()
    # Scales arrive host-swizzled by swizzle_32_4_4 as (m0, k0, 32, 16).
    # NOTE(review): shape (1, 1, m1, m2 * k1) unswizzles to a 128 x 4 scale
    # tile, so this only matches TM == 128 and (for B = 32) TK == 128; other
    # tile sizes would need m0/k0 > 1 here.
    x_s = ct.load(X_scale, index=(0, 0, 0, 0), shape=(1, 1, m1, m2 * k1))
    x_s = unswizzle_32_4_4(x_s)  # -> [TM, K_s]
    y_s = ct.load(Y_scale, index=(0, 0, 0, 0), shape=(1, 1, m1, m2 * k1))
    # [TN, K_s] -> [K_s, TN], the y_scale layout mma_scaled expects.
    y_s = unswizzle_32_4_4(y_s).transpose()
    acc = ct.zeros((TM, TN), ct.float32)
    ct.store(Z, index=(0, 0),
             tile=ct.mma_scaled(x, x_s, y, y_s, acc))

# Problem size: one 128x128x128 tile with f8e4m3fn inputs and f8e8m0fnu
# scales, one scale per SCALE_BLOCK_SIZE elements along K.
M, N, K = 128, 128, 128
SCALE_BLOCK_SIZE = 32
K_s = K // SCALE_BLOCK_SIZE
TM, TN, TK = 128, 128, 128

X = torch.ones((M, K), device='cuda').to(torch.float8_e4m3fn)
# Y is stored [N, K]; the kernel transposes its tile into [K, N].
Y = torch.ones((N, K), device='cuda').to(torch.float8_e4m3fn)
# All scales are 2.0; swizzle them into the layout the kernel loads.
X_scale = torch.full((M, K_s), 2.0, device='cuda').to(torch.float8_e8m0fnu)
X_scale = swizzle_32_4_4(X_scale)
Y_scale = torch.full((N, K_s), 2.0, device='cuda').to(torch.float8_e8m0fnu)
Y_scale = swizzle_32_4_4(Y_scale)
Z = torch.zeros((M, N), dtype=torch.float32, device='cuda')
# Query the current stream here so this excerpt does not depend on a
# separately defined `stream` variable.
ct.launch(torch.cuda.current_stream(), (1, 1, 1), kernel,
          (X, X_scale, Y, Y_scale, Z, TM, TN, TK))
torch.cuda.synchronize()
# Every element is the sum over K = 128 of 1 * 2 * 1 * 2, i.e. 512.0.
print(Z.unique().tolist())
import cuda.tile as ct
import torch

# Initialise CUDA and capture the current stream; `stream` is passed to
# ct.launch further down.
torch.cuda.init()
stream = torch.cuda.current_stream()

# Sub-block geometry of the 32x4x4 scale layout (m1 rows x m2 row groups,
# k1 scale columns). The kernel below also reads these to size its scale tiles.
m1, m2, k1 = 32, 4, 4

def swizzle_32_4_4(scale):
    '''
    Prepare the original scale tensor to align with the expected tmem layout.

    The (M, K_s) scale matrix is split into sub-blocks whose innermost
    dimensions are (m1=32, m2=4, k1=4) and whose outer dimensions are
    m0 = M // (m1 * m2) and k0 = K_s // k1. Element ``scale[i, k]`` with
    ``i = i0 * 128 + a * 32 + b`` and ``k = j0 * 4 + c`` is placed at
    ``result[i0, j0, b, a * 4 + c]``.

    Args:
        scale: 2D tensor of shape (M, K_s). M must be a multiple of 128 and
            K_s a multiple of 4.

    Returns:
        Tensor of shape (m0, k0, 32, 16).

    Raises:
        ValueError: if M is not a multiple of 128 or K_s is not a multiple of 4.

    Reference: PTX ISA tcgen05.mma scale factor layout.
    https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-mma-scale-factor-a-layout-1x
    '''
    # Bound locally (as unswizzle_32_4_4 does) so the layout always matches
    # the function name, independent of module-level names.
    m1, m2, k1 = 32, 4, 4
    M, K_s = scale.shape
    if M % (m1 * m2) != 0 or K_s % k1 != 0:
        raise ValueError(
            f"scale shape {tuple(scale.shape)} is not divisible into "
            f"({m1 * m2}, {k1}) blocks"
        )
    m0 = M // (m1 * m2)
    k0 = K_s // k1
    # Rows split as (m0, m2, m1), columns as (k0, k1); the permute reorders
    # them to (m0, k0, m1, m2, k1) before the last two axes are merged.
    scale = scale.reshape(m0, m2, m1, k0, k1).permute(0, 3, 2, 1, 4).contiguous()
    return scale.reshape(m0, k0, m1, m2 * k1)

def unswizzle_32_4_4(tile):
    '''
    Undo ``swizzle_32_4_4`` on the kernel side.

    ``tile`` is a 4D block of the host-swizzled scale tensor shaped
    (m0, k0, 32, 16). The result is the plain (m0 * 128, k0 * 4) scale view
    that ``ct.mma_scaled`` consumes.
    '''
    rows_inner, row_groups, cols_inner = (32, 4, 4)
    outer_m, outer_k, _, _ = tile.shape
    # Re-expose the merged last axis as (row_groups, cols_inner).
    split = tile.reshape((outer_m, outer_k, rows_inner, row_groups, cols_inner))
    # (m0, k0, m1, m2, k1) -> (m0, m2, m1, k0, k1), inverting the host permute.
    reordered = split.permute((0, 3, 2, 1, 4))
    return reordered.reshape((outer_m * rows_inner * row_groups, outer_k * cols_inner))

@ct.kernel
def kernel(X, X_scale, Y, Y_scale, Z,
           TM: ct.Constant[int], TN: ct.Constant[int], TK: ct.Constant[int]):
    # Computes one (TM, TN) block-scaled MMA tile and stores it at tile
    # index (0, 0) of Z. m1, m2, k1 are the module-level layout constants.
    # X is [M, K]; load the first (TM, TK) tile.
    x = ct.load(X, index=(0, 0), shape=(TM, TK))
    # Y is stored [N, K]; transpose its (TN, TK) tile into the [K, N] RHS.
    y = ct.load(Y, index=(0, 0), shape=(TN, TK)).transpose()
    # Scales arrive host-swizzled by swizzle_32_4_4 as (m0, k0, 32, 16).
    # NOTE(review): shape (1, 1, m1, m2 * k1) unswizzles to a 128 x 4 scale
    # tile, so this only matches TM == 128 and (for B = 32) TK == 128; other
    # tile sizes would need m0/k0 > 1 here.
    x_s = ct.load(X_scale, index=(0, 0, 0, 0), shape=(1, 1, m1, m2 * k1))
    x_s = unswizzle_32_4_4(x_s)  # -> [TM, K_s]
    y_s = ct.load(Y_scale, index=(0, 0, 0, 0), shape=(1, 1, m1, m2 * k1))
    # [TN, K_s] -> [K_s, TN], the y_scale layout mma_scaled expects.
    y_s = unswizzle_32_4_4(y_s).transpose()
    acc = ct.zeros((TM, TN), ct.float32)
    ct.store(Z, index=(0, 0),
             tile=ct.mma_scaled(x, x_s, y, y_s, acc))

# Problem size: one 128x128x128 tile with f8e4m3fn inputs and f8e8m0fnu
# scales, one scale per SCALE_BLOCK_SIZE elements along K.
M, N, K = 128, 128, 128
SCALE_BLOCK_SIZE = 32
K_s = K // SCALE_BLOCK_SIZE
TM, TN, TK = 128, 128, 128

X = torch.ones((M, K), device='cuda').to(torch.float8_e4m3fn)
# Y is stored [N, K]; the kernel transposes its tile into [K, N].
Y = torch.ones((N, K), device='cuda').to(torch.float8_e4m3fn)
# All scales are 2.0; swizzle them into the layout the kernel loads.
X_scale = torch.full((M, K_s), 2.0, device='cuda').to(torch.float8_e8m0fnu)
X_scale = swizzle_32_4_4(X_scale)
Y_scale = torch.full((N, K_s), 2.0, device='cuda').to(torch.float8_e8m0fnu)
Y_scale = swizzle_32_4_4(Y_scale)
Z = torch.zeros((M, N), dtype=torch.float32, device='cuda')
ct.launch(stream, (1, 1, 1), kernel, (X, X_scale, Y, Y_scale, Z, TM, TN, TK))
# A single synchronize is enough before reading Z back on the host.
torch.cuda.synchronize()
# Every element is the sum over K = 128 of 1 * 2 * 1 * 2, i.e. 512.0.
print(Z.unique().tolist())

Output

[512.0]