cuda.tile.where

cuda.tile.where(cond, x, y, /)

Returns a tile whose elements are taken from x where cond is True and from y where cond is False. All arguments are positional-only.

Parameters:
  • cond (Tile) – Boolean tile of shape S.

  • x (Tile) – Tile of shape S and dtype T, selected if cond is True.

  • y (Tile) – Tile of shape S and dtype T, selected if cond is False.

Return type:

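Tile – Tile of shape S and dtype T, holding the element of x at each position where cond is True and the element of y elsewhere.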
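To make the selection rule concrete, here is a minimal host-side sketch of the semantics described above, written in plain Python. It is not cuTile's implementation: nested Python lists stand in for tiles, the helper name `where_reference` is hypothetical, and it models only the equal-shape case documented here (no dtype handling, and no broadcasting is assumed).

```python
def where_reference(cond, x, y):
    """Hypothetical pure-Python model of elementwise selection.

    cond, x and y are nested lists of identical shape; the result has the
    same shape, taking x where cond is truthy and y otherwise.
    """
    cond_is_list = isinstance(cond, list)
    # All three inputs must agree on whether this level is a list or a scalar.
    if cond_is_list != isinstance(x, list) or cond_is_list != isinstance(y, list):
        raise ValueError("cond, x and y must have the same shape")
    if not cond_is_list:
        # Scalar leaf: pick from x if the condition holds, else from y.
        return x if cond else y
    if not (len(cond) == len(x) == len(y)):
        raise ValueError("cond, x and y must have the same shape")
    # Recurse elementwise over the current dimension.
    return [where_reference(c, a, b) for c, a, b in zip(cond, x, y)]


# Mirror of the kernel example: positions 2 and 3 satisfy the condition.
cond = [i >= 2 for i in range(4)]
print(where_reference(cond, [1] * 4, [-1] * 4))  # [-1, -1, 1, 1]
```

The recursive form handles a shape S of any rank, which matches the rule that cond, x and y share one shape and the result keeps it.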

Examples

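import cuda.tile as ct
import torch


@ct.kernel
def kernel():
    # cond is [False, False, True, True]: True where the index is >= 2
    cond = ct.arange(4, dtype=ct.int32) >= 2
    x_true = ct.full((4,), 1, dtype=ct.int32)    # chosen where cond is True
    x_false = ct.full((4,), -1, dtype=ct.int32)  # chosen where cond is False
    result = ct.where(cond, x_true, x_false)
    print(result)


torch.cuda.init()
# Launch a single block (grid (1,)) with no kernel arguments on the current stream
ct.launch(torch.cuda.current_stream(), (1,), kernel, ())
torch.cuda.synchronize()  # wait for the kernel to finish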

Output

[-1, -1, 1, 1]