cuda.tile.num_blocks
cuda.tile.num_blocks(axis)
Gets the number of blocks along the given axis of the block index space.
- Parameters:
axis (const int) – The axis of the block index space. Possible values are 0, 1, 2.
- Return type:
int32
Examples
# Excerpt: assumes `import cuda.tile as ct` and a CUDA `stream` have been set
# up beforehand (the complete example obtains it via
# torch.cuda.current_stream()).
@ct.kernel
def kernel():
    # Block ids along axes 0, 1 and 2 (ct.bid(axis) -- presumably the current
    # block's index along `axis`; see the cuda.tile.bid docs).
    bidx = ct.bid(0)
    bidy = ct.bid(1)
    bidz = ct.bid(2)
    # Number of blocks along each axis of the block index space.
    nx = ct.num_blocks(0)
    ny = ct.num_blocks(1)
    nz = ct.num_blocks(2)
    # Only the block whose ids are all 0 prints, so the message is emitted
    # once instead of once per block.
    master_block = (bidx == 0 and bidy == 0 and bidz == 0)
    if master_block:
        print(f"Number of tile blocks: {(nx, ny, nz)}")


# (2, 3, 4) is the grid shape that ct.num_blocks reports inside the kernel
# (see the Output section); the trailing () passes no arguments, matching
# kernel()'s empty parameter list.
ct.launch(stream, (2, 3, 4), kernel, ())
# Complete runnable example: launch a 2 x 3 x 4 grid of blocks and print the
# per-axis block counts from a single block.
import cuda.tile as ct
import torch

# Initialize PyTorch's CUDA state before querying the current stream.
torch.cuda.init()
# The kernel is launched on PyTorch's current CUDA stream.
stream = torch.cuda.current_stream()


@ct.kernel
def kernel():
    # Block ids along axes 0, 1 and 2 (ct.bid(axis) -- presumably the current
    # block's index along `axis`; see the cuda.tile.bid docs).
    bidx = ct.bid(0)
    bidy = ct.bid(1)
    bidz = ct.bid(2)
    # Number of blocks along each axis of the block index space.
    nx = ct.num_blocks(0)
    ny = ct.num_blocks(1)
    nz = ct.num_blocks(2)
    # Only the block whose ids are all 0 prints, so the message is emitted
    # once instead of once per block.
    master_block = (bidx == 0 and bidy == 0 and bidz == 0)
    if master_block:
        print(f"Number of tile blocks: {(nx, ny, nz)}")


# (2, 3, 4) is the grid shape that ct.num_blocks reports inside the kernel
# (see the Output section); the trailing () passes no arguments, matching
# kernel()'s empty parameter list.
ct.launch(stream, (2, 3, 4), kernel, ())
# Wait for outstanding GPU work (including the kernel above) to complete --
# presumably so the kernel's printed output appears before the script exits.
torch.cuda.synchronize()
Output
Number of tile blocks: (2, 3, 4)