Linear System Solver Using cuSolverDx#
In this introduction, we will solve a batch of linear systems with Cholesky factorization using the cuSolverDx library. This section is based on the posv_batched.cu example shipped with cuSolverDx. Refer to the Examples section for other cuSolverDx samples.
Defining a Function Descriptor#
NVIDIA MathDX uses a list of operators to create a full description of the function to execute. Therefore, the first step is to define a function descriptor by adding together cuSolverDx operators that capture the properties of the function.
The correctness of this description is evaluated at compile time.
A well-defined cuSolverDx routine description must include two parts:
Selected linear algebra routine. In this case it is the Cholesky decomposition and solver: cusolverdx::function::posv.
Valid and sufficient description of the input/output matrices, including the dimensions (m, n, nrhs), the precision (float or double), the data type (real or complex), the fill mode (lower or upper), and the data arrangement of the matrices (row- or column-major).
To get a descriptor for a Cholesky solver with m = n = 32 and nrhs = 1, we just need to write the following lines:
#include <cusolverdx.hpp>
using namespace cusolverdx;
using Solver = decltype(Size<32 /* m */, 32 /* n */, 1 /* nrhs */>()
                      + Precision<double>()
                      + Type<type::complex>()
                      + Function<function::posv>()
                      + FillMode<lower>()
                      + Arrangement<col_major /* A and B */>());
In order to encode the operation properties, cuSolverDx provides the operators Size, Precision, Type, Function, FillMode, and Arrangement, which can be combined by using addition (+).
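To make the idea of building a description by addition more concrete, the following is a toy sketch in plain C++ of how empty tag types can be combined with operator+ and captured with decltype. All names here (toy_operator, ToySize, ToyPrecision, ToyPosv, ToyDescription, operator_count) are hypothetical and are not part of cuSolverDx; the library's actual implementation is not shown in this document and may differ substantially.

```cpp
#include <cstddef>
#include <type_traits>

// Toy stand-ins only; none of these types exist in cuSolverDx.
struct toy_operator {};  // marker base shared by all toy operators

template <unsigned M, unsigned N, unsigned NRHS>
struct ToySize : toy_operator {
    static constexpr unsigned m = M;
    static constexpr unsigned n = N;
    static constexpr unsigned nrhs = NRHS;
};

template <class T>
struct ToyPrecision : toy_operator {
    using type = T;
};

struct ToyPosv : toy_operator {};

// A description is a compile-time list of operator types.
template <class... Ops>
struct ToyDescription {};

// operator + operator starts a description.
template <class A, class B,
          std::enable_if_t<std::is_base_of_v<toy_operator, A> &&
                           std::is_base_of_v<toy_operator, B>, int> = 0>
constexpr ToyDescription<A, B> operator+(A, B) { return {}; }

// description + operator appends to the list.
template <class... Ops, class B,
          std::enable_if_t<std::is_base_of_v<toy_operator, B>, int> = 0>
constexpr ToyDescription<Ops..., B> operator+(ToyDescription<Ops...>, B) { return {}; }

// Number of operators collected in a description.
template <class... Ops>
constexpr std::size_t operator_count(ToyDescription<Ops...>) { return sizeof...(Ops); }

// The expression is never evaluated at run time; only its type is used.
using ToySolver = decltype(ToySize<32, 32, 1>() + ToyPrecision<double>() + ToyPosv());

static_assert(std::is_same_v<ToySolver,
                             ToyDescription<ToySize<32, 32, 1>, ToyPrecision<double>, ToyPosv>>,
              "the description records every operator in order");
```

Because everything is encoded in the type, an invalid combination can be rejected with a static_assert before any kernel is compiled, which is the property the text above refers to as compile-time evaluation.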
Optionally, users can set the leading dimension of each matrix using the LeadingDimension operator, or set the leading dimensions dynamically at execution time.
To obtain a fully usable operation that executes the function at the CUDA block level, we need to provide at least two additional pieces of information:
The first one is the SM operator, which indicates the targeted CUDA architecture on which we want to run the function. Each GPU architecture is different, therefore each may use a different implementation and require a different CUDA block size for the best performance. In the posv_batched.cu example this is passed as a template parameter, but here we can assume we are targeting A100 GPUs (SM<800>()).
The second one is the Block operator, which indicates that the routine will be performed by multiple threads in a single CUDA block. At this point, cuSolverDx performs additional verifications to make sure that the provided description is valid and that it is possible to execute it on the requested architecture.
#include <cusolverdx.hpp>
using namespace cusolverdx;
using Solver = decltype(Size<32 /* m */, 32 /* n */, 1 /* nrhs */>()
                      + Precision<double>()
                      + Type<type::complex>()
                      + Function<function::posv>()
                      + FillMode<lower>()
                      + Arrangement<col_major /* A, X, and B */>()
                      + SM<800>()
                      + Block());
Optionally, the user can also specify the number of threads that will execute the Solver. This is done with the BlockDim operator. If the BlockDim operator is not used, cuSolverDx selects a preferred block size, which can be obtained with Solver::block_dim.
Tip
If there is no need to set custom block dimensions, it is recommended not to use the BlockDim operator and to rely on Solver::block_dim instead. For more details, see the Block Execute Method section, BlockDim, and Suggested block dim.
Executing the Function#
The Solver class, which describes the function, can be instantiated into one or more objects.
Constructing such an object has no computational cost, and the object should be seen as a handle.
The function descriptor object provides a compute method, execute(...), that performs the requested function.
#include <cusolverdx.hpp>
using namespace cusolverdx;
using Solver = decltype(Size<32 /* m */, 32 /* n */, 1 /* nrhs */>()
                      + Precision<double>()
                      + Type<type::complex>()
                      + Function<function::posv>()
                      + FillMode<lower>()
                      + Arrangement<col_major /* A, X, and B */>()
                      + SM<800>()
                      + Block());
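// The element type follows the descriptor (complex, double precision), so it is taken from Solver
__global__ void posv_kernel(typename Solver::a_data_type* A,
                            typename Solver::a_data_type* B,
                            typename Solver::status_type* info) {
    // A_smem and B_smem point to shared memory (declarations omitted for brevity)
    // Copy A and B from global memory to shared memory
    // Execute Solver
    Solver().execute(A_smem, B_smem, info);
    // Copy the output A and B from shared memory to global memory
}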
If the Cholesky factorization fails, i.e. some leading minor of A is not positive definite, the output parameter info indicates, for each batch, the smallest leading minor of A that is not positive definite.
As shown in the comments of the kernel above, the execution method assumes that all the matrices reside in shared memory. It is up to the user to load the matrices from global to shared memory before calling the execution method, and to store the results afterwards. The examples shipped with cuSolverDx provide common load and store utility functions for convenience.
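To illustrate what a nonzero info value means, here is a minimal host-side reference in plain C++ that factors and solves a single system. It is a sketch under stated assumptions, not the cuSolverDx implementation: it handles real double data only (the descriptor in this section is complex), uses column-major storage with the lower triangle, and returns a 1-based index for the failing leading minor, following the LAPACK potrf convention. The function names reference_potrf_lower and reference_posv_lower are hypothetical.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// In-place Cholesky factorization A = L * L^T (real, column-major, lower triangle).
// Returns 0 on success, or the 1-based order of the first leading minor
// that is not positive definite (LAPACK-style convention, assumed here).
int reference_potrf_lower(std::vector<double>& A, std::size_t n, std::size_t lda) {
    assert(lda >= n && A.size() >= lda * n);
    for (std::size_t j = 0; j < n; ++j) {
        double d = A[j + j * lda];
        for (std::size_t k = 0; k < j; ++k) d -= A[j + k * lda] * A[j + k * lda];
        if (!(d > 0.0)) return static_cast<int>(j + 1);  // minor of order j+1 fails
        const double ljj = std::sqrt(d);
        A[j + j * lda] = ljj;
        for (std::size_t i = j + 1; i < n; ++i) {
            double s = A[i + j * lda];
            for (std::size_t k = 0; k < j; ++k) s -= A[i + k * lda] * A[j + k * lda];
            A[i + j * lda] = s / ljj;
        }
    }
    return 0;
}

// Solves A x = b for one right-hand side; b is overwritten with x.
// Returns the info value of the factorization; b is untouched on failure.
int reference_posv_lower(std::vector<double>& A, std::vector<double>& b,
                         std::size_t n, std::size_t lda) {
    assert(b.size() >= n);
    const int info = reference_potrf_lower(A, n, lda);
    if (info != 0) return info;
    // Forward substitution: L y = b
    for (std::size_t i = 0; i < n; ++i) {
        double s = b[i];
        for (std::size_t k = 0; k < i; ++k) s -= A[i + k * lda] * b[k];
        b[i] = s / A[i + i * lda];
    }
    // Backward substitution: L^T x = y
    for (std::size_t i = n; i-- > 0;) {
        double s = b[i];
        for (std::size_t k = i + 1; k < n; ++k) s -= A[k + i * lda] * b[k];
        b[i] = s / A[i + i * lda];
    }
    return 0;
}
```

A reference like this can also serve as a correctness check for small batches: run it on the host for each batch and compare the result with the B matrix produced by the kernel.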
Launching the Kernel#
To launch a kernel executing the defined Solver, we need to know the required block dimensions and the amount of shared memory needed.
The required shared memory can be obtained using Solver::shared_memory_size. It accounts for any padding declared using the LeadingDimension operator.
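As a rough illustration of what such a size has to cover, the sketch below estimates the bytes needed to hold a column-major A (lda x n) followed by B (ldb x nrhs), which is the layout used by the posv_kernel listing in this section. The helper name posv_smem_bytes is hypothetical, and std::complex<double> is used only as a stand-in for the library's complex type. The estimate is an assumption: Solver::shared_memory_size remains the authoritative value and may account for alignment or workspace that this sketch ignores.

```cpp
#include <complex>
#include <cstddef>

// Hypothetical helper: bytes for A (lda x n) followed by B (ldb x nrhs),
// both stored column-major with element type T.
template <class T>
constexpr std::size_t posv_smem_bytes(std::size_t n, std::size_t nrhs,
                                      std::size_t lda, std::size_t ldb) {
    return (lda * n + ldb * nrhs) * sizeof(T);
}

// m = n = 32, nrhs = 1, no padding: (32*32 + 32*1) elements of 16 bytes each.
constexpr std::size_t unpadded =
    posv_smem_bytes<std::complex<double>>(32, 1, 32, 32);

// Padding A to lda = 33 adds one extra element per column of A.
constexpr std::size_t padded =
    posv_smem_bytes<std::complex<double>>(32, 1, 33, 32);
```

The padded variant shows why the leading dimension matters for the launch configuration: changing lda changes the dynamic shared memory request even though the logical matrix size is the same.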
For simplicity, the example below shows only the main component of the code. Please check the full example, as well as others shipped with cuSolverDx.
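#include <vector>
#include <cusolverdx.hpp>
// common::io is provided by the utility code shipped with the cuSolverDx examples
using namespace cusolverdx;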
template<class POSV, class DataType = typename POSV::a_data_type>
__global__ void posv_kernel(DataType* A, const unsigned int lda_gmem, DataType* B,
                            const unsigned int ldb_gmem, typename POSV::status_type* info) {
    constexpr auto m    = POSV::m_size;
    constexpr auto nrhs = POSV::nrhs;
    const auto one_batch_size_a_gmem = lda_gmem * m;
    const auto one_batch_size_b_gmem = (arrangement_of_v<POSV> == arrangement::col_major) ?
                                       ldb_gmem * nrhs : m * ldb_gmem;
    constexpr auto lda_smem = POSV::lda;
    constexpr auto ldb_smem = POSV::ldb;
    constexpr auto one_batch_size_a_smem = lda_smem * m;

    // B is placed right after A in dynamic shared memory
    extern __shared__ __align__(16) char shared_mem[];
    DataType* As = reinterpret_cast<DataType*>(shared_mem);
    DataType* Bs = As + one_batch_size_a_smem;

    // Each block processes one batch
    const auto batch_idx = blockIdx.x;
    auto Ag = A + size_t(one_batch_size_a_gmem) * batch_idx;
    auto Bg = B + size_t(one_batch_size_b_gmem) * batch_idx;

    // Load data from global memory to shared memory
    common::io<POSV>::load(Ag, lda_gmem, As, lda_smem);
    common::io<POSV>::load_rhs(Bg, ldb_gmem, Bs, ldb_smem);

    POSV().execute(As, lda_smem, Bs, ldb_smem, &info[batch_idx]);

    // Store data back to global memory
    common::io<POSV>::store(As, lda_smem, Ag, lda_gmem);
    common::io<POSV>::store_rhs(Bs, ldb_smem, Bg, ldb_gmem);
}
int posv_batched() {
    using POSV = decltype(Size<32, 32, 1>()
                        + Precision<double>()
                        + Type<type::complex>()
                        + Function<function::posv>()
                        + FillMode<lower>()
                        + Arrangement<col_major /* A, X, and B */>()
                        + SM<800>()
                        + Block());
    using data_type = typename POSV::a_data_type;
    constexpr auto m    = POSV::m_size;
    constexpr auto n    = POSV::n_size;
    constexpr auto nrhs = POSV::nrhs;
    static_assert(m == n, "posv requires a square (Hermitian positive-definite) matrix A");
    constexpr auto lda_smem = POSV::lda;
    constexpr auto ldb_smem = POSV::ldb;
    constexpr auto lda = m;
    constexpr auto ldb = (arrangement_of_v<POSV> == arrangement::col_major) ? m : nrhs;
    const auto batches = 2;
    const auto one_batch_size_A = lda * n;
    const auto one_batch_size_B = m * nrhs;

    std::vector<data_type> A(one_batch_size_A * batches);
    // fill the A matrix
    std::vector<data_type> B(one_batch_size_B * batches);
    // fill the B matrix
    std::vector<int> info(batches, 0);

    data_type* d_A    = nullptr; /* device copy of A */
    data_type* d_B    = nullptr; /* device copy of B */
    int*       d_info = nullptr; /* error info */
    cudaMalloc(reinterpret_cast<void**>(&d_A), sizeof(data_type) * A.size());
    cudaMalloc(reinterpret_cast<void**>(&d_B), sizeof(data_type) * B.size());
    cudaMalloc(reinterpret_cast<void**>(&d_info), sizeof(int) * batches);
    cudaMemcpy(d_A, A.data(), sizeof(data_type) * A.size(), cudaMemcpyHostToDevice);
    cudaMemcpy(d_B, B.data(), sizeof(data_type) * B.size(), cudaMemcpyHostToDevice);

    const auto sm_size = POSV::shared_memory_size;
    // Invoke the kernel with one batch per block
    posv_kernel<POSV><<<batches, POSV::block_dim.x, sm_size>>>(d_A, lda, d_B, ldb, d_info);
    cudaDeviceSynchronize();

    // Copy the solution (stored in B) and the per-batch status back to the host
    cudaMemcpy(B.data(), d_B, sizeof(data_type) * B.size(), cudaMemcpyDeviceToHost);
    cudaMemcpy(info.data(), d_info, sizeof(int) * batches, cudaMemcpyDeviceToHost);

    cudaFree(d_A);
    cudaFree(d_B);
    cudaFree(d_info);
    return 0;
}
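Once the per-batch status values are available on the host (for example after copying the device info array back with cudaMemcpy), they can be scanned for failures. The following is a small sketch in plain C++; the helper name failed_batches is hypothetical, and it assumes that a value of 0 means success and any nonzero value marks a batch whose factorization failed.

```cpp
#include <cstddef>
#include <vector>

// Returns the indices of all batches whose status value is nonzero.
// Assumption: 0 means the factorization succeeded for that batch.
std::vector<std::size_t> failed_batches(const std::vector<int>& info) {
    std::vector<std::size_t> failed;
    for (std::size_t i = 0; i < info.size(); ++i) {
        if (info[i] != 0) failed.push_back(i);
    }
    return failed;
}
```

Since each batch is solved independently, a failure in one batch does not invalidate the results of the others, so the caller can keep the remaining solutions and handle only the reported indices.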
It is important to note that, unlike the cuSolver library, cuSolverDx does not require moving data back to global memory after executing an operation, nor does it require the input data to be loaded from global memory. These properties can be a major performance advantage for certain use cases. The list of possible optimizations includes but is not limited to:
Fusing Solver functions with custom pre- and post-processing.
Fusing multiple Solver functions together.
Fusing Solver functions with BLAS, RAND, or FFT operations.
Compilation#
For instructions on how to compile programs with cuSolverDx, see the Installation Guide.