Execution Methods

Execution methods run the BLAS function that the user defined with cuBLASDx operators.

Note

Currently, cuBLASDx supports execution only at the CUDA thread block level (block execution).

Block Execute Method

The block execution methods are available if the descriptor has been constructed using the Block operator and the is_complete_blas_execution trait is true.

The execute(...) method runs the calculations defined by the BLAS descriptor. It has three overloads, which differ in how the matrices are passed:

using BLAS = decltype(cublasdx::Size<M, N, K>() + ...);

// #1 - Tensor API
//
template<
  class AEngine, class ALayout,
  class BEngine, class BLayout,
  class CEngine, class CLayout,
  class ALoadTransformOp  = cublasdx::identity, // Transform operation applied when data is loaded from matrix A
  class BLoadTransformOp  = cublasdx::identity, // Transform operation applied when data is loaded from matrix B
  class CLoadTransformOp  = cublasdx::identity, // Transform operation applied when data is loaded from matrix C
  class CStoreTransformOp = cublasdx::identity  // Transform operation applied when data is stored to matrix C
>
void BLAS::execute(const typename CEngine::value_type& alpha,
                   cublasdx::tensor<AEngine, ALayout>  matrix_a,
                   cublasdx::tensor<BEngine, BLayout>  matrix_b,
                   const typename CEngine::value_type& beta,
                   cublasdx::tensor<CEngine, CLayout>  matrix_c,
                   const ALoadTransformOp&             a_load_op  = {},
                   const BLoadTransformOp&             b_load_op  = {},
                   const CLoadTransformOp&             c_load_op  = {},
                   const CStoreTransformOp&            c_store_op = {});

// #2 - Pointer API
template<
  class TA, // Value type of matrix A
  class TB, // Value type of matrix B
  class TC, // Value type of matrix C
  class ALoadTransformOp  = cublasdx::identity, // Transform operation applied when data is loaded from matrix A
  class BLoadTransformOp  = cublasdx::identity, // Transform operation applied when data is loaded from matrix B
  class CLoadTransformOp  = cublasdx::identity, // Transform operation applied when data is loaded from matrix C
  class CStoreTransformOp = cublasdx::identity  // Transform operation applied when data is stored to matrix C
>
void BLAS::execute(const TC&                alpha,
                   TA*                      matrix_a,
                   TB*                      matrix_b,
                   const TC&                beta,
                   TC*                      matrix_c,
                   const ALoadTransformOp&  a_load_op  = {},
                   const BLoadTransformOp&  b_load_op  = {},
                   const CLoadTransformOp&  c_load_op  = {},
                   const CStoreTransformOp& c_store_op = {});

// #3 - Pointer API, which allows providing runtime/dynamic leading dimensions for matrices A, B, and C
template<
  class TA,
  class TB,
  class TC,
  class ALoadTransformOp  = cublasdx::identity,
  class BLoadTransformOp  = cublasdx::identity,
  class CLoadTransformOp  = cublasdx::identity,
  class CStoreTransformOp = cublasdx::identity
>
void BLAS::execute(const TC&                alpha,
                   TA*                      matrix_a,
                   const unsigned int       lda,
                   TB*                      matrix_b,
                   const unsigned int       ldb,
                   const TC&                beta,
                   TC*                      matrix_c,
                   const unsigned int       ldc,
                   const ALoadTransformOp&  a_load_op  = {},
                   const BLoadTransformOp&  b_load_op  = {},
                   const CLoadTransformOp&  c_load_op  = {},
                   const CStoreTransformOp& c_store_op = {});

Method #1 accepts cublasdx::tensor objects as representations of the shared memory storage for matrices A, B, and C. This overload was introduced in cuBLASDx 0.2.0. A cublasdx::tensor is essentially a CuTe tensor (cute::Tensor): a representation of a multi-dimensional array with rich functionality that abstracts away the details of how the array's elements are organized and stored in memory. By default, the element types of the tensors (cublasdx::tensor<Engine, Layout>::value_type) are assumed to be BLAS::<a/b/c>_value_type, but they can be any types whose size and alignment match those of the corresponding BLAS::<a/b/c>_value_type. See Tensor, Tensor Creation, and Get Memory Layout for how to create a tensor from a raw memory pointer and a CuTe layout, including dynamically defined leading dimensions if needed. Tensors with custom layouts can also be passed.
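To make the idea of a layout concrete, here is a minimal host-side C++ sketch. Layout2D, Tensor2D, col_major, and row_major are hypothetical names introduced only for illustration; they are not part of cuBLASDx or CuTe, whose layouts are far more general. The sketch assumes a 2-D layout is just a shape plus one stride per mode, so the same raw pointer can be viewed as packed column-major, padded column-major, or row-major storage.

```cpp
#include <cassert>
#include <cstddef>

// Hypothetical, simplified stand-in for a CuTe-style layout: a 2-D shape plus
// one stride per mode. It maps a logical (row, col) coordinate to an offset.
struct Layout2D {
    std::size_t rows, cols;              // shape
    std::size_t row_stride, col_stride;  // stride of each mode

    std::size_t operator()(std::size_t i, std::size_t j) const {
        assert(i < rows && j < cols);
        return i * row_stride + j * col_stride;
    }
};

// Hypothetical minimal "tensor": a non-owning pointer plus a layout.
template <class T>
struct Tensor2D {
    T* data;
    Layout2D layout;
    T& operator()(std::size_t i, std::size_t j) const { return data[layout(i, j)]; }
};

// Column-major m x n with leading dimension ld (ld >= m): (i, j) -> i + j * ld
inline Layout2D col_major(std::size_t m, std::size_t n, std::size_t ld) { return {m, n, 1, ld}; }

// Row-major m x n with leading dimension ld (ld >= n): (i, j) -> i * ld + j
inline Layout2D row_major(std::size_t m, std::size_t n, std::size_t ld) { return {m, n, ld, 1}; }
```

In this model a padded layout (for example, column-major with ld greater than the number of rows) is only a different stride, which is why a tensor can carry its leading dimension with it instead of needing separate arguments.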

In methods #2 and #3, T<A/B/C> is BLAS::<a/b/c>_value_type by default, but it can be any type (such as float2 or cuda::std::complex<double>) as long as its alignment and size are the same as those of BLAS::<a/b/c>_value_type. Pointers matrix_a, matrix_b, and matrix_c must point to shared memory regions aligned to BLAS::<a/b/c>_alignment. If the Alignment operator was not used, BLAS::<a/b/c>_alignment is equal to alignof(BLAS::<a/b/c>_value_type).
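The alignment requirement is easiest to see when A, B, and C are carved out of one buffer. The following sketch uses a hypothetical slice_offsets helper (it is not the cuBLASDx slice_shared_memory implementation) that places each matrix at the next offset that is a multiple of its alignment, plus an is_aligned check on the resulting pointers. It assumes the buffer itself is aligned to the largest of the three alignments.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>

// Round offset up to the next multiple of alignment (alignment > 0).
constexpr std::size_t align_up(std::size_t offset, std::size_t alignment) {
    return (offset + alignment - 1) / alignment * alignment;
}

// Byte offsets of A, B, and C inside one buffer, plus the total size in bytes.
struct SmemOffsets {
    std::size_t a, b, c, total;
};

// Hypothetical helper: place A, B, C back to back so that each starts at a
// multiple of its required alignment. Assumes the buffer base is aligned to
// the largest of a_align, b_align, c_align.
constexpr SmemOffsets slice_offsets(std::size_t a_bytes, std::size_t a_align,
                                    std::size_t b_bytes, std::size_t b_align,
                                    std::size_t c_bytes, std::size_t c_align) {
    const std::size_t a = align_up(0, a_align);
    const std::size_t b = align_up(a + a_bytes, b_align);
    const std::size_t c = align_up(b + b_bytes, c_align);
    return SmemOffsets{a, b, c, c + c_bytes};
}

// True if p is aligned to `alignment` bytes.
inline bool is_aligned(const void* p, std::size_t alignment) {
    return reinterpret_cast<std::uintptr_t>(p) % alignment == 0;
}
```

For example, with a 60-byte A (a 3 × 5 float matrix) and 16-byte alignment, B cannot start at byte 60; it is pushed to byte 64.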

Methods #2 and #3 assume that the layout of each matrix corresponds to the arrangement set for that matrix. For example, if Arrangement<col-major, row-major, col-major> was used in the BLAS description, matrix A must be column-major, matrix B row-major, and matrix C column-major. The default arrangement corresponds to Arrangement<row-major, col-major, col-major>.
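As a sketch of what an arrangement means for the pointer API, the same (row, column) coordinate maps to a different offset depending on whether the matrix is column-major or row-major. The arrangement enum, element_offset, and packed_ld below are hypothetical host-side helpers, not cuBLASDx API; they assume that a tightly packed leading dimension equals the number of rows for column-major storage and the number of columns for row-major storage.

```cpp
#include <cassert>
#include <cstddef>

// Hypothetical host-side mirror of the row-major / column-major choice.
enum class arrangement { col_major, row_major };

// Offset of element (i, j) for the given arrangement and leading dimension ld.
inline std::size_t element_offset(arrangement arr, std::size_t i, std::size_t j, std::size_t ld) {
    return arr == arrangement::col_major ? i + j * ld   // columns are contiguous
                                         : i * ld + j;  // rows are contiguous
}

// Leading dimension of a tightly packed rows x cols matrix.
inline std::size_t packed_ld(arrangement arr, std::size_t rows, std::size_t cols) {
    return arr == arrangement::col_major ? rows : cols;
}
```

Under these assumptions, for Size<M, N, K> with the default arrangement, A (M × K) is row-major with packed leading dimension K, while B (K × N) and C (M × N) are column-major with packed leading dimensions K and M.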

Method #3 allows the user to provide leading dimensions at runtime via the lda, ldb, and ldc arguments. In this case, the leading dimension values set via the LeadingDimension operator are ignored. The values of lda, ldb, and ldc must follow the same rules as those described for the LeadingDimension operator.
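A runtime leading dimension is only usable if it leaves room for a full row or column. The following sketch assumes the usual BLAS rule (ld at least the number of rows for column-major storage, at least the number of columns for row-major storage); the LeadingDimension operator documentation remains the authoritative source. The helper names are hypothetical, and storage_elements returns the padded footprint, ld times the outer dimension, that a buffer would typically reserve. GEMM::get_shared_memory_size(lda, ldb, ldc) computes its own value, which may differ.

```cpp
#include <cassert>
#include <cstddef>

// Hypothetical host-side mirror of the row-major / column-major choice.
enum class arrangement { col_major, row_major };

// Smallest leading dimension allowed under the assumed BLAS rule.
inline std::size_t min_ld(arrangement arr, std::size_t rows, std::size_t cols) {
    return arr == arrangement::col_major ? rows : cols;
}

inline bool is_valid_ld(arrangement arr, std::size_t rows, std::size_t cols, std::size_t ld) {
    return ld >= min_ld(arr, rows, cols);
}

// Padded footprint in elements: ld times the number of columns (column-major)
// or the number of rows (row-major).
inline std::size_t storage_elements(arrangement arr, std::size_t rows, std::size_t cols, std::size_t ld) {
    assert(is_valid_ld(arr, rows, cols, ld));
    return ld * (arr == arrangement::col_major ? cols : rows);
}
```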

After calling execute(...), the user must synchronize the CUDA block (for example, with __syncthreads()) before accessing matrix_a, matrix_b, or matrix_c.

The code example below shows how the three execute(...) methods can be used.

#include <cublasdx.hpp>

using GEMM = decltype(cublasdx::Size<32, 32, 32>()
              + cublasdx::Precision<cublasdx::tfloat32_t, cublasdx::tfloat32_t, float>()
              + cublasdx::Type<cublasdx::type::real>()
              + cublasdx::Arrangement<cublasdx::row_major, cublasdx::col_major>()
              + cublasdx::Function<cublasdx::function::MM>()
              + cublasdx::MaxAlignment() // Sets the alignment of A, B, and C to (16, 16, 16) bytes
              + cublasdx::SM<800>()
              + cublasdx::Block());

using a_data_type = typename GEMM::a_value_type;
using b_data_type = typename GEMM::b_value_type;
using c_data_type = typename GEMM::c_value_type;

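// The snippets below are assumed to run inside a __global__ kernel in which
// a, b, c (global memory pointers), alpha, beta, lda, ldb, and ldc are defined.
extern __shared__ __align__(16) char smem[];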

// smem_<a/b/c> are aligned to cublasdx::alignment_of<GEMM>::<a/b/c>
auto [smem_a, smem_b, smem_c] = GEMM::slice_shared_memory(smem);

//*********** Method #1, using cublasdx tensor APIs
{
    // Make global memory tensor
    auto a_global_tensor = cublasdx::make_tensor(a, GEMM::get_layout_gmem_a());
    auto b_global_tensor = cublasdx::make_tensor(b, GEMM::get_layout_gmem_b());
    auto c_global_tensor = cublasdx::make_tensor(c, GEMM::get_layout_gmem_c());

    // Make shared memory tensor
    auto a_shared_tensor = cublasdx::make_tensor(smem_a, GEMM::get_layout_smem_a());
    auto b_shared_tensor = cublasdx::make_tensor(smem_b, GEMM::get_layout_smem_b());
    auto c_shared_tensor = cublasdx::make_tensor(smem_c, GEMM::get_layout_smem_c());

    // Load data from global to shared memory using cublasdx::copy API
    using alignment = cublasdx::alignment_of<GEMM>;
    cublasdx::copy<GEMM, alignment::a>(a_global_tensor, a_shared_tensor);
    cublasdx::copy<GEMM, alignment::b>(b_global_tensor, b_shared_tensor);
    cublasdx::copy<GEMM, alignment::c>(c_global_tensor, c_shared_tensor);
    cublasdx::copy_wait();

    // Execute
    GEMM().execute(alpha, a_shared_tensor, b_shared_tensor, beta, c_shared_tensor);
    __syncthreads();

    // Store back to global memory using cublasdx::copy API
    cublasdx::copy<GEMM, alignment::c>(c_shared_tensor, c_global_tensor);
    cublasdx::copy_wait(); // Only needed if the kernel performs further operations on the shared memory used by c_shared_tensor
}

//*********** Method #1, cublasdx tensor APIs, with dynamic leading dimensions
{
    // Make global memory tensor
    auto a_global_tensor = cublasdx::make_tensor(a, GEMM::get_layout_gmem_a(lda));
    auto b_global_tensor = cublasdx::make_tensor(b, GEMM::get_layout_gmem_b(ldb));
    auto c_global_tensor = cublasdx::make_tensor(c, GEMM::get_layout_gmem_c(ldc));

    // Make shared memory tensor
    auto a_shared_tensor = cublasdx::make_tensor(smem_a, GEMM::get_layout_smem_a(lda));
    auto b_shared_tensor = cublasdx::make_tensor(smem_b, GEMM::get_layout_smem_b(ldb));
    auto c_shared_tensor = cublasdx::make_tensor(smem_c, GEMM::get_layout_smem_c(ldc));

    // Load data from global to shared memory using cublasdx::copy API
    using alignment = cublasdx::alignment_of<GEMM>;
    cublasdx::copy<GEMM, alignment::a>(a_global_tensor, a_shared_tensor);
    cublasdx::copy<GEMM, alignment::b>(b_global_tensor, b_shared_tensor);
    cublasdx::copy<GEMM, alignment::c>(c_global_tensor, c_shared_tensor);
    cublasdx::copy_wait();

    // Execute
    GEMM().execute(alpha, a_shared_tensor, b_shared_tensor, beta, c_shared_tensor);
    __syncthreads();

    // Store back to global memory using cublasdx::copy API
    cublasdx::copy<GEMM, alignment::c>(c_shared_tensor, c_global_tensor);
    cublasdx::copy_wait(); // Only needed if the kernel performs further operations on the shared memory used by c_shared_tensor
}

//*********** Method #2, using raw shared memory pointers
{
    // User code to load data from global to shared memory
    // smem_a <-- a, smem_b <-- b, smem_c <-- c

    // Execute
    GEMM().execute(alpha, smem_a, smem_b, beta, smem_c);
    __syncthreads();

    // User code to store back to global memory
    // smem_c --> c (A and B are inputs only and need not be stored back)
}

//*********** Method #3, using raw shared memory pointers with dynamic leading dimensions
{
    // User code to load data from global to shared memory
    // smem_a <-- a, smem_b <-- b, smem_c <-- c

    // Execute
    GEMM().execute(alpha, smem_a, lda, smem_b, ldb, beta, smem_c, ldc);
    __syncthreads();

    // User code to store back to global memory
    // smem_c --> c (A and B are inputs only and need not be stored back)
}

All three overloads accept four optional transform functors. a_load_op, b_load_op, and c_load_op are applied as elements are read from the corresponding matrix, and c_store_op is applied before the results of the matrix multiplication are stored in matrix C. Each functor must be an element-wise transform that accepts a value of type BLAS::<a/b/c>_value_type and returns a value of a type convertible to BLAS::<a/b/c>_value_type.
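The following host-side reference shows where each of the four functors acts, assuming that c_store_op is applied to the final alpha * A * B + beta * C value as described above. reference_gemm, identity, multiply_by_2, and negate are hypothetical names defined here; the function uses packed column-major storage for all three matrices and is only a correctness model, not how cuBLASDx computes the product on the GPU.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical no-op functor (plays the role of a default transform).
struct identity { template <class T> T operator()(T x) const { return x; } };

// Host reference for
//   C(i,j) = c_store(alpha * sum_p a_op(A(i,p)) * b_op(B(p,j)) + beta * c_load(C(i,j)))
// with A (m x k), B (k x n), C (m x n), all packed column-major.
template <class T, class AOp = identity, class BOp = identity,
          class CLoad = identity, class CStore = identity>
void reference_gemm(std::size_t m, std::size_t n, std::size_t k, T alpha,
                    const std::vector<T>& a, const std::vector<T>& b, T beta,
                    std::vector<T>& c, AOp a_op = {}, BOp b_op = {},
                    CLoad c_load = {}, CStore c_store = {}) {
    assert(a.size() == m * k && b.size() == k * n && c.size() == m * n);
    for (std::size_t j = 0; j < n; ++j) {
        for (std::size_t i = 0; i < m; ++i) {
            T acc{};
            for (std::size_t p = 0; p < k; ++p)
                acc += a_op(a[i + p * m]) * b_op(b[p + j * k]);  // load transforms
            c[i + j * m] = c_store(alpha * acc + beta * c_load(c[i + j * m]));
        }
    }
}

struct multiply_by_2 { template <class T> T operator()(T x) const { return x * static_cast<T>(2); } };
struct negate { template <class T> T operator()(T x) const { return -x; } };
```

With A = [[1, 2], [3, 4]], B equal to the identity, alpha = 1, beta = 0, multiply_by_2 on A, and negate on store, the result is C = [[-2, -4], [-6, -8]].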

Example

using GEMM = decltype(Size<128, 128, 128>() + Type<type::real>() + Precision<float, float, double>() + Block() + ...);

struct multiply_by_2 {
  template<class T>
  __device__ constexpr T operator()(const T arg) const {
    return arg * static_cast<T>(2.0f);
  }
};

struct negate {
  template <class T>
  __device__ constexpr T operator()(const T arg) const {
    return -arg;
  }
};

// A is multiplied by 2 on load, B is conjugated on load, C is loaded unchanged, and the result is negated on store
GEMM().execute(..., multiply_by_2{}, cublasdx::conjugate{}, cublasdx::identity{}, negate{});

Warning

It is not guaranteed that executions of exactly the same BLAS function with exactly the same inputs but with different

will produce bit-identical results.

Warning

It is not guaranteed that executions of exactly the same BLAS function with exactly the same inputs on GPUs of different CUDA architectures will produce bit-identical results.
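One source of such differences is that floating-point addition is not associative: if the summation order inside the dot products changes (for example, because a different architecture uses a different instruction sequence), the rounded result can change. A minimal C++ illustration with three float values, assuming strict IEEE evaluation (no -ffast-math or similar reassociation flags):

```cpp
#include <cassert>

// The same three values summed in two different orders.
inline float sum_left(float x, float y, float z) { return (x + y) + z; }
inline float sum_right(float x, float y, float z) { return x + (y + z); }
```

With x = 1e8f, y = -1e8f, z = 1.0f, the left-to-right order yields 1.0f, while summing y + z first rounds -99999999 to -1e8f (adjacent floats near 1e8 are 8 apart) and the final result is 0.0f.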

Value Format

BLAS::a_value_type
BLAS::b_value_type
BLAS::c_value_type

For complex numbers of every precision, the first value is the real part and the second is the imaginary part. For real numbers, BLAS::<a/b/c>_value_type is the same as the corresponding P<A/B/C> in the Precision<PA, PB, PC> operator used to describe the BLAS operation (or the default precision).
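For a host-side analogue of this format, the C++ standard guarantees that std::complex<T> is array-compatible with T[2], with the real part first. The sketch below uses std::complex<float> only as an illustration and assumes, as the paragraph above states, that the complex value types used by cuBLASDx follow the same real-then-imaginary ordering; as_interleaved is a hypothetical helper.

```cpp
#include <cassert>
#include <complex>

// View an array of std::complex<float> as interleaved floats:
// element 2*i is the real part and element 2*i + 1 the imaginary part of z[i].
inline const float* as_interleaved(const std::complex<float>* z) {
    return reinterpret_cast<const float*>(z);
}
```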

Input/Output Data Format

This section describes the input and output data format (layout) required for correct calculations.

GEMM (function::MM)

The tensor API for general matrix multiplication (execute() method which expects matrices represented using cublasdx::tensor) accepts matrices represented by tensors with arbitrary layouts. Since the tensor object carries all the information about the dimensions, the memory location and layout of a matrix, no other implicit assumptions are needed. The dimensions of the matrices must match the dimensions defined by Size operator. See also Get Memory Layout and Suggested Shared Memory Layout sections.

The pointer API for general matrix multiplication (overloads #2 and #3 of execute()) assumes that the values in matrix_a, matrix_b, and matrix_c are stored as defined by the Arrangement operator added to the description (by default, row-major for matrix_a, column-major for matrix_b, and column-major for matrix_c).

Shared Memory Usage

Large BLAS operations (as defined by the Size operator) may require more than 48 KB of shared memory per CUDA block for the matrices. Therefore, as described in the CUDA Programming Guide, kernels with such BLAS operations must use dynamic shared memory rather than statically sized shared memory arrays. Additionally, these kernels require an explicit opt-in using cudaFuncSetAttribute() to set cudaFuncAttributeMaxDynamicSharedMemorySize. See the example code below.

BLAS::shared_memory_size and BLAS::get_shared_memory_size provide the amount of shared memory required for a specific BLAS operation.

#include <cublasdx.hpp>
using namespace cublasdx;

using GEMM = decltype(cublasdx::Size<128, 128, 64>()
              + cublasdx::Precision<__nv_fp8_e4m3, __nv_fp8_e5m2, float>()
              + cublasdx::Type<cublasdx::type::real>()
              + cublasdx::Arrangement<cublasdx::row_major, cublasdx::col_major>()
              + cublasdx::Function<cublasdx::function::MM>()
              + cublasdx::SM<900>()
              + cublasdx::Block());

void example() {
  // ...

  // Get the required shared memory size; choose one of the options:
  // 1 - Shared memory size required for matrices based on the GEMM definition
  auto shared_memory_size = GEMM::shared_memory_size;
  // shared_memory_size = GEMM::get_shared_memory_size(); // Same as GEMM::shared_memory_size
  // 2 - Shared memory size when dynamic leading dimensions are used
  // shared_memory_size = GEMM::get_shared_memory_size(lda, ldb, ldc);
  // 3 - Shared memory size calculated based on custom layouts for matrices A, B, C
  // shared_memory_size = GEMM::get_shared_memory_size(matrix_a_layout, matrix_b_layout, matrix_c_layout);

  // Increase the max dynamic shared memory size to match the GEMM requirements
  cudaFuncSetAttribute(gemm_kernel<GEMM>, cudaFuncAttributeMaxDynamicSharedMemorySize, shared_memory_size);
  // Launch the kernel with GEMM::block_dim threads per CUDA block
  gemm_kernel<GEMM><<<1, GEMM::block_dim, shared_memory_size>>>(alpha, a, b, beta, c);

  // ...
}
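As a back-of-the-envelope check of the 48 KB threshold, the sketch below estimates the unpadded storage for A (M × K), B (K × N), and C (M × N). For the GEMM above (128 × 128 × 64, 1-byte FP8 inputs, 4-byte float output), it gives 8192 + 8192 + 65536 = 81920 bytes (80 KB), well above 48 KB. smem_bytes_estimate is a hypothetical helper; the value reported by GEMM::shared_memory_size may be larger because of alignment padding or custom layouts.

```cpp
#include <cassert>
#include <cstddef>

// Unpadded estimate: A is m x k, B is k x n, C is m x n, with the given
// element sizes in bytes. Ignores alignment padding between matrices.
constexpr std::size_t smem_bytes_estimate(std::size_t m, std::size_t n, std::size_t k,
                                          std::size_t a_bytes, std::size_t b_bytes,
                                          std::size_t c_bytes) {
    return m * k * a_bytes + k * n * b_bytes + m * n * c_bytes;
}

// Limit for statically sized shared memory per block (48 KB).
constexpr std::size_t kStaticSmemLimit = 48 * 1024;
```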