Matrix Multiplication

Matrix multiplication of tiles is supported through the ct::matmul() and ct::mma() functions. The ct::mma() function additionally accepts an accumulator argument that is added to the product. When the arguments have rank \(3\), both functions perform a batch of matrix multiplies along the first (batch) dimension.

The result of a matrix multiplication operation is an approximation of the infinitely precise mathematical result. The complete numerical behavior of these operations is not currently specified: information about error bounds, rounding modes, and the handling of non-finite, signed zero, and subnormal arguments may be absent.

Additionally, overflow may occur in unspecified scenarios during matrix multiplication of integral operands. In these scenarios, the result of the operation is unspecified.
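To make the integral overflow concern concrete, the following sketch computes a worst-case bound, assuming signed \(8\)-bit operands whose products are accumulated exactly into a \(32\)-bit signed integer (the \(8\)-bit to \(32\)-bit pairing listed for ct::mma_compatible). The helper max_safe_inner_dim is a hypothetical name, and the bound is plain arithmetic, not a statement about when the library actually overflows.

```cpp
#include <cassert>
#include <cstdint>
#include <limits>

// Largest magnitude of a single product of two signed 8-bit values:
// (-128) * (-128) = 16384 = 2^14.
constexpr std::int64_t kMaxInt8Product = 128 * 128;

// Largest inner dimension K for which a sum of K such products is
// guaranteed to be representable in int32_t, ignoring the accumulator.
// This is only a worst-case arithmetic bound for exact accumulation; it does
// not describe when ct::mma or ct::matmul actually overflow.
constexpr std::int64_t max_safe_inner_dim() {
    return std::numeric_limits<std::int32_t>::max() / kMaxInt8Product;
}

// (2^31 - 1) / 2^14 rounds down to 2^17 - 1.
static_assert(max_safe_inner_dim() == 131071, "unexpected bound");
```

With an inner dimension of \(131072 = 2^{17}\), the worst-case sum reaches \(2^{31}\), which no longer fits in a \(32\)-bit signed integer.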

cuda::tiles::mma_compatible

template<typename L, typename R, typename A>
concept mma_compatible = ct::numeric_tile<L> && ct::numeric_tile<R> && ct::numeric_tile<A> && /* atomic constraint */;

Indicates whether the types \(L\) (the left matrix), \(R\) (the right matrix) and \(A\) (the accumulator matrix) can participate in a matrix multiply accumulate operation.

The types \(L\), \(R\) and \(A\) can participate in a matrix multiply accumulate if:

  1. The element types of \(L\) and \(R\) are either integral scalars of the same bitwidth or floating point scalars of the same conversion rank.

  2. The element types of \(L\), \(R\) and \(A\) are selected from a single row of the following table:

    Element type of \(L\) and \(R\)        | Element type of \(A\)
    ---------------------------------------+---------------------------------
    \(8\)-bit integral types               | \(32\)-bit signed integral types
    __nv_fp8_e4m3, __nv_fp8_e5m2, __half   | __half or float
    __nv_bfloat16, __nv_tf32, float        | float
    double                                 | double

  3. All of \(L\), \(R\) and \(A\) have the same rank and the rank is either \(2\) or \(3\).

  4. In the case where the rank is \(2\), the shapes of \(L\), \(R\) and \(A\) are taken from the table for some integers \(N\), \(K\) and \(M\):

    Shape of \(L\)   | Shape of \(R\)   | Shape of \(A\)
    -----------------+------------------+-----------------
    \(N \times K\)   | \(K \times M\)   | \(N \times M\)

  5. In the case where the rank is \(3\), the shapes of \(L\), \(R\) and \(A\) are taken from the table below for some integers \(P\), \(Q\), \(C\), \(N\), \(K\) and \(M\):

    Shape of \(L\)           | Shape of \(R\)           | Shape of \(A\)
    -------------------------+--------------------------+-------------------------
    \(P \times N \times K\)  | \(Q \times K \times M\)  | \(C \times N \times M\)

    Additionally, each of the batch extents \(P\) and \(Q\) is either equal to \(C\) or equal to \(1\):

    \[P \in \{1, C\} \quad \text{and} \quad Q \in \{1, C\}\]
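The shape rules in items 3 to 5 above can be checked on plain extent lists with a small host-side sketch. The function mma_shapes_compatible and the Shape alias are hypothetical names, not part of cuda::tiles; element types and any tile-compatibility requirements are deliberately not modeled.

```cpp
#include <cassert>
#include <vector>

// Extents of a tile, outermost dimension first.
using Shape = std::vector<long>;

// A batch extent is acceptable if it matches the accumulator's batch extent
// or is 1 (in which case the operand is broadcast).
bool batch_ok(long batch, long acc_batch) {
    return batch == acc_batch || batch == 1;
}

// Hypothetical helper: checks only the rank and shape rules of mma_compatible.
bool mma_shapes_compatible(const Shape& l, const Shape& r, const Shape& a) {
    if (l.size() != r.size() || l.size() != a.size()) return false;
    if (l.size() == 2) {
        // L: N x K, R: K x M, A: N x M
        return l[1] == r[0] && a[0] == l[0] && a[1] == r[1];
    }
    if (l.size() == 3) {
        // L: batch_l x N x K, R: batch_r x K x M, A: C x N x M
        return l[2] == r[1] && a[1] == l[1] && a[2] == r[2] &&
               batch_ok(l[0], a[0]) && batch_ok(r[0], a[0]);
    }
    return false;  // only ranks 2 and 3 are supported
}
```

Note that the accumulator's batch extent is never broadcast: an accumulator with batch extent \(1\) cannot absorb operands with batch extent \(2\).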

cuda::tiles::matmul_compatible

template<typename L, typename R>
concept matmul_compatible = ct::numeric_tile<L> && ct::numeric_tile<R> && /* atomic constraint */;

Indicates whether the types \(L\) and \(R\) may participate in a matrix multiply operation.

The types \(L\) and \(R\) may participate in a matrix multiply operation if:

  1. The element types of \(L\) and \(R\) are integral scalars whose bitwidth is \(8\) or they are floating point scalars of the same conversion rank.

  2. \(L\) and \(R\) have the same rank and that rank is either \(2\) or \(3\).

  3. In the case where the rank is \(2\), the shapes of \(L\) and \(R\) are taken from the table below for some integers \(N\), \(K\) and \(M\):

    Shape of \(L\)   | Shape of \(R\)
    -----------------+-----------------
    \(N \times K\)   | \(K \times M\)

    The hypothetical ct::extents specialization of dimensions \(N \times M\) must be tile compatible.

  4. In the case where the rank is \(3\), the shapes of \(L\) and \(R\) are taken from the table below for some integers \(A\), \(B\), \(N\), \(K\) and \(M\):

    Shape of \(L\)           | Shape of \(R\)
    -------------------------+-------------------------
    \(A \times N \times K\)  | \(B \times K \times M\)

    Additionally, either \(A = B\) or at least one of them is equal to \(1\). The hypothetical ct::extents specialization whose dimensions are \(\operatorname{max}(A, B) \times N \times M\) must be tile compatible.

cuda::tiles::matmul_result_t

template<ct::numeric_tile L, ct::numeric_tile R>
requires ct::matmul_compatible<L, R>
using matmul_result_t = /* see below */;

Yields the result type of performing matrix multiplication on operands of types \(L\) and \(R\).

The result type is a specialization of ct::tile. The element type of the result is determined by the element types of \(L\) and \(R\) according to the following table:

\(L\) and \(R\) element type            | Result element type
----------------------------------------+--------------------
\(8\)-bit integral types                | int32_t
__nv_fp8_e4m3, __nv_fp8_e5m2, __half    | __half
__nv_bfloat16, __nv_tf32, float         | float
double                                  | double

The shape of the result depends on the rank of \(L\) and \(R\):

  1. If \(L\) and \(R\) have rank \(2\) with shapes \(N \times K\) and \(K \times M\) respectively, the result shape is \(N \times M\).

  2. If \(L\) and \(R\) have rank \(3\) with shapes \(A \times N \times K\) and \(B \times K \times M\), the result shape is \(C \times N \times M\) where \(C\) is the larger of \(A\) and \(B\).
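The shape computation can be sketched on the host as follows; the sketch also rejects shapes that violate the shape rules of ct::matmul_compatible. matmul_result_shape is a hypothetical helper, not a cuda::tiles API, and it does not check element types or whether the resulting ct::extents specialization is tile compatible.

```cpp
#include <algorithm>
#include <cassert>
#include <optional>
#include <vector>

// Extents of a tile, outermost dimension first.
using Shape = std::vector<long>;

// Hypothetical helper: the shape of matmul_result_t<L, R> for operands with
// shapes l and r, or std::nullopt if the shapes break the matmul_compatible
// shape rules. Element types and tile compatibility are not checked.
std::optional<Shape> matmul_result_shape(const Shape& l, const Shape& r) {
    if (l.size() != r.size()) return std::nullopt;
    if (l.size() == 2) {
        if (l[1] != r[0]) return std::nullopt;  // N x K times K x M
        return Shape{l[0], r[1]};                // N x M
    }
    if (l.size() == 3) {
        if (l[2] != r[1]) return std::nullopt;  // inner extents must match
        // Batch extents must be equal, or at least one of them must be 1.
        if (l[0] != r[0] && l[0] != 1 && r[0] != 1) return std::nullopt;
        return Shape{std::max(l[0], r[0]), l[1], r[2]};  // max(A, B) x N x M
    }
    return std::nullopt;  // only ranks 2 and 3 are supported
}
```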

Note

The result element type of ct::matmul() cannot be configured. To obtain a different result element type, use ct::mma() with a zero-initialized accumulator of the desired element type.

cuda::tiles::mma

template<ct::numeric_tile L, ct::numeric_tile R, ct::numeric_tile A>
requires ct::mma_compatible<L, R, A>
__tile__ A mma(L lhs, R rhs, A acc) noexcept;

Performs matrix multiply on lhs and rhs and adds the result to acc. For rank \(3\) arguments, the matrix multiply accumulate is performed for each set of matrices along the first dimension of the operands.

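If the operands have rank \(3\), lhs undergoes broadcast conversion to the shape \(A_0 \times L_1 \times L_2\) and rhs undergoes broadcast conversion to the shape \(A_0 \times R_1 \times R_2\), where \(L_i\), \(R_i\) and \(A_i\) denote the extent of dimension \(i\) of \(L\), \(R\) and \(A\) respectively.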

Let \(a\), \(b\) and \(c\) denote the converted operands lhs, rhs, and acc respectively.

For rank \(2\) arguments, the result is an approximation of the matrix \(r\) determined by

\[r(i_0, i_1) = \sum_{k = 0}^{L_1 - 1} a(i_0, k) \cdot b(k, i_1) + c(i_0, i_1)\]
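The rank \(2\) formula can be checked against a plain C++ reference loop. This is a hypothetical host-side sketch (Matrix and reference_mma are illustrative names, not cuda::tiles APIs); it computes in double, so it models the exact value \(r\) that ct::mma approximates rather than the operation's actual rounding behavior.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical row-major matrix of doubles.
struct Matrix {
    std::size_t rows, cols;
    std::vector<double> data;  // rows * cols elements
    double& operator()(std::size_t i, std::size_t j) { return data[i * cols + j]; }
    double operator()(std::size_t i, std::size_t j) const { return data[i * cols + j]; }
};

// r(i0, i1) = sum_{k=0}^{L1-1} a(i0, k) * b(k, i1) + c(i0, i1)
Matrix reference_mma(const Matrix& a, const Matrix& b, const Matrix& c) {
    assert(a.cols == b.rows && c.rows == a.rows && c.cols == b.cols);
    Matrix r = c;  // start from the accumulator
    for (std::size_t i0 = 0; i0 < r.rows; ++i0)
        for (std::size_t i1 = 0; i1 < r.cols; ++i1)
            for (std::size_t k = 0; k < a.cols; ++k)
                r(i0, i1) += a(i0, k) * b(k, i1);
    return r;
}
```

With the operands from the rank \(2\) ct::mma example in this section, the reference produces the elements 28, 35, 78 and 101.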

For rank \(3\) arguments, the result is an approximation of the matrix \(r\) determined by

\[r(i_0, i_1, i_2) = \sum_{k = 0}^{L_2 - 1} a(i_0, i_1, k) \cdot b(i_0, k, i_2) + c(i_0, i_1, i_2)\]
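The rank \(3\) case can be sketched on the host as well, assuming the usual broadcast semantics in which a leading extent of \(1\) is replicated, so an operand with batch extent \(1\) reuses its single matrix for every batch. Tile3 and reference_batched_mma are hypothetical names; the sketch broadcasts lhs and rhs to the accumulator's leading extent \(A_0\) and computes in double.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical rank-3 tile: extents {d0, d1, d2}, stored row-major.
struct Tile3 {
    std::size_t d0, d1, d2;
    std::vector<double> data;
    double at(std::size_t i, std::size_t j, std::size_t k) const {
        return data[(i * d1 + j) * d2 + k];
    }
};

// r(i0, i1, i2) = sum_k a(i0, i1, k) * b(i0, k, i2) + c(i0, i1, i2),
// where lhs and rhs are first broadcast to acc's leading extent.
Tile3 reference_batched_mma(const Tile3& lhs, const Tile3& rhs, const Tile3& acc) {
    assert(lhs.d2 == rhs.d1 && acc.d1 == lhs.d1 && acc.d2 == rhs.d2);
    assert(lhs.d0 == acc.d0 || lhs.d0 == 1);
    assert(rhs.d0 == acc.d0 || rhs.d0 == 1);
    Tile3 r = acc;  // start from the accumulator
    for (std::size_t i0 = 0; i0 < acc.d0; ++i0) {
        // Broadcasting a leading extent of 1 reuses batch 0.
        const std::size_t lb = lhs.d0 == 1 ? 0 : i0;
        const std::size_t rb = rhs.d0 == 1 ? 0 : i0;
        for (std::size_t i1 = 0; i1 < acc.d1; ++i1)
            for (std::size_t i2 = 0; i2 < acc.d2; ++i2)
                for (std::size_t k = 0; k < lhs.d2; ++k)
                    r.data[(i0 * r.d1 + i1) * r.d2 + i2] +=
                        lhs.at(lb, i1, k) * rhs.at(rb, k, i2);
    }
    return r;
}
```

Because both batches of lhs in the rank \(3\) ct::mma example are identical, passing a single-batch lhs and letting it broadcast yields the same result as the concatenated one.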

Example

The following example shows a matrix multiply accumulate computation for rank \(2\) arguments:

namespace ct = ::cuda::tiles;
using i32x2x4 = ct::tile<int, ct::shape<2, 4>>;
using i32x4x2 = ct::tile<int, ct::shape<4, 2>>;
using i32x2x2 = ct::tile<int, ct::shape<2, 2>>;

auto lhs = ct::element_cast<float>(ct::iota<i32x2x4>());
auto rhs = ct::element_cast<float>(ct::iota<i32x4x2>());
auto acc = ct::element_cast<float>(ct::iota<i32x2x2>());

auto r = ct::mma(lhs, rhs, acc);
\[\begin{split}\begin{pmatrix} 0 & 1 & 2 & 3 \\ 4 & 5 & 6 & 7 \\ \end{pmatrix} \times \begin{pmatrix} 0 & 1 \\ 2 & 3 \\ 4 & 5 \\ 6 & 7 \end{pmatrix} + \begin{pmatrix} 0 & 1 \\ 2 & 3 \end{pmatrix} \rightarrow \begin{pmatrix} 28 & 35 \\ 78 & 101 \end{pmatrix}\end{split}\]

Example

The following example shows a batched matrix multiply accumulate computation for rank \(3\) arguments:

namespace ct = ::cuda::tiles;
using i32x1x2x4 = ct::tile<int, ct::shape<1, 2, 4>>;
using i32x1x4x2 = ct::tile<int, ct::shape<1, 4, 2>>;
using i32x1x2x2 = ct::tile<int, ct::shape<1, 2, 2>>;

auto lhs = ct::element_cast<float>(
             ct::cat<0>(ct::iota<i32x1x2x4>(), ct::iota<i32x1x2x4>()));
auto rhs = ct::element_cast<float>(
             ct::cat<0>(ct::iota<i32x1x4x2>(), -ct::iota<i32x1x4x2>()));
auto acc = ct::element_cast<float>(
             ct::cat<0>(ct::iota<i32x1x2x2>(), -ct::iota<i32x1x2x2>()));

auto r = ct::mma(lhs, rhs, acc);
\[\begin{split}\begin{pmatrix} \begin{pmatrix} 0 & 1 & 2 & 3 \\ 4 & 5 & 6 & 7 \\ \end{pmatrix} \\ \begin{pmatrix} 0 & 1 & 2 & 3 \\ 4 & 5 & 6 & 7 \\ \end{pmatrix} \end{pmatrix} \times \begin{pmatrix} \begin{pmatrix} 0 & 1 \\ 2 & 3 \\ 4 & 5 \\ 6 & 7 \end{pmatrix} \\ \begin{pmatrix} 0 & -1 \\ -2 & -3 \\ -4 & -5 \\ -6 & -7 \end{pmatrix} \end{pmatrix} + \begin{pmatrix} \begin{pmatrix} 0 & 1 \\ 2 & 3 \end{pmatrix} \\ \begin{pmatrix} 0 & -1 \\ -2 & -3 \end{pmatrix} \end{pmatrix} \rightarrow \begin{pmatrix} \begin{pmatrix} 28 & 35 \\ 78 & 101 \end{pmatrix} \\ \begin{pmatrix} -28 & -35 \\ -78 & -101 \end{pmatrix} \end{pmatrix}\end{split}\]

cuda::tiles::matmul

template<ct::numeric_tile L, ct::numeric_tile R>
requires ct::matmul_compatible<L, R>
__tile__ ct::matmul_result_t<L, R> matmul(L lhs, R rhs) noexcept;

Performs matrix multiply on lhs and rhs. For rank \(3\) arguments, the matrix multiply is performed for each set of matrices along the first dimension of the operands.

If the operands have rank \(3\), lhs undergoes broadcast conversion to the shape \(\max(L_0, R_0) \times L_1 \times L_2\) and rhs undergoes broadcast conversion to the shape \(\max(L_0, R_0) \times R_1 \times R_2\), where \(L_i\) and \(R_i\) denote the extent of dimension \(i\) of \(L\) and \(R\) respectively.

Let \(a\) and \(b\) denote the converted operands lhs and rhs respectively.

For rank \(2\) arguments, the result is an approximation of the matrix \(r\) determined by

\[r(i_0, i_1) = \sum_{k = 0}^{L_1 - 1} a(i_0, k) \cdot b(k, i_1)\]

For rank \(3\) arguments, the result is an approximation of the matrix \(r\) determined by

\[r(i_0, i_1, i_2) = \sum_{k = 0}^{L_2 - 1} a(i_0, i_1, k) \cdot b(i_0, k, i_2)\]
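Unlike ct::mma, where the batch extent is fixed by the accumulator, ct::matmul broadcasts both operands to \(\max(L_0, R_0)\) matrices. The following hypothetical host-side sketch (BatchedMatrix and reference_batched_matmul are illustrative names) shows that rule with a left operand whose batch extent is \(1\), assuming the usual broadcast semantics in which a batch extent of \(1\) is replicated.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical rank-3 operand: extents {batch, rows, cols}, stored row-major.
struct BatchedMatrix {
    std::size_t batch, rows, cols;
    std::vector<double> data;
    double at(std::size_t b, std::size_t i, std::size_t j) const {
        return data[(b * rows + i) * cols + j];
    }
};

// r(i0, i1, i2) = sum_k a(i0, i1, k) * b(i0, k, i2), where both operands are
// first broadcast to max(L0, R0) matrices.
BatchedMatrix reference_batched_matmul(const BatchedMatrix& lhs, const BatchedMatrix& rhs) {
    assert(lhs.cols == rhs.rows);
    assert(lhs.batch == rhs.batch || lhs.batch == 1 || rhs.batch == 1);
    const std::size_t batch = std::max(lhs.batch, rhs.batch);
    BatchedMatrix r{batch, lhs.rows, rhs.cols,
                    std::vector<double>(batch * lhs.rows * rhs.cols, 0.0)};
    for (std::size_t i0 = 0; i0 < batch; ++i0) {
        const std::size_t lb = lhs.batch == 1 ? 0 : i0;  // broadcast lhs
        const std::size_t rb = rhs.batch == 1 ? 0 : i0;  // broadcast rhs
        for (std::size_t i1 = 0; i1 < r.rows; ++i1)
            for (std::size_t i2 = 0; i2 < r.cols; ++i2)
                for (std::size_t k = 0; k < lhs.cols; ++k)
                    r.data[(i0 * r.rows + i1) * r.cols + i2] +=
                        lhs.at(lb, i1, k) * rhs.at(rb, k, i2);
    }
    return r;
}
```

Multiplying a single \(2 \times 4\) matrix by a batch of two \(4 \times 2\) matrices this way gives a result with batch extent \(2\).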

Example

The following example shows a matrix multiply computation for rank \(2\) arguments:

namespace ct = ::cuda::tiles;
using i32x2x4 = ct::tile<int, ct::shape<2, 4>>;
using i32x4x2 = ct::tile<int, ct::shape<4, 2>>;

auto lhs = ct::element_cast<float>(ct::iota<i32x2x4>());
auto rhs = ct::element_cast<float>(ct::iota<i32x4x2>());

auto r = ct::matmul(lhs, rhs);
\[\begin{split}\begin{pmatrix} 0 & 1 & 2 & 3 \\ 4 & 5 & 6 & 7 \\ \end{pmatrix} \times \begin{pmatrix} 0 & 1 \\ 2 & 3 \\ 4 & 5 \\ 6 & 7 \end{pmatrix} \rightarrow \begin{pmatrix} 28 & 34 \\ 76 & 98 \end{pmatrix}\end{split}\]

Example

The following example shows a batched matrix multiply computation for rank \(3\) arguments:

namespace ct = ::cuda::tiles;
using i32x1x2x4 = ct::tile<int, ct::shape<1, 2, 4>>;
using i32x1x4x2 = ct::tile<int, ct::shape<1, 4, 2>>;

auto lhs = ct::element_cast<float>(ct::cat<0>(ct::iota<i32x1x2x4>(), ct::iota<i32x1x2x4>()));
auto rhs = ct::element_cast<float>(ct::cat<0>(ct::iota<i32x1x4x2>(), -ct::iota<i32x1x4x2>()));

auto r = ct::matmul(lhs, rhs);
\[\begin{split}\begin{pmatrix} \begin{pmatrix} 0 & 1 & 2 & 3 \\ 4 & 5 & 6 & 7 \\ \end{pmatrix} \\ \begin{pmatrix} 0 & 1 & 2 & 3 \\ 4 & 5 & 6 & 7 \\ \end{pmatrix} \end{pmatrix} \times \begin{pmatrix} \begin{pmatrix} 0 & 1 \\ 2 & 3 \\ 4 & 5 \\ 6 & 7 \end{pmatrix} \\ \begin{pmatrix} 0 & -1 \\ -2 & -3 \\ -4 & -5 \\ -6 & -7 \end{pmatrix} \end{pmatrix} \rightarrow \begin{pmatrix} \begin{pmatrix} 28 & 34 \\ 76 & 98 \end{pmatrix} \\ \begin{pmatrix} -28 & -34 \\ -76 & -98 \end{pmatrix} \end{pmatrix}\end{split}\]