cuda.tile.permute
cuda.tile.permute(x, /, axes)
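Permutes the axes of the input tile. Axis i of the result is axis axes[i] of x, so the result has shape (x.shape[axes[0]], x.shape[axes[1]], ...).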
- Parameters:
x (Tile) – input tile.
axes (tuple[const int, ...]) – the desired order of the axes, given as a permutation of the axis indices of x (for example, (2, 0, 1) for a 3-D tile).
- Return type: Tile
Examples
# tx holds 0..7 laid out as a 2x2x2 tile: tx[i, j, k] == 4*i + 2*j + k
tx = ct.arange(8, dtype=ct.int32).reshape((2, 2, 2))
# Axis 0 of ty is axis 2 of tx, axis 1 is axis 0, axis 2 is axis 1,
# so ty[a, b, c] == tx[b, c, a].
ty = ct.permute(tx, (2, 0, 1))
print(ty)
import cuda.tile as ct
import torch


@ct.kernel
def kernel():
    # tx holds 0..7 laid out as a 2x2x2 tile: tx[i, j, k] == 4*i + 2*j + k
    tx = ct.arange(8, dtype=ct.int32).reshape((2, 2, 2))
    # Move the last axis to the front: ty[a, b, c] == tx[b, c, a].
    ty = ct.permute(tx, (2, 0, 1))
    print(ty)


torch.cuda.init()
# Launch `kernel` on the current CUDA stream; it takes no arguments.
ct.launch(torch.cuda.current_stream(), (1,), kernel, ())
# Wait for the kernel to finish (and its printed output to appear) before exiting.
torch.cuda.synchronize()
Output
[[[0, 2], [4, 6]], [[1, 3], [5, 7]]]