cuda.tile.expand_dims

cuda.tile.expand_dims(x, /, axis)

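Reshapes the tile by inserting a new axis of size 1 at the position given by axis. The element values and their order are unchanged; only the shape gains one dimension.

The same result can be obtained with NumPy-style indexing, x[np.newaxis, :] or x[:, None]. For a 1-D tile x, x[np.newaxis, :] is equivalent to expand_dims(x, 0), and x[:, None] is equivalent to expand_dims(x, 1).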

Parameters:
  • x (Tile) – input tile.

  • axis (const int) – index at which the new axis of size 1 is inserted in the result tile.

Return type:

Tile
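The shape rule can be illustrated in plain Python. This is a sketch that assumes cuda.tile.expand_dims follows the same convention as NumPy's np.expand_dims for non-negative axis values; expanded_shape is a hypothetical helper, not part of cuda.tile, and negative axis values are deliberately not handled.

```python
def expanded_shape(shape, axis):
    """Return the shape produced by inserting a size-1 axis at position `axis`.

    Hypothetical helper for illustration only; not part of cuda.tile.
    Assumes 0 <= axis <= len(shape); negative axes are not supported here.
    """
    if not 0 <= axis <= len(shape):
        raise ValueError(f"axis {axis} out of range for a {len(shape)}-d shape")
    # Dimensions before `axis` stay in place, the new size-1 axis goes at
    # `axis`, and the remaining dimensions shift one position to the right.
    return tuple(shape[:axis]) + (1,) + tuple(shape[axis:])


print(expanded_shape((4,), 0))     # (1, 4): like ct.expand_dims(tx, 0)
print(expanded_shape((4,), 1))     # (4, 1): like tx[:, None]
print(expanded_shape((2, 3), 1))   # (2, 1, 3)
```

Since axis may equal the number of dimensions, appending a trailing size-1 axis (for example (4,) to (4, 1)) is covered by the same rule.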

Examples

import cuda.tile as ct
import torch

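@ct.kernel
def kernel():
    # 1-D tile [0, 1, 2, 3] with shape (4,)
    tx = ct.arange(4, dtype=ct.int32)
    # New axis at position 0: shape (1, 4)
    print(ct.expand_dims(tx, 0))
    # New axis at position 1 via indexing: shape (4, 1)
    print(tx[:, None])


torch.cuda.init()
# Launch the kernel on the current stream with a grid of one block and no arguments
ct.launch(torch.cuda.current_stream(), (1,), kernel, ())
# Wait for the kernel (and its printed output) to finish
torch.cuda.synchronize()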

Output

[[0, 1, 2, 3]]
[[0], [1], [2], [3]]
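To make the printed output easier to read, here is a hypothetical nested-list reference model (expand_dims_nested is not part of cuda.tile) that reproduces it on the host. It is a sketch, assuming a tile's values are represented as nested Python lists and that 0 <= axis <= number of dimensions; it does not validate its input.

```python
def expand_dims_nested(value, axis):
    """Insert a size-1 axis into a nested-list "tile" at position `axis`.

    Hypothetical reference model for illustration only; not part of cuda.tile.
    Assumes 0 <= axis <= nesting depth; out-of-range axes are not checked.
    """
    if axis == 0:
        # Wrapping the whole (sub)array in a one-element list creates the new axis here.
        return [value]
    # Otherwise keep the current axis and recurse one level deeper.
    return [expand_dims_nested(sub, axis - 1) for sub in value]


tx = list(range(4))
print(expand_dims_nested(tx, 0))  # [[0, 1, 2, 3]]
print(expand_dims_nested(tx, 1))  # [[0], [1], [2], [3]]
```

The two printed lines match the Output above: axis 0 wraps the whole tile once, while axis 1 wraps each element.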