cuda.tile.load_advanced_indexing

cuda.tile.load_advanced_indexing(
    array,
    indices,
    /,
    *,
    padding_mode=PaddingMode.UNDETERMINED,
    latency=None,
    allow_tma=None,
)

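Loads a tile from array by gathering non-contiguous slices along one dimension (the sparse dim) and contiguous ranges along all other dimensions (the dense dims).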

indices is a tuple of length array.ndim. Exactly one entry must be a 1-D integer Tile (the sparse dim); every other entry must be a Slice (start, length) where start is a runtime element-space offset and length is a compile-time power-of-two tile size.
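The rules for indices can be modeled on the CPU as a small checker. This is only a sketch: SliceSpec and find_sparse_dim are hypothetical names that are not part of cuda.tile, and a plain Python list of ints stands in for the 1-D integer index Tile.

```python
from dataclasses import dataclass
from typing import Sequence, Union


@dataclass(frozen=True)
class SliceSpec:
    """Hypothetical stand-in for ct.Slice: a dense-dim range."""
    start: int   # runtime element-space offset
    length: int  # compile-time tile size, required to be a power of two


# A dense-dim entry is a SliceSpec; the sparse-dim entry is a list of ints
# standing in for the 1-D integer index Tile (hypothetical modeling choice).
IndexEntry = Union[SliceSpec, Sequence[int]]


def is_power_of_two(n: int) -> bool:
    return n > 0 and (n & (n - 1)) == 0


def find_sparse_dim(ndim: int, indices: Sequence[IndexEntry]) -> int:
    """Validate an indices tuple and return the position of its sparse dim."""
    if len(indices) != ndim:
        raise ValueError(f"expected {ndim} entries, got {len(indices)}")
    # Every entry that is not a SliceSpec is treated as the sparse index list.
    sparse = [d for d, e in enumerate(indices) if not isinstance(e, SliceSpec)]
    if len(sparse) != 1:
        raise ValueError(f"exactly one sparse dim required, found {len(sparse)}")
    for d, e in enumerate(indices):
        if isinstance(e, SliceSpec):
            if not is_power_of_two(e.length):
                raise ValueError(f"dim {d}: Slice length {e.length} is not a power of two")
        elif not all(isinstance(i, int) for i in e):
            raise ValueError(f"dim {d}: sparse indices must be integers")
    return sparse[0]
```

For the indices used in the Examples section, find_sparse_dim(2, ([0, 1, 2, 3], SliceSpec(2, 4))) returns 0, while a tuple with two index lists, no index list, or a Slice length of 3 raises ValueError.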

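The sparse-dim tile holds element-space indices, and each value selects one slice of array along that dimension. Each dense-dim Slice describes the contiguous range [start, start + length). The resulting tile has shape (len_0, ..., len_{n-1}), where len_i is the index-tile length for the sparse dim and Slice.length for each dense dim. Element (i_0, ..., i_{n-1}) of the result is read from array at coordinate (c_0, ..., c_{n-1}), where c_d = indices[d][i_d] for the sparse dim and c_d = start_d + i_d for a dense dim.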
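The shape rule and element mapping can be sketched as a pure-Python reference model that runs on nested lists, assuming every selected element is in bounds (no padding). reference_gather and its helpers are hypothetical; a Python list stands in for the index Tile and a (start, length) pair stands in for each Slice.

```python
from itertools import product
from typing import List, Sequence, Tuple, Union

# Sparse dim: list of element-space indices. Dense dim: (start, length) pair.
Entry = Union[List[int], Tuple[int, int]]


def get_element(array, coord):
    # Walk nested lists one dimension at a time.
    for c in coord:
        array = array[c]
    return array


def result_shape(indices: Sequence[Entry]) -> Tuple[int, ...]:
    # len_i is the index-list length for the sparse dim, Slice length otherwise.
    return tuple(len(e) if isinstance(e, list) else e[1] for e in indices)


def source_coord(indices: Sequence[Entry], out_coord) -> Tuple[int, ...]:
    # Sparse dim: c_d = indices[d][i_d]. Dense dim: c_d = start_d + i_d.
    return tuple(e[i] if isinstance(e, list) else e[0] + i
                 for e, i in zip(indices, out_coord))


def reference_gather(array, indices: Sequence[Entry]):
    shape = result_shape(indices)
    # itertools.product visits output coordinates in row-major order.
    flat = [get_element(array, source_coord(indices, oc))
            for oc in product(*(range(n) for n in shape))]
    # Regroup the flat values into nested lists, innermost dimension first.
    for n in reversed(shape[1:]):
        flat = [flat[k:k + n] for k in range(0, len(flat), n)]
    return flat
```

With x = [[8 * r + c for c in range(8)] for r in range(8)], the same 8 x 8 matrix used in the Examples section, reference_gather(x, ([0, 1, 2, 3], (2, 4))) reproduces the kernel's printed output, and reference_gather(x, ((1, 2), [7, 0, 3, 5])) shows the sparse dim placed on columns instead.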

If the tile lies entirely outside the tiled view, the behavior is undefined.

Parameters:
  • array (Array) – Array to load from.

  • indices (tuple) – Length must equal array.ndim. Exactly one entry is a 1-D integer Tile (sparse dim); the rest are Slice objects (dense dims).

  • padding_mode (PaddingMode) – Fill value for out-of-bounds (OOB) elements on both the sparse and dense dims. Defaults to PaddingMode.UNDETERMINED.

  • latency (int, optional) – Hint for the expected DRAM traffic, an integer from 1 (low) to 10 (high).

  • allow_tma (bool, optional) – If False, TMA (Tensor Memory Accelerator) is not used for the load.
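The padding behavior can be sketched for the 2-D case with sparse rows and dense columns, assuming a zero fill as with PaddingMode.ZERO. padded_gather_2d is a hypothetical helper, and treating negative index values as out of bounds is an assumption of this sketch rather than documented behavior. Any out-of-bounds position receives the fill value, whether it comes from an index in the sparse tile or from a dense Slice that runs past the edge of the array.

```python
from typing import List, Sequence


def padded_gather_2d(array: List[List[int]], rows: Sequence[int],
                     col_start: int, col_len: int, fill: int = 0) -> List[List[int]]:
    """Gather sparse rows and a dense column range; OOB reads return `fill`.

    fill=0 models PaddingMode.ZERO. Negative indices count as out of bounds
    (an assumption of this sketch).
    """
    n_rows = len(array)
    n_cols = len(array[0]) if array else 0
    out = []
    for r in rows:                                   # sparse dim
        row = []
        for c in range(col_start, col_start + col_len):  # dense dim
            in_bounds = 0 <= r < n_rows and 0 <= c < n_cols
            row.append(array[r][c] if in_bounds else fill)
        out.append(row)
    return out
```

On an 8 x 8 matrix holding 8 * r + c, padded_gather_2d(x, [6, 9], 6, 4) keeps the two in-bounds values of row 6 (columns 6 and 7) and fills the rest, including every element of the out-of-range row 9, with zeros.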

Returns:

A tile of shape (len_0, ..., len_{n-1}). The sparse-dim length equals the index-tile length, and each dense-dim length equals the corresponding Slice.length.

Return type:

Tile

Examples

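import cuda.tile as ct
import torch

torch.cuda.init()
stream = torch.cuda.current_stream()

@ct.kernel
def kernel(x, y, col_start):
    # Sparse dim (rows): gather rows 0, 1, 2, 3 of x.
    row_indices = ct.arange(4, dtype=ct.int32)
    # Dense dim (columns): the contiguous range [col_start, col_start + 4).
    tile = ct.load_advanced_indexing(x, (row_indices, ct.Slice(col_start, 4)),
                                     padding_mode=ct.PaddingMode.ZERO)
    # The loaded tile has shape (4, 4); store it at tile index (0, 0) of y.
    ct.store(y, (0, 0), tile)

# 8x8 matrix whose element (r, c) holds 8 * r + c.
x = torch.arange(64, device='cuda', dtype=torch.int32).reshape(8, 8)
y = torch.zeros(4, 4, device='cuda', dtype=torch.int32)
# Launch a single block with col_start = 2.
ct.launch(stream, (1,), kernel, (x, y, 2))
print(y.tolist())

torch.cuda.synchronize()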

Output

[[2, 3, 4, 5], [10, 11, 12, 13], [18, 19, 20, 21], [26, 27, 28, 29]]