Just-In-Time (JIT) Compilation¶
This section introduces the Just-In-Time (JIT) compilation feature, which allows users to compile specialized kernels that maximize performance for a specific operation.
The complexity of a given contraction (e.g., the number and order of modes, the number of contracted modes) determines the size of its kernel search space (i.e., the set of candidate kernels that can be used to perform the contraction). As the complexity increases, the search space can become prohibitively large. Our pre-compiled kernels are carefully selected to perform well on a wide variety of contractions. However, as the complexity of a contraction increases, a just-in-time compiled kernel that is tailored to the given contraction can outperform any kernel from the fixed-size set of pre-compiled kernels.
JIT compilation overcomes this limitation by creating a kernel that better exploits the optimization opportunities applicable to a given contraction.
Compiling a kernel typically takes 1 to 8 seconds (depending on the kernel and the host CPU). This cost is incurred only once per kernel, during the planning stage; the kernel can then be reused by subsequent contractions (i.e., kernels are automatically cached once they are compiled).
All JIT-compiled kernels are added to the kernel cache, which is accessible by the whole library (i.e., shared across library handles). We provide functions to read and write the kernel cache to disk, so that the cost of JIT compiling the same kernels can be avoided when re-running a program.
The remainder of this section assumes familiarity with Getting Started.
Note
The JIT compilation feature is only supported on GPUs with compute capability greater than or equal to 8.0 (Ampere or newer). Moreover, this feature is currently limited to tensor contractions.
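Because availability depends on the GPU's compute capability, a program that must run on a mix of GPU generations can select the JIT mode at runtime. The following is a minimal sketch (not part of the samples below); it assumes the HANDLE_CUDA_ERROR macro defined later in this section and queries device 0:
cudaDeviceProp prop;
HANDLE_CUDA_ERROR(cudaGetDeviceProperties(&prop, /*device=*/0));
// JIT compilation requires compute capability 8.0 (Ampere) or newer.
const cutensorJitMode_t jitMode = (prop.major >= 8) ? CUTENSOR_JIT_MODE_DEFAULT
                                                    : CUTENSOR_JIT_MODE_NONE;
// jitMode can then be passed as the last argument of cutensorCreatePlanPreference().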
Introductory Example¶
This subsection provides a basic overview of the API calls and features related to JIT compilation.
We begin by computing a contraction using the same steps as described in Getting Started, but with a different contraction example to emphasize the benefit of JIT compilation when the number of contracted modes increases. Then, we describe the necessary modifications to enable JIT compilation, and compare the performance of the pre-compiled and JIT-compiled kernels.
Our code computes the following operation (note that we now use numbers instead of letters to name each mode, since the number of modes exceeds the number of letters in the Latin alphabet); the mode lists correspond exactly to modeC, modeA, and modeB in the code below:
C_{0,1,2,3,4,6,8,9,25,26,10,12,14,27,15,28,17,19,29,20,21,30,23,24} = α A_{0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24} B_{25,26,27,28,29,30,5,7,11,13,16,18,22} + β C_{0,1,2,3,4,6,8,9,25,26,10,12,14,27,15,28,17,19,29,20,21,30,23,24}
All operands contain single-precision complex values, and computation is performed with emulated single-precision arithmetic (3XTF32). All modes have an extent of 2.
The steps to compute the above contraction are the same as the ones in Getting Started, and are listed below:
#include <chrono>
#include <complex>
#include <stdlib.h>
#include <stdio.h>
#include <unordered_map>
#include <vector>
#include <cuda_runtime.h>
#include <cutensor.h>
// Handle cuTENSOR errors
#define HANDLE_ERROR(x) { \
const auto err = x; \
if( err != CUTENSOR_STATUS_SUCCESS ) \
{ printf("Error: %s in line %d\n", cutensorGetErrorString(err), __LINE__); exit(-1); } \
};
// Handle CUDA errors
#define HANDLE_CUDA_ERROR(x) { \
const auto err = x; \
if( err != cudaSuccess ) \
{ printf("Error: %s in line %d\n", cudaGetErrorString(err), __LINE__); exit(-1); } \
};
class CPUTimer
{
public:
void start()
{
start_ = std::chrono::steady_clock::now();
}
double seconds()
{
end_ = std::chrono::steady_clock::now();
elapsed_ = end_ - start_;
//return in ms
return elapsed_.count() * 1000;
}
private:
typedef std::chrono::steady_clock::time_point tp;
tp start_;
tp end_;
std::chrono::duration<double> elapsed_;
};
struct GPUTimer
{
GPUTimer()
{
cudaEventCreate(&start_);
cudaEventCreate(&stop_);
cudaEventRecord(start_, 0);
}
~GPUTimer()
{
cudaEventDestroy(start_);
cudaEventDestroy(stop_);
}
void start()
{
cudaEventRecord(start_, 0);
}
float seconds()
{
cudaEventRecord(stop_, 0);
cudaEventSynchronize(stop_);
float time;
cudaEventElapsedTime(&time, start_, stop_);
return time * 1e-3;
}
private:
cudaEvent_t start_, stop_;
};
int main()
{
typedef std::complex<float> TypeA;
typedef std::complex<float> TypeB;
typedef std::complex<float> TypeC;
typedef std::complex<float> TypeScalar;
auto alpha = TypeScalar(1.1, 0.0);
auto beta = TypeScalar(0.0, 0.0);
// CUDA types
cutensorDataType_t typeA = CUTENSOR_C_32F;
cutensorDataType_t typeB = CUTENSOR_C_32F;
cutensorDataType_t typeC = CUTENSOR_C_32F;
cutensorComputeDescriptor_t descCompute = CUTENSOR_COMPUTE_DESC_3XTF32;
/* ***************************** */
// Create vector of modes
std::vector<int> modeC{0,1,2,3,4,6,8,9,25,26,10,12,14,27,15,28,17,19,29,20,21,30,23,24};
std::vector<int> modeA{0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24};
std::vector<int> modeB{25,26,27,28,29,30,5,7,11,13,16,18,22};
int nmodeA = modeA.size();
int nmodeB = modeB.size();
int nmodeC = modeC.size();
// Extents
std::unordered_map<int, int64_t> extent;
for (auto i = 0; i <= 30; i++)
extent[i] = 2;
// Create a vector of extents for each tensor
std::vector<int64_t> extentC;
for (auto mode : modeC)
extentC.push_back(extent[mode]);
std::vector<int64_t> extentA;
for (auto mode : modeA)
extentA.push_back(extent[mode]);
std::vector<int64_t> extentB;
for (auto mode : modeB)
extentB.push_back(extent[mode]);
/**********************
* Allocating data
**********************/
// Number of elements of each tensor
size_t elementsA = 1;
for (auto mode : modeA)
elementsA *= extent[mode];
size_t elementsB = 1;
for (auto mode : modeB)
elementsB *= extent[mode];
size_t elementsC = 1;
for (auto mode : modeC)
elementsC *= extent[mode];
// Size in bytes
size_t sizeA = sizeof(TypeA) * elementsA;
size_t sizeB = sizeof(TypeB) * elementsB;
size_t sizeC = sizeof(TypeC) * elementsC;
printf("Total memory: %.2f GiB\n", (sizeA + sizeB + sizeC)/1024./1024./1024);
// Allocate on device
void *A_d, *B_d, *C_d;
HANDLE_CUDA_ERROR(cudaMalloc((void**) &A_d, sizeA));
HANDLE_CUDA_ERROR(cudaMalloc((void**) &B_d, sizeB));
HANDLE_CUDA_ERROR(cudaMalloc((void**) &C_d, sizeC));
// Allocate on host
TypeA *A = (TypeA*) malloc(sizeof(TypeA) * elementsA);
TypeB *B = (TypeB*) malloc(sizeof(TypeB) * elementsB);
TypeC *C = (TypeC*) malloc(sizeof(TypeC) * elementsC);
if (A == nullptr || B == nullptr || C == nullptr)
{
printf("Error: Host allocation of A, B, or C.\n");
exit(-1);
}
/*******************
* Initialize data
*******************/
for (int64_t i = 0; i < elementsA; i++)
A[i] = (((float) rand())/RAND_MAX - 0.5)*100;
for (int64_t i = 0; i < elementsB; i++)
B[i] = (((float) rand())/RAND_MAX - 0.5)*100;
for (int64_t i = 0; i < elementsC; i++)
C[i] = (((float) rand())/RAND_MAX - 0.5)*100;
// Copy to device
HANDLE_CUDA_ERROR(cudaMemcpy(A_d, A, sizeA, cudaMemcpyHostToDevice));
HANDLE_CUDA_ERROR(cudaMemcpy(B_d, B, sizeB, cudaMemcpyHostToDevice));
HANDLE_CUDA_ERROR(cudaMemcpy(C_d, C, sizeC, cudaMemcpyHostToDevice));
const uint32_t kAlignment = 128; // Alignment of the global-memory device pointers (bytes)
/*************************
* cuTENSOR
*************************/
// Initialize cuTENSOR library
cutensorHandle_t handle;
HANDLE_ERROR(cutensorCreate(&handle));
/**********************
* Create Tensor Descriptors
**********************/
cutensorTensorDescriptor_t descA;
HANDLE_ERROR(cutensorCreateTensorDescriptor(handle,
&descA,
nmodeA,
extentA.data(),
NULL,/*stride*/
typeA, kAlignment));
cutensorTensorDescriptor_t descB;
HANDLE_ERROR(cutensorCreateTensorDescriptor(handle,
&descB,
nmodeB,
extentB.data(),
NULL,/*stride*/
typeB, kAlignment));
cutensorTensorDescriptor_t descC;
HANDLE_ERROR(cutensorCreateTensorDescriptor(handle,
&descC,
nmodeC,
extentC.data(),
NULL,/*stride*/
typeC, kAlignment));
/*******************************
* Create Contraction Descriptor
*******************************/
cutensorOperationDescriptor_t desc;
HANDLE_ERROR(cutensorCreateContraction(handle,
&desc,
descA, modeA.data(), /* unary operator A*/CUTENSOR_OP_IDENTITY,
descB, modeB.data(), /* unary operator B*/CUTENSOR_OP_IDENTITY,
descC, modeC.data(), /* unary operator C*/CUTENSOR_OP_IDENTITY,
descC, modeC.data(),
descCompute));
/**************************
* Set the algorithm to use -- without just-in-time compilation
***************************/
const cutensorAlgo_t algo = CUTENSOR_ALGO_GETT;
cutensorPlanPreference_t planPref;
HANDLE_ERROR(cutensorCreatePlanPreference(handle,
&planPref,
algo,
CUTENSOR_JIT_MODE_NONE));
/**********************
* Query workspace estimate
**********************/
uint64_t workspaceSizeEstimate = 0;
const cutensorWorksizePreference_t workspacePref = CUTENSOR_WORKSPACE_DEFAULT;
HANDLE_ERROR(cutensorEstimateWorkspaceSize(handle,
desc,
planPref,
workspacePref,
&workspaceSizeEstimate));
// Allocate workspace
void *work = nullptr;
if (workspaceSizeEstimate > 0)
{
HANDLE_CUDA_ERROR(cudaMalloc(&work, workspaceSizeEstimate));
}
/**************************
* Create Contraction Plan -- without just-in-time compilation
**************************/
cutensorPlan_t plan;
HANDLE_ERROR(cutensorCreatePlan(handle,
&plan,
desc,
planPref,
workspaceSizeEstimate));
/**********************
* Execute the tensor contraction
**********************/
cudaStream_t stream;
HANDLE_CUDA_ERROR(cudaStreamCreate(&stream));
double minTimeCUTENSOR = 1e100;
for (int i=0; i < 3; ++i)
{
cudaMemcpy(C_d, C, sizeC, cudaMemcpyHostToDevice);
// Set up timing
GPUTimer timer;
timer.start();
HANDLE_ERROR(cutensorContract(handle,
plan,
(void*) &alpha, A_d, B_d,
(void*) &beta, C_d, C_d,
work, workspaceSizeEstimate, stream))
// Synchronize and measure timing
auto time = timer.seconds();
minTimeCUTENSOR = (minTimeCUTENSOR < time) ? minTimeCUTENSOR : time;
}
/*************************/
float flops;
HANDLE_ERROR(cutensorOperationDescriptorGetAttribute(handle,
desc,
CUTENSOR_OPERATION_DESCRIPTOR_FLOPS,
(void*)&flops,
sizeof(flops)));
auto gflops = flops / 1e9;
auto gflopsPerSec = gflops / minTimeCUTENSOR;
printf("cuTENSOR : %6.0f GFLOPs/s\n", gflopsPerSec);
HANDLE_ERROR(cutensorDestroy(handle));
HANDLE_ERROR(cutensorDestroyOperationDescriptor(desc));
HANDLE_ERROR(cutensorDestroyTensorDescriptor(descA));
HANDLE_ERROR(cutensorDestroyTensorDescriptor(descB));
HANDLE_ERROR(cutensorDestroyTensorDescriptor(descC));
HANDLE_CUDA_ERROR(cudaStreamDestroy(stream));
HANDLE_ERROR(cutensorDestroyPlanPreference(planPref));
HANDLE_ERROR(cutensorDestroyPlan(plan));
if (A) free(A);
if (B) free(B);
if (C) free(C);
if (A_d) cudaFree(A_d);
if (B_d) cudaFree(B_d);
if (C_d) cudaFree(C_d);
if (work) cudaFree(work);
printf("Successful completion\n");
return 0;
}
All it takes to enable JIT compilation is to change the last argument of cutensorCreatePlanPreference() from CUTENSOR_JIT_MODE_NONE to CUTENSOR_JIT_MODE_DEFAULT; no further changes are required:
cutensorPlanPreference_t planPrefJit;
cutensorCreatePlanPreference(handle,
&planPrefJit,
algo,
CUTENSOR_JIT_MODE_DEFAULT);
The kernel is compiled during the call to cutensorCreatePlan(). This call is blocking, and the compilation process can take up to a few seconds. Once the plan has been created, the compiled kernel is stored in the kernel cache (see Reading and writing the kernel cache to disk) and is used via a call to cutensorContract().
Note
To re-use a JIT-compiled kernel in a subsequent contraction (using a different plan), the user must again set CUTENSOR_JIT_MODE_DEFAULT during cutensorCreatePlanPreference(). Otherwise, a pre-compiled kernel will be used.
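As a hedged sketch of such a re-use (the names planPrefJit2, planJit2, and workspaceSizeJit2 are hypothetical; handle, desc, and algo are the variables of the example below), a second plan that requests JIT compilation for the same contraction finds the kernel already in the cache, so creating it does not trigger another compilation:
cutensorPlanPreference_t planPrefJit2;
HANDLE_ERROR(cutensorCreatePlanPreference(handle, &planPrefJit2, algo,
                                          CUTENSOR_JIT_MODE_DEFAULT));
uint64_t workspaceSizeJit2 = 0;
HANDLE_ERROR(cutensorEstimateWorkspaceSize(handle, desc, planPrefJit2,
                                           CUTENSOR_WORKSPACE_DEFAULT, &workspaceSizeJit2));
// The kernel for this contraction was already compiled and cached, so this
// call returns quickly instead of spending seconds on compilation.
cutensorPlan_t planJit2;
HANDLE_ERROR(cutensorCreatePlan(handle, &planJit2, desc, planPrefJit2, workspaceSizeJit2));
// ... use planJit2 with cutensorContract() as usual, then destroy the plan and preference ...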
The full working example is as follows:
#include <chrono>
#include <complex>
#include <stdlib.h>
#include <stdio.h>
#include <unordered_map>
#include <vector>
#include <cuda_runtime.h>
#include <cutensor.h>
// Handle cuTENSOR errors
#define HANDLE_ERROR(x) { \
const auto err = x; \
if( err != CUTENSOR_STATUS_SUCCESS ) \
{ printf("Error: %s in line %d\n", cutensorGetErrorString(err), __LINE__); exit(-1); } \
};
// Handle CUDA errors
#define HANDLE_CUDA_ERROR(x) { \
const auto err = x; \
if( err != cudaSuccess ) \
{ printf("Error: %s in line %d\n", cudaGetErrorString(err), __LINE__); exit(-1); } \
};
class CPUTimer
{
public:
void start()
{
start_ = std::chrono::steady_clock::now();
}
double seconds()
{
end_ = std::chrono::steady_clock::now();
elapsed_ = end_ - start_;
//return in ms
return elapsed_.count() * 1000;
}
private:
typedef std::chrono::steady_clock::time_point tp;
tp start_;
tp end_;
std::chrono::duration<double> elapsed_;
};
struct GPUTimer
{
GPUTimer()
{
cudaEventCreate(&start_);
cudaEventCreate(&stop_);
cudaEventRecord(start_, 0);
}
~GPUTimer()
{
cudaEventDestroy(start_);
cudaEventDestroy(stop_);
}
void start()
{
cudaEventRecord(start_, 0);
}
float seconds()
{
cudaEventRecord(stop_, 0);
cudaEventSynchronize(stop_);
float time;
cudaEventElapsedTime(&time, start_, stop_);
return time * 1e-3;
}
private:
cudaEvent_t start_, stop_;
};
int main()
{
typedef std::complex<float> TypeA;
typedef std::complex<float> TypeB;
typedef std::complex<float> TypeC;
typedef std::complex<float> TypeScalar;
auto alpha = TypeScalar(1.1, 0.0);
auto beta = TypeScalar(0.0, 0.0);
// CUDA types
cutensorDataType_t typeA = CUTENSOR_C_32F;
cutensorDataType_t typeB = CUTENSOR_C_32F;
cutensorDataType_t typeC = CUTENSOR_C_32F;
cutensorComputeDescriptor_t descCompute = CUTENSOR_COMPUTE_DESC_3XTF32;
/* ***************************** */
// Create vector of modes
std::vector<int> modeC{0,1,2,3,4,6,8,9,25,26,10,12,14,27,15,28,17,19,29,20,21,30,23,24};
std::vector<int> modeA{0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24};
std::vector<int> modeB{25,26,27,28,29,30,5,7,11,13,16,18,22};
int nmodeA = modeA.size();
int nmodeB = modeB.size();
int nmodeC = modeC.size();
// Extents
std::unordered_map<int, int64_t> extent;
for (auto i = 0; i <= 30; i++)
extent[i] = 2;
// Create a vector of extents for each tensor
std::vector<int64_t> extentC;
for (auto mode : modeC)
extentC.push_back(extent[mode]);
std::vector<int64_t> extentA;
for (auto mode : modeA)
extentA.push_back(extent[mode]);
std::vector<int64_t> extentB;
for (auto mode : modeB)
extentB.push_back(extent[mode]);
/**********************
* Allocating data
**********************/
// Number of elements of each tensor
size_t elementsA = 1;
for (auto mode : modeA)
elementsA *= extent[mode];
size_t elementsB = 1;
for (auto mode : modeB)
elementsB *= extent[mode];
size_t elementsC = 1;
for (auto mode : modeC)
elementsC *= extent[mode];
// Size in bytes
size_t sizeA = sizeof(TypeA) * elementsA;
size_t sizeB = sizeof(TypeB) * elementsB;
size_t sizeC = sizeof(TypeC) * elementsC;
printf("Total memory: %.2f GiB\n", (sizeA + sizeB + sizeC)/1024./1024./1024);
// Allocate on device
void *A_d, *B_d, *C_d;
HANDLE_CUDA_ERROR(cudaMalloc((void**) &A_d, sizeA));
HANDLE_CUDA_ERROR(cudaMalloc((void**) &B_d, sizeB));
HANDLE_CUDA_ERROR(cudaMalloc((void**) &C_d, sizeC));
// Allocate on host
TypeA *A = (TypeA*) malloc(sizeof(TypeA) * elementsA);
TypeB *B = (TypeB*) malloc(sizeof(TypeB) * elementsB);
TypeC *C = (TypeC*) malloc(sizeof(TypeC) * elementsC);
if (A == nullptr || B == nullptr || C == nullptr)
{
printf("Error: Host allocation of A, B, or C.\n");
exit(-1);
}
/*******************
* Initialize data
*******************/
for (int64_t i = 0; i < elementsA; i++)
A[i] = (((float) rand())/RAND_MAX - 0.5)*100;
for (int64_t i = 0; i < elementsB; i++)
B[i] = (((float) rand())/RAND_MAX - 0.5)*100;
for (int64_t i = 0; i < elementsC; i++)
C[i] = (((float) rand())/RAND_MAX - 0.5)*100;
// Copy to device
HANDLE_CUDA_ERROR(cudaMemcpy(A_d, A, sizeA, cudaMemcpyHostToDevice));
HANDLE_CUDA_ERROR(cudaMemcpy(B_d, B, sizeB, cudaMemcpyHostToDevice));
HANDLE_CUDA_ERROR(cudaMemcpy(C_d, C, sizeC, cudaMemcpyHostToDevice));
const uint32_t kAlignment = 128; // Alignment of the global-memory device pointers (bytes)
/*************************
* cuTENSOR
*************************/
// Initialize cuTENSOR library
cutensorHandle_t handle;
HANDLE_ERROR(cutensorCreate(&handle));
/**********************
* Create Tensor Descriptors
**********************/
cutensorTensorDescriptor_t descA;
HANDLE_ERROR(cutensorCreateTensorDescriptor(handle,
&descA,
nmodeA,
extentA.data(),
NULL,/*stride*/
typeA, kAlignment));
cutensorTensorDescriptor_t descB;
HANDLE_ERROR(cutensorCreateTensorDescriptor(handle,
&descB,
nmodeB,
extentB.data(),
NULL,/*stride*/
typeB, kAlignment));
cutensorTensorDescriptor_t descC;
HANDLE_ERROR(cutensorCreateTensorDescriptor(handle,
&descC,
nmodeC,
extentC.data(),
NULL,/*stride*/
typeC, kAlignment));
/*******************************
* Create Contraction Descriptor
*******************************/
cutensorOperationDescriptor_t desc;
HANDLE_ERROR(cutensorCreateContraction(handle,
&desc,
descA, modeA.data(), /* unary operator A*/CUTENSOR_OP_IDENTITY,
descB, modeB.data(), /* unary operator B*/CUTENSOR_OP_IDENTITY,
descC, modeC.data(), /* unary operator C*/CUTENSOR_OP_IDENTITY,
descC, modeC.data(),
descCompute));
/**************************
* Set the algorithm to use -- without just-in-time compilation
***************************/
const cutensorAlgo_t algo = CUTENSOR_ALGO_GETT;
cutensorPlanPreference_t planPref;
HANDLE_ERROR(cutensorCreatePlanPreference(handle,
&planPref,
algo,
CUTENSOR_JIT_MODE_NONE));
/**********************
* Query workspace estimate
**********************/
uint64_t workspaceSizeEstimate = 0;
const cutensorWorksizePreference_t workspacePref = CUTENSOR_WORKSPACE_DEFAULT;
HANDLE_ERROR(cutensorEstimateWorkspaceSize(handle,
desc,
planPref,
workspacePref,
&workspaceSizeEstimate));
// Allocate workspace
void *work = nullptr;
if (workspaceSizeEstimate > 0)
{
HANDLE_CUDA_ERROR(cudaMalloc(&work, workspaceSizeEstimate));
}
/**************************
* Create Contraction Plan -- without just-in-time compilation
**************************/
cutensorPlan_t plan;
HANDLE_ERROR(cutensorCreatePlan(handle,
&plan,
desc,
planPref,
workspaceSizeEstimate));
/**********************
* Execute the tensor contraction
**********************/
cudaStream_t stream;
HANDLE_CUDA_ERROR(cudaStreamCreate(&stream));
double minTimeCUTENSOR = 1e100;
for (int i=0; i < 3; ++i)
{
cudaMemcpy(C_d, C, sizeC, cudaMemcpyHostToDevice);
// Set up timing
GPUTimer timer;
timer.start();
HANDLE_ERROR(cutensorContract(handle,
plan,
(void*) &alpha, A_d, B_d,
(void*) &beta, C_d, C_d,
work, workspaceSizeEstimate, stream))
// Synchronize and measure timing
auto time = timer.seconds();
minTimeCUTENSOR = (minTimeCUTENSOR < time) ? minTimeCUTENSOR : time;
}
/*************************/
/**************************
* Set the algorithm to use -- with just-in-time compilation
***************************/
cutensorPlanPreference_t planPrefJit;
HANDLE_ERROR(cutensorCreatePlanPreference(handle,
&planPrefJit,
algo,
CUTENSOR_JIT_MODE_DEFAULT));
/**********************
* Query workspace estimate
**********************/
uint64_t workspaceSizeEstimateJit = 0;
const cutensorWorksizePreference_t workspacePrefJit = CUTENSOR_WORKSPACE_DEFAULT;
HANDLE_ERROR(cutensorEstimateWorkspaceSize(handle,
desc,
planPrefJit,
workspacePrefJit,
&workspaceSizeEstimateJit));
// Allocate workspace
void *workJit = nullptr;
if (workspaceSizeEstimateJit > 0)
{
HANDLE_CUDA_ERROR(cudaMalloc(&workJit, workspaceSizeEstimateJit));
}
/**************************
* Create Contraction Plan -- with just-in-time compilation
**************************/
cutensorPlan_t planJit;
CPUTimer jitPlanTimer;
jitPlanTimer.start();
// This is where the kernel is actually compiled
HANDLE_ERROR(cutensorCreatePlan(handle,
&planJit,
desc,
planPrefJit,
workspaceSizeEstimateJit));
auto jitPlanTime = jitPlanTimer.seconds();
/**********************
* Execute the tensor contraction using the JIT compiled kernel
**********************/
double minTimeCUTENSORJit = 1e100;
for (int i=0; i < 3; ++i)
{
cudaMemcpy(C_d, C, sizeC, cudaMemcpyHostToDevice);
// Set up timing
GPUTimer timer;
timer.start();
HANDLE_ERROR(cutensorContract(handle,
planJit,
(void*) &alpha, A_d, B_d,
(void*) &beta, C_d, C_d,
workJit, workspaceSizeEstimateJit, stream))
// Synchronize and measure timing
auto time = timer.seconds();
minTimeCUTENSORJit = (minTimeCUTENSORJit < time) ? minTimeCUTENSORJit : time;
}
/*************************/
float flops;
HANDLE_ERROR(cutensorOperationDescriptorGetAttribute(handle,
desc,
CUTENSOR_OPERATION_DESCRIPTOR_FLOPS,
(void*)&flops,
sizeof(flops)));
auto gflops = flops / 1e9;
auto gflopsPerSec = gflops / minTimeCUTENSOR;
auto gflopsPerSecJit = gflops / minTimeCUTENSORJit;
printf("cuTENSOR : %6.0f GFLOPs/s\n", gflopsPerSec);
printf("cuTENSOR JIT: %6.0f GFLOPs/s\n", gflopsPerSecJit);
printf("Speedup: %.1fx\n", gflopsPerSecJit / gflopsPerSec);
printf("JIT Compilation time: %.1f seconds\n", jitPlanTime / 1e3);
HANDLE_ERROR(cutensorDestroy(handle));
HANDLE_ERROR(cutensorDestroyOperationDescriptor(desc));
HANDLE_ERROR(cutensorDestroyTensorDescriptor(descA));
HANDLE_ERROR(cutensorDestroyTensorDescriptor(descB));
HANDLE_ERROR(cutensorDestroyTensorDescriptor(descC));
HANDLE_CUDA_ERROR(cudaStreamDestroy(stream));
HANDLE_ERROR(cutensorDestroyPlanPreference(planPref));
HANDLE_ERROR(cutensorDestroyPlan(plan));
HANDLE_ERROR(cutensorDestroyPlanPreference(planPrefJit));
HANDLE_ERROR(cutensorDestroyPlan(planJit));
if (A) free(A);
if (B) free(B);
if (C) free(C);
if (A_d) cudaFree(A_d);
if (B_d) cudaFree(B_d);
if (C_d) cudaFree(C_d);
if (work) cudaFree(work);
if (workJit) cudaFree(workJit);
printf("Successful completion\n");
return 0;
}
Below is the output of the above program on an NVIDIA H100 PCIe GPU:
cuTENSOR : 774 GFLOPs/s
cuTENSOR JIT: 5374 GFLOPs/s
Speedup: 6.9x
JIT Compilation time: 8.3 seconds
This concludes the introductory example.
Reading and writing the kernel cache to disk¶
JIT compilation can take a significant amount of time, especially if applied to tens or hundreds of different contractions. To amortize this overhead, we provide functions to write the kernel cache to disk once all plans have been created. This way, the compilation cost is incurred only once per kernel across different executions of the program.
Note
A kernel cache file stores information about the version of cuTENSOR that was used (CUTENSOR_VERSION), the version of CUDA on the system (CUDA_VERSION), and the model of the GPU (GPU_MODEL). For a kernel cache file to be read successfully, all three of these values must match exactly on the target system.
To write the kernel cache to a file, use the cutensorWriteKernelCacheToFile() function after all plans that use JIT compilation have been created:
cutensorWriteKernelCacheToFile(handle, "kernelCache.bin");
To read the file and load the kernels into a running instance of cuTENSOR, use the cutensorReadKernelCacheFromFile() function:
cutensorReadKernelCacheFromFile(handle, "kernelCache.bin");
Note
After reading a kernel cache from file, users must still enable JIT compilation during cutensorCreatePlanPreference() (by providing the CUTENSOR_JIT_MODE_DEFAULT argument) for the contractions that are intended to use the previously JIT-compiled kernels.
Below, we repeat the Introductory Example, but in line 188 we check whether a kernel cache file can be read, and in line 399 we write the kernel cache to a file. On the second execution of the sample below, the kernel cache is read and compilation is avoided.
1#include <chrono>
2#include <complex>
3#include <stdlib.h>
4#include <stdio.h>
5#include <unordered_map>
6#include <vector>
7
8#include <cuda_runtime.h>
9#include <cutensor.h>
10
11// Handle cuTENSOR errors
12#define HANDLE_ERROR(x) { \
13 const auto err = x; \
14 if( err != CUTENSOR_STATUS_SUCCESS ) \
15 { printf("Error: %s in line %d\n", cutensorGetErrorString(err), __LINE__); exit(-1); } \
16};
17
18// Handle CUDA errors
19#define HANDLE_CUDA_ERROR(x) { \
20 const auto err = x; \
21 if( err != cudaSuccess ) \
22 { printf("Error: %s in line %d\n", cudaGetErrorString(err), __LINE__); exit(-1); } \
23};
24
25class CPUTimer
26{
27public:
28 void start()
29 {
30 start_ = std::chrono::steady_clock::now();
31 }
32
33 double seconds()
34 {
35 end_ = std::chrono::steady_clock::now();
36 elapsed_ = end_ - start_;
37 //return in ms
38 return elapsed_.count() * 1000;
39 }
40
41private:
42 typedef std::chrono::steady_clock::time_point tp;
43 tp start_;
44 tp end_;
45 std::chrono::duration<double> elapsed_;
46};
47
48struct GPUTimer
49{
50 GPUTimer()
51 {
52 cudaEventCreate(&start_);
53 cudaEventCreate(&stop_);
54 cudaEventRecord(start_, 0);
55 }
56
57 ~GPUTimer()
58 {
59 cudaEventDestroy(start_);
60 cudaEventDestroy(stop_);
61 }
62
63 void start()
64 {
65 cudaEventRecord(start_, 0);
66 }
67
68 float seconds()
69 {
70 cudaEventRecord(stop_, 0);
71 cudaEventSynchronize(stop_);
72 float time;
73 cudaEventElapsedTime(&time, start_, stop_);
74 return time * 1e-3;
75 }
76 private:
77 cudaEvent_t start_, stop_;
78};
79
80int main()
81{
82 typedef std::complex<float> TypeA;
83 typedef std::complex<float> TypeB;
84 typedef std::complex<float> TypeC;
85 typedef std::complex<float> TypeScalar;
86
87 auto alpha = TypeScalar(1.1, 0.0);
88 auto beta = TypeScalar(0.0, 0.0);
89
90 // CUDA types
91 cutensorDataType_t typeA = CUTENSOR_C_32F;
92 cutensorDataType_t typeB = CUTENSOR_C_32F;
93 cutensorDataType_t typeC = CUTENSOR_C_32F;
94 cutensorComputeDescriptor_t descCompute = CUTENSOR_COMPUTE_DESC_3XTF32;
95
96
97 /* ***************************** */
98
99 // Create vector of modes
100 std::vector<int> modeC{0,1,2,3,4,6,8,9,25,26,10,12,14,27,15,28,17,19,29,20,21,30,23,24};
101 std::vector<int> modeA{0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24};
102 std::vector<int> modeB{25,26,27,28,29,30,5,7,11,13,16,18,22};
103 int nmodeA = modeA.size();
104 int nmodeB = modeB.size();
105 int nmodeC = modeC.size();
106
107 // Extents
108 std::unordered_map<int, int64_t> extent;
109 for (auto i = 0; i <= 30; i++)
110 extent[i] = 2;
111
112 // Create a vector of extents for each tensor
113 std::vector<int64_t> extentC;
114 for (auto mode : modeC)
115 extentC.push_back(extent[mode]);
116 std::vector<int64_t> extentA;
117 for (auto mode : modeA)
118 extentA.push_back(extent[mode]);
119 std::vector<int64_t> extentB;
120 for (auto mode : modeB)
121 extentB.push_back(extent[mode]);
122
123 /**********************
124 * Allocating data
125 **********************/
126
127 // Number of elements of each tensor
128 size_t elementsA = 1;
129 for (auto mode : modeA)
130 elementsA *= extent[mode];
131 size_t elementsB = 1;
132 for (auto mode : modeB)
133 elementsB *= extent[mode];
134 size_t elementsC = 1;
135 for (auto mode : modeC)
136 elementsC *= extent[mode];
137
138 // Size in bytes
139 size_t sizeA = sizeof(TypeA) * elementsA;
140 size_t sizeB = sizeof(TypeB) * elementsB;
141 size_t sizeC = sizeof(TypeC) * elementsC;
142 printf("Total memory: %.2f GiB\n", (sizeA + sizeB + sizeC)/1024./1024./1024);
143
144 // Allocate on device
145 void *A_d, *B_d, *C_d;
146 HANDLE_CUDA_ERROR(cudaMalloc((void**) &A_d, sizeA));
147 HANDLE_CUDA_ERROR(cudaMalloc((void**) &B_d, sizeB));
148 HANDLE_CUDA_ERROR(cudaMalloc((void**) &C_d, sizeC));
149
150 // Allocate on host
151 TypeA *A = (TypeA*) malloc(sizeof(TypeA) * elementsA);
152 TypeB *B = (TypeB*) malloc(sizeof(TypeB) * elementsB);
153 TypeC *C = (TypeC*) malloc(sizeof(TypeC) * elementsC);
154
155 if (A == nullptr || B == nullptr || C == nullptr)
156 {
157 printf("Error: Host allocation of A, B, or C.\n");
158 exit(-1);
159 }
160
161 /*******************
162 * Initialize data
163 *******************/
164
165 for (int64_t i = 0; i < elementsA; i++)
166 A[i] = (((float) rand())/RAND_MAX - 0.5)*100;
167 for (int64_t i = 0; i < elementsB; i++)
168 B[i] = (((float) rand())/RAND_MAX - 0.5)*100;
169 for (int64_t i = 0; i < elementsC; i++)
170 C[i] = (((float) rand())/RAND_MAX - 0.5)*100;
171
172 // Copy to device
173 HANDLE_CUDA_ERROR(cudaMemcpy(A_d, A, sizeA, cudaMemcpyHostToDevice));
174 HANDLE_CUDA_ERROR(cudaMemcpy(B_d, B, sizeB, cudaMemcpyHostToDevice));
175 HANDLE_CUDA_ERROR(cudaMemcpy(C_d, C, sizeC, cudaMemcpyHostToDevice));
176
177 const uint32_t kAlignment = 128; // Alignment of the global-memory device pointers (bytes)
178
179 /*************************
180 * cuTENSOR
181 *************************/
182
183 // Initialize cuTENSOR library
184 cutensorHandle_t handle;
185 HANDLE_ERROR(cutensorCreate(&handle));
186
187 // Read kernel cache from file (if the file was generated by a prior execution)
188 auto readKernelCacheStatus = cutensorReadKernelCacheFromFile(handle, "kernelCache.bin");
189
190 if (readKernelCacheStatus == CUTENSOR_STATUS_IO_ERROR)
191 printf("No kernel cache found. It will be generated before the end of this execution.\n");
192 else if (readKernelCacheStatus == CUTENSOR_STATUS_SUCCESS)
193 printf("Kernel cache found and read successfully.\n");
194 else
195 HANDLE_ERROR(readKernelCacheStatus);
196
197 /**********************
198 * Create Tensor Descriptors
199 **********************/
200
201 cutensorTensorDescriptor_t descA;
202 HANDLE_ERROR(cutensorCreateTensorDescriptor(handle,
203 &descA,
204 nmodeA,
205 extentA.data(),
206 NULL,/*stride*/
207 typeA, kAlignment));
208
209 cutensorTensorDescriptor_t descB;
210 HANDLE_ERROR(cutensorCreateTensorDescriptor(handle,
211 &descB,
212 nmodeB,
213 extentB.data(),
214 NULL,/*stride*/
215 typeB, kAlignment));
216
217 cutensorTensorDescriptor_t descC;
218 HANDLE_ERROR(cutensorCreateTensorDescriptor(handle,
219 &descC,
220 nmodeC,
221 extentC.data(),
222 NULL,/*stride*/
223 typeC, kAlignment));
224
225 /*******************************
226 * Create Contraction Descriptor
227 *******************************/
228
229 cutensorOperationDescriptor_t desc;
230 HANDLE_ERROR(cutensorCreateContraction(handle,
231 &desc,
232 descA, modeA.data(), /* unary operator A*/CUTENSOR_OP_IDENTITY,
233 descB, modeB.data(), /* unary operator B*/CUTENSOR_OP_IDENTITY,
234 descC, modeC.data(), /* unary operator C*/CUTENSOR_OP_IDENTITY,
235 descC, modeC.data(),
236 descCompute));
237
238 /**************************
239 * Set the algorithm to use -- without just-in-time compilation
240 ***************************/
241
242 const cutensorAlgo_t algo = CUTENSOR_ALGO_GETT;
243
244 cutensorPlanPreference_t planPref;
245 HANDLE_ERROR(cutensorCreatePlanPreference(handle,
246 &planPref,
247 algo,
248 CUTENSOR_JIT_MODE_NONE));
249
250 /**********************
251 * Query workspace estimate
252 **********************/
253
254 uint64_t workspaceSizeEstimate = 0;
255 const cutensorWorksizePreference_t workspacePref = CUTENSOR_WORKSPACE_DEFAULT;
256 HANDLE_ERROR(cutensorEstimateWorkspaceSize(handle,
257 desc,
258 planPref,
259 workspacePref,
260 &workspaceSizeEstimate));
261 // Allocate workspace
262 void *work = nullptr;
263 if (workspaceSizeEstimate > 0)
264 {
265 HANDLE_CUDA_ERROR(cudaMalloc(&work, workspaceSizeEstimate));
266 }
267
268 /**************************
269 * Create Contraction Plan -- without just-in-time compilation
270 **************************/
271
272 cutensorPlan_t plan;
273 HANDLE_ERROR(cutensorCreatePlan(handle,
274 &plan,
275 desc,
276 planPref,
277 workspaceSizeEstimate));
278
279 /**********************
280 * Execute the tensor contraction
281 **********************/
282
283 cudaStream_t stream;
284 HANDLE_CUDA_ERROR(cudaStreamCreate(&stream));
285
286 double minTimeCUTENSOR = 1e100;
287 for (int i=0; i < 3; ++i)
288 {
289 cudaMemcpy(C_d, C, sizeC, cudaMemcpyHostToDevice);
290
291 // Set up timing
292 GPUTimer timer;
293 timer.start();
294
295 HANDLE_ERROR(cutensorContract(handle,
296 plan,
297 (void*) &alpha, A_d, B_d,
298 (void*) &beta, C_d, C_d,
299 work, workspaceSizeEstimate, stream))
300
301 // Synchronize and measure timing
302 auto time = timer.seconds();
303
304 minTimeCUTENSOR = (minTimeCUTENSOR < time) ? minTimeCUTENSOR : time;
305 }
306
307 /*************************/
308
309 /**************************
310 * Set the algorithm to use -- with just-in-time compilation
311 ***************************/
312
313 cutensorPlanPreference_t planPrefJit;
314 HANDLE_ERROR(cutensorCreatePlanPreference(handle,
315 &planPrefJit,
316 algo,
317 CUTENSOR_JIT_MODE_DEFAULT));
318
319 /**********************
320 * Query workspace estimate
321 **********************/
322
323 uint64_t workspaceSizeEstimateJit = 0;
324 const cutensorWorksizePreference_t workspacePrefJit = CUTENSOR_WORKSPACE_DEFAULT;
325 HANDLE_ERROR(cutensorEstimateWorkspaceSize(handle,
326 desc,
327 planPrefJit,
328 workspacePrefJit,
329 &workspaceSizeEstimateJit));
330 // Allocate workspace
331 void *workJit = nullptr;
332 if (workspaceSizeEstimateJit > 0)
333 {
334 HANDLE_CUDA_ERROR(cudaMalloc(&workJit, workspaceSizeEstimateJit));
335 }
336
337 /**************************
338 * Create Contraction Plan -- with just-in-time compilation
339 **************************/
340
341 cutensorPlan_t planJit;
342 CPUTimer jitPlanTimer;
343 jitPlanTimer.start();
344 // This is where the kernel is actually compiled
345 HANDLE_ERROR(cutensorCreatePlan(handle,
346 &planJit,
347 desc,
348 planPrefJit,
349 workspaceSizeEstimateJit));
350 auto jitPlanTime = jitPlanTimer.seconds();
351
352 /**********************
353 * Execute the tensor contraction using the JIT compiled kernel
354 **********************/
355
356 double minTimeCUTENSORJit = 1e100;
357 for (int i=0; i < 3; ++i)
358 {
359 cudaMemcpy(C_d, C, sizeC, cudaMemcpyHostToDevice);
360
361 // Set up timing
362 GPUTimer timer;
363 timer.start();
364
365 HANDLE_ERROR(cutensorContract(handle,
366 planJit,
367 (void*) &alpha, A_d, B_d,
368 (void*) &beta, C_d, C_d,
369 workJit, workspaceSizeEstimateJit, stream))
370
371 // Synchronize and measure timing
372 auto time = timer.seconds();
373
374 minTimeCUTENSORJit = (minTimeCUTENSORJit < time) ? minTimeCUTENSORJit : time;
375 }
376
377 /*************************/
378
379 float flops;
380 HANDLE_ERROR(cutensorOperationDescriptorGetAttribute(handle,
381 desc,
382 CUTENSOR_OPERATION_DESCRIPTOR_FLOPS,
383 (void*)&flops,
384 sizeof(flops)));
385 auto gflops = flops / 1e9;
386 auto gflopsPerSec = gflops / minTimeCUTENSOR;
387 auto gflopsPerSecJit = gflops / minTimeCUTENSORJit;
388
389 printf("cuTENSOR : %6.0f GFLOPs/s\n", gflopsPerSec);
390 printf("cuTENSOR JIT: %6.0f GFLOPs/s\n", gflopsPerSecJit);
391 printf("Speedup: %.1fx\n", gflopsPerSecJit / gflopsPerSec);
392 printf("JIT Compilation time: %.1f seconds ", jitPlanTime / 1e3);
393 if (readKernelCacheStatus == CUTENSOR_STATUS_SUCCESS)
394 printf("(Kernel cache file was read successfully; Compilation was not required)\n");
395 else
396 printf("\n");
397
398 // Write kernel cache to file
399 HANDLE_ERROR(cutensorWriteKernelCacheToFile(handle, "kernelCache.bin"))
400 printf("Kernel cache written to file. Will be read in next execution.\n");
401
402 HANDLE_ERROR(cutensorDestroy(handle));
403 HANDLE_ERROR(cutensorDestroyOperationDescriptor(desc));
404 HANDLE_ERROR(cutensorDestroyTensorDescriptor(descA));
405 HANDLE_ERROR(cutensorDestroyTensorDescriptor(descB));
406 HANDLE_ERROR(cutensorDestroyTensorDescriptor(descC));
407 HANDLE_CUDA_ERROR(cudaStreamDestroy(stream));
408 HANDLE_ERROR(cutensorDestroyPlanPreference(planPref));
409 HANDLE_ERROR(cutensorDestroyPlan(plan));
410 HANDLE_ERROR(cutensorDestroyPlanPreference(planPrefJit));
411 HANDLE_ERROR(cutensorDestroyPlan(planJit));
412
413 if (A) free(A);
414 if (B) free(B);
415 if (C) free(C);
416 if (A_d) cudaFree(A_d);
417 if (B_d) cudaFree(B_d);
418 if (C_d) cudaFree(C_d);
419 if (work) cudaFree(work);
420 if (workJit) cudaFree(workJit);
421
422 printf("Successful completion\n");
423 return 0;
424}