cuda.tile.store_advanced_indexing
cuda.tile.store_advanced_indexing(array, indices, tile, /, *, latency=None, allow_tma=None)
Stores a tile into non-contiguous slices of array.
Uses the same indices convention as load_advanced_indexing(): exactly one entry is a 1-D integer Tile (sparse dim) and the rest are Slice objects (dense dims). The tile's shape must exactly match the shape implied by the indices.
If the tile lies entirely outside the tiled view, the behavior is undefined.
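As a rough illustration of the shape rule, the pure-Python sketch below computes the shape that a tuple of indices implies. It is a host-side model only and not part of cuda.tile: the helper name `implied_shape` is hypothetical, a plain list stands in for the 1-D integer Tile, Python's built-in `slice` stands in for Slice, and slices are assumed to have explicit bounds and unit step.

```python
def implied_shape(indices):
    """Return the tile shape implied by one index list plus slices.

    Hypothetical host-side model (not the cuda.tile API): a list stands in
    for the 1-D integer Tile (sparse dim) and built-in slice objects stand
    in for Slice (dense dims). Slices are assumed to have explicit start
    and stop values and a step of 1.
    """
    sparse_count = sum(isinstance(entry, list) for entry in indices)
    if sparse_count != 1:
        raise ValueError("exactly one entry must be a 1-D index list")

    shape = []
    for entry in indices:
        if isinstance(entry, list):
            # Sparse dim: one tile element per listed index.
            shape.append(len(entry))
        elif isinstance(entry, slice):
            if entry.start is None or entry.stop is None:
                raise ValueError("slices need explicit start and stop")
            # Dense dim: a contiguous run of stop - start elements.
            shape.append(entry.stop - entry.start)
        else:
            raise TypeError("entries must be an index list or a slice")
    return tuple(shape)


# Mirrors the indices used in the example on this page: four row indices
# and columns 0..3, so the tile to store must have shape (4, 4).
print(implied_shape(([1, 2, 3, 4], slice(0, 4))))
```

Under this model, a tile whose shape differs from the returned tuple would violate the exact-match requirement stated above.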
Parameters:
- array (Array) – Array to store into.
- indices (tuple) – Same convention as load_advanced_indexing().
- tile (Tile) – Tile to store. Shape must exactly match the shape implied by indices.
- latency (int, optional) – DRAM traffic hint (1 = low, 10 = high).
- allow_tma (bool, optional) – If False, TMA will not be used.
Examples
import cuda.tile as ct
import torch

torch.cuda.init()
stream = torch.cuda.current_stream()

@ct.kernel
def kernel(y):
    row_indices = ct.arange(4, dtype=ct.int32) + 1
    tile = ct.full((4, 4), 1, dtype=y.dtype)
    ct.store_advanced_indexing(y, (row_indices, ct.Slice(0, 4)), tile)

y = torch.zeros(6, 4, device='cuda', dtype=torch.int32)
ct.launch(stream, (1,), kernel, (y,))
print(y.tolist())
torch.cuda.synchronize()
Output
[[0, 0, 0, 0], [1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1], [0, 0, 0, 0]]
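To check the output above without a GPU, the following pure-Python sketch emulates the store on nested lists for the 2-D case. It is an assumption-laden reference model, not the library's implementation: `reference_store` is a hypothetical helper, a list of ints stands in for the index Tile, and the built-in `slice` (explicit bounds, unit step) stands in for Slice.

```python
def reference_store(array, indices, tile):
    """Scatter `tile` into the nested-list `array` (2-D case only).

    Hypothetical host-side model of the store semantics: exactly one of
    the two index entries is a list of ints (sparse dim) and the other is
    a built-in slice with explicit bounds and unit step (dense dim).
    """
    first, second = indices
    if isinstance(first, list):
        rows = first
        cols = list(range(second.start, second.stop))
    else:
        rows = list(range(first.start, first.stop))
        cols = second

    # tile[i][j] lands at the i-th selected row and j-th selected column.
    for i, r in enumerate(rows):
        for j, c in enumerate(cols):
            array[r][c] = tile[i][j]
    return array


# Same setup as the kernel example: a 6x4 zero array, rows 1..4, a 4x4 tile of ones.
y = [[0] * 4 for _ in range(6)]
row_indices = [i + 1 for i in range(4)]
tile = [[1] * 4 for _ in range(4)]
reference_store(y, (row_indices, slice(0, 4)), tile)
print(y)
```

Rows 0 and 5 are never selected by the row indices, which is why they stay zero in both the model and the kernel output.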
See also