cuda.tile.store_advanced_indexing

cuda.tile.store_advanced_indexing(array, indices, tile, /, *, latency=None, allow_tma=None)

Stores a tile into non-contiguous slices of array.

This function uses the same indices convention as load_advanced_indexing(): exactly one entry of indices is a 1-D integer Tile, which selects positions along the sparse dimension, and every other entry is a Slice object covering a dense dimension. The tile's shape must exactly match the shape implied by indices.
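To make the indices convention concrete without a GPU, here is a host-side model written with NumPy. It is a sketch under stated assumptions, not the cuda.tile implementation: the helper name store_advanced_indexing_reference is hypothetical, a NumPy integer array stands in for the 1-D integer Tile, and Python's built-in slice stands in for ct.Slice. It checks the convention (one integer-array entry, slices everywhere else, tile shape equal to the implied shape) and then writes the tile in place.

```python
import numpy as np


def store_advanced_indexing_reference(array, indices, tile):
    """Hypothetical host-side model of the store semantics (not part of
    cuda.tile). Writes `tile` into `array` in place."""
    indices = tuple(indices)
    if len(indices) != array.ndim:
        raise ValueError("need one index entry per array dimension")

    # Exactly one entry must be an integer array: the sparse dimension.
    sparse = [i for i, idx in enumerate(indices) if isinstance(idx, np.ndarray)]
    if len(sparse) != 1:
        raise ValueError("exactly one entry must be a 1-D integer array")
    idx = indices[sparse[0]]
    if idx.ndim != 1 or not np.issubdtype(idx.dtype, np.integer):
        raise ValueError("the sparse index must be a 1-D integer array")

    # Every other entry must be a slice: a dense dimension.
    for i, entry in enumerate(indices):
        if i != sparse[0] and not isinstance(entry, slice):
            raise ValueError("non-sparse entries must be slices")

    # With a single integer-array index, NumPy keeps that axis in place,
    # so the view's shape is the shape implied by the indices.
    implied_shape = array[indices].shape
    if tile.shape != implied_shape:
        raise ValueError(f"tile shape {tile.shape} != implied shape {implied_shape}")
    array[indices] = tile


# Same store as the GPU example: rows 1..4, columns 0..3 of a 6x4 array.
y = np.zeros((6, 4), dtype=np.int32)
rows = np.arange(4, dtype=np.int32) + 1
store_advanced_indexing_reference(y, (rows, slice(0, 4)), np.ones((4, 4), dtype=np.int32))
print(y.tolist())  # rows 1 through 4 are set to 1, rows 0 and 5 stay 0
```

The sparse dimension does not have to be the first one; for example, (slice(0, 2), np.array([0, 4])) on a 2x5 array selects a 2x2 region from columns 0 and 4. One difference from the real API: NumPy raises IndexError for out-of-range indices, whereas the documentation only states that a tile lying entirely outside the tiled view is undefined behavior.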

If the tile lies entirely outside the tiled view, the behavior is undefined.

Parameters:
  • array (Array) – Array to store into.

  • indices (tuple) – Same convention as load_advanced_indexing().

  • tile (Tile) – Tile to store. Shape must exactly match the shape implied by indices.

  • latency (int, optional) – DRAM traffic hint (1 = low, 10 = high).

  • allow_tma (bool, optional) – If False, TMA (the Tensor Memory Accelerator) will not be used.
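The / and * markers in the signature are standard Python syntax: array, indices and tile are positional-only, while latency and allow_tma are keyword-only. The stub below is hypothetical; it mirrors only the parameter kinds of the documented signature and stores nothing, but it shows which call forms Python accepts.

```python
def store_advanced_indexing_stub(array, indices, tile, /, *, latency=None, allow_tma=None):
    """Hypothetical stub with the same parameter kinds as the documented
    signature; it only reports the keyword-only hints it received."""
    return {"latency": latency, "allow_tma": allow_tma}


# Accepted: data arguments positionally, hints by keyword.
print(store_advanced_indexing_stub("y", "idx", "t", latency=2, allow_tma=False))

# Rejected: the hints cannot be passed positionally.
try:
    store_advanced_indexing_stub("y", "idx", "t", 2)
except TypeError as exc:
    print("TypeError:", exc)

# Rejected: the first three parameters cannot be passed by name.
try:
    store_advanced_indexing_stub(array="y", indices="idx", tile="t")
except TypeError as exc:
    print("TypeError:", exc)
```

In practice this means a call looks like ct.store_advanced_indexing(y, indices, tile, latency=..., allow_tma=...), with the hints always spelled out by name.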

Examples

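import cuda.tile as ct
import torch

torch.cuda.init()
stream = torch.cuda.current_stream()

@ct.kernel
def kernel(y):
    # Sparse dimension: rows 1, 2, 3 and 4 of y.
    row_indices = ct.arange(4, dtype=ct.int32) + 1
    # A 4x4 tile of ones; its shape matches (len(row_indices), 4).
    tile = ct.full((4, 4), 1, dtype=y.dtype)
    # Dense dimension: columns 0 through 3 via ct.Slice(0, 4).
    ct.store_advanced_indexing(y, (row_indices, ct.Slice(0, 4)), tile)

y = torch.zeros(6, 4, device='cuda', dtype=torch.int32)
ct.launch(stream, (1,), kernel, (y,))
torch.cuda.synchronize()  # wait for the kernel before reading y on the host
print(y.tolist())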

Output

[[0, 0, 0, 0], [1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1], [0, 0, 0, 0]]