General Matrix Multiply Using cuBLASDx#
In this introduction, we will perform a general matrix multiplication using the cuBLASDx library. Three variants of this operation are exposed:
Shared memory API:
\(\mathbf{C}_{m\times n} = {\alpha} \times \mathbf{A}_{m\times k} \times \mathbf{B}_{k\times n} + {\beta} \times \mathbf{C}_{m\times n}\)
Register API, with accumulator:
\(\mathbf{C}_{m\times n} = \mathbf{A}_{m\times k} \times \mathbf{B}_{k\times n} + \mathbf{C}_{m\times n}\)
Register API, without accumulator:
\(\mathbf{C}_{m\times n} = \mathbf{A}_{m\times k} \times \mathbf{B}_{k\times n}\)
This section is based on the introduction_example.cu example shipped with cuBLASDx. See the Examples section for other cuBLASDx samples.
Defining GEMM Operation#
The first step is defining the GEMM we want to perform. This is done by adding together cuBLASDx operators to create a GEMM description. The correctness of this type is evaluated at compile time every time a new operator is added. A well-defined cuBLASDx GEMM routine description must include two parts:
The selected linear algebra routine. In this case it is matrix multiplication: cublasdx::function::MM.
A valid and sufficient description of the inputs and outputs: the dimensions of the matrices (m, n, k), the precision (half, float, double, etc.), the data type (real or complex), and the data arrangement of the matrices (row- or column-major).
To get a descriptor for any of the operations described by
\(\mathbf{C}_{m\times n} = \left[ {\alpha} \times \right] \mathbf{A}_{m\times k} \times \mathbf{B}_{k\times n} \left[ + {\beta} \times \mathbf{C}_{m\times n} \right]\)
with m = n = k = 32, we just need to write the following lines:
#include <cublasdx.hpp>
using namespace cublasdx;
using GEMM = decltype(Size<32 /* m */, 32 /* n */, 32 /* k */>()
+ Precision<double>()
+ Type<type::real>()
+ Function<function::MM>()
+ Arrangement<cublasdx::row_major /* A */, cublasdx::col_major /* B */>());
In order to encode the operation properties, cuBLASDx provides the operators Size, Precision, Type, Function, and Arrangement, which can be combined by using the standard addition operator (+).
Optionally, the user can set alignments and leading dimensions for each matrix using the Alignment and LeadingDimension operators, respectively. When using custom input types that differ from the compute types, the Alignment operator must be set to appropriate values.
Leading dimensions can also be set dynamically at execution time; note, however, that this may affect performance.
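As a hedged sketch (the padded leading dimension of 34 and the 16-byte alignments are illustrative values, not requirements), adding these operators to the earlier description could look like:
#include <cublasdx.hpp>
using namespace cublasdx;
using GEMMWithLD = decltype(Size<32, 32, 32>()
+ Precision<double>()
+ Type<type::real>()
+ Function<function::MM>()
+ Arrangement<cublasdx::row_major, cublasdx::col_major>()
+ LeadingDimension<34, 34, 34>() // pad A, B, C to ld = 34 (illustrative)
+ Alignment<16, 16, 16>());      // 16-byte alignment for A, B, C (illustrative)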
Tip
cuBLASDx also supports matrices that cannot simply be expressed by row- or column-major arrangements and leading dimensions. See the simple_gemm_custom_layout.cu example.
To obtain a fully usable operation that executes GEMM at the CUDA block level, we need to provide at least two additional pieces of information:
The first one is the SM Operator, which indicates the target CUDA architecture on which we want to run the GEMM. Each GPU architecture is different, therefore each can use a different implementation and may require a different CUDA block size for the best performance. In the introduction_example.cu example this is passed as a template parameter, but here we can assume we're targeting Volta GPUs (SM<700>()).
Finally, we use the Block Operator to indicate that the BLAS routine will be performed by multiple threads in a single CUDA block. At this point, cuBLASDx performs additional verifications to make sure the provided description is valid and that it is possible to execute it on the requested architecture.
#include <cublasdx.hpp>
using namespace cublasdx;
using GEMM = decltype(Size<32, 32, 32>()
+ Precision<double>()
+ Type<type::real>()
+ Function<function::MM>()
+ Arrangement<cublasdx::row_major, cublasdx::col_major>()
+ SM<700>()
+ Block());
The user can also specify the layout and the number of threads that will be performing the GEMM. This is done with the BlockDim Operator. Adding BlockDim<X, Y, Z> means that the GEMM will only work correctly if the kernel is launched with block dimensions dim3(X1, Y1, Z1), where X1 >= X, Y1 >= Y, and Z1 >= Z. Detailed requirements can be found in the section dedicated to the BlockDim Operator.
If the BlockDim operator is not used, cuBLASDx will select a preferred block size, which can be obtained with GEMM::block_dim.
Tip
If there is no need to set custom block dimensions, it is recommended not to use the BlockDim operator and to rely on GEMM::block_dim instead. For more details, see the Block Execute Method section, the BlockDim Operator, and the Suggested Block Dim Trait.
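For instance, a host-side sketch of relying on these traits (GEMM is the block-level type defined above; suggested_block_dim is assumed here to be the spelling of the trait referenced above):
// Query block dimensions from the GEMM type instead of forcing them with BlockDim
dim3 block_dim = GEMM::block_dim;               // block size the GEMM was prepared for
dim3 suggested_dim = GEMM::suggested_block_dim; // performance-oriented suggestion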
For this sample, let’s assume we want to use a 1D CUDA thread block with 256 threads.
#include <cublasdx.hpp>
using namespace cublasdx;
using GEMM = decltype(Size<32, 32, 32>()
+ Precision<double>()
+ Type<type::real>()
+ Function<function::MM>()
+ Arrangement<cublasdx::row_major, cublasdx::col_major>()
+ SM<700>()
+ Block()
+ BlockDim<256>());
Executing GEMM#
The class GEMM, which describes the matrix multiplication, can be instantiated into an object (or objects). Forming the object has no computational cost and should be seen as a handle. The function descriptor object provides a compute method, execute(...), that performs the requested function.
#include <cublasdx.hpp>
using namespace cublasdx;
using GEMM = decltype(Size<32, 32, 32>()
+ Precision<double>()
+ Type<type::real>()
+ Function<function::MM>()
+ Arrangement<cublasdx::row_major, cublasdx::col_major>()
+ SM<700>()
+ Block());
__global__ void gemm_kernel(double alpha, double *a, double *b, double beta, double *c) {
// Execute GEMM
GEMM().execute(/* What are the arguments? */);
}
Starting from cuBLASDx 0.2.0, the execute method takes tensors (cublasdx::tensor) as inputs and outputs. cublasdx::tensor is an alias of a CuTe tensor (cute::Tensor), which is a representation of a multidimensional array combining data held in any kind of memory (global memory, shared memory, or register memory) with a CuTe layout (cute::Layout) describing how the elements are organized.
Tensor Creation#
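A minimal sketch of tensor creation, based on the patterns used throughout this section (the layout getters and the shared memory slicing helper appear in context in the full examples below):
// Wrap raw global memory pointers in tensors using the layouts encoded in GEMM
auto a_global_tensor = cublasdx::make_tensor(a, GEMM::get_layout_gmem_a());
auto b_global_tensor = cublasdx::make_tensor(b, GEMM::get_layout_gmem_b());
auto c_global_tensor = cublasdx::make_tensor(c, GEMM::get_layout_gmem_c());
// Slice dynamic shared memory into per-matrix chunks and wrap them in tensors
extern __shared__ __align__(16) char smem[];
auto [smem_a, smem_b, smem_c] = cublasdx::slice_shared_memory<GEMM>(smem);
auto a_shared_tensor = cublasdx::make_tensor(smem_a, GEMM::get_layout_smem_a());
auto b_shared_tensor = cublasdx::make_tensor(smem_b, GEMM::get_layout_smem_b());
auto c_shared_tensor = cublasdx::make_tensor(smem_c, GEMM::get_layout_smem_c());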
Tensor Partitioning#
Starting with cuBLASDx 0.3.0 and the register fragment APIs, the library offers new interfaces allowing for efficient partitioning of global and shared memory tensors between the threads taking part in the GEMM, and subsequent modification of these register tensors. The entry point for these operations is a partitioner object tied to a specific GEMM instance:
auto partitioner = GEMM::get_partitioner();
auto partitioner = GEMM::suggest_partitioner();
Such an object makes it possible to:
Create a register fragment Accumulator for this GEMM
Map fragment indexes to global tensor indexes
Partition other tensors (like C) to get their subtensors
Apply predication to out-of-bounds elements and threads
Please refer to Partitioner And Register Fragment Tensors for more details.
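As a hedged illustration of the manual index-mapping path (map_fragment_index is an assumed helper name; the Partitioner section documents the exact API), loading C elements by hand could look like:
auto partitioner = GEMM::get_partitioner();
auto c_fragment = partitioner.make_accumulator_fragment();
for (int i = 0; i < cublasdx::size(c_fragment); ++i) {
    // map_fragment_index (assumed name) translates a thread-local fragment index
    // into the index of the corresponding element of the full C tensor;
    // out-of-bounds predication is omitted for brevity
    auto c_index = partitioner.map_fragment_index(i);
    c_fragment(i) = c_global_tensor(c_index);
}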
Warning
A register fragment can be used as an accumulator ONLY with the GEMM that was used to create it.
Register fragment Accumulators#
A register fragment Accumulator is an array stored in thread-local register file (RF) memory, wrapped into a cublasdx::tensor with an opaque layout describing the internal GEMM execution. As opposed to global and shared memory tensors, this layout may not be arbitrary and can only be obtained from a partitioner object (see Partitioner And Register Fragment Tensors).
Note
A register fragment is an opaque hierarchical tensor exposing a 1D tensor interface. The implementation details of the specific layout of any register fragment are tied to the GEMM execution, but each fragment can be accessed with 1D indices ranging from 0 to cublasdx::size(register_fragment).
Each register fragment Accumulator represents a fragment of a global or shared memory matrix, determined by the index of the thread holding it and the GEMM instance it has been created from. cuBLASDx exposes two ways of mapping memory from the thread-local index space to the entire tensor index space:
Through the manual index-mapping utilities of a partitioner object
Through automatic copying functions, which have gather / scatter semantics
See Copying register fragments and Copying register tensors for more information.
To get a register fragment for a GEMM instance, it's enough to obtain a partitioner and use it to create an uninitialized accumulator:
auto partitioner = GEMM::get_partitioner();
auto c_fragment_accumulator = partitioner.make_accumulator_fragment();
// Now you can access it as a regular 1D tensor:
auto val_0 = c_fragment_accumulator(0);
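Because of the 1D interface, the uninitialized fragment can, for example, be zero-filled elementwise before it is used for accumulation (a minimal sketch):
for (int i = 0; i < cublasdx::size(c_fragment_accumulator); ++i) {
    c_fragment_accumulator(i) = 0.0; // element type follows GEMM::c_value_type
}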
Copying Tensors#
Copying register fragments#
To copy register fragment Accumulators with GEMM results, cuBLASDx offers a helper function, cublasdx::copy_fragment(...), responsible for performing loads and stores between a local tensor fragment and the appropriate locations in a global / shared tensor. The function takes the given alignments into account and attempts to vectorize the loads and stores when possible.
This copy is a per-thread operation, and the global / shared data partitioning is based on:
the partitioner object containing the appropriate GEMM execution details
the thread index (contained in the partitioner object)
The partitioner object offers many helper APIs allowing for significant flexibility in data operations. See Partitioner And Register Fragment Tensors for more details.
using alignment = cublasdx::alignment_of<GEMM>;
auto partitioner = GEMM::get_partitioner();
auto c_fragment_accumulator = partitioner.make_accumulator_fragment();
// Load data from global to registers
cublasdx::copy_fragment<alignment::c>(c_global_tensor, c_fragment_accumulator, partitioner);
// Load data from shared to registers
cublasdx::copy_fragment<alignment::c>(c_shared_tensor, c_fragment_accumulator, partitioner);
// Store data from registers to global
cublasdx::copy_fragment<alignment::c>(c_fragment_accumulator, c_global_tensor, partitioner);
// Store data from registers to shared
cublasdx::copy_fragment<alignment::c>(c_fragment_accumulator, c_shared_tensor, partitioner);
Accumulator register GEMM API#
A typical structure of a register Accumulation API GEMM kernel is as follows:
#include <cublasdx.hpp>
using namespace cublasdx;
using GEMM = decltype(Size<32, 32, 32>()
+ Precision<double>()
+ Type<type::real>()
+ Function<function::MM>()
+ Arrangement<cublasdx::row_major, cublasdx::col_major>()
+ SM<700>()
+ Block());
// Type <a/b/c>_value_type is defined based on the GEMM description. Precision operator defines its numerical
// precision, and via Type operator user specifies if it is complex or real.
//
// In this case, a/b/c_value_type are all double since set precision is double, and type is real.
using a_value_type = typename GEMM::a_value_type;
using b_value_type = typename GEMM::b_value_type;
using c_value_type = typename GEMM::c_value_type;
__global__ void gemm_kernel_registers_accumulation(a_value_type *a, b_value_type *b, c_value_type *c) {
extern __shared__ __align__(16) char smem[];
// Create global memory tensor
// a_global_tensor = (from a)
// b_global_tensor = (from b)
// c_global_tensor = (from c)
// Make shared memory tensor
// a_shared_tensor = (from smem)
// b_shared_tensor = (from smem + ...)
// Load data from global memory tensor to shared memory tensor
// a_shared_tensor <-- a_global_tensor
// b_shared_tensor <-- b_global_tensor
// Make C register Accumulator fragment
// c_register_accumulator = (from GEMM)
// Load appropriate data from global memory tensor to register fragment tensor
// c_register_accumulator <- c_global_tensor
// Execute GEMM
GEMM().execute(a_shared_tensor, b_shared_tensor, c_register_accumulator);
// Store data from register accumulator tensor to global memory tensor
// c_global_tensor <-- c_register_accumulator
}
This API is more involved, adding extra steps for the C accumulator:
Create global and shared memory tensors (see Tensor Creation).
Copy data from global memory tensors to shared memory tensors (see Copying Tensors).
Create a register memory C accumulator tensor.
Copy the appropriate part of the global input tensor C into register memory.
(main step) Execute GEMM using the tensor APIs.
Copy data from the register accumulator tensor to the appropriate places in the global memory tensor (see Copying Tensors).
After filling all these steps in with tensor creation and copying code, we get:
#include <cublasdx.hpp>
using namespace cublasdx;
template<class GEMM>
__global__ void gemm_kernel_registers_accumulation(const typename GEMM::a_value_type* a,
const typename GEMM::b_value_type* b,
typename GEMM::c_value_type* c) {
extern __shared__ __align__(16) char smem[];
// Make global memory tensor
auto a_global_tensor = cublasdx::make_tensor(a, GEMM::get_layout_gmem_a());
auto b_global_tensor = cublasdx::make_tensor(b, GEMM::get_layout_gmem_b());
auto c_global_tensor = cublasdx::make_tensor(c, GEMM::get_layout_gmem_c());
// Make shared memory tensor
auto [smem_a, smem_b] = cublasdx::slice_shared_memory_ab<GEMM>(smem);
auto a_shared_tensor = cublasdx::make_tensor(smem_a, GEMM::get_layout_smem_a());
auto b_shared_tensor = cublasdx::make_tensor(smem_b, GEMM::get_layout_smem_b());
// Load data from global memory tensor to shared memory tensor
using alignment = cublasdx::alignment_of<GEMM>;
cublasdx::copy<GEMM, alignment::a>(a_global_tensor, a_shared_tensor);
cublasdx::copy<GEMM, alignment::b>(b_global_tensor, b_shared_tensor);
cublasdx::copy_wait();
// Get default data partitioner
auto partitioner = GEMM::get_partitioner();
// Create register fragment Accumulator
auto c_register_fragment = partitioner.make_accumulator_fragment();
// Partition Global C for GEMM and load appropriate elements into register fragment
cublasdx::copy_fragment<alignment::c>(c_global_tensor, c_register_fragment, partitioner);
// Execute GEMM with accumulation
GEMM().execute(a_shared_tensor, b_shared_tensor, c_register_fragment);
// Partition Global C for GEMM and store appropriate elements to global memory
cublasdx::copy_fragment<alignment::c>(c_register_fragment, c_global_tensor, partitioner);
}
Return Value register GEMM API#
A typical structure of a Return Value register API GEMM kernel is as follows:
#include <cublasdx.hpp>
using namespace cublasdx;
using GEMM = decltype(Size<32, 32, 32>()
+ Precision<double>()
+ Type<type::real>()
+ Function<function::MM>()
+ Arrangement<cublasdx::row_major, cublasdx::col_major>()
+ SM<700>()
+ Block());
// Type <a/b/c>_value_type is defined based on the GEMM description. Precision operator defines its numerical
// precision, and via Type operator user specifies if it is complex or real.
//
// In this case, a/b/c_value_type are all double since set precision is double, and type is real.
using a_value_type = typename GEMM::a_value_type;
using b_value_type = typename GEMM::b_value_type;
using c_value_type = typename GEMM::c_value_type;
__global__ void gemm_kernel_registers(a_value_type *a, b_value_type *b, c_value_type *c) {
extern __shared__ __align__(16) char smem[];
// Create global memory tensor
// a_global_tensor = (from a)
// b_global_tensor = (from b)
// c_global_tensor = (from c)
// Make shared memory tensor
// a_shared_tensor = (from smem)
// b_shared_tensor = (from smem + ...)
// Load data from global memory tensor to shared memory tensor
// a_shared_tensor <-- a_global_tensor
// b_shared_tensor <-- b_global_tensor
// Execute GEMM and get the register fragment results and data partitioner in return
auto [c_register_fragment, partitioner] = GEMM().execute(a_shared_tensor, b_shared_tensor);
// Partition Global C for GEMM and store appropriate elements to global memory
// c_global_tensor <-- c_register_fragment
}
This API doesn't expect the register fragment upfront, but returns it as a result:
Create global and shared memory tensors (see Tensor Creation).
Copy data from global memory tensors to shared memory tensors (see Copying Tensors).
(main step) Execute GEMM using the tensor APIs, getting the results as a register fragment.
Copy data from the register fragment tensor to the appropriate places in the global memory tensor (see Copying Tensors).
After filling all these steps in with tensor creation and copying code, we get:
#include <cublasdx.hpp>
using namespace cublasdx;
template<class GEMM>
__global__ void gemm_kernel_registers(const typename GEMM::a_value_type* a,
const typename GEMM::b_value_type* b,
typename GEMM::c_value_type* c) {
extern __shared__ __align__(16) char smem[];
// Make global memory tensor
auto a_global_tensor = cublasdx::make_tensor(a, GEMM::get_layout_gmem_a());
auto b_global_tensor = cublasdx::make_tensor(b, GEMM::get_layout_gmem_b());
auto c_global_tensor = cublasdx::make_tensor(c, GEMM::get_layout_gmem_c());
// Make shared memory tensor
auto [smem_a, smem_b] = cublasdx::slice_shared_memory_ab<GEMM>(smem);
auto a_shared_tensor = cublasdx::make_tensor(smem_a, GEMM::get_layout_smem_a());
auto b_shared_tensor = cublasdx::make_tensor(smem_b, GEMM::get_layout_smem_b());
// Load data from global memory tensor to shared memory tensor
using alignment = cublasdx::alignment_of<GEMM>;
cublasdx::copy<GEMM, alignment::a>(a_global_tensor, a_shared_tensor);
cublasdx::copy<GEMM, alignment::b>(b_global_tensor, b_shared_tensor);
cublasdx::copy_wait();
// Execute GEMM and get register fragment results and data partitioner in return
auto [c_register_fragment, partitioner] = GEMM().execute(a_shared_tensor, b_shared_tensor);
// Partition Global C for GEMM and store appropriate elements to global memory
cublasdx::copy_fragment<alignment::c>(c_register_fragment, c_global_tensor, partitioner);
}
Launching GEMM Kernel#
To launch a kernel executing the defined GEMM, we need to know the required block dimensions and the amount of shared memory needed for all three matrices: A, B, and C. Elements in the matrix A should be in a row-major format, and matrices B and C in a column-major format, accounting for leading dimensions.
#include <cublasdx.hpp>
using namespace cublasdx;
// Kernels are unfolded in their appropriate sections above
template<class GEMM>
__global__ void gemm_kernel_shared(typename GEMM::c_value_type alpha, typename GEMM::a_value_type *a, typename GEMM::b_value_type *b, typename GEMM::c_value_type beta, typename GEMM::c_value_type *c)
{
...
}
template<class GEMM>
__global__ void gemm_kernel_registers_accumulation(typename GEMM::a_value_type *a, typename GEMM::b_value_type *b, typename GEMM::c_value_type *c)
{
...
}
template<class GEMM>
__global__ void gemm_kernel_registers(typename GEMM::a_value_type *a, typename GEMM::b_value_type *b, typename GEMM::c_value_type *c)
{
...
}
// CUDA_CHECK_AND_EXIT - macro that checks if a function returns cudaSuccess; if not, it prints the error code and exits the program
void introduction_example(double alpha, double *a, double *b, double beta, double *c) {
using GEMM = decltype(Size<32, 32, 32>()
+ Precision<double>()
+ Type<type::real>()
+ Arrangement<cublasdx::row_major, cublasdx::col_major>()
+ Function<function::MM>()
+ SM<700>()
+ Block());
// Shared memory API: C = alpha * A * B + beta * C
// Invokes kernel with GEMM::block_dim threads in CUDA block
gemm_kernel_shared<GEMM><<<1, GEMM::block_dim, cublasdx::get_shared_storage_size<GEMM>()>>>(alpha, a, b, beta, c);
// Register fragment Accumulation API: C = A * B + C
// Invokes kernel with GEMM::block_dim threads in CUDA block
gemm_kernel_registers_accumulation<GEMM><<<1, GEMM::block_dim, cublasdx::get_shared_storage_size_ab<GEMM>()>>>(a, b, c);
// Register fragment API: C = A * B
// Invokes kernel with GEMM::block_dim threads in CUDA block
gemm_kernel_registers<GEMM><<<1, GEMM::block_dim, cublasdx::get_shared_storage_size_ab<GEMM>()>>>(a, b, c);
CUDA_CHECK_AND_EXIT(cudaPeekAtLastError());
CUDA_CHECK_AND_EXIT(cudaDeviceSynchronize());
}
The required shared memory size can be obtained using cublasdx::get_shared_storage_size<GEMM>() and cublasdx::get_shared_storage_size_ab<GEMM>(). These account for any padding declared using the LeadingDimension Operator or resulting from the Alignment Operator.
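If the reported size exceeds the default dynamic shared memory limit of the target GPU, the kernel must opt in before launch. A minimal host-side sketch (assuming the kernels defined above):
// Raise the dynamic shared memory limit for the kernel if the GEMM
// needs more shared memory than the architecture's default limit
auto shared_size = cublasdx::get_shared_storage_size<GEMM>();
CUDA_CHECK_AND_EXIT(cudaFuncSetAttribute(gemm_kernel_shared<GEMM>,
                                         cudaFuncAttributeMaxDynamicSharedMemorySize,
                                         shared_size));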
For simplicity, in the example we allocate managed memory for the device matrices and assume that the Volta architecture is used. Please check the full introduction_example.cu example, as well as the others shipped with cuBLASDx, for more detailed code.
#include <iostream>
#include <vector>
#include <cuda_runtime_api.h>
#include <cublasdx.hpp>
#include "common.hpp"
#include "reference.hpp"
template<class GEMM>
__global__ void gemm_kernel_shared(const typename GEMM::c_value_type alpha,
const typename GEMM::a_value_type* a,
const typename GEMM::b_value_type* b,
const typename GEMM::c_value_type beta,
typename GEMM::c_value_type* c) {
extern __shared__ __align__(16) char smem[];
// Make global memory tensor
auto a_global_tensor = cublasdx::make_tensor(a, GEMM::get_layout_gmem_a());
auto b_global_tensor = cublasdx::make_tensor(b, GEMM::get_layout_gmem_b());
auto c_global_tensor = cublasdx::make_tensor(c, GEMM::get_layout_gmem_c());
// Make shared memory tensor
auto [smem_a, smem_b, smem_c] = cublasdx::slice_shared_memory<GEMM>(smem);
auto a_shared_tensor = cublasdx::make_tensor(smem_a, GEMM::get_layout_smem_a());
auto b_shared_tensor = cublasdx::make_tensor(smem_b, GEMM::get_layout_smem_b());
auto c_shared_tensor = cublasdx::make_tensor(smem_c, GEMM::get_layout_smem_c());
// Load data from global memory tensor to shared memory tensor
using alignment = cublasdx::alignment_of<GEMM>;
cublasdx::copy<GEMM, alignment::a>(a_global_tensor, a_shared_tensor);
cublasdx::copy<GEMM, alignment::b>(b_global_tensor, b_shared_tensor);
cublasdx::copy<GEMM, alignment::c>(c_global_tensor, c_shared_tensor);
cublasdx::copy_wait();
// Execute GEMM
GEMM().execute(alpha, a_shared_tensor, b_shared_tensor, beta, c_shared_tensor);
__syncthreads();
// Store data from shared memory tensor to global memory tensor
cublasdx::copy<GEMM, alignment::c>(c_shared_tensor, c_global_tensor);
}
template<class GEMM>
__global__ void gemm_kernel_registers_accumulation(const typename GEMM::a_value_type* a,
const typename GEMM::b_value_type* b,
typename GEMM::c_value_type* c) {
extern __shared__ __align__(16) char smem[];
// Make global memory tensor
auto a_global_tensor = cublasdx::make_tensor(a, GEMM::get_layout_gmem_a());
auto b_global_tensor = cublasdx::make_tensor(b, GEMM::get_layout_gmem_b());
auto c_global_tensor = cublasdx::make_tensor(c, GEMM::get_layout_gmem_c());
// Make shared memory tensor
auto [smem_a, smem_b] = cublasdx::slice_shared_memory_ab<GEMM>(smem);
auto a_shared_tensor = cublasdx::make_tensor(smem_a, GEMM::get_layout_smem_a());
auto b_shared_tensor = cublasdx::make_tensor(smem_b, GEMM::get_layout_smem_b());
// Load data from global memory tensor to shared memory tensor
using alignment = cublasdx::alignment_of<GEMM>;
cublasdx::copy<GEMM, alignment::a>(a_global_tensor, a_shared_tensor);
cublasdx::copy<GEMM, alignment::b>(b_global_tensor, b_shared_tensor);
cublasdx::copy_wait();
// Get default partitioner
auto partitioner = GEMM::get_partitioner();
// Create register fragment Accumulator
auto c_register_fragment = partitioner.make_accumulator_fragment();
// Partition Global C for GEMM and load appropriate elements into register fragment
cublasdx::copy_fragment<alignment::c>(c_global_tensor, c_register_fragment, partitioner);
// Execute GEMM with accumulation
GEMM().execute(a_shared_tensor, b_shared_tensor, c_register_fragment);
// Partition Global C for GEMM and store appropriate elements to global memory
cublasdx::copy_fragment<alignment::c>(c_register_fragment, c_global_tensor, partitioner);
}
template<class GEMM>
__global__ void gemm_kernel_registers(const typename GEMM::a_value_type* a,
const typename GEMM::b_value_type* b,
typename GEMM::c_value_type* c) {
extern __shared__ __align__(16) char smem[];
// Make global memory tensor
auto a_global_tensor = cublasdx::make_tensor(a, GEMM::get_layout_gmem_a());
auto b_global_tensor = cublasdx::make_tensor(b, GEMM::get_layout_gmem_b());
auto c_global_tensor = cublasdx::make_tensor(c, GEMM::get_layout_gmem_c());
// Make shared memory tensor
auto [smem_a, smem_b] = cublasdx::slice_shared_memory_ab<GEMM>(smem);
auto a_shared_tensor = cublasdx::make_tensor(smem_a, GEMM::get_layout_smem_a());
auto b_shared_tensor = cublasdx::make_tensor(smem_b, GEMM::get_layout_smem_b());
// Load data from global memory tensor to shared memory tensor
using alignment = cublasdx::alignment_of<GEMM>;
cublasdx::copy<GEMM, alignment::a>(a_global_tensor, a_shared_tensor);
cublasdx::copy<GEMM, alignment::b>(b_global_tensor, b_shared_tensor);
cublasdx::copy_wait();
// Execute GEMM and get register fragment results and data partitioner in return
auto [c_register_fragment, partitioner] = GEMM().execute(a_shared_tensor, b_shared_tensor);
// Partition Global C for GEMM and store appropriate elements to global memory
cublasdx::copy_fragment<alignment::c>(c_register_fragment, c_global_tensor, partitioner);
}
template<unsigned int Arch>
int introduction_example() {
using GEMM = decltype(cublasdx::Size<32, 32, 32>()
+ cublasdx::Precision<double>()
+ cublasdx::Type<cublasdx::type::real>()
+ cublasdx::Arrangement<cublasdx::row_major, cublasdx::col_major>()
+ cublasdx::Function<cublasdx::function::MM>()
+ cublasdx::SM<Arch>()
+ cublasdx::Block()
+ cublasdx::BlockDim<256>());
using value_type = typename example::uniform_value_type_t<GEMM>;
constexpr auto global_a_size = example::global_memory_size_of<GEMM>::a_size;
constexpr auto global_b_size = example::global_memory_size_of<GEMM>::b_size;
constexpr auto global_c_size = example::global_memory_size_of<GEMM>::c_size;
// Allocate managed memory for A, B, C matrices in one go
value_type* abc;
auto size = global_a_size + global_b_size + global_c_size;
auto size_bytes = size * sizeof(value_type);
CUDA_CHECK_AND_EXIT(cudaMallocManaged(&abc, size_bytes));
// Generate data
for (size_t i = 0; i < size; i++) {
abc[i] = double(i) / double(size);
}
value_type* a = abc;
value_type* b = abc + global_a_size;
value_type* c = abc + global_a_size + global_b_size;
// Shared memory API: C = alpha * A * B + beta * C
// Invokes kernel with GEMM::block_dim threads in CUDA block
gemm_kernel_shared<GEMM><<<1, GEMM::block_dim, cublasdx::get_shared_storage_size<GEMM>()>>>(1.0, a, b, 0.5, c);
// Register fragment Accumulation API: C = A * B + C
// Invokes kernel with GEMM::block_dim threads in CUDA block
gemm_kernel_registers_accumulation<GEMM><<<1, GEMM::block_dim, cublasdx::get_shared_storage_size_ab<GEMM>()>>>(a, b, c);
// Register fragment API: C = A * B
// Invokes kernel with GEMM::block_dim threads in CUDA block
gemm_kernel_registers<GEMM><<<1, GEMM::block_dim, cublasdx::get_shared_storage_size_ab<GEMM>()>>>(a, b, c);
CUDA_CHECK_AND_EXIT(cudaPeekAtLastError());
CUDA_CHECK_AND_EXIT(cudaDeviceSynchronize());
CUDA_CHECK_AND_EXIT(cudaFree(abc));
std::cout << "Success" << std::endl;
return 0;
}
struct introduction_example_functor {
template<unsigned int Arch>
int operator()(std::integral_constant<unsigned int, Arch>) {
return introduction_example<Arch>();
}
};
int main(int, char**) {
return example::sm_runner(introduction_example_functor{});
}
It is important to notice that, unlike the cuBLAS library, cuBLASDx does not require moving data back to global memory after executing a BLAS operation. Nor does it require the input data to be loaded from global memory. These properties can be a major performance advantage for certain use cases. The list of possible optimizations includes, but is not limited to:
Fusing BLAS routines with custom pre- and post-processing (see the sketch after this list).
Fusing multiple BLAS operations together.
Fusing BLAS and FFT operations (using cuFFTDx) together.
Generating input matrices or parts of them.
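As a hedged sketch of the first point (the ReLU epilogue is an illustrative choice, not part of the shipped example), custom post-processing can be fused into the register-API kernel before the results ever touch global memory:
// Inside gemm_kernel_registers, after executing the GEMM:
auto [c_register_fragment, partitioner] = GEMM().execute(a_shared_tensor, b_shared_tensor);
// Fused epilogue: apply ReLU elementwise to the thread-local results
for (int i = 0; i < cublasdx::size(c_register_fragment); ++i) {
    c_register_fragment(i) = (c_register_fragment(i) > 0.0) ? c_register_fragment(i) : 0.0;
}
// Only then store the post-processed results to global memory
cublasdx::copy_fragment<alignment::c>(c_register_fragment, c_global_tensor, partitioner);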
Compilation#
For instructions on how to compile programs with cuBLASDx see Quick Installation Guide.