Shared Memory Management#
cuBLASDx’s GEMM requires the input matrices to reside in shared memory, with an option for the C matrix to be partitioned across threads’ registers. The use of shared memory imposes certain rules (alignment) and limitations (a limited amount of shared memory space). cuBLASDx provides shared memory management tools that make it easier to operate on shared memory correctly.
Warning

Starting with cuBLASDx 0.3.0, the ::shared_memory_size traits and ::shared_memory_size() methods no longer exist and have been replaced with the new APIs described in this chapter.
Shared Storage Size Utilities#
CUDA requires the user to specify the amount of dynamic shared memory upfront, at kernel launch. For executing a GEMM, this size depends on:

- Problem size (Size<M, N, K>).
- Chosen input type (since cuBLASDx 0.3.0 the input precision can differ from the compute precision, see Precision Operator).
- Chosen alignments of the matrices (Alignment<A, B, C>).
- The API which will be used to execute the GEMM (register or shared memory).
Since this information is contained neither in the BLAS type (input precision is decoupled from compute precision) nor in the tensors (they lack alignment information), helper utilities have been created to streamline the process:
// Shared memory API
template<class BLAS,
         class AValueType = typename BLAS::a_value_type,
         class BValueType = typename BLAS::b_value_type,
         class CValueType = typename BLAS::c_value_type,
         class ALayout, class BLayout, class CLayout>
__host__ __device__ __forceinline__ constexpr unsigned
get_shared_storage_size(ALayout const& a_layout,
                        BLayout const& b_layout,
                        CLayout const& c_layout);

template<class BLAS,
         class AValueType = typename BLAS::a_value_type,
         class BValueType = typename BLAS::b_value_type,
         class CValueType = typename BLAS::c_value_type>
__host__ __device__ __forceinline__ constexpr unsigned
get_shared_storage_size(unsigned lda = leading_dimension_of<BLAS>::a,
                        unsigned ldb = leading_dimension_of<BLAS>::b,
                        unsigned ldc = leading_dimension_of<BLAS>::c);

// Register API
template<class BLAS,
         class AValueType = typename BLAS::a_value_type,
         class BValueType = typename BLAS::b_value_type,
         class ALayout, class BLayout>
__host__ __device__ __forceinline__ constexpr unsigned
get_shared_storage_size_ab(ALayout const& a_layout,
                           BLayout const& b_layout);

template<class BLAS,
         class AValueType = typename BLAS::a_value_type,
         class BValueType = typename BLAS::b_value_type>
__host__ __device__ __forceinline__ constexpr unsigned
get_shared_storage_size_ab(unsigned lda = leading_dimension_of<BLAS>::a,
                           unsigned ldb = leading_dimension_of<BLAS>::b);
The resulting value is the size, in bytes, of the shared memory required to allocate the input and output matrices and perform the computation. It is determined by the sizes of the value types, the matrix sizes, and the alignments.
Note that cublasdx::get_shared_storage_size accepts arbitrary CuTe layouts: the ALayout, BLayout, and CLayout classes in the above function prototypes can be either cute::Layout or cute::ComposedLayout.
These functions can be used as follows:
// Shared memory API - Regular execution
auto shared_size = cublasdx::get_shared_storage_size<BLAS>();
// Shared memory API - Decoupled input precision execution
auto shared_size = cublasdx::get_shared_storage_size<BLAS, InputTypeA, InputTypeB, InputTypeC>();
// Register API - Regular execution
auto shared_size_ab = cublasdx::get_shared_storage_size_ab<BLAS>();
// Register API - Decoupled input precision execution
auto shared_size_ab = cublasdx::get_shared_storage_size_ab<BLAS, InputTypeA, InputTypeB>();
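The returned size is typically passed as the dynamic shared memory size at kernel launch. Below is a minimal sketch, assuming a hypothetical gemm_kernel running the GEMM and device pointers a, b, and c; the attribute call is only needed when the size exceeds the default 48 KB limit:

// Minimal sketch, assuming a hypothetical gemm_kernel<BLAS> kernel
// and device pointers a, b, c.
constexpr auto shared_size = cublasdx::get_shared_storage_size<BLAS>();
// Opt in to larger dynamic shared memory if the GEMM needs more than 48 KB.
cudaFuncSetAttribute(gemm_kernel<BLAS>,
                     cudaFuncAttributeMaxDynamicSharedMemorySize,
                     shared_size);
gemm_kernel<BLAS><<<1, BLAS::block_dim, shared_size>>>(a, b, c);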
For special cases, cuBLASDx offers a shared memory size calculator:

shared_storage_calculator make_shared_storage_calculator();

It exposes the following API:

template<class Layout>
shared_storage_calculator& add(unsigned alignment, unsigned elem_size, const Layout& layout);

__host__ __device__ __forceinline__ constexpr
shared_storage_calculator& add(unsigned alignment, unsigned matrix_size_bytes);

__host__ __device__ __forceinline__ constexpr
shared_storage_calculator& add(unsigned alignment, unsigned elem_size, unsigned num_elements);
It can be used to calculate shared memory requirements for pipelined register API execution:
// 2-stage pipelined register API execution (two buffered copies of the A and B tiles)
auto shared_memory_size =
cublasdx::make_shared_storage_calculator()
.add(cublasdx::alignment_of_v_a<BLAS>, sizeof(AInputType), BLAS::suggest_layout_smem_a())
.add(cublasdx::alignment_of_v_b<BLAS>, sizeof(BInputType), BLAS::suggest_layout_smem_b())
.add(cublasdx::alignment_of_v_a<BLAS>, sizeof(AInputType), BLAS::suggest_layout_smem_a())
.add(cublasdx::alignment_of_v_b<BLAS>, sizeof(BInputType), BLAS::suggest_layout_smem_b())
.get();
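The same calculator can also be used with the element-count overload. A sketch, assuming M, N, and K are the problem sizes matching Size<M, N, K> from the BLAS description, that reserves space for the A, B, and C matrices:

// A sketch using the element-count overload; M, N, K are assumed constants
// matching Size<M, N, K> from the BLAS description.
auto shared_memory_size =
    cublasdx::make_shared_storage_calculator()
        .add(cublasdx::alignment_of_v_a<BLAS>, sizeof(typename BLAS::a_value_type), M * K)
        .add(cublasdx::alignment_of_v_b<BLAS>, sizeof(typename BLAS::b_value_type), K * N)
        .add(cublasdx::alignment_of_v_c<BLAS>, sizeof(typename BLAS::c_value_type), M * N)
        .get();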
Shared Memory Slicing#
Warning

Starting with cuBLASDx 0.3.0, the ::slice_shared_memory() methods no longer exist; their functionality has been moved to the APIs described below.
The shared memory slicing free functions work with a BLAS type if its is_complete_blas_execution trait is true.
// #1a: Slice shared memory with default leading dimensions and default matrix layouts for the A, B and C matrices
template<class BLAS,
         class AValueType = typename BLAS::a_value_type,
         class BValueType = typename BLAS::b_value_type,
         class CValueType = typename BLAS::c_value_type>
cute::tuple<AValueType*, BValueType*, CValueType*>
cublasdx::slice_shared_memory(void* smem_ptr);

// #1b: Slice shared memory with default leading dimensions and default matrix layouts for the A and B matrices
template<class BLAS,
         class AValueType = typename BLAS::a_value_type,
         class BValueType = typename BLAS::b_value_type>
cute::tuple<AValueType*, BValueType*>
cublasdx::slice_shared_memory_ab(void* smem_ptr);

// #2a: Slice shared memory with dynamic leading dimensions
template<class BLAS,
         class AValueType = typename BLAS::a_value_type,
         class BValueType = typename BLAS::b_value_type,
         class CValueType = typename BLAS::c_value_type>
cute::tuple<AValueType*, BValueType*, CValueType*>
cublasdx::slice_shared_memory(void* smem_ptr,
                              unsigned int lda,
                              unsigned int ldb,
                              unsigned int ldc);

// #2b: Slice shared memory with dynamic leading dimensions
template<class BLAS,
         class AValueType = typename BLAS::a_value_type,
         class BValueType = typename BLAS::b_value_type>
cute::tuple<AValueType*, BValueType*>
cublasdx::slice_shared_memory_ab(void* smem_ptr,
                                 unsigned int lda,
                                 unsigned int ldb);

// #3a: Slice shared memory with custom matrix layouts
template<class BLAS,
         class AValueType = typename BLAS::a_value_type,
         class BValueType = typename BLAS::b_value_type,
         class CValueType = typename BLAS::c_value_type,
         class ALayout, class BLayout, class CLayout>
cute::tuple<AValueType*, BValueType*, CValueType*>
cublasdx::slice_shared_memory(void* smem_ptr,
                              ALayout a_layout,
                              BLayout b_layout,
                              CLayout c_layout);

// #3b: Slice shared memory with custom matrix layouts
template<class BLAS,
         class AValueType = typename BLAS::a_value_type,
         class BValueType = typename BLAS::b_value_type,
         class ALayout, class BLayout>
cute::tuple<AValueType*, BValueType*>
cublasdx::slice_shared_memory_ab(void* smem_ptr,
                                 ALayout a_layout,
                                 BLayout b_layout);
The cublasdx::slice_shared_memory(...) and cublasdx::slice_shared_memory_ab(...) functions slice shared memory into chunks: one each for A, B, and C, or only for A and B, depending on the variant.
The return values are pointers to the first element of each slice.
The slices follow the alignments in the BLAS description and, at the same time, are not over-aligned, i.e., the padding between two slices is smaller than the alignment.
Note that cublasdx::slice_shared_memory(...) and cublasdx::slice_shared_memory_ab(...) accept arbitrary CuTe layouts.
The ALayout, BLayout (and, for cublasdx::slice_shared_memory(...), CLayout) classes in the above function prototypes can model either the cute::Layout or the cute::ComposedLayout concept.
Example
using BLAS = decltype(...);
extern __shared__ __align__(16) char smem[];
// use structured binding
auto [smem_a, smem_b, smem_c] = cublasdx::slice_shared_memory<BLAS>(smem);
// or
auto smem_slices = cublasdx::slice_shared_memory<BLAS>(smem);
auto smem_a = cute::get<0>(smem_slices);
auto smem_b = cute::get<1>(smem_slices);
auto smem_c = cute::get<2>(smem_slices);
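The returned pointers are typically wrapped into tensors before being used in the GEMM. A sketch, assuming the default shared memory layouts from the BLAS description:

// Sketch: wrap the slices into shared memory tensors using the default
// layouts from the BLAS description.
auto [smem_a, smem_b, smem_c] = cublasdx::slice_shared_memory<BLAS>(smem);
auto a = cublasdx::make_tensor(smem_a, BLAS::get_layout_smem_a());
auto b = cublasdx::make_tensor(smem_b, BLAS::get_layout_smem_b());
auto c = cublasdx::make_tensor(smem_c, BLAS::get_layout_smem_c());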
cuBLASDx also offers advanced generic slicing APIs that allow for an arbitrary number of input matrices:

Note

The following functions are available in the cublasdx::shared_memory namespace, not in the cublasdx namespace, as they are shared between different Device Extensions libraries.
// 1. Use type, alignment and number of elements to slice into pointers
template <class... Ts, class... Args>
__host__ __device__ __forceinline__
cute::tuple<Ts*...> slice_into_pointers(void* smem, const Args... args);

// 2. Use type, alignment and layout to slice into tensors
template <class... Ts, class... Args>
__host__ __device__ __forceinline__
auto slice_into_tensors(void* smem, const Args... args);

// 3. Accept types, alignments and either counts of elements or layouts to get
// the appropriate combination of outputs
template <class... PointerTypes, class... Args>
__host__ __device__ __forceinline__
auto slice(void* smem, const Args... args);
The last variant offers the greatest configurability in a single call for experienced users, as it incorporates the functionality of the first two API variants, which can be seen as simplified helpers.
It can be used in the following way to slice shared memory into properly aligned tensors for a 2-stage pipelined GEMM execution:
// Slice shared memory into tensors for proper alignment in 2-stage pipelining
auto [s_a, s_b, s_a_n, s_b_n] =
cublasdx::shared_memory::slice<AValueType, BValueType, AValueType, BValueType>(
smem,
cublasdx::alignment_of_v_a<BLAS>, BLAS::suggest_layout_smem_a(),
cublasdx::alignment_of_v_b<BLAS>, BLAS::suggest_layout_smem_b(),
cublasdx::alignment_of_v_a<BLAS>, BLAS::suggest_layout_smem_a(),
cublasdx::alignment_of_v_b<BLAS>, BLAS::suggest_layout_smem_b()
);
Depending on whether a layout or a count of elements is provided for each matrix, the function returns a tensor or a pointer for that entry, respectively.
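For comparison, a minimal sketch of the pointer-based helper, where each matrix is described by its alignment and element count; num_elements_a and num_elements_b are illustrative placeholders:

// Sketch of slice_into_pointers: one (alignment, element count) pair
// per requested pointer type; num_elements_a/b are illustrative values.
auto [ptr_a, ptr_b] =
    cublasdx::shared_memory::slice_into_pointers<AValueType, BValueType>(
        smem,
        cublasdx::alignment_of_v_a<BLAS>, num_elements_a,
        cublasdx::alignment_of_v_b<BLAS>, num_elements_b
    );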