General Matrix Multiply Using cuBLASDx¶
In this introduction, we will perform a general matrix multiplication \(\mathbf{C}_{m\times n} = {\alpha} \times \mathbf{A}_{m\times k} \times \mathbf{B}_{k\times n} + {\beta} \times \mathbf{C}_{m\times n}\) using the cuBLASDx library. This section is based on the introduction_example.cu example shipped with cuBLASDx. See the Examples section for other cuBLASDx samples.
Defining GEMM Operation¶
The first step is defining the GEMM we want to perform. This is done by adding together cuBLASDx operators to create a GEMM description. The correctness of this type is evaluated at compile time each time a new operator is added. A well-defined cuBLASDx GEMM routine description must include two parts:
Selected linear algebra routine. In this case that is matrix multiplication: cublasdx::function::MM.
Valid and sufficient description of the inputs and outputs: the dimensions of the matrices (m, n, k), the precision (half, float, double, etc.), the data type (real or complex), and the data arrangement of the matrices (row- or column-major).
To get a descriptor for \(\mathbf{C}_{m\times n} = {\alpha} \times \mathbf{A}_{m\times k} \times \mathbf{B}_{k\times n} + {\beta} \times \mathbf{C}_{m\times n}\) with m = n = k = 32, we just need to write the following lines:
#include <cublasdx.hpp>
using namespace cublasdx;
using GEMM = decltype(Size<32 /* m */, 32 /* n */, 32 /* k */>()
+ Precision<double>()
+ Type<type::real>()
+ Function<function::MM>()
+ Arrangement<cublasdx::row_major /* A */, cublasdx::col_major /* B */>());
In order to encode the operation properties, cuBLASDx provides the operators Size, Precision, Type, Function, and Arrangement, which can be combined using the standard addition operator (+).
Optionally, the user can set alignments and leading dimensions for each matrix using the Alignment and LeadingDimension operators, respectively. Leading dimensions can also be set dynamically during the execution; however, it is worth noting that this may affect performance.
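As an illustration (not taken from introduction_example.cu, and with alignment and leading dimension values chosen purely for the example), the description above could be extended with both operators as follows:
using GEMMWithLD = decltype(Size<32, 32, 32>()
+ Precision<double>()
+ Type<type::real>()
+ Function<function::MM>()
+ Arrangement<cublasdx::row_major, cublasdx::col_major>()
+ Alignment<16, 16, 16>()          // alignments (in bytes) of A, B, and C
+ LeadingDimension<32, 32, 32>()); // static leading dimensions of A, B, and C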
Tip
cuBLASDx also supports matrices that cannot simply be expressed by row- or column-major order and leading dimensions. See the simple_gemm_custom_layout.cu example.
To obtain a fully usable operation that executes GEMM at the CUDA block level, we need to provide at least two additional pieces of information:
The first one is the SM Operator, which indicates the targeted CUDA architecture on which we want to run the GEMM. Each GPU architecture is different, therefore each can use a different implementation and may require a different CUDA block size for the best performance. In the introduction_example.cu example this is passed as a template parameter, but here we can assume we're targeting Volta GPUs (SM<700>()).
Finally, we use the Block Operator to indicate that the BLAS routine will be performed by multiple threads in a single CUDA block. At this point, cuBLASDx performs additional verifications to make sure the provided description is valid and that it is possible to execute it on the requested architecture.
#include <cublasdx.hpp>
using namespace cublasdx;
using GEMM = decltype(Size<32, 32, 32>()
+ Precision<double>()
+ Type<type::real>()
+ Function<function::MM>()
+ Arrangement<cublasdx::row_major, cublasdx::col_major>()
+ SM<700>()
+ Block());
The user can also specify the layout and the number of threads that will be performing the GEMM. This is done with the BlockDim Operator. Adding BlockDim<X, Y, Z> means that the GEMM will only work correctly if the kernel is launched with block dimensions dim3(X1, Y1, Z1) where X1 >= X, Y1 >= Y, and Z1 >= Z. Detailed requirements can be found in the section dedicated to the BlockDim operator. If the BlockDim operator is not used, cuBLASDx will select a preferred block size, which can be obtained with GEMM::block_dim.
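For instance, as a hedged illustration (assuming a hypothetical description that includes BlockDim<128, 2, 1> and a kernel like the gemm_kernel defined later in this section), the following launch satisfies the requirement because 256 >= 128, 2 >= 2, and 1 >= 1:
// Block dimensions are at least as large as those requested via BlockDim<128, 2, 1>
gemm_kernel<GEMM><<<1, dim3(256, 2, 1), GEMM::shared_memory_size>>>(alpha, a, b, beta, c);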
Tip
If there is no need to set custom block dimensions, it is recommended not to use the BlockDim operator and to rely on GEMM::block_dim instead.
For more details, see the Block Execute Method section, the BlockDim Operator, and the Suggested Block Dim Trait.
For this sample, let’s assume we want to use a 1D CUDA thread block with 256 threads.
#include <cublasdx.hpp>
using namespace cublasdx;
using GEMM = decltype(Size<32, 32, 32>()
+ Precision<double>()
+ Type<type::real>()
+ Function<function::MM>()
+ Arrangement<cublasdx::row_major, cublasdx::col_major>()
+ SM<700>()
+ Block()
+ BlockDim<256>());
Executing GEMM¶
The GEMM class, which describes the matrix multiplication, can be instantiated into an object (or objects). Forming the object has no computational cost and should be seen as a handle. The function descriptor object provides a compute method, execute(...), that performs the requested function.
#include <cublasdx.hpp>
using namespace cublasdx;
using GEMM = decltype(Size<32, 32, 32>()
+ Precision<double>()
+ Type<type::real>()
+ Function<function::MM>()
+ Arrangement<cublasdx::row_major, cublasdx::col_major>()
+ SM<700>()
+ Block());
__global__ void gemm_kernel(double alpha, double *a, double *b, double beta, double *c) {
// Execute GEMM
GEMM().execute(/* What are the arguments? */);
}
It is assumed that all matrices reside in shared memory. It is up to the user to load the matrices from global to shared memory before calling the execution method. In the same way, the user is responsible for saving the results.
Starting from cuBLASDx 0.2.0, the execute method takes tensors (cublasdx::tensor) as inputs and outputs. cublasdx::tensor is an alias of a CuTe tensor (cute::Tensor), which is a representation of a multidimensional array that holds data in any kind of memory (global, shared, or register memory) together with a CuTe layout (cute::Layout) describing how the elements are organized. A typical structure of a GEMM kernel is as follows:
#include <cublasdx.hpp>
using namespace cublasdx;
using GEMM = decltype(Size<32, 32, 32>()
+ Precision<double>()
+ Type<type::real>()
+ Function<function::MM>()
+ Arrangement<cublasdx::row_major, cublasdx::col_major>()
+ SM<700>()
+ Block());
// Type <a/b/c>_value_type is defined based on the GEMM description. The Precision operator defines its numerical
// precision, and via the Type operator the user specifies whether it is complex or real.
//
// In this case, a/b/c_value_type are all double since set precision is double, and type is real.
using a_value_type = typename GEMM::a_value_type;
using b_value_type = typename GEMM::b_value_type;
using c_value_type = typename GEMM::c_value_type;
__global__ void gemm_kernel(c_value_type alpha, a_value_type *a, b_value_type *b, c_value_type beta, c_value_type *c) {
extern __shared__ __align__(16) char smem[];
// Create global memory tensor
// a_global_tensor = (from a)
// b_global_tensor = (from b)
// c_global_tensor = (from c)
// Make shared memory tensor
// a_shared_tensor = (from smem)
// b_shared_tensor = (from smem + ...)
// c_shared_tensor = (from smem + ...)
// Load data from global memory tensor to shared memory tensor
// a_shared_tensor <-- a_global_tensor
// b_shared_tensor <-- b_global_tensor
// c_shared_tensor <-- c_global_tensor
// Execute GEMM
GEMM().execute(alpha, a_shared_tensor, b_shared_tensor, beta, c_shared_tensor);
__syncthreads();
// Store data from shared memory tensor to global memory tensor
// c_global_tensor <-- c_shared_tensor
}
As hinted by the comments, there are four steps:
Create global and shared memory tensors (see Tensor Creation).
Copy data from global memory tensors to shared memory tensors (see Copying Tensors).
(Main step) Execute GEMM using the tensor APIs.
Copy data from shared memory tensors to global memory tensors (see Copying Tensors).
Tensor Creation¶
To create tensors backed by global and shared memory, cuBLASDx provides a helper function, cublasdx::make_tensor(...), which works together with the layouts returned by the get_layout_<gmem/smem>_<a/b/c>(...) methods of the defined GEMM type.
Both layouts take the arrangements into account, and the shared memory layouts additionally use the leading dimension information from the GEMM type. For global memory layouts, leading dimension information can be passed through an extra argument; otherwise it is inferred from the given problem size.
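As a brief sketch (lda below is a hypothetical runtime value, not part of introduction_example.cu), a non-default leading dimension for A in global memory would be supplied through that extra argument:
// Global memory tensor for A with an explicitly provided runtime leading dimension
auto a_global_tensor = cublasdx::make_tensor(a, GEMM::get_layout_gmem_a(lda));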
For creating shared memory tensors, we need pointers to the shared memory slices for the A, B, and C matrices. The slice_shared_memory(...) method provides this functionality.
template<class GEMM>
__global__ void gemm_kernel(typename GEMM::c_value_type alpha, typename GEMM::a_value_type *a, typename GEMM::b_value_type *b, typename GEMM::c_value_type beta, typename GEMM::c_value_type *c) {
extern __shared__ __align__(16) char smem[];
// Make global memory tensor
auto a_global_tensor = cublasdx::make_tensor(a, GEMM::get_layout_gmem_a());
auto b_global_tensor = cublasdx::make_tensor(b, GEMM::get_layout_gmem_b());
auto c_global_tensor = cublasdx::make_tensor(c, GEMM::get_layout_gmem_c());
// Make shared memory tensor
auto [smem_a, smem_b, smem_c] = GEMM::slice_shared_memory(smem); // smem_<a/b/c> are aligned to cublasdx::alignment_of<GEMM>::<a/b/c>
auto a_shared_tensor = cublasdx::make_tensor(smem_a, GEMM::get_layout_smem_a());
auto b_shared_tensor = cublasdx::make_tensor(smem_b, GEMM::get_layout_smem_b());
auto c_shared_tensor = cublasdx::make_tensor(smem_c, GEMM::get_layout_smem_c());
}
Tip
If there is no need to use plain row- or column-major layouts for shared memory, it is recommended to use the layouts returned by GEMM::suggest_layout_smem_<a/b/c>(...), as in many cases they will lead to better performance. See Suggested Shared Memory Layout.
auto [smem_a, smem_b, smem_c] = GEMM::slice_shared_memory(smem);
auto a_shared_tensor = cublasdx::make_tensor(smem_a, GEMM::suggest_layout_smem_a());
auto b_shared_tensor = cublasdx::make_tensor(smem_b, GEMM::suggest_layout_smem_b());
auto c_shared_tensor = cublasdx::make_tensor(smem_c, GEMM::suggest_layout_smem_c());
For more details of the mentioned helper function and methods, see Tensor Creation, Suggested Shared Memory Layout and Shared Memory Slicing.
Copying Tensors¶
cuBLASDx offers a helper function, cublasdx::copy(...), that copies data between tensor objects. All threads from the BlockDim Operator participate in the copy. The function takes the given alignments into account and attempts to vectorize loads and stores whenever possible. It is recommended to use it for achieving better kernel performance. See Copying Tensors for more details.
// Load data from global memory tensor to shared memory tensor
using alignment = cublasdx::alignment_of<GEMM>;
cublasdx::copy<GEMM, alignment::a>(a_global_tensor, a_shared_tensor); // <a/b/c>_shared_tensor, created from smem_<a/b/c>, is aligned to alignment::<a/b/c>
cublasdx::copy<GEMM, alignment::b>(b_global_tensor, b_shared_tensor);
cublasdx::copy<GEMM, alignment::c>(c_global_tensor, c_shared_tensor);
cublasdx::copy_wait();
// Store data to global memory
cublasdx::copy<GEMM, alignment::c>(c_shared_tensor, c_global_tensor);
Launching GEMM Kernel¶
To launch a kernel executing the defined GEMM, we need to know the required block dimensions and the amount of shared memory needed for all three matrices: A, B, and C. Elements of matrix A should be in row-major format, and matrices B and C in column-major format, accounting for leading dimensions.
#include <cublasdx.hpp>
using namespace cublasdx;
template<class GEMM>
__global__ void gemm_kernel(typename GEMM::c_value_type alpha, typename GEMM::a_value_type *a, typename GEMM::b_value_type *b, typename GEMM::c_value_type beta, typename GEMM::c_value_type *c) {
extern __shared__ __align__(16) char smem[];
// Make global memory tensor
auto a_global_tensor = cublasdx::make_tensor(a, GEMM::get_layout_gmem_a());
auto b_global_tensor = cublasdx::make_tensor(b, GEMM::get_layout_gmem_b());
auto c_global_tensor = cublasdx::make_tensor(c, GEMM::get_layout_gmem_c());
// Make shared memory tensor
auto [smem_a, smem_b, smem_c] = GEMM::slice_shared_memory(smem);
auto a_shared_tensor = cublasdx::make_tensor(smem_a, GEMM::get_layout_smem_a());
auto b_shared_tensor = cublasdx::make_tensor(smem_b, GEMM::get_layout_smem_b());
auto c_shared_tensor = cublasdx::make_tensor(smem_c, GEMM::get_layout_smem_c());
// Load data from global memory tensor to shared memory tensor
using alignment = cublasdx::alignment_of<GEMM>;
cublasdx::copy<GEMM, alignment::a>(a_global_tensor, a_shared_tensor);
cublasdx::copy<GEMM, alignment::b>(b_global_tensor, b_shared_tensor);
cublasdx::copy<GEMM, alignment::c>(c_global_tensor, c_shared_tensor);
cublasdx::copy_wait();
// Execute GEMM
GEMM().execute(alpha, a_shared_tensor, b_shared_tensor, beta, c_shared_tensor);
__syncthreads();
// Store data to global memory
cublasdx::copy<GEMM, alignment::c>(c_shared_tensor, c_global_tensor);
}
// CUDA_CHECK_AND_EXIT - macro that checks if a function returns cudaSuccess; if not, it prints the error code and exits the program
void introduction_example(double alpha, double *a, double *b, double beta, double *c) {
using GEMM = decltype(Size<32, 32, 32>()
+ Precision<double>()
+ Type<type::real>()
+ Arrangement<cublasdx::row_major, cublasdx::col_major>()
+ Function<function::MM>()
+ SM<700>()
+ Block());
// Invokes kernel with GEMM::block_dim threads in CUDA block
gemm_kernel<GEMM><<<1, GEMM::block_dim, GEMM::shared_memory_size>>>(alpha, a, b, beta, c);
CUDA_CHECK_AND_EXIT(cudaPeekAtLastError());
CUDA_CHECK_AND_EXIT(cudaDeviceSynchronize());
}
The required shared memory can be obtained using GEMM::shared_memory_size. It accounts for any padding declared using the LeadingDimension Operator.
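If GEMM::shared_memory_size exceeds the default per-block limit for dynamic shared memory, the kernel must opt in to the larger allocation before it is launched. A minimal sketch, reusing the names from the snippets above:
// Allow gemm_kernel<GEMM> to use GEMM::shared_memory_size bytes of dynamic shared memory
CUDA_CHECK_AND_EXIT(cudaFuncSetAttribute(gemm_kernel<GEMM>,
cudaFuncAttributeMaxDynamicSharedMemorySize,
GEMM::shared_memory_size));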
For simplicity, in the example we allocate managed memory for the device matrices, assume that the Volta architecture is used, and don't check CUDA error codes returned by CUDA API functions. Please check the full introduction_example.cu example, as well as the others shipped with cuBLASDx, for more detailed code.
#include <cublasdx.hpp>
using namespace cublasdx;
template<class GEMM>
__global__ void gemm_kernel(typename GEMM::c_value_type alpha, typename GEMM::a_value_type *a, typename GEMM::b_value_type *b, typename GEMM::c_value_type beta, typename GEMM::c_value_type *c) {
extern __shared__ __align__(16) char smem[];
// Make global memory tensor
auto a_global_tensor = cublasdx::make_tensor(a, GEMM::get_layout_gmem_a());
auto b_global_tensor = cublasdx::make_tensor(b, GEMM::get_layout_gmem_b());
auto c_global_tensor = cublasdx::make_tensor(c, GEMM::get_layout_gmem_c());
// Make shared memory tensor
auto [smem_a, smem_b, smem_c] = GEMM::slice_shared_memory(smem);
auto a_shared_tensor = cublasdx::make_tensor(smem_a, GEMM::get_layout_smem_a());
auto b_shared_tensor = cublasdx::make_tensor(smem_b, GEMM::get_layout_smem_b());
auto c_shared_tensor = cublasdx::make_tensor(smem_c, GEMM::get_layout_smem_c());
// Load data from global memory tensor to shared memory tensor
using alignment = cublasdx::alignment_of<GEMM>;
cublasdx::copy<GEMM, alignment::a>(a_global_tensor, a_shared_tensor);
cublasdx::copy<GEMM, alignment::b>(b_global_tensor, b_shared_tensor);
cublasdx::copy<GEMM, alignment::c>(c_global_tensor, c_shared_tensor);
cublasdx::copy_wait();
// Execute GEMM
GEMM().execute(alpha, a_shared_tensor, b_shared_tensor, beta, c_shared_tensor);
__syncthreads();
// Store data to global memory
cublasdx::copy<GEMM, alignment::c>(c_shared_tensor, c_global_tensor);
}
void introduction_example() {
using GEMM = decltype(Size<32, 32, 32>()
+ Precision<double>()
+ Type<type::real>()
+ Arrangement<cublasdx::row_major, cublasdx::col_major>()
+ Function<function::MM>()
+ SM<700>()
+ Block()
+ BlockDim<256>());
// Allocate managed memory for A, B, C matrices in one go
using value_type = typename example::uniform_value_type_t<GEMM>; // in the example A, B, C are of the same value_type
value_type* abc;
auto size = GEMM::a_size + GEMM::b_size + GEMM::c_size;
auto size_bytes = size * sizeof(value_type);
cudaMallocManaged(&abc, size_bytes);
// Generate data
for (size_t i = 0; i < size; i++) {
abc[i] = double(i) / size;
}
value_type* a = abc;
value_type* b = abc + GEMM::a_size;
value_type* c = abc + GEMM::a_size + GEMM::b_size;
// Invokes kernel with GEMM::block_dim threads in CUDA block
gemm_kernel<GEMM><<<1, GEMM::block_dim, GEMM::shared_memory_size>>>(1.0, a, b, 1.0, c);
cudaDeviceSynchronize();
cudaFree(abc);
}
It is important to notice that, unlike the cuBLAS library, cuBLASDx does not require moving data back to global memory after executing a BLAS operation, nor does it require the input data to be loaded from global memory. These properties can be a major performance advantage for certain use cases. The list of possible optimizations includes, but is not limited to:
Fusing BLAS routines with custom pre- and post-processing (see the sketch after this list).
Fusing multiple BLAS operations together.
Fusing BLAS and FFT operations (using cuFFTDx) together.
Generating input matrices or parts of them.
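As a purely illustrative sketch of the first point (the ReLU epilogue and the assumption of a 1D thread block are not part of introduction_example.cu), a custom post-processing step can be applied inside a kernel like the one above while the result still resides in shared memory:
// Execute GEMM and post-process C in shared memory before the single store to global memory
GEMM().execute(alpha, a_shared_tensor, b_shared_tensor, beta, c_shared_tensor);
__syncthreads();
// Illustrative epilogue: clamp negative values (ReLU); GEMM::c_size is the number of elements of C
for (unsigned int i = threadIdx.x; i < GEMM::c_size; i += blockDim.x) {
c_shared_tensor(i) = (c_shared_tensor(i) > 0.0) ? c_shared_tensor(i) : 0.0;
}
__syncthreads();
// Store the fused result to global memory
cublasdx::copy<GEMM, alignment::c>(c_shared_tensor, c_global_tensor);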
Compilation¶
For instructions on how to compile programs with cuBLASDx see Quick Installation Guide.