cuda.tile.where
- cuda.tile.where(cond, x, y, /)
Returns a tile whose elements are chosen elementwise from x or y: the element of x where cond is True, and the element of y where cond is False.
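- Parameters (all positional-only, as indicated by the `/` in the signature):
  - cond – boolean condition tile; selects which input each result element comes from.
  - x – values used where cond is True.
  - y – values used where cond is False.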
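- Return type: a tile of the selected elements (in the example below it has the same shape, (4,), and dtype, int32, as x and y).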
Examples
# Kernel-body fragment (``ct`` is ``cuda.tile``); the complete runnable
# version, including the launch code, follows.
# Compare each index against 2 -> [False, False, True, True].
cond = ct.arange(4, dtype=ct.int32) >= 2
# Values taken where cond is True.
x_true = ct.full((4,), 1, dtype=ct.int32)
# Values taken where cond is False.
x_false = ct.full((4,), -1, dtype=ct.int32)
# Elementwise select: 1 where cond holds, -1 elsewhere.
y = ct.where(cond, x_true, x_false)
print(y)  # [-1, -1, 1, 1]
import cuda.tile as ct
import torch


@ct.kernel
def kernel():
    # Demonstrates ct.where: pick 1 where the condition holds, -1 elsewhere,
    # then print the resulting tile.
    # Compare each index against 2 -> [False, False, True, True].
    cond = ct.arange(4, dtype=ct.int32) >= 2
    x_true = ct.full((4,), 1, dtype=ct.int32)
    x_false = ct.full((4,), -1, dtype=ct.int32)
    y = ct.where(cond, x_true, x_false)
    print(y)


torch.cuda.init()
# Launch on the current CUDA stream with a grid of shape (1,); the kernel
# takes no arguments, hence the empty tuple.
ct.launch(torch.cuda.current_stream(), (1,), kernel, ())
# Block until the launched kernel has finished.
torch.cuda.synchronize()
Output
[-1, -1, 1, 1]