cuda.tile.matmul

cuda.tile.matmul(x, y, /)

Performs a matrix multiplication of the given tiles.

Parameters:
  • x (Tile) – LHS of the matmul, 1D, 2D, or 3D.

  • y (Tile) – RHS of the matmul, 1D, 2D, or 3D.

Supported input datatypes: [f16, bf16, f32, f64, tf32, f8e4m3fn, f8e5m2, i8, u8]

If x and y have different dtypes, they are first promoted to a common dtype, and the result has that promoted dtype. The shapes of x and y are broadcast together on all axes except the last two (the matrix dimensions).
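These promotion and broadcasting rules follow the familiar NumPy matmul semantics. As a plain NumPy illustration of the rules described above (not cuda.tile code, and assuming cuda.tile's promotion matches NumPy's):

```python
import numpy as np

# Mixed dtypes: f16 @ f32 promotes to the common dtype f32,
# and the result carries the promoted dtype.
a = np.ones((2, 4), dtype=np.float16)
b = np.ones((4, 2), dtype=np.float32)
print((a @ b).dtype)  # float32

# Batch broadcasting: axes before the last two are broadcast,
# so a 3D x 2D matmul broadcasts y across the batch axis:
# (2, 2, 4) @ (4, 2) -> (2, 2, 2)
x = np.ones((2, 2, 4))
y = np.ones((4, 2))
print((x @ y).shape)  # (2, 2, 2)
```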

Return type:

Tile

Examples

2D x 2D.

import cuda.tile as ct
import torch

@ct.kernel
def kernel():
    x = ct.ones((2, 4), dtype=ct.float32)
    y = ct.ones((4, 2), dtype=ct.float32)
    # (2, 4) @ (4, 2) -> (2, 2)
    print(f"{x @ y:.1f}")


torch.cuda.init()
# Launch a single-block grid of the kernel on the current CUDA stream
ct.launch(torch.cuda.current_stream(), (1,), kernel, ())
torch.cuda.synchronize()

Output

[[4.0, 4.0], [4.0, 4.0]]

1D x 1D (dot product).

import cuda.tile as ct
import torch

@ct.kernel
def kernel():
    x = ct.ones(4, dtype=ct.float32)
    y = ct.ones(4, dtype=ct.float32)
    # 1D @ 1D is a dot product; the result is a scalar
    print(f"{x @ y:.1f}")


torch.cuda.init()
ct.launch(torch.cuda.current_stream(), (1,), kernel, ())
torch.cuda.synchronize()

Output

4.0

Batched: 3D x 2D with broadcast.

import cuda.tile as ct
import torch

@ct.kernel
def kernel():
    x = ct.ones((2, 2, 4), dtype=ct.float32)
    y = ct.ones((4, 2), dtype=ct.float32)
    # y is broadcast across the batch axis: (2, 2, 4) @ (4, 2) -> (2, 2, 2)
    print(f"{x @ y:.1f}")


torch.cuda.init()
ct.launch(torch.cuda.current_stream(), (1,), kernel, ())
torch.cuda.synchronize()

Output

[[[4.0, 4.0], [4.0, 4.0]], [[4.0, 4.0], [4.0, 4.0]]]