cuda.tile.sum

cuda.tile.sum(
    x,
    /,
    axis=None,
    *,
    keepdims=False,
    rounding_mode=None,
    flush_to_zero=False,
)

Performs a sum reduction of a tile along the given axis (or axes).

Parameters:
  • x (Tile) – input tile.

  • axis (None | const int | tuple[const int,...]) – the axis or axes to reduce. The default, axis=None, reduces over all elements. A tuple of axes is accepted here (see the examples below); note that the related argmin and argmax reductions do not support a tuple of axes.

  • keepdims (const bool) – If True, the reduced axes are kept with length 1, so the result has the same number of dimensions as the input tile. Default is False.

  • rounding_mode (RoundingMode) – The rounding mode for the operation. Supported only for floating-point types; defaults to RoundingMode.RN where applicable.

  • flush_to_zero (const bool) – If True, subnormal inputs and results are flushed to sign-preserving zero. Default is False.

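Returns:
  the reduced tile. With keepdims=True the reduced axes are retained with length 1; otherwise they are removed from the result.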

Examples

Reduce over all axes (axis=None).

# Excerpt of the kernel body from the full script below (assumes `import cuda.tile as ct`).
tx = ct.arange(8, dtype=ct.int32).reshape((2, 4))
print("input:", tx)
print("reduced:", ct.sum(tx, None))
import cuda.tile as ct
import torch

# Example: sum-reduce every element of a (2, 4) int32 tile (axis=None).
@ct.kernel
def kernel():
    # ct.arange(8) produces the values 0..7, reshaped here into a (2, 4) tile.
    tx = ct.arange(8, dtype=ct.int32).reshape((2, 4))
    print("input:", tx)
    # axis=None reduces over all elements: 0 + 1 + ... + 7 == 28.
    print("reduced:", ct.sum(tx, None))


# Explicitly initialize PyTorch's CUDA state before querying its stream.
torch.cuda.init()
# Launch on PyTorch's current stream. (1,) is presumably the launch grid
# (a single block) and () is the empty argument tuple for `kernel`.
ct.launch(torch.cuda.current_stream(), (1,), kernel, ())
# Block until the launched work has finished before the script exits.
torch.cuda.synchronize()

Output

input: [[0, 1, 2, 3], [4, 5, 6, 7]]
reduced: 28

Reduce along axis 1 with keepdims=True.

# Excerpt of the kernel body from the full script below (assumes `import cuda.tile as ct`).
tx = ct.arange(8, dtype=ct.int32).reshape((2, 4))
print("input:", tx)
print("reduced:", ct.sum(tx, 1, keepdims=True))
import cuda.tile as ct
import torch

# Example: sum-reduce a (2, 4) int32 tile along axis 1, keeping that axis.
@ct.kernel
def kernel():
    # Values 0..7 laid out as two rows of four.
    tx = ct.arange(8, dtype=ct.int32).reshape((2, 4))
    print("input:", tx)
    # Each row is summed; keepdims=True keeps axis 1 with length 1, giving a
    # (2, 1) result: [[0+1+2+3], [4+5+6+7]] == [[6], [22]].
    print("reduced:", ct.sum(tx, 1, keepdims=True))


# Explicitly initialize PyTorch's CUDA state before querying its stream.
torch.cuda.init()
# Launch on PyTorch's current stream. (1,) is presumably the launch grid
# (a single block) and () is the empty argument tuple for `kernel`.
ct.launch(torch.cuda.current_stream(), (1,), kernel, ())
# Block until the launched work has finished before the script exits.
torch.cuda.synchronize()

Output

input: [[0, 1, 2, 3], [4, 5, 6, 7]]
reduced: [[6], [22]]

Reduce over axes (1, 2) of a 3-D tile.

# Excerpt of the kernel body from the full script below (assumes `import cuda.tile as ct`).
tx = ct.arange(8, dtype=ct.int32).reshape((2, 2, 2))
print("input:", tx)
print("reduced:", ct.sum(tx, (1, 2)))
import cuda.tile as ct
import torch

# Example: sum-reduce a (2, 2, 2) int32 tile over axes 1 and 2 together.
@ct.kernel
def kernel():
    # Values 0..7 arranged as a (2, 2, 2) tile.
    tx = ct.arange(8, dtype=ct.int32).reshape((2, 2, 2))
    print("input:", tx)
    # A tuple of axes reduces both at once; without keepdims they are removed,
    # leaving shape (2,): [0+1+2+3, 4+5+6+7] == [6, 22].
    print("reduced:", ct.sum(tx, (1, 2)))


# Explicitly initialize PyTorch's CUDA state before querying its stream.
torch.cuda.init()
# Launch on PyTorch's current stream. (1,) is presumably the launch grid
# (a single block) and () is the empty argument tuple for `kernel`.
ct.launch(torch.cuda.current_stream(), (1,), kernel, ())
# Block until the launched work has finished before the script exits.
torch.cuda.synchronize()

Output

input: [[[0, 1], [2, 3]], [[4, 5], [6, 7]]]
reduced: [6, 22]

Return type:

Tile