cuda.tile.matmul
- cuda.tile.matmul(x, y, /)
Performs matrix multiply on the given tiles.
- Parameters:
Supported input datatypes: [f16, bf16, f32, f64, tf32, f8e4m3fn, f8e5m2, i8, u8]
If x and y have different dtypes, they are first promoted to a common dtype. The result dtype is the same as the promoted input dtype. The shapes of x and y are broadcast across all axes except the last two.
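The shape and dtype rules can be sketched on the host in plain Python. The helper `matmul_result_shape` below is hypothetical and is not part of cuda.tile. It assumes NumPy-style matmul semantics (a 1D operand is treated as a vector, and leading batch axes are broadcast), which agrees with the examples on this page but is not stated by it. The `np.result_type` call shows NumPy's promotion rule only as an analogy; the promotion table used by cuda.tile may differ.

```python
import numpy as np

def matmul_result_shape(x_shape, y_shape):
    """Hypothetical helper: result shape under NumPy-style matmul rules."""
    x_shape, y_shape = tuple(x_shape), tuple(y_shape)
    if not x_shape or not y_shape:
        raise ValueError("matmul operands must be at least 1D")
    # A 1D operand is temporarily treated as 2D:
    # x gains a leading axis of size 1, y gains a trailing axis of size 1.
    x_vec, y_vec = len(x_shape) == 1, len(y_shape) == 1
    if x_vec:
        x_shape = (1,) + x_shape
    if y_vec:
        y_shape = y_shape + (1,)
    # The contracted dimensions must agree: (..., M, K) @ (..., K, N).
    if x_shape[-1] != y_shape[-2]:
        raise ValueError(f"inner dimensions differ: {x_shape[-1]} vs {y_shape[-2]}")
    # Only the axes before the last two take part in broadcasting.
    batch = np.broadcast_shapes(x_shape[:-2], y_shape[:-2])
    result = tuple(batch) + (x_shape[-2], y_shape[-1])
    # Drop the axes that were added for 1D operands.
    if x_vec:
        result = result[:-2] + result[-1:]
    if y_vec:
        result = result[:-1]
    return result

shape_2d = matmul_result_shape((2, 4), (4, 2))          # (2, 2)
shape_dot = matmul_result_shape((4,), (4,))             # ()
shape_batched = matmul_result_shape((2, 2, 4), (4, 2))  # (2, 2, 2)
print(shape_2d, shape_dot, shape_batched)

# NumPy's promotion for mixed float16/float32 inputs (analogy only).
promoted = np.result_type(np.float16, np.float32)
print(promoted)  # float32
```

The batched case shows the rule from the text: the leading axis of size 2 from x is broadcast against the empty batch shape of y, while the trailing (2, 4) and (4, 2) axes are multiplied as ordinary matrices.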
- Return type:
Examples
2D x 2D.
    x = ct.ones((2, 4), dtype=ct.float32)
    y = ct.ones((4, 2), dtype=ct.float32)
    print(f"{x @ y:.1f}")

Complete program:

    import cuda.tile as ct
    import torch

    @ct.kernel
    def kernel():
        x = ct.ones((2, 4), dtype=ct.float32)
        y = ct.ones((4, 2), dtype=ct.float32)
        print(f"{x @ y:.1f}")

    torch.cuda.init()
    ct.launch(torch.cuda.current_stream(), (1,), kernel, ())
    torch.cuda.synchronize()
Output
[[4.0, 4.0], [4.0, 4.0]]
1D x 1D (dot product).
    x = ct.ones(4, dtype=ct.float32)
    y = ct.ones(4, dtype=ct.float32)
    print(f"{x @ y:.1f}")

Complete program:

    import cuda.tile as ct
    import torch

    @ct.kernel
    def kernel():
        x = ct.ones(4, dtype=ct.float32)
        y = ct.ones(4, dtype=ct.float32)
        print(f"{x @ y:.1f}")

    torch.cuda.init()
    ct.launch(torch.cuda.current_stream(), (1,), kernel, ())
    torch.cuda.synchronize()
Output
4.0
Batched: 3D x 2D with broadcast.
    x = ct.ones((2, 2, 4), dtype=ct.float32)
    y = ct.ones((4, 2), dtype=ct.float32)
    print(f"{x @ y:.1f}")

Complete program:

    import cuda.tile as ct
    import torch

    @ct.kernel
    def kernel():
        x = ct.ones((2, 2, 4), dtype=ct.float32)
        y = ct.ones((4, 2), dtype=ct.float32)
        print(f"{x @ y:.1f}")

    torch.cuda.init()
    ct.launch(torch.cuda.current_stream(), (1,), kernel, ())
    torch.cuda.synchronize()
Output
[[[4.0, 4.0], [4.0, 4.0]], [[4.0, 4.0], [4.0, 4.0]]]
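As a host-side cross-check, the three examples can be reproduced with NumPy. This is a sketch under the assumption that `x @ y` on tiles matches `numpy.matmul` for these shapes, which the printed outputs suggest. It does not use cuda.tile and does not run on the GPU.

```python
import numpy as np

# Shapes taken from the three examples: (x_shape, y_shape).
cases = {
    "2D x 2D": ((2, 4), (4, 2)),
    "1D x 1D": ((4,), (4,)),
    "3D x 2D": ((2, 2, 4), (4, 2)),
}

results = {}
for name, (x_shape, y_shape) in cases.items():
    x = np.ones(x_shape, dtype=np.float32)
    y = np.ones(y_shape, dtype=np.float32)
    # Every entry is a sum of four 1.0 * 1.0 products, hence 4.0.
    results[name] = x @ y
    print(name, results[name].shape, results[name].tolist())
```

Each result element is 4.0 because the contracted axis has length 4 and all inputs are ones, matching the outputs listed for the tile examples.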