Matrix Multiplication
Matrix multiplication of tiles is supported through the ct::matmul() and ct::mma() functions.
The ct::mma function accepts an accumulator argument that is added to the multiplication result.
These functions can perform batched matrix multiplies along a third dimension when the arguments have rank \(3\).
The result of a matrix multiplication operation is an approximation of an infinitely precise mathematical computation. The exhaustive numerical behavior of these operations is not currently specified. Information about error bounds, rounding modes, and the behavior of non-finite, signed zero, and subnormal arguments may be absent.
Additionally, there are unspecified scenarios wherein overflow may occur while performing matrix multiplication of integral operands. In these scenarios, the result of the operation is unspecified.
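One reason a floating point result can only be described as an approximation is that floating point addition is not associative, so the order in which partial products are summed can change the outcome. The following plain C++ sketch (host code that does not use the tile library; the function names are hypothetical) illustrates this for three float values. It makes no claim about the summation order actually used by ct::matmul() or ct::mma().

```cpp
#include <cassert>

// Sums three floats left to right: (x + y) + z.
inline float sum_left_to_right(float x, float y, float z) {
    float t = x + y;
    return t + z;
}

// Sums the same three floats right to left: x + (y + z).
inline float sum_right_to_left(float x, float y, float z) {
    float t = y + z;
    return x + t;
}

// Demonstrates that the two orders disagree for values of very different magnitude.
inline void check_order_dependence() {
    // (1e20 + -1e20) + 1 keeps the small term.
    assert(sum_left_to_right(1e20f, -1e20f, 1.0f) == 1.0f);
    // 1e20 + (-1e20 + 1) loses it, because 1 is below the rounding step of 1e20.
    assert(sum_right_to_left(1e20f, -1e20f, 1.0f) == 0.0f);
}
```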
cuda::tiles::mma_compatible
-
template<typename L, typename R, typename A>
concept mma_compatible = ct::numeric_tile<L> && ct::numeric_tile<R> && ct::numeric_tile<A> && /* atomic constraint */;
-
Indicates whether the types \(L\) (the left matrix), \(R\) (the right matrix) and \(A\) (the accumulator matrix) can participate in a matrix multiply accumulate operation.
The types \(L\), \(R\) and \(A\) can participate in a matrix multiply accumulate operation if:
- The element types of \(L\) and \(R\) are integral scalars of the same bitwidth, or they are floating point scalars of the same conversion rank.
- The element types of \(L\), \(R\) and \(A\) are selected from a row in the following table:

  Element type of \(L\) and \(R\) | Element type of \(A\)
  --- | ---
  \(8\) bit integral types | \(32\) bit signed integral types
  __nv_fp8_e4m3, __nv_fp8_e5m2, __half | __half, float
  __nv_bfloat16, __nv_tf32, float | float
  double | double

- All of \(L\), \(R\) and \(A\) have the same rank, and the rank is either \(2\) or \(3\).
- In the case where the rank is \(2\), the shapes of \(L\), \(R\) and \(A\) are taken from the following table for some integers \(N\), \(K\) and \(M\):

  Shape of \(L\) | Shape of \(R\) | Shape of \(A\)
  --- | --- | ---
  \(N \times K\) | \(K \times M\) | \(N \times M\)

- In the case where the rank is \(3\), the shapes of \(L\), \(R\) and \(A\) are taken from the following table for some integers \(A\), \(B\), \(C\), \(N\), \(K\) and \(M\) (here \(A\), \(B\) and \(C\) denote the leading extents of the three operands, not the accumulator type):

  Shape of \(L\) | Shape of \(R\) | Shape of \(A\)
  --- | --- | ---
  \(A \times N \times K\) | \(B \times K \times M\) | \(C \times N \times M\)

  Additionally, \(A\) and \(B\) satisfy the following predicates:
  \[\begin{split}\begin{cases} A = C & \text{or} \\ A = 1 & \end{cases} \quad \quad \begin{cases} B = C & \text{or} \\ B = 1 & \end{cases}\end{split}\]
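The rank \(3\) shape rules above can be sketched as a compile-time predicate. This is a minimal plain C++ illustration, not part of the library: the shape3 alias and the function mma_shapes_compatible are hypothetical names, and the sketch assumes that both batch predicates must hold together. It covers only the shape rules, not the element type rules.

```cpp
#include <array>
#include <cassert>
#include <cstddef>

// Hypothetical stand-in for a rank-3 shape: {batch, rows, columns}.
using shape3 = std::array<std::size_t, 3>;

// Returns true when L is A x N x K, R is B x K x M and the accumulator is
// C x N x M, with (A == C or A == 1) and (B == C or B == 1).
constexpr bool mma_shapes_compatible(shape3 l, shape3 r, shape3 acc) {
    const bool inner_ok = l[2] == r[1];    // shared K
    const bool rows_ok = l[1] == acc[1];   // shared N
    const bool cols_ok = r[2] == acc[2];   // shared M
    const bool lhs_batch_ok = l[0] == acc[0] || l[0] == 1u;
    const bool rhs_batch_ok = r[0] == acc[0] || r[0] == 1u;
    return inner_ok && rows_ok && cols_ok && lhs_batch_ok && rhs_batch_ok;
}

// Matching batch extents are accepted.
static_assert(mma_shapes_compatible({2, 2, 4}, {2, 4, 2}, {2, 2, 2}), "same batch");
// A batch extent of 1 on the left operand is accepted.
static_assert(mma_shapes_compatible({1, 2, 4}, {2, 4, 2}, {2, 2, 2}), "lhs batch of 1");
// A left batch extent that is neither 1 nor C is rejected.
static_assert(!mma_shapes_compatible({3, 2, 4}, {2, 4, 2}, {2, 2, 2}), "bad lhs batch");
```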
cuda::tiles::matmul_compatible
-
template<typename L, typename R>
concept matmul_compatible = ct::numeric_tile<L> && ct::numeric_tile<R> && /* atomic constraint */;
-
Indicates whether the types \(L\) and \(R\) may participate in a matrix multiply operation.
The types \(L\) and \(R\) may participate in a matrix multiply operation if:
The element types of \(L\) and \(R\) are integral scalars whose bitwidth is \(8\) or they are floating point scalars of the same conversion rank.
\(L\) and \(R\) have the same rank and that rank is either \(2\) or \(3\).
- In the case where the rank is \(2\), the shapes of \(L\) and \(R\) are taken from the table below for some integers \(N\), \(K\) and \(M\):

  Shape of \(L\) | Shape of \(R\)
  --- | ---
  \(N \times K\) | \(K \times M\)

  The hypothetical ct::extents specialization of dimensions \(N \times M\) must be tile compatible.
- In the case where the rank is \(3\), the shapes of \(L\) and \(R\) are taken from the table below for some integers \(A\), \(B\), \(N\), \(K\) and \(M\):

  Shape of \(L\) | Shape of \(R\)
  --- | ---
  \(A \times N \times K\) | \(B \times K \times M\)

  Additionally, either \(A = B\) or at least one of them is equal to \(1\). The hypothetical ct::extents specialization whose dimensions are \(\operatorname{max}(A, B) \times N \times M\) must be tile compatible.
cuda::tiles::matmul_result_t
-
template<ct::numeric_tile L, ct::numeric_tile R>
requires ct::matmul_compatible<L, R>
using matmul_result_t = /* see below */;
-
Yields the result type of performing matrix multiplication on operands of types \(L\) and \(R\).
The result type is a specialization of ct::tile. The element type of the result is determined by the element types of \(L\) and \(R\) according to the following table:

\(L\) and \(R\) element type | Result element type
--- | ---
\(8\) bit integral types | int32_t
__nv_fp8_e4m3, __nv_fp8_e5m2, __half | __half
__nv_bfloat16, __nv_tf32, float | float
double | double

The shape of the result depends on the rank of \(L\) and \(R\):
If \(L\) and \(R\) have rank \(2\) with shapes \(N \times K\) and \(K \times M\) respectively, the result shape is \(N \times M\).
If \(L\) and \(R\) have rank \(3\) with shapes \(A \times N \times K\) and \(B \times K \times M\) respectively, the result shape is \(C \times N \times M\) where \(C\) is the larger of \(A\) and \(B\).
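The rank \(3\) result shape rule can be sketched in plain C++ as follows. The shape3 alias and the function matmul_result_shape are hypothetical helpers for illustration only; the library computes this shape at the type level through ct::matmul_result_t.

```cpp
#include <algorithm>
#include <array>
#include <cassert>
#include <cstddef>

// Hypothetical stand-in for a rank-3 shape: {batch, rows, columns}.
using shape3 = std::array<std::size_t, 3>;

// Result shape of a rank-3 matrix multiply: max(A, B) x N x M for operands
// of shape A x N x K and B x K x M.
inline shape3 matmul_result_shape(shape3 l, shape3 r) {
    assert(l[2] == r[1]);                              // shared K
    assert(l[0] == r[0] || l[0] == 1u || r[0] == 1u);  // batch extents match or one is 1
    return {std::max(l[0], r[0]), l[1], r[2]};
}
```

For example, operands of shape \(1 \times 2 \times 4\) and \(2 \times 4 \times 2\) give a result of shape \(2 \times 2 \times 2\).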
Note
The result element type for matmul cannot be configured. To use a different element type, use ct::mma() with a zeroed accumulator.
cuda::tiles::mma
-
template<ct::numeric_tile L, ct::numeric_tile R, ct::numeric_tile A>
requires ct::mma_compatible<L, R, A>
__tile__ A mma(L lhs, R rhs, A acc) noexcept;
-
Performs matrix multiply on lhs and rhs and adds the result to acc. For rank \(3\) arguments, the matrix multiply accumulate is performed for each set of matrices along the first dimension of the operands.
In the following, \(L_i\), \(R_i\) and \(A_i\) denote the extents of lhs, rhs and acc along dimension \(i\).
If the operands have rank \(3\), lhs undergoes broadcast conversion to the shape \(A_0 \times L_1 \times L_2\) and rhs undergoes broadcast conversion to the shape \(A_0 \times R_1 \times R_2\).
Let \(a\), \(b\) and \(c\) denote the converted operands lhs, rhs, and acc respectively.
For rank \(2\) arguments, the result is an approximation of the matrix \(r\) determined by
\[r(i_0, i_1) = \sum_{k = 0}^{k < L_1} a(i_0, k) \cdot b(k, i_1) + c(i_0, i_1)\]
For rank \(3\) arguments, the result is an approximation of the matrix \(r\) determined by
\[r(i_0, i_1, i_2) = \sum_{k = 0}^{k < L_2} a(i_0, i_1, k) \cdot b(i_0, k, i_2) + c(i_0, i_1, i_2)\]
Example
The following example shows a matrix multiply accumulate computation for rank \(2\) arguments:
namespace ct = ::cuda::tiles;

using i32x2x4 = ct::tile<int, ct::shape<2, 4>>;
using i32x4x2 = ct::tile<int, ct::shape<4, 2>>;
using i32x2x2 = ct::tile<int, ct::shape<2, 2>>;

// Operands filled with 0, 1, 2, ... and converted to float.
auto lhs = ct::element_cast<float>(ct::iota<i32x2x4>());
auto rhs = ct::element_cast<float>(ct::iota<i32x4x2>());
auto acc = ct::element_cast<float>(ct::iota<i32x2x2>());

// r = lhs x rhs + acc
auto r = ct::mma(lhs, rhs, acc);

\[\begin{split}\begin{pmatrix} 0 & 1 & 2 & 3 \\ 4 & 5 & 6 & 7 \\ \end{pmatrix} \times \begin{pmatrix} 0 & 1 \\ 2 & 3 \\ 4 & 5 \\ 6 & 7 \end{pmatrix} + \begin{pmatrix} 0 & 1 \\ 2 & 3 \end{pmatrix} \rightarrow \begin{pmatrix} 28 & 35 \\ 78 & 101 \end{pmatrix}\end{split}\]
Example
The following example shows a batched matrix multiply accumulate computation for rank \(3\) arguments:
namespace ct = ::cuda::tiles;

using i32x1x2x4 = ct::tile<int, ct::shape<1, 2, 4>>;
using i32x1x4x2 = ct::tile<int, ct::shape<1, 4, 2>>;
using i32x1x2x2 = ct::tile<int, ct::shape<1, 2, 2>>;

// Each operand is a batch of two matrices concatenated along dimension 0.
// The second matrix of rhs and of acc is negated.
auto lhs = ct::element_cast<float>(
    ct::cat<0>(ct::iota<i32x1x2x4>(), ct::iota<i32x1x2x4>()));
auto rhs = ct::element_cast<float>(
    ct::cat<0>(ct::iota<i32x1x4x2>(), -ct::iota<i32x1x4x2>()));
auto acc = ct::element_cast<float>(
    ct::cat<0>(ct::iota<i32x1x2x2>(), -ct::iota<i32x1x2x2>()));

// One multiply accumulate per batch entry.
auto r = ct::mma(lhs, rhs, acc);
\[\begin{split}\begin{pmatrix} \begin{pmatrix} 0 & 1 & 2 & 3 \\ 4 & 5 & 6 & 7 \\ \end{pmatrix} \\ \begin{pmatrix} 0 & 1 & 2 & 3 \\ 4 & 5 & 6 & 7 \\ \end{pmatrix} \end{pmatrix} \times \begin{pmatrix} \begin{pmatrix} 0 & 1 \\ 2 & 3 \\ 4 & 5 \\ 6 & 7 \end{pmatrix} \\ \begin{pmatrix} 0 & -1 \\ -2 & -3 \\ -4 & -5 \\ -6 & -7 \end{pmatrix} \end{pmatrix} + \begin{pmatrix} \begin{pmatrix} 0 & 1 \\ 2 & 3 \end{pmatrix} \\ \begin{pmatrix} 0 & -1 \\ -2 & -3 \end{pmatrix} \end{pmatrix} \rightarrow \begin{pmatrix} \begin{pmatrix} 28 & 35 \\ 78 & 101 \end{pmatrix} \\ \begin{pmatrix} -28 & -35 \\ -78 & -101 \end{pmatrix} \end{pmatrix}\end{split}\]
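The rank \(2\) formula for \(r\) can be checked against the first example with a small host-side reference. This is a plain C++ sketch that does not use the tile library: the matrix alias, reference_mma and iota_matrix are hypothetical names, and the loop order shown here is one possible choice rather than the order used by ct::mma().

```cpp
#include <array>
#include <cassert>
#include <cstddef>

// Hypothetical host-side stand-in for a rank-2 tile.
template <typename T, std::size_t Rows, std::size_t Cols>
using matrix = std::array<std::array<T, Cols>, Rows>;

// r(i0, i1) = sum over k of a(i0, k) * b(k, i1), plus c(i0, i1).
template <typename T, std::size_t N, std::size_t K, std::size_t M>
matrix<T, N, M> reference_mma(const matrix<T, N, K>& a,
                              const matrix<T, K, M>& b,
                              const matrix<T, N, M>& c) {
    matrix<T, N, M> r = c;  // start from the accumulator
    for (std::size_t i0 = 0; i0 < N; ++i0)
        for (std::size_t i1 = 0; i1 < M; ++i1)
            for (std::size_t k = 0; k < K; ++k)
                r[i0][i1] += a[i0][k] * b[k][i1];
    return r;
}

// Fills a matrix in row-major order with 0, 1, 2, ...
template <typename T, std::size_t Rows, std::size_t Cols>
matrix<T, Rows, Cols> iota_matrix() {
    matrix<T, Rows, Cols> m{};
    for (std::size_t i = 0; i < Rows; ++i)
        for (std::size_t j = 0; j < Cols; ++j)
            m[i][j] = static_cast<T>(i * Cols + j);
    return m;
}
```

Calling reference_mma with iota_matrix<float, 2, 4>(), iota_matrix<float, 4, 2>() and iota_matrix<float, 2, 2>() reproduces the values 28, 35, 78 and 101 from the rank \(2\) example.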
cuda::tiles::matmul
-
template<ct::numeric_tile L, ct::numeric_tile R>
requires ct::matmul_compatible<L, R>
__tile__ ct::matmul_result_t<L, R> matmul(L lhs, R rhs) noexcept;
-
Performs matrix multiply on lhs and rhs. For rank \(3\) arguments, the matrix multiply is performed for each set of matrices along the first dimension of the operands.
In the following, \(L_i\) and \(R_i\) denote the extents of lhs and rhs along dimension \(i\).
If the operands have rank \(3\), lhs undergoes broadcast conversion to the shape \(\operatorname{max}(L_0, R_0) \times L_1 \times L_2\) and rhs undergoes broadcast conversion to the shape \(\operatorname{max}(L_0, R_0) \times R_1 \times R_2\).
Let \(a\) and \(b\) denote the converted operands lhs and rhs respectively.
For rank \(2\) arguments, the result is an approximation of the matrix \(r\) determined by
\[r(i_0, i_1) = \sum_{k = 0}^{k < L_1} a(i_0, k) \cdot b(k, i_1)\]
For rank \(3\) arguments, the result is an approximation of the matrix \(r\) determined by
\[r(i_0, i_1, i_2) = \sum_{k = 0}^{k < L_2} a(i_0, i_1, k) \cdot b(i_0, k, i_2)\]
Example
The following example shows a matrix multiply computation for rank \(2\) arguments:
namespace ct = ::cuda::tiles;

using i32x2x4 = ct::tile<int, ct::shape<2, 4>>;
using i32x4x2 = ct::tile<int, ct::shape<4, 2>>;

// Operands filled with 0, 1, 2, ... and converted to float.
auto lhs = ct::element_cast<float>(ct::iota<i32x2x4>());
auto rhs = ct::element_cast<float>(ct::iota<i32x4x2>());

// r = lhs x rhs
auto r = ct::matmul(lhs, rhs);

\[\begin{split}\begin{pmatrix} 0 & 1 & 2 & 3 \\ 4 & 5 & 6 & 7 \\ \end{pmatrix} \times \begin{pmatrix} 0 & 1 \\ 2 & 3 \\ 4 & 5 \\ 6 & 7 \end{pmatrix} \rightarrow \begin{pmatrix} 28 & 34 \\ 76 & 98 \end{pmatrix}\end{split}\]
Example
The following example shows a batched matrix multiply computation for rank \(3\) arguments:
namespace ct = ::cuda::tiles;

using i32x1x2x4 = ct::tile<int, ct::shape<1, 2, 4>>;
using i32x1x4x2 = ct::tile<int, ct::shape<1, 4, 2>>;

// Each operand is a batch of two matrices concatenated along dimension 0.
// The second matrix of rhs is negated.
auto lhs = ct::element_cast<float>(
    ct::cat<0>(ct::iota<i32x1x2x4>(), ct::iota<i32x1x2x4>()));
auto rhs = ct::element_cast<float>(
    ct::cat<0>(ct::iota<i32x1x4x2>(), -ct::iota<i32x1x4x2>()));

// One matrix multiply per batch entry.
auto r = ct::matmul(lhs, rhs);
\[\begin{split}\begin{pmatrix} \begin{pmatrix} 0 & 1 & 2 & 3 \\ 4 & 5 & 6 & 7 \\ \end{pmatrix} \\ \begin{pmatrix} 0 & 1 & 2 & 3 \\ 4 & 5 & 6 & 7 \\ \end{pmatrix} \end{pmatrix} \times \begin{pmatrix} \begin{pmatrix} 0 & 1 \\ 2 & 3 \\ 4 & 5 \\ 6 & 7 \end{pmatrix} \\ \begin{pmatrix} 0 & -1 \\ -2 & -3 \\ -4 & -5 \\ -6 & -7 \end{pmatrix} \end{pmatrix} \rightarrow \begin{pmatrix} \begin{pmatrix} 28 & 34 \\ 76 & 98 \end{pmatrix} \\ \begin{pmatrix} -28 & -34 \\ -76 & -98 \end{pmatrix} \end{pmatrix}\end{split}\]
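The broadcast conversion described for rank \(3\) operands can be illustrated with a host-side reference in plain C++. The batch alias and reference_batched_matmul are hypothetical names for this sketch, which does not use the tile library. Because the two matrices of lhs in the batched example are identical, passing a single lhs matrix with batch extent \(1\) should give the same values.

```cpp
#include <array>
#include <cassert>
#include <cstddef>

// Hypothetical host-side stand-in for a rank-3 tile: B matrices of Rows x Cols.
template <typename T, std::size_t B, std::size_t Rows, std::size_t Cols>
using batch = std::array<std::array<std::array<T, Cols>, Rows>, B>;

// Batched product in which a batch extent of 1 is broadcast to max(A, B).
template <typename T, std::size_t A, std::size_t B,
          std::size_t N, std::size_t K, std::size_t M>
batch<T, (A > B ? A : B), N, M> reference_batched_matmul(
    const batch<T, A, N, K>& lhs, const batch<T, B, K, M>& rhs) {
    static_assert(A == B || A == 1 || B == 1, "batch extents must match or be 1");
    constexpr std::size_t C = A > B ? A : B;
    batch<T, C, N, M> r{};  // value-initialized, so every element starts at zero
    for (std::size_t i0 = 0; i0 < C; ++i0) {
        // Broadcasting: an operand with batch extent 1 reuses its only matrix.
        const auto& a = lhs[A == 1 ? 0 : i0];
        const auto& b = rhs[B == 1 ? 0 : i0];
        for (std::size_t i1 = 0; i1 < N; ++i1)
            for (std::size_t i2 = 0; i2 < M; ++i2)
                for (std::size_t k = 0; k < K; ++k)
                    r[i0][i1][i2] += a[i1][k] * b[k][i2];
    }
    return r;
}
```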