cuda.tile.expand_dims
- cuda.tile.expand_dims(x, /, axis)
Reshapes the tile by inserting a new axis of size 1 at the given position.
The same result is available through NumPy-style indexing: x[:, None] inserts the new axis at position 1, and x[np.newaxis, :] inserts it at position 0.
- Parameters:
x (Tile) – input tile.
axis (const int) – position at which the new axis of size 1 is inserted.
- Return type:
Examples
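tx = ct.arange(4, dtype=ct.int32)
print(ct.expand_dims(tx, 0))
print(tx[:, None])

Complete runnable version:

import cuda.tile as ct
import torch

@ct.kernel
def kernel():
    # 1-D tile holding 0, 1, 2, 3
    tx = ct.arange(4, dtype=ct.int32)
    # New axis at position 0: shape (4,) becomes (1, 4)
    print(ct.expand_dims(tx, 0))
    # New axis at position 1: shape (4,) becomes (4, 1)
    print(tx[:, None])

torch.cuda.init()
# Launch the kernel on a single-block grid with no arguments
ct.launch(torch.cuda.current_stream(), (1,), kernel, ())
torch.cuda.synchronize()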
Output
[[0, 1, 2, 3]]
[[0], [1], [2], [3]]
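
To make the shape change behind the two printed results explicit, the following is a minimal host-side sketch in plain Python. It does not call cuda.tile at all. The helper `expanded_shape` is hypothetical and exists only for illustration, and restricting `axis` to non-negative values is an assumption of this sketch, since this page does not say whether negative axes are accepted.

```python
# Host-side illustration only; this is not cuda.tile code.
# `expanded_shape` is a hypothetical helper, not part of any library.

def expanded_shape(shape, axis):
    """Return `shape` with a new axis of size 1 inserted at `axis`."""
    # Assumption of this sketch: only non-negative axes are handled.
    if not 0 <= axis <= len(shape):
        raise ValueError(f"axis {axis} out of range for shape {shape}")
    return shape[:axis] + (1,) + shape[axis:]

# Shape of the 1-D tile from the example above.
row = expanded_shape((4,), 0)     # new leading axis, like x[np.newaxis, :]
column = expanded_shape((4,), 1)  # new trailing axis, like x[:, None]

# The elements are unchanged; only the nesting (shape) differs.
values = list(range(4))
as_row = [values]                  # nested list with shape (1, 4)
as_column = [[v] for v in values]  # nested list with shape (4, 1)

print(row, column)
print(as_row)
print(as_column)
```

The two nested lists mirror the layout of the printed tiles in the Output above: a single row of four elements for axis 0, and four rows of one element each for axis 1.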