cuda.tile.mma
cuda.tile.mma(x, y, /, acc, *, use_fast_acc=False)
Matrix multiply-accumulate.
Computes (x @ y) + acc as a single operation, where @ denotes matrix multiplication. The result has the same dtype as acc.
Parameters:
x (Tile) – LHS of the mma, 2D or 3D.
y (Tile) – RHS of the mma, 2D or 3D.
acc (Tile) – Accumulator of the mma.
use_fast_acc (bool) – Enable fast accumulation mode, which trades accumulator precision for throughput. Requires fp8 input dtypes (float8_e4m3fn or float8_e5m2). Currently has an effect only on Hopper GPUs and is silently ignored on other architectures. Default: False (since CTK 13.3).
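To make the "preserves the dtype of acc" rule concrete, here is a minimal host-side sketch in NumPy. `mma_reference` is a hypothetical helper, not part of cuda.tile; it assumes the product is carried in the accumulator's dtype and ignores hardware rounding order and use_fast_acc, so treat it as an approximate model rather than the actual implementation.

```python
import numpy as np

def mma_reference(x, y, acc):
    """Approximate host-side model of ct.mma: (x @ y) + acc in acc's dtype.

    Hypothetical helper for illustration; not part of cuda.tile. It does not
    model hardware rounding order or use_fast_acc.
    """
    # Widen both operands to the accumulator dtype before multiplying,
    # so products and sums are carried in acc's precision.
    prod = np.matmul(x.astype(acc.dtype), y.astype(acc.dtype))
    # The result keeps the dtype of acc.
    return (prod + acc).astype(acc.dtype)

# int8 inputs with an int32 accumulator: each row-column dot product is
# 64 * (127 * 2) = 16256, which would overflow int8 but fits in int32.
x = np.full((2, 64), 127, dtype=np.int8)
y = np.full((64, 2), 2, dtype=np.int8)
acc = np.zeros((2, 2), dtype=np.int32)
out = mma_reference(x, y, acc)
print(out.dtype, out.tolist())  # int32 [[16256, 16256], [16256, 16256]]
```

The int8/int32 pairing is chosen because it corresponds to the [u|i]8 to i32 row of the datatype table and shows why the accumulator must be wider than the inputs.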
Supported datatypes:

Input        Acc/Output
-----------  ----------
f16          f16 or f32
bf16         f32
f32          f32
f64          f64
tf32         f32
f8e4m3fn     f16 or f32
f8e5m2       f16 or f32
[u|i]8       i32
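The table can be encoded as a lookup for validating dtype choices before writing a kernel. This is a hypothetical helper, not a cuda.tile API: dtype names are plain strings, x and y are assumed to share one input dtype, and a non-fp8 input combined with use_fast_acc=True is simply reported as unsupported.

```python
# Input dtype -> allowed accumulator/output dtypes, transcribed from the table.
# "u8" and "i8" expand the table's "[u|i]8" row.
SUPPORTED_MMA_DTYPES = {
    "f16": {"f16", "f32"},
    "bf16": {"f32"},
    "f32": {"f32"},
    "f64": {"f64"},
    "tf32": {"f32"},
    "f8e4m3fn": {"f16", "f32"},
    "f8e5m2": {"f16", "f32"},
    "u8": {"i32"},
    "i8": {"i32"},
}

# use_fast_acc is documented as requiring one of these input dtypes.
FP8_DTYPES = {"f8e4m3fn", "f8e5m2"}

def is_supported_mma(input_dtype: str, acc_dtype: str, use_fast_acc: bool = False) -> bool:
    """Hypothetical checker: is this input/accumulator pairing in the table?"""
    if acc_dtype not in SUPPORTED_MMA_DTYPES.get(input_dtype, set()):
        return False
    if use_fast_acc and input_dtype not in FP8_DTYPES:
        return False
    return True

print(is_supported_mma("bf16", "f32"))                         # True
print(is_supported_mma("bf16", "f16"))                         # False
print(is_supported_mma("f16", "f32", use_fast_acc=True))       # False
print(is_supported_mma("f8e4m3fn", "f32", use_fast_acc=True))  # True
```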
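If x and y have different dtypes, they are NOT promoted to a common dtype. The shapes of x and y are broadcast against each other over all axes except the last two; the last two axes are the matrix dimensions that take part in the multiplication.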
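The broadcasting rule can be sketched as a shape computation. `mma_result_shape` is a hypothetical helper that assumes NumPy-style broadcasting of the leading (batch) axes, which matches the batched example below; it is not a cuda.tile function.

```python
import numpy as np

def mma_result_shape(x_shape, y_shape):
    """Hypothetical helper: output shape of (x @ y) under the rule above.

    Leading (batch) axes are broadcast; the last two axes are multiplied.
    """
    m, k = x_shape[-2:]
    k2, n = y_shape[-2:]
    if k != k2:
        raise ValueError(f"inner dimensions differ: {k} vs {k2}")
    # Broadcast every axis except the last two, NumPy-style.
    batch = np.broadcast_shapes(tuple(x_shape[:-2]), tuple(y_shape[:-2]))
    return batch + (m, n)

print(mma_result_shape((2, 4), (4, 2)))     # (2, 2)
print(mma_result_shape((2, 2, 4), (4, 2)))  # (2, 2, 2)

# np.matmul follows the same batch-broadcasting convention.
print(np.matmul(np.ones((2, 2, 4)), np.ones((4, 2))).shape)  # (2, 2, 2)
```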
Return type: Tile
Examples
2D x 2D with accumulation.
import cuda.tile as ct
import torch

@ct.kernel
def kernel():
    x = ct.ones((2, 4), dtype=ct.float32)
    y = ct.ones((4, 2), dtype=ct.float32)
    acc = ct.full((2, 2), 10.0, dtype=ct.float32)
    # (x @ y) + acc: each element = 1*4 + 10 = 14
    print(f"{ct.mma(x, y, acc):.1f}")

torch.cuda.init()
ct.launch(torch.cuda.current_stream(), (1,), kernel, ())
torch.cuda.synchronize()
Output
[[14.0, 14.0], [14.0, 14.0]]
Batched: 3D x 2D with broadcast.
import cuda.tile as ct
import torch

@ct.kernel
def kernel():
    x = ct.ones((2, 2, 4), dtype=ct.float32)
    y = ct.ones((4, 2), dtype=ct.float32)
    acc = ct.zeros((2, 2, 2), dtype=ct.float32)
    print(f"{ct.mma(x, y, acc):.1f}")

torch.cuda.init()
ct.launch(torch.cuda.current_stream(), (1,), kernel, ())
torch.cuda.synchronize()
Output
[[[4.0, 4.0], [4.0, 4.0]], [[4.0, 4.0], [4.0, 4.0]]]
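When writing tests around these kernels, the expected outputs of both examples can be reproduced on the host with NumPy, without a GPU. This sketch assumes NumPy's matmul semantics coincide with ct.mma for these float32 inputs, which holds here because every value is exactly representable.

```python
import numpy as np

# Example 1: 2D x 2D with accumulation.
x1 = np.ones((2, 4), dtype=np.float32)
y1 = np.ones((4, 2), dtype=np.float32)
acc1 = np.full((2, 2), 10.0, dtype=np.float32)
out1 = x1 @ y1 + acc1
print(out1.tolist())  # [[14.0, 14.0], [14.0, 14.0]]

# Example 2: batched 3D x 2D; y is broadcast across the batch axis.
x2 = np.ones((2, 2, 4), dtype=np.float32)
y2 = np.ones((4, 2), dtype=np.float32)
acc2 = np.zeros((2, 2, 2), dtype=np.float32)
out2 = x2 @ y2 + acc2
print(out2.tolist())  # [[[4.0, 4.0], [4.0, 4.0]], [[4.0, 4.0], [4.0, 4.0]]]
```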