cuda.tile.load_advanced_indexing
cuda.tile.load_advanced_indexing(array, indices, /, *, padding_mode=PaddingMode.UNDETERMINED, latency=None, allow_tma=None)
Loads a tile from non-contiguous slices of array.
indices is a tuple of length array.ndim. Exactly one entry must be a 1-D integer Tile (the sparse dim); every other entry must be a Slice(start, length), where start is a runtime element-space offset and length is a compile-time power-of-two tile size.

The sparse-dim tile contains element-space indices: each value selects one slice of the array along that dimension. Each dense-dim Slice describes a contiguous range [start, start + length). The resulting tile has shape (len_0, ..., len_{n-1}), where len_i is the index-tile length for the sparse dim or Slice.length for a dense dim.

If the tile lies entirely outside the tiled view, the behavior is undefined.
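To make the indexing rules concrete, here is a minimal pure-Python sketch that models on the host what the load computes. It is not cuTile's implementation, and several things in it are assumptions for illustration: the helper name reference_load_advanced_indexing is hypothetical, the array is a flat row-major list with an explicit shape, a Python list stands in for the 1-D index Tile, (start, length) tuples stand in for ct.Slice, and the pad argument models a constant fill such as PaddingMode.ZERO.

```python
import itertools
from math import prod


def reference_load_advanced_indexing(data, shape, indices, pad=0):
    """Hypothetical host-side model of load_advanced_indexing (not cuTile code).

    data    -- flat, row-major list holding an array of the given shape
    shape   -- tuple of array dimensions
    indices -- tuple of len(shape): exactly one entry is a list of ints
               (stands in for the 1-D index Tile on the sparse dim); the
               others are (start, length) tuples (stand in for ct.Slice)
    pad     -- value written for out-of-bounds elements (e.g. 0 for ZERO)
    """
    if len(indices) != len(shape):
        raise ValueError("indices must have length array.ndim")
    if sum(isinstance(e, list) for e in indices) != 1:
        raise ValueError("exactly one entry must be the sparse index list")

    # Element-space coordinates selected along each dimension.
    coords = []
    for entry in indices:
        if isinstance(entry, list):
            coords.append(entry)  # sparse dim: arbitrary element indices
        else:
            start, length = entry
            if length <= 0 or length & (length - 1):
                raise ValueError("Slice length must be a power of two")
            coords.append(list(range(start, start + length)))  # [start, start + length)

    out_shape = tuple(len(c) for c in coords)
    strides = [prod(shape[d + 1:]) for d in range(len(shape))]

    # Visit output positions in row-major order and pad anything outside the
    # array. (The real function's behavior is undefined when the whole tile
    # lies outside the view; this model simply pads in that case.)
    out = []
    for point in itertools.product(*coords):
        if all(0 <= p < n for p, n in zip(point, shape)):
            out.append(data[sum(p * s for p, s in zip(point, strides))])
        else:
            out.append(pad)
    return out_shape, out


if __name__ == "__main__":
    # Mirrors the Examples section: rows 0..3, columns [2, 6) of an 8x8 arange.
    out_shape, flat = reference_load_advanced_indexing(
        list(range(64)), (8, 8), ([0, 1, 2, 3], (2, 4)))
    rows = [flat[r * out_shape[1]:(r + 1) * out_shape[1]] for r in range(out_shape[0])]
    print(rows)  # [[2, 3, 4, 5], [10, 11, 12, 13], [18, 19, 20, 21], [26, 27, 28, 29]]
```

The sparse dim does not have to be dim 0: passing ((0, 2), [1, 3]) gathers columns 1 and 3 from rows 0 and 1. Indices that fall past the array edge receive the pad value, which is how the model represents padding_mode on both sparse and dense dims.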
- Parameters:
array (Array) – Array to load from.
indices (tuple) – Length must equal array.ndim. Exactly one entry is a 1-D integer Tile (the sparse dim); the rest are Slice objects (dense dims).
padding_mode (PaddingMode) – Fill value for out-of-bounds (OOB) elements on both sparse and dense dims. Defaults to UNDETERMINED.
latency (int, optional) – DRAM traffic hint (1 = low, 10 = high).
allow_tma (bool, optional) – If False, TMA will not be used.
- Returns:
A tile of shape (len_0, ..., len_{n-1}); the sparse-dim length equals the index-tile length, and each dense-dim length equals the corresponding Slice.length value.
- Return type:
Tile
Examples
import cuda.tile as ct
import torch

torch.cuda.init()
stream = torch.cuda.current_stream()

@ct.kernel
def kernel(x, y, col_start):
    # Sparse dim 0: gather rows 0, 1, 2, 3
    row_indices = ct.arange(4, dtype=ct.int32)
    # Dense dim 1: contiguous columns [col_start, col_start + 4)
    tile = ct.load_advanced_indexing(x, (row_indices, ct.Slice(col_start, 4)),
                                     padding_mode=ct.PaddingMode.ZERO)
    # Write the 4x4 tile to the top-left tile of y
    ct.store(y, (0, 0), tile)

x = torch.arange(64, device='cuda', dtype=torch.int32).reshape(8, 8)
y = torch.zeros(4, 4, device='cuda', dtype=torch.int32)
ct.launch(stream, (1,), kernel, (x, y, 2))  # one block, col_start = 2
print(y.tolist())
torch.cuda.synchronize()
Output
[[2, 3, 4, 5], [10, 11, 12, 13], [18, 19, 20, 21], [26, 27, 28, 29]]
See also