cuda.tile.store

cuda.tile.store(array, /, index, tile, *, order='C', latency=None, allow_tma=None, memory_order=MemoryOrder.WEAK, memory_scope=MemoryScope.NONE)

Stores tile into array at position index of the array's tile space.

The tile space is the grid obtained by partitioning array into equally sized tiles whose shape is the shape of tile.

For example, given a tile t of shape (tm, tn) and an array of shape (M, N):

# tile `t` has shape (tm, tn)
ct.store(array, (i, j), t)

The call above stores elements according to:

array[i * tm + x, j * tn + y] = t[x, y]  (for 0 <= x < tm, 0 <= y < tn)

If the tile extends partially beyond the array boundaries, its out-of-bounds elements are ignored (not written). If the tile lies entirely outside the array, the behavior is undefined.
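To make the index mapping and the boundary rule concrete, here is a small host-side model in plain Python for the 2D case. It is a sketch of the documented semantics under stated assumptions, not how cuda.tile implements store: the helper name reference_store_2d is hypothetical, arrays and tiles are nested lists, indices are assumed non-negative, and a tile lying entirely outside the array simply writes nothing in this model, whereas on the device that case is undefined.

```python
def reference_store_2d(array, index, tile):
    """Hypothetical host-side model of ct.store for a 2D array and a 2D tile.

    Applies array[i*tm + x, j*tn + y] = tile[x, y] and silently drops
    elements whose target position falls outside the array.
    """
    M, N = len(array), len(array[0])   # array shape (M, N)
    tm, tn = len(tile), len(tile[0])   # tile shape (tm, tn)
    i, j = index                       # index in tile space (assumed non-negative)
    for x in range(tm):
        for y in range(tn):
            r, c = i * tm + x, j * tn + y
            if r < M and c < N:        # out-of-bounds elements are not written
                array[r][c] = tile[x][y]


# A 5x6 array partitioned into 3x4 tiles: tile-space index (1, 1) starts at
# element (3, 4), so only the top-left 2x2 corner of the tile lands in bounds.
a = [[0] * 6 for _ in range(5)]
t = [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]]
reference_store_2d(a, (1, 1), t)
print(a[3], a[4])  # [0, 0, 0, 0, 1, 2] [0, 0, 0, 0, 5, 6]
```

The nested loop mirrors the formula term by term, which keeps the model easy to compare against the mapping above at the cost of speed.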

Parameters:
  • array (Array) – The array to store to.

  • index (tuple[int,...]) – An index into the tile space of array. The tile shape used to partition the array is inferred from the tile argument.

  • tile (Tile) – The tile to store. The rank of the tile must match the rank of the array, unless the tile is a scalar or 0d tile.

  • order ("C" or "F", or tuple[const int,...]) – Order of axis mapping. See load().

  • latency (int, optional) – A hint indicating how heavy the DRAM traffic of this store is expected to be, given as an integer from 1 (low) to 10 (high). By default, the compiler infers the latency.

  • allow_tma (bool, optional) – If False, the store will not use TMA (the Tensor Memory Accelerator). By default, TMA is allowed.

  • memory_order (MemoryOrder) – Memory ordering semantics for the store. Defaults to MemoryOrder.WEAK. Valid values: WEAK, RELAXED, RELEASE.

  • memory_scope (MemoryScope) – The scope of threads that participate in memory ordering. Only meaningful when memory_order is not WEAK.
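The rank and value constraints listed above can be checked ahead of time on the host. The sketch below is hypothetical: check_store_args is not part of cuda.tile, memory orders are passed as plain strings standing in for ct.MemoryOrder members, and treating a scalar tile as shape (1, ..., 1) is inferred from the scalar example further down this page rather than stated as a rule.

```python
# Plain-string stand-ins for the ct.MemoryOrder members documented as valid.
VALID_MEMORY_ORDERS = {"WEAK", "RELAXED", "RELEASE"}


def check_store_args(array_shape, tile_shape, latency=None, memory_order="WEAK"):
    """Hypothetical pre-flight check mirroring the documented store() constraints.

    Returns the effective tile shape that partitions the array into its tile space.
    """
    if len(tile_shape) == 0:
        # Scalar / 0d tile: assumed to behave like a tile of shape (1, ..., 1).
        tile_shape = (1,) * len(array_shape)
    elif len(tile_shape) != len(array_shape):
        raise ValueError(
            f"tile rank {len(tile_shape)} does not match array rank {len(array_shape)}"
        )
    if latency is not None and not (isinstance(latency, int) and 1 <= latency <= 10):
        raise ValueError("latency must be an integer from 1 (low) to 10 (high)")
    if memory_order not in VALID_MEMORY_ORDERS:
        raise ValueError(f"memory_order must be one of {sorted(VALID_MEMORY_ORDERS)}")
    return tuple(tile_shape)


print(check_store_args((2, 2), ()))             # (1, 1)
print(check_store_args((6,), (4,), latency=3))  # (4,)
```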

Examples

Store into a 1D array. The out-of-bounds part of a partially out-of-bounds tile is ignored.

import cuda.tile as ct
import torch

torch.cuda.init()
stream = torch.cuda.current_stream()

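@ct.kernel
def kernel(x):
    # A 4-element tile of ones; x has 6 elements, so its tile space has 2 tiles.
    tile = ct.ones((4,), dtype=x.dtype)
    # Tile index 0 writes x[0:4].
    ct.store(x, (0,), tile)
    # Tile index 1 covers x[4:8]; x[6:8] is out of bounds and ignored.
    ct.store(x, (1,), tile * 2)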

x = torch.zeros(6, dtype=torch.int32, device='cuda')
ct.launch(stream, (1,), kernel, (x,))
print(x.tolist())

torch.cuda.synchronize()

Output

[1, 1, 1, 1, 2, 2]
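The output follows from how the 6-element array is split into 4-element tiles: two tile-space indices are needed to cover it, and the second tile is only half in bounds. The plain-Python sketch below reproduces that arithmetic on the host; tile_coverage is a hypothetical helper, not a cuda.tile function.

```python
def tile_coverage(n, tile_size):
    """Hypothetical helper: in-bounds [start, stop) range for each 1D tile index."""
    num_tiles = (n + tile_size - 1) // tile_size  # ceil(n / tile_size)
    ranges = []
    for k in range(num_tiles):
        start = k * tile_size
        stop = min(start + tile_size, n)  # elements at index >= n are dropped
        ranges.append((start, stop))
    return ranges


coverage = tile_coverage(6, 4)
print(coverage)  # [(0, 4), (4, 6)]

# Replay the two stores from the kernel: ones at index 0, twos at index 1.
x = [0] * 6
for (start, stop), value in zip(coverage, [1, 2]):
    x[start:stop] = [value] * (stop - start)
print(x)  # [1, 1, 1, 1, 2, 2]
```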

When storing a scalar (0d tile), it is broadcast to the rank of the array.

import cuda.tile as ct
import torch

torch.cuda.init()
stream = torch.cuda.current_stream()

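@ct.kernel
def kernel(x):
    # Each scalar is broadcast to the rank of x (2), so each store writes
    # exactly one element: tile index (i, j) addresses x[i, j].
    ct.store(x, (0, 0), tile=0)
    ct.store(x, (0, 1), tile=1)
    ct.store(x, (1, 0), tile=2)
    ct.store(x, (1, 1), tile=3)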

x = torch.zeros(4, dtype=torch.int32, device='cuda').reshape(2, 2)
ct.launch(stream, (1,), kernel, (x,))
print(x.tolist())

torch.cuda.synchronize()

Output

[[0, 1], [2, 3]]
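Because each scalar here behaves like a tile of shape (1, 1), the tile space coincides with the element space and index (i, j) addresses x[i, j] directly. The plain-Python sketch below models that; broadcast_scalar is a hypothetical helper, and the (1, ..., 1) shape is inferred from this example's output rather than taken from the cuda.tile implementation.

```python
def broadcast_scalar(value, rank):
    """Hypothetical helper: wrap a scalar into a nested-list tile of shape (1,) * rank."""
    tile = value
    for _ in range(rank):
        tile = [tile]
    return tile


x = [[0, 0], [0, 0]]
for (i, j), value in [((0, 0), 0), ((0, 1), 1), ((1, 0), 2), ((1, 1), 3)]:
    t = broadcast_scalar(value, 2)   # [[value]], shape (1, 1)
    tm, tn = len(t), len(t[0])       # tile shape (1, 1)
    x[i * tm][j * tn] = t[0][0]      # array[i*tm + 0, j*tn + 0] = t[0, 0]
print(x)  # [[0, 1], [2, 3]]
```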