cuda.tile.permute

cuda.tile.permute(x, /, axes)

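Permutes the axes of the input tile. Axis i of the result is axis axes[i] of x, so the result has shape (x.shape[axes[0]], x.shape[axes[1]], ...).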

Parameters:
  • x (Tile) – input tile.

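  • axes (tuple[const int,...]) – the desired order of the axes, given as a permutation of the axis indices 0 through N-1, where N is the number of dimensions of x.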

Return type:

Tile
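The shape rule for axes can be checked without a GPU. The sketch below is plain Python, not part of cuda.tile: the helpers permuted_shape and inverse_axes are hypothetical names, and the validity check inside permuted_shape belongs to this sketch only, not to ct.permute. It also shows how to build the axes order that undoes a given permutation, which is useful when a tile has to be permuted back to its original layout.

```python
def permuted_shape(shape, axes):
    """Shape of the result: dimension i comes from dimension axes[i] of the input."""
    # Sketch-only validation: axes must reorder 0..N-1 exactly once each.
    if sorted(axes) != list(range(len(shape))):
        raise ValueError("axes must be a permutation of 0..N-1")
    return tuple(shape[a] for a in axes)


def inverse_axes(axes):
    """Return the axes order that undoes a permutation by `axes`."""
    inv = [0] * len(axes)
    for i, a in enumerate(axes):
        inv[a] = i  # axis a of the input ended up at position i
    return tuple(inv)


shape = (2, 3, 4)
axes = (2, 0, 1)
print(permuted_shape(shape, axes))  # (4, 2, 3)
print(inverse_axes(axes))           # (1, 2, 0)
# Permuting by axes and then by its inverse restores the original shape.
print(permuted_shape(permuted_shape(shape, axes), inverse_axes(axes)))  # (2, 3, 4)
```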

Examples

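import cuda.tile as ct
import torch

@ct.kernel
def kernel():
    # 2x2x2 tile holding 0..7 in row-major order: tx[i][j][k] == 4*i + 2*j + k
    tx = ct.arange(8, dtype=ct.int32).reshape((2, 2, 2))
    # Axis 0 of ty is axis 2 of tx, axis 1 is axis 0, axis 2 is axis 1,
    # so ty[i][j][k] == tx[j][k][i]
    ty = ct.permute(tx, (2, 0, 1))
    print(ty)


torch.cuda.init()
# Launch a grid of one block on the current PyTorch stream; the kernel takes no arguments
ct.launch(torch.cuda.current_stream(), (1,), kernel, ())
torch.cuda.synchronize()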

Output

[[[0, 2], [4, 6]], [[1, 3], [5, 7]]]
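The output above can be reproduced on the CPU with a small reference model of the permutation on nested Python lists. This is a sketch for checking results by hand, not how cuda.tile implements permute; the helper names (shape_of, get, build, permute_ref) are hypothetical. It assumes the convention that result[i_0][i_1]... equals x at the index whose position axes[k] holds i_k, which is the convention the printed output is consistent with.

```python
def shape_of(nested):
    """Return the shape of a rectangular, non-empty nested list."""
    shape = []
    while isinstance(nested, list):
        shape.append(len(nested))
        nested = nested[0]
    return tuple(shape)


def get(nested, index):
    """Read the element at a multi-dimensional index."""
    for i in index:
        nested = nested[i]
    return nested


def build(shape, fn, prefix=()):
    """Build a nested list of `shape`, filling each position idx with fn(idx)."""
    if len(prefix) == len(shape):
        return fn(prefix)
    return [build(shape, fn, prefix + (i,)) for i in range(shape[len(prefix)])]


def permute_ref(x, axes):
    """Reference permute: axis k of the result is axis axes[k] of x."""
    src_shape = shape_of(x)
    out_shape = tuple(src_shape[a] for a in axes)

    def element(out_idx):
        src_idx = [0] * len(axes)
        for k, a in enumerate(axes):
            src_idx[a] = out_idx[k]  # output position k walks input axis a
        return get(x, src_idx)

    return build(out_shape, element)


# Same data as ct.arange(8).reshape((2, 2, 2)) in the example above
tx = [[[0, 1], [2, 3]], [[4, 5], [6, 7]]]
ty = permute_ref(tx, (2, 0, 1))
print(ty)  # [[[0, 2], [4, 6]], [[1, 3], [5, 7]]]
```

For a 2-D input, permute_ref(x, (1, 0)) is an ordinary matrix transpose.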