Execution Methods#

Execution methods are used to run the BLAS function defined by the user with cuBLASDx operators.

Note

Right now, cuBLASDx supports only execution on CUDA thread block level (block execution).

Block Execute Method#

The block execution methods are available if the descriptor has been constructed using the Block Operator and the is_complete_blas_execution Trait is true.

Shared Memory API#

Method execute(...) runs the calculations defined by the BLAS descriptor, accepting three types of arguments.

using BLAS = decltype(cublasdx::Size<M, N, K>() + ...);

// #1 - Tensor API

template<class Alpha,                         // Must be convertible to BLAS::c_value_type
         class AEngine, class ALayout,        // Types derived from pointer and layout used to create tensor_a
         class BEngine, class BLayout,        // Types derived from pointer and layout used to create tensor_b
         class Beta,                          // Must be convertible to BLAS::c_value_type
         class CEngine, class CLayout,        // Types derived from pointer and layout used to create tensor_c
         class ALoadOp = cublasdx::identity,  // Transform operation applied when data is loaded from matrix A
         class BLoadOp = cublasdx::identity,  // Transform operation applied when data is loaded from matrix B
         class CLoadOp = cublasdx::identity,  // Transform operation applied when data is loaded from matrix C
         class CStoreOp = cublasdx::identity> // Transform operation applied when data is stored to matrix C
inline __device__ void execute(const Alpha&                               alpha,
                               const cublasdx::tensor<AEngine, ALayout>&  tensor_a,
                               const cublasdx::tensor<BEngine, BLayout>&  tensor_b,
                               const Beta&                                beta,
                               cublasdx::tensor<CEngine, CLayout>&        tensor_c,
                               const ALoadOp&                             a_load_op  = {},
                               const BLoadOp&                             b_load_op  = {},
                               const CLoadOp&                             c_load_op  = {},
                               const CStoreOp&                            c_store_op = {})


// #2 - Pointer API
template<
  class Alpha, // Must be convertible to BLAS::c_value_type
  class TA,    // Value type of matrix A
  class TB,    // Value type of matrix B
  class Beta,  // Must be convertible to BLAS::c_value_type
  class TC,    // Value type of matrix C
  class ALoadTransformOp  = cublasdx::identity, // Transform operation applied when data is loaded from matrix A
  class BLoadTransformOp  = cublasdx::identity, // Transform operation applied when data is loaded from matrix B
  class CLoadTransformOp  = cublasdx::identity, // Transform operation applied when data is loaded from matrix C
  class CStoreTransformOp = cublasdx::identity> // Transform operation applied when data is stored to matrix C
inline __device__ void execute(const Alpha              alpha,
                               TA*                      matrix_a,
                               TB*                      matrix_b,
                               const Beta               beta,
                               TC*                      matrix_c,
                               const ALoadTransformOp&  a_load_op  = {},
                               const BLoadTransformOp&  b_load_op  = {},
                               const CLoadTransformOp&  c_load_op  = {},
                               const CStoreTransformOp& c_store_op = {})

// #3 - Pointer API, which allows providing runtime/dynamic leading dimensions for matrices A, B, and C
template<
  class Alpha,
  class TA,
  class TB,
  class Beta,
  class TC,
  class ALoadTransformOp  = cublasdx::identity,
  class BLoadTransformOp  = cublasdx::identity,
  class CLoadTransformOp  = cublasdx::identity,
  class CStoreTransformOp = cublasdx::identity>
inline __device__ void BLAS::execute(const Alpha&             alpha,
                                     TA*                      matrix_a,
                                     const unsigned int       lda,
                                     TB*                      matrix_b,
                                     const unsigned int       ldb,
                                     const Beta&              beta,
                                     TC*                      matrix_c,
                                     const unsigned int       ldc,
                                     const ALoadTransformOp&  a_load_op  = {},
                                     const BLoadTransformOp&  b_load_op  = {},
                                     const CLoadTransformOp&  c_load_op  = {},
                                     const CStoreTransformOp& c_store_op = {})

Method #1 accepts cublasdx::tensor as representations of shared memory storage for matrices A, B, and C. cublasdx::tensor is essentially a CuTe tensor (cute::Tensor), a representation of a multi-dimensional array, with rich functionality abstracting away the details of how the array’s elements are organized and stored in memory.

See cuBLASDx Tensor, Tensor Creation, and Get Memory Layout for how to create a tensor using a raw memory pointer and a CuTe layout and, if needed, dynamically defined leading dimensions. If necessary, the user can pass tensors with custom layouts.

In methods #2 and #3, the pointers matrix_a, matrix_b, and matrix_c must point to shared memory regions aligned to BLAS::<a/b/c>_alignment. If the Alignment operator was not used, BLAS::<a/b/c>_alignment is equal to alignof(BLAS::<a/b/c>_value_type).

Methods #2 and #3 assume the layout of each matrix corresponds to the arrangement set for that matrix, i.e. if Arrangement<col_major, row_major, col_major> was used in the BLAS description, the A matrix should be column-major, the B matrix row-major, and the C matrix column-major. The default arrangement corresponds to using Arrangement<row_major, col_major, col_major>.

Method #3 allows the user to provide custom dynamic leading dimensions via the lda, ldb, and ldc arguments. In this case, leading dimension values set via the LeadingDimension operator are ignored. The values lda, ldb, and ldc have to follow the same rules as presented in the LeadingDimension operator.

After calling the execute(...) method, the user has to perform CUDA block synchronization before accessing A, B, or C.

The code example below shows how the three execute(...) methods can be used.

#include <cublasdx.hpp>

using GEMM = decltype(cublasdx::Size<32, 32, 32>()
              + cublasdx::Precision<cublasdx::tfloat32_t, cublasdx::tfloat32_t, float>()
              + cublasdx::Type<cublasdx::type::real>()
              + cublasdx::Arrangement<cublasdx::row_major, cublasdx::col_major>()
              + cublasdx::Function<cublasdx::function::MM>()
              + cublasdx::MaxAlignment() // MaxAlignment is equivalent to Alignment<16, 16, 16>
              + cublasdx::SM<800>()
              + cublasdx::Block());

using a_data_type = typename GEMM::a_value_type;
using b_data_type = typename GEMM::b_value_type;
using c_data_type = typename GEMM::c_value_type;

extern __shared__ __align__(16) char smem[];

// smem_<a/b/c> are aligned to cublasdx::alignment_of<GEMM>::<a/b/c>
auto [smem_a, smem_b, smem_c] = cublasdx::slice_shared_memory<GEMM>(smem);

//*********** Method #1, using cublasdx tensor APIs
{
    // Make global memory tensor
    auto a_global_tensor = cublasdx::make_tensor(a, GEMM::get_layout_gmem_a());
    auto b_global_tensor = cublasdx::make_tensor(b, GEMM::get_layout_gmem_b());
    auto c_global_tensor = cublasdx::make_tensor(c, GEMM::get_layout_gmem_c());

    // Make shared memory tensor
    auto a_shared_tensor = cublasdx::make_tensor(smem_a, GEMM::get_layout_smem_a());
    auto b_shared_tensor = cublasdx::make_tensor(smem_b, GEMM::get_layout_smem_b());
    auto c_shared_tensor = cublasdx::make_tensor(smem_c, GEMM::get_layout_smem_c());

    // Load data from global to shared memory using cublasdx::copy API
    using alignment = cublasdx::alignment_of<GEMM>;
    cublasdx::copy<GEMM, alignment::a>(a_global_tensor, a_shared_tensor);
    cublasdx::copy<GEMM, alignment::b>(b_global_tensor, b_shared_tensor);
    cublasdx::copy<GEMM, alignment::c>(c_global_tensor, c_shared_tensor);
    cublasdx::copy_wait();

    // Execute
    GEMM().execute(alpha, a_shared_tensor, b_shared_tensor, beta, c_shared_tensor);
    __syncthreads();

    // Store back to global memory using cublasdx::copy API
    cublasdx::copy<GEMM, alignment::c>(c_shared_tensor, c_global_tensor);
    cublasdx::copy_wait(); // Needed to ensure c_global_tensor has a defined state and the data in it can be used by any following operations in the kernel. If there are no further instructions, the kernel's finalization will be the final synchronization point.
}

//*********** Method #1, cublasdx tensor APIs, with dynamic leading dimensions
{
    // Make global memory tensor
    auto a_global_tensor = cublasdx::make_tensor(a, GEMM::get_layout_gmem_a(lda));
    auto b_global_tensor = cublasdx::make_tensor(b, GEMM::get_layout_gmem_b(ldb));
    auto c_global_tensor = cublasdx::make_tensor(c, GEMM::get_layout_gmem_c(ldc));

    // Make shared memory tensor
    auto a_shared_tensor = cublasdx::make_tensor(smem_a, GEMM::get_layout_smem_a(lda));
    auto b_shared_tensor = cublasdx::make_tensor(smem_b, GEMM::get_layout_smem_b(ldb));
    auto c_shared_tensor = cublasdx::make_tensor(smem_c, GEMM::get_layout_smem_c(ldc));

    // Load data from global to shared memory using cublasdx::copy API
    using alignment = cublasdx::alignment_of<GEMM>;
    cublasdx::copy<GEMM, alignment::a>(a_global_tensor, a_shared_tensor);
    cublasdx::copy<GEMM, alignment::b>(b_global_tensor, b_shared_tensor);
    cublasdx::copy<GEMM, alignment::c>(c_global_tensor, c_shared_tensor);
    cublasdx::copy_wait();

    // Execute
    GEMM().execute(alpha, a_shared_tensor, b_shared_tensor, beta, c_shared_tensor);
    __syncthreads();

    // Store back to global memory using cublasdx::copy API
    cublasdx::copy<GEMM, alignment::c>(c_shared_tensor, c_global_tensor);
    cublasdx::copy_wait(); // Only needed if more operations on the shared memory used by c_shared_tensor happen in the kernel
}

//*********** Method #2, using raw shared memory pointers
{
    // User code to load data from global to shared memory
    // smem_a <-- a, smem_b <-- b, smem_c <-- c

    // Execute
    GEMM().execute(alpha, smem_a, smem_b, beta, smem_c);
    __syncthreads();

    // User code to store back to global memory
    // smem_c --> c
}

//*********** Method #3, using raw shared memory pointers with dynamic leading dimensions
{
    // User code to load data from global to shared memory
    // smem_a <-- a, smem_b <-- b, smem_c <-- c

    // Execute
    GEMM().execute(alpha, smem_a, lda, smem_b, ldb, beta, smem_c, ldc);
    __syncthreads();

    // User code to store back to global memory
    // smem_c --> c
}

Register API#

Method execute(...) runs the calculations defined by the BLAS descriptor, accepting two types of arguments.

using BLAS = decltype(cublasdx::Size<M, N, K>() + ...);

// #1 - Registers with accumulator API

template<class AEngine, class ALayout,       // Types derived from pointer and layout used to create tensor_a
         class BEngine, class BLayout,       // Types derived from pointer and layout used to create tensor_b
         class CEngine, class CLayout,       // Types derived from pointer and layout used to create tensor_c
         class ALoadOp = cublasdx::identity, // Transform operation applied when data is loaded from matrix A
         class BLoadOp = cublasdx::identity> // Transform operation applied when data is loaded from matrix B
inline __device__ void execute(const cublasdx::tensor<AEngine, ALayout>&  tensor_a,
                               const cublasdx::tensor<BEngine, BLayout>&  tensor_b,
                               cublasdx::tensor<CEngine, CLayout>      &  tensor_c,
                               const ALoadOp&                             a_load_op  = {},
                               const BLoadOp&                             b_load_op  = {})

// #2 - Registers without accumulator API

template<class AEngine, class ALayout,
         class BEngine, class BLayout,
         class ALoadOp = cublasdx::identity,
         class BLoadOp = cublasdx::identity>
inline __device__ auto execute(const cublasdx::tensor<AEngine, ALayout>&  tensor_a,
                               const cublasdx::tensor<BEngine, BLayout>&  tensor_b,
                               const ALoadOp&                             a_load_op  = {},
                               const BLoadOp&                             b_load_op  = {})

Method #1 accepts shared memory tensors for the A and B matrices, but a register fragment for the C matrix. It returns nothing, as it adds the result of multiplying A with B to C, resulting in:

\(\mathbf{C}_{m\times n} = \mathbf{A}_{m\times k} \times \mathbf{B}_{k\times n} + \mathbf{C}_{m\times n}\)

The register fragment must exist beforehand, either from a previous execution of method #2, or created from a partitioner object (see Data Partitioner). It must exactly match the precision and partitioning of the GEMM it is used with.

Method #2 accepts only shared memory tensors for the A and B matrices. It returns an opaque tuple containing the resulting register fragment together with its corresponding partitioner. The returned fragment holds the result of:

\(\mathbf{C}_{m\times n} = \mathbf{A}_{m\times k} \times \mathbf{B}_{k\times n}\)

A C++17 structured binding can be used to cleanly retrieve both values:

auto [c_register_fragment, partitioner] = BLAS().execute(a_shared_tensor, b_shared_tensor, a_load_op, b_load_op);

The code example below shows how the two execute(...) methods can be used.

#include <cublasdx.hpp>

using GEMM = decltype(cublasdx::Size<32, 32, 32>()
              + cublasdx::Precision<cublasdx::tfloat32_t, cublasdx::tfloat32_t, float>()
              + cublasdx::Type<cublasdx::type::real>()
              + cublasdx::Arrangement<cublasdx::row_major, cublasdx::col_major>()
              + cublasdx::Function<cublasdx::function::MM>()
              + cublasdx::MaxAlignment() // MaxAlignment is equivalent to Alignment<16, 16, 16>
              + cublasdx::SM<800>()
              + cublasdx::Block());

using a_data_type = typename GEMM::a_value_type;
using b_data_type = typename GEMM::b_value_type;
using c_data_type = typename GEMM::c_value_type;

extern __shared__ __align__(16) char smem[];

// smem_<a/b> are aligned to cublasdx::alignment_of<GEMM>::<a/b>
auto [smem_a, smem_b] = cublasdx::slice_shared_memory_ab<GEMM>(smem);

//*********** Method #1, register API with accumulator
{
    // Make global memory tensor
    auto a_global_tensor = cublasdx::make_tensor(a, GEMM::get_layout_gmem_a());
    auto b_global_tensor = cublasdx::make_tensor(b, GEMM::get_layout_gmem_b());
    auto c_global_tensor = cublasdx::make_tensor(c, GEMM::get_layout_gmem_c());

    // Make shared memory tensor
    auto a_shared_tensor = cublasdx::make_tensor(smem_a, GEMM::get_layout_smem_a());
    auto b_shared_tensor = cublasdx::make_tensor(smem_b, GEMM::get_layout_smem_b());

    // Load data from global to shared memory using cublasdx::copy API
    using alignment = cublasdx::alignment_of<GEMM>;
    cublasdx::copy<GEMM, alignment::a>(a_global_tensor, a_shared_tensor);
    cublasdx::copy<GEMM, alignment::b>(b_global_tensor, b_shared_tensor);
    cublasdx::copy_wait();

    // Execute
    auto partitioner = GEMM::get_partitioner();
    auto c_register_fragment = partitioner.make_accumulator_fragment();
    cublasdx::copy_fragment<alignment::c>(c_global_tensor, c_register_fragment, partitioner);

    GEMM().execute(a_shared_tensor, b_shared_tensor, c_register_fragment);

    // Store back to global memory using cublasdx::copy_fragment API
    cublasdx::copy_fragment<alignment::c>(c_register_fragment, c_global_tensor, partitioner);
}

//*********** Method #2, register API without accumulator
{
    // Make global memory tensor
    auto a_global_tensor = cublasdx::make_tensor(a, GEMM::get_layout_gmem_a());
    auto b_global_tensor = cublasdx::make_tensor(b, GEMM::get_layout_gmem_b());
    auto c_global_tensor = cublasdx::make_tensor(c, GEMM::get_layout_gmem_c());

    // Make shared memory tensor
    auto a_shared_tensor = cublasdx::make_tensor(smem_a, GEMM::get_layout_smem_a());
    auto b_shared_tensor = cublasdx::make_tensor(smem_b, GEMM::get_layout_smem_b());

    // Load data from global to shared memory using cublasdx::copy API
    using alignment = cublasdx::alignment_of<GEMM>;
    cublasdx::copy<GEMM, alignment::a>(a_global_tensor, a_shared_tensor);
    cublasdx::copy<GEMM, alignment::b>(b_global_tensor, b_shared_tensor);
    cublasdx::copy_wait();

    // Execute
    auto [c_register_fragment, partitioner] = GEMM().execute(a_shared_tensor, b_shared_tensor);

    // Store back to global memory using cublasdx::copy_fragment API
    cublasdx::copy_fragment<alignment::c>(c_register_fragment, c_global_tensor, partitioner);
}

Input data properties#

Note

Starting from cuBLASDx 0.3.0, computational precision has been decoupled from data precision, i.e. the input / output data for each matrix can be of arbitrary type (even integral input for floating point GEMM) provided that the Alignment Operator is set and at least one of these conditions is met:

  1. It’s implicitly convertible to the data type chosen with Precision Operator and Type Operator.

  2. For inputs: An appropriate converting loading operation is provided as one of the arguments. It takes the input type value. Its result must be at least implicitly convertible to the compute type.

  3. For output: An appropriate converting storing operation is provided as one of the arguments. It takes the result computational type (usually C type as defined by Precision Operator and Type Operator). Its result must be at least implicitly convertible to the output type.

Warning

If using computation precision decoupled from input types, the Alignment Operator must be explicitly set.

The underlying element types of scalars (alpha and beta) by default are assumed to be BLAS::c_value_type, but they can be any types as long as:

  1. the alignment and the size of each of them are the same as BLAS::c_value_type, and

  2. they are convertible to BLAS::c_value_type

Transform operation inputs#

All of the methods accept transform functors. a_load_op, b_load_op, c_load_op are applied as elements are read from each matrix, and c_store_op is applied before the results of matrix multiplication are stored in C matrix. Each functor has to represent an element-wise transform which:

  1. For load transformations: accepts respective input type to execute(...) method and returns value of type implicitly convertible to BLAS::<a/b/c>_value_type.

  2. For store transformations: accepts BLAS::<a/b/c>_value_type and returns respective input type to execute(...) method.

Example

using GEMM = decltype(Size<128, 128, 128>() + Type<type::real>() + Precision<float, float, double>() + Block() + ...);

struct multiple_by_2 {
  template<class T>
  __device__ constexpr T operator()(const T arg) const {
    return arg * static_cast<T>(2.0f);
  }
};

struct negate {
  template <class T>
  __device__ constexpr T operator()(const T arg) const {
    return -arg;
  }
};

GEMM().execute(..., multiple_by_2{}, cublasdx::conjugate{}, cublasdx::identity{}, negate{});

Warning

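It is not guaranteed that executions of exactly the same BLAS function with exactly the same inputs but with different

  - leading dimensions (LeadingDimension),

  - CUDA architecture (SM), or

  - number of threads (BlockDim)

will produce bit-identical results.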

Warning

It is not guaranteed that executions of exactly the same BLAS function with exactly the same inputs on GPUs of different CUDA architectures will produce bit-identical results.

Value Format#

BLAS::a_value_type
BLAS::b_value_type
BLAS::c_value_type

For complex numbers of every precision, the first value in a complex number is the real part and the second is the imaginary part. For real numbers, BLAS::<a/b/c>_value_type is the same as the corresponding precision (PA, PB, or PC) in Precision<PA, PB, PC> used to describe BLAS (or the default precision).

Input/Output Data Format#

This section describes the input and output data format (layout) required for correct calculations.

GEMM (function::MM)#

The tensor API for general matrix multiplication (execute() method which expects matrices represented using cublasdx::tensor) accepts matrices represented by tensors with arbitrary layouts. Since the tensor object carries all the information about the dimensions, the memory location and layout of a matrix, no other implicit assumptions are needed. The dimensions of the matrices must match the dimensions defined by Size operator. See also Get Memory Layout and Suggested shared memory Layout sections.

The pointer API for general matrix multiplication (#2 and #3 overloads of execute()) assumes that values in input matrices matrix_a, matrix_b, matrix_c are stored as defined by the Arrangement operator added to the description (by default it’s row-major format for matrix_a, column-major for matrix_b, and column-major for matrix_c).

Shared memory Usage#

It’s important to note that large BLAS operations (as defined by the Size operator) may require more than 48 KB of shared memory per CUDA block for the matrices. Therefore, as described in the CUDA Programming Guide (#1, #2, #3), kernels with such BLAS operations must use dynamic shared memory rather than statically sized shared memory arrays. Additionally, these kernels require an explicit opt-in using cudaFuncSetAttribute() to set cudaFuncAttributeMaxDynamicSharedMemorySize. See the example code below.

#include <cublasdx.hpp>
using namespace cublasdx;

using GEMM = decltype(cublasdx::Size<128, 128, 64>()
              + cublasdx::Precision<__nv_fp8_e4m3, __nv_fp8_e5m2, float>()
              + cublasdx::Type<cublasdx::type::real>()
              + cublasdx::Arrangement<cublasdx::row_major, cublasdx::col_major>()
              + cublasdx::Function<cublasdx::function::MM>()
              + cublasdx::SM<900>()
              + cublasdx::Block());

void example() {
  (...)

  // Get required shared memory sizes, options:

  // Shared Memory API
  // 1 - Shared memory size required for matrices based on GEMM definition
  auto shared_memory_size = cublasdx::get_shared_storage_size<GEMM>();
  // 2 - Shared memory size when dynamic leading dimensions are used
  auto shared_memory_size = cublasdx::get_shared_storage_size<GEMM>(lda, ldb, ldc);
  // 3 - Shared memory size calculated based on custom matrix layouts for A, B, C matrices
  auto shared_memory_size = cublasdx::get_shared_storage_size<GEMM>(matrix_a_layout, matrix_b_layout, matrix_c_layout);

  // Register API
  // 1 - Shared memory size required for matrices based on GEMM definition
  auto shared_memory_size = cublasdx::get_shared_storage_size_ab<GEMM>();
  // 2 - Shared memory size when dynamic leading dimensions are used
  auto shared_memory_size = cublasdx::get_shared_storage_size_ab<GEMM>(lda, ldb);
  // 3 - Shared memory size calculated based on custom matrix layouts for A, B matrices
  auto shared_memory_size = cublasdx::get_shared_storage_size_ab<GEMM>(matrix_a_layout, matrix_b_layout);

  // Increases the max dynamic shared memory size to match GEMM requirements
  cudaFuncSetAttribute(gemm_kernel<GEMM>, cudaFuncAttributeMaxDynamicSharedMemorySize, shared_memory_size);
  // Invokes kernel with GEMM::block_dim threads in CUDA block
  gemm_kernel<GEMM><<<1, GEMM::block_dim, shared_memory_size>>>(alpha, a, b, beta, c);

  (...)
}