Just-In-Time (JIT) Compilation

This section introduces the Just-In-Time (JIT) compilation feature, which allows users to compile kernels that are specialized for a specific operation in order to maximize performance.

The complexity of a given contraction (e.g., the number and order of modes, or the number of contracted modes) determines the size of its kernel search space, i.e., the set of candidate kernels that can be used to perform the contraction. As the complexity increases, the search space can become prohibitively large. Our pre-compiled kernels are carefully selected to perform well on a wide variety of contractions. However, as the complexity of a contraction grows, a just-in-time compiled kernel that is tailored to that contraction can outperform any kernel from the fixed set of pre-compiled kernels.

JIT compilation overcomes this limitation by creating a kernel that better exploits the optimization opportunities applicable to the given contraction.

Compiling a kernel typically takes 1 to 8 seconds, depending on the kernel and the host CPU. This cost is incurred only once per kernel, during the planning stage; once compiled, a kernel is automatically cached and can be reused by subsequent contractions.

All JIT-compiled kernels are added to the kernel cache, which is accessible to the whole library (i.e., shared across library handles). We provide functions to read and write the kernel cache to disk, so that the cost of JIT-compiling the same kernels can be avoided when re-running a program.

The remainder of this section assumes familiarity with Getting Started.

Note

The JIT compilation feature is only supported on GPUs with compute capability greater than or equal to 8.0 (Ampere or newer). Moreover, this feature is currently limited to tensor contractions.
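
For example, one way to guard the feature at run time is to check the device's compute capability with the CUDA runtime before choosing the JIT mode. The sketch below is a minimal illustration (the helper name selectJitMode is ours, and it assumes the cutensorJitMode_t enumeration accepted by cutensorCreatePlanPreference(); error checking is omitted for brevity):

#include <cuda_runtime.h>
#include <cutensor.h>

// Choose the JIT mode based on the compute capability of the given device:
// JIT compilation requires compute capability >= 8.0 (Ampere or newer).
static cutensorJitMode_t selectJitMode(int device)
{
    cudaDeviceProp prop;
    cudaGetDeviceProperties(&prop, device);
    return (prop.major >= 8) ? CUTENSOR_JIT_MODE_DEFAULT : CUTENSOR_JIT_MODE_NONE;
}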

Introductory Example

This subsection provides a basic overview of the API calls and features related to JIT compilation.

We begin by computing a contraction using the same steps as described in Getting Started, but with a different contraction example to emphasize the benefit of JIT compilation when the number of contracted modes increases. Then, we describe the necessary modifications to enable JIT compilation, and compare the performance of the pre-compiled and JIT-compiled kernels.

Our code computes the following operation (note that we now use numbers instead of letters to name each mode, since the number of modes exceeds the number of letters in the Latin alphabet):

\[
\begin{aligned}
C_{0,1,2,3,4,6,8,9,25,26,10,12,14,27,15,28,17,19,29,20,21,30,23,24} ={}& \alpha\, A_{0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24}\, B_{25,26,27,28,29,30,5,7,11,13,16,18,22} \\
&+ \beta\, C_{0,1,2,3,4,6,8,9,25,26,10,12,14,27,15,28,17,19,29,20,21,30,23,24}
\end{aligned}
\]

All operands contain single-precision complex values, and computation is performed with emulated single-precision arithmetic (3XTF32). All modes have an extent of 2.
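
Since every mode has an extent of 2, the operand sizes follow directly from the number of modes of each tensor (8 bytes per single-precision complex element):

\[
\begin{aligned}
\text{size}(A) &= 2^{25} \cdot 8\ \text{B} = 256\ \text{MiB}, \\
\text{size}(B) &= 2^{13} \cdot 8\ \text{B} = 64\ \text{KiB}, \\
\text{size}(C) &= 2^{24} \cdot 8\ \text{B} = 128\ \text{MiB},
\end{aligned}
\]

which add up to roughly 0.38 GiB, the "Total memory" value printed by the program below.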

The steps to compute the above contraction are the same as the ones in Getting Started, and are listed below:

#include <chrono>
#include <complex>
#include <stdlib.h>
#include <stdio.h>
#include <unordered_map>
#include <vector>

#include <cuda_runtime.h>
#include <cutensor.h>

// Handle cuTENSOR errors
#define HANDLE_ERROR(x) {                                                                \
  const auto err = x;                                                                    \
  if( err != CUTENSOR_STATUS_SUCCESS )                                                   \
  { printf("Error: %s in line %d\n", cutensorGetErrorString(err), __LINE__); exit(-1); } \
};

// Handle CUDA errors
#define HANDLE_CUDA_ERROR(x) {                                                       \
  const auto err = x;                                                                \
  if( err != cudaSuccess )                                                           \
  { printf("Error: %s in line %d\n", cudaGetErrorString(err), __LINE__); exit(-1); } \
};

class CPUTimer
{
public:
    void start()
    {
        start_ = std::chrono::steady_clock::now();
    }

    double seconds()
    {
        end_ = std::chrono::steady_clock::now();
        elapsed_ = end_ - start_;
        //return in ms
        return elapsed_.count() * 1000;
    }

private:
    typedef std::chrono::steady_clock::time_point tp;
    tp start_;
    tp end_;
    std::chrono::duration<double> elapsed_;
};

struct GPUTimer
{
  GPUTimer()
  {
    cudaEventCreate(&start_);
    cudaEventCreate(&stop_);
    cudaEventRecord(start_, 0);
  }

  ~GPUTimer()
  {
    cudaEventDestroy(start_);
    cudaEventDestroy(stop_);
  }

  void start()
  {
    cudaEventRecord(start_, 0);
  }

  float seconds()
  {
    cudaEventRecord(stop_, 0);
    cudaEventSynchronize(stop_);
    float time;
    cudaEventElapsedTime(&time, start_, stop_);
    return time * 1e-3;
  }
  private:
  cudaEvent_t start_, stop_;
};

int main()
{
  typedef std::complex<float> TypeA;
  typedef std::complex<float> TypeB;
  typedef std::complex<float> TypeC;
  typedef std::complex<float> TypeScalar;

  auto alpha = TypeScalar(1.1, 0.0);
  auto beta  = TypeScalar(0.0, 0.0);

  // CUDA types
  cutensorDataType_t typeA = CUTENSOR_C_32F;
  cutensorDataType_t typeB = CUTENSOR_C_32F;
  cutensorDataType_t typeC = CUTENSOR_C_32F;
  cutensorComputeDescriptor_t descCompute = CUTENSOR_COMPUTE_DESC_3XTF32;


  /* ***************************** */

  // Create vector of modes
  std::vector<int> modeC{0,1,2,3,4,6,8,9,25,26,10,12,14,27,15,28,17,19,29,20,21,30,23,24};
  std::vector<int> modeA{0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24};
  std::vector<int> modeB{25,26,27,28,29,30,5,7,11,13,16,18,22};
  int nmodeA = modeA.size();
  int nmodeB = modeB.size();
  int nmodeC = modeC.size();

  // Extents
  std::unordered_map<int, int64_t> extent;
  for (auto i = 0; i <= 30; i++)
      extent[i] = 2;

  // Create a vector of extents for each tensor
  std::vector<int64_t> extentC;
  for (auto mode : modeC)
      extentC.push_back(extent[mode]);
  std::vector<int64_t> extentA;
  for (auto mode : modeA)
      extentA.push_back(extent[mode]);
  std::vector<int64_t> extentB;
  for (auto mode : modeB)
      extentB.push_back(extent[mode]);

  /**********************
   * Allocating data
   **********************/

  // Number of elements of each tensor
  size_t elementsA = 1;
  for (auto mode : modeA)
      elementsA *= extent[mode];
  size_t elementsB = 1;
  for (auto mode : modeB)
      elementsB *= extent[mode];
  size_t elementsC = 1;
  for (auto mode : modeC)
      elementsC *= extent[mode];

  // Size in bytes
  size_t sizeA = sizeof(TypeA) * elementsA;
  size_t sizeB = sizeof(TypeB) * elementsB;
  size_t sizeC = sizeof(TypeC) * elementsC;
  printf("Total memory: %.2f GiB\n", (sizeA + sizeB + sizeC)/1024./1024./1024);

  // Allocate on device
  void *A_d, *B_d, *C_d;
  HANDLE_CUDA_ERROR(cudaMalloc((void**) &A_d, sizeA));
  HANDLE_CUDA_ERROR(cudaMalloc((void**) &B_d, sizeB));
  HANDLE_CUDA_ERROR(cudaMalloc((void**) &C_d, sizeC));

  // Allocate on host
  TypeA *A = (TypeA*) malloc(sizeof(TypeA) * elementsA);
  TypeB *B = (TypeB*) malloc(sizeof(TypeB) * elementsB);
  TypeC *C = (TypeC*) malloc(sizeof(TypeC) * elementsC);

  if (A == nullptr || B == nullptr || C == nullptr)
  {
      printf("Error: Host allocation of A, B, or C.\n");
      exit(-1);
  }

  /*******************
   * Initialize data
   *******************/

  for (int64_t i = 0; i < elementsA; i++)
      A[i] = (((float) rand())/RAND_MAX - 0.5)*100;
  for (int64_t i = 0; i < elementsB; i++)
      B[i] = (((float) rand())/RAND_MAX - 0.5)*100;
  for (int64_t i = 0; i < elementsC; i++)
      C[i] = (((float) rand())/RAND_MAX - 0.5)*100;

  // Copy to device
  HANDLE_CUDA_ERROR(cudaMemcpy(A_d, A, sizeA, cudaMemcpyHostToDevice));
  HANDLE_CUDA_ERROR(cudaMemcpy(B_d, B, sizeB, cudaMemcpyHostToDevice));
  HANDLE_CUDA_ERROR(cudaMemcpy(C_d, C, sizeC, cudaMemcpyHostToDevice));

  const uint32_t kAlignment = 128; // Alignment of the global-memory device pointers (bytes)

  /*************************
   * cuTENSOR
   *************************/

  // Initialize cuTENSOR library
  cutensorHandle_t handle;
  HANDLE_ERROR(cutensorCreate(&handle));

  /**********************
   * Create Tensor Descriptors
   **********************/

  cutensorTensorDescriptor_t descA;
  HANDLE_ERROR(cutensorCreateTensorDescriptor(handle,
                                              &descA,
                                              nmodeA,
                                              extentA.data(),
                                              NULL,/*stride*/
                                              typeA, kAlignment));

  cutensorTensorDescriptor_t descB;
  HANDLE_ERROR(cutensorCreateTensorDescriptor(handle,
                                              &descB,
                                              nmodeB,
                                              extentB.data(),
                                              NULL,/*stride*/
                                              typeB, kAlignment));

  cutensorTensorDescriptor_t descC;
  HANDLE_ERROR(cutensorCreateTensorDescriptor(handle,
                                              &descC,
                                              nmodeC,
                                              extentC.data(),
                                              NULL,/*stride*/
                                              typeC, kAlignment));

  /*******************************
   * Create Contraction Descriptor
   *******************************/

  cutensorOperationDescriptor_t desc;
  HANDLE_ERROR(cutensorCreateContraction(handle,
                                         &desc,
                                         descA, modeA.data(), /* unary operator A*/CUTENSOR_OP_IDENTITY,
                                         descB, modeB.data(), /* unary operator B*/CUTENSOR_OP_IDENTITY,
                                         descC, modeC.data(), /* unary operator C*/CUTENSOR_OP_IDENTITY,
                                         descC, modeC.data(),
                                         descCompute));

  /**************************
  * Set the algorithm to use -- without just-in-time compilation
  ***************************/

  const cutensorAlgo_t algo = CUTENSOR_ALGO_DEFAULT;

  cutensorPlanPreference_t planPref;
  HANDLE_ERROR(cutensorCreatePlanPreference(handle,
                                            &planPref,
                                            algo,
                                            CUTENSOR_JIT_MODE_NONE));

  /**********************
   * Query workspace estimate
   **********************/

  uint64_t workspaceSizeEstimate = 0;
  const cutensorWorksizePreference_t workspacePref = CUTENSOR_WORKSPACE_DEFAULT;
  HANDLE_ERROR(cutensorEstimateWorkspaceSize(handle,
                                             desc,
                                             planPref,
                                             workspacePref,
                                             &workspaceSizeEstimate));
  // Allocate workspace
  void *work = nullptr;
  if (workspaceSizeEstimate > 0)
  {
      HANDLE_CUDA_ERROR(cudaMalloc(&work, workspaceSizeEstimate));
  }

  /**************************
   * Create Contraction Plan -- without just-in-time compilation
   **************************/

  cutensorPlan_t plan;
  HANDLE_ERROR(cutensorCreatePlan(handle,
                                  &plan,
                                  desc,
                                  planPref,
                                  workspaceSizeEstimate));

  /**********************
   * Execute the tensor contraction
   **********************/

  cudaStream_t stream;
  HANDLE_CUDA_ERROR(cudaStreamCreate(&stream));

  double minTimeCUTENSOR = 1e100;
  for (int i=0; i < 3; ++i)
  {
      cudaMemcpy(C_d, C, sizeC, cudaMemcpyHostToDevice);

      // Set up timing
      GPUTimer timer;
      timer.start();

      HANDLE_ERROR(cutensorContract(handle,
                                    plan,
                                    (void*) &alpha, A_d, B_d,
                                    (void*) &beta,  C_d, C_d,
                                    work, workspaceSizeEstimate, stream))

      // Synchronize and measure timing
      auto time = timer.seconds();

      minTimeCUTENSOR = (minTimeCUTENSOR < time) ? minTimeCUTENSOR : time;
  }

  /*************************/

  float flops;
  HANDLE_ERROR(cutensorOperationDescriptorGetAttribute(handle,
                                                       desc,
                                                       CUTENSOR_OPERATION_DESCRIPTOR_FLOPS,
                                                       (void*)&flops,
                                                       sizeof(flops)));
  auto gflops = flops / 1e9;
  auto gflopsPerSec = gflops / minTimeCUTENSOR;

  printf("cuTENSOR    : %6.0f GFLOPs/s\n", gflopsPerSec);

  HANDLE_ERROR(cutensorDestroy(handle));
  HANDLE_ERROR(cutensorDestroyOperationDescriptor(desc));
  HANDLE_ERROR(cutensorDestroyTensorDescriptor(descA));
  HANDLE_ERROR(cutensorDestroyTensorDescriptor(descB));
  HANDLE_ERROR(cutensorDestroyTensorDescriptor(descC));
  HANDLE_CUDA_ERROR(cudaStreamDestroy(stream));
  HANDLE_ERROR(cutensorDestroyPlanPreference(planPref));
  HANDLE_ERROR(cutensorDestroyPlan(plan));

  if (A) free(A);
  if (B) free(B);
  if (C) free(C);
  if (A_d) cudaFree(A_d);
  if (B_d) cudaFree(B_d);
  if (C_d) cudaFree(C_d);
  if (work) cudaFree(work);

  printf("Successful completion\n");
  return 0;
}

All it takes to enable JIT compilation is changing the last argument of cutensorCreatePlanPreference() from CUTENSOR_JIT_MODE_NONE to CUTENSOR_JIT_MODE_DEFAULT; no further changes are required:

cutensorPlanPreference_t planPrefJit;
cutensorCreatePlanPreference(handle,
                             &planPrefJit,
                             algo,
                             CUTENSOR_JIT_MODE_DEFAULT);

The kernel is compiled during the call to cutensorCreatePlan(). This call is blocking, and the compilation can take up to a few seconds. Once the plan has been created, the compiled kernel is stored in the kernel cache (see Reading and writing the kernel cache to disk) and is used by subsequent calls to cutensorContract().
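
Because kernels are cached, creating a second plan for the same contraction with the same preferences does not trigger another compilation. The following sketch (reusing the handle, desc, planPrefJit, and workspaceSizeEstimateJit variables from the full example below, and the CPUTimer class defined above) illustrates how the cache hit can be observed:

CPUTimer planTimer;

// First JIT-enabled plan: the kernel is compiled here (may take several seconds).
planTimer.start();
cutensorPlan_t planJitFirst;
HANDLE_ERROR(cutensorCreatePlan(handle, &planJitFirst, desc, planPrefJit, workspaceSizeEstimateJit));
printf("First JIT plan : %.1f ms\n", planTimer.seconds());

// Second plan for the same contraction: the cached kernel is reused,
// so no compilation takes place and the call returns quickly.
planTimer.start();
cutensorPlan_t planJitSecond;
HANDLE_ERROR(cutensorCreatePlan(handle, &planJitSecond, desc, planPrefJit, workspaceSizeEstimateJit));
printf("Second JIT plan: %.1f ms\n", planTimer.seconds());

HANDLE_ERROR(cutensorDestroyPlan(planJitFirst));
HANDLE_ERROR(cutensorDestroyPlan(planJitSecond));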

The full working example is as follows:

#include <chrono>
#include <complex>
#include <stdlib.h>
#include <stdio.h>
#include <unordered_map>
#include <vector>

#include <cuda_runtime.h>
#include <cutensor.h>

// Handle cuTENSOR errors
#define HANDLE_ERROR(x) {                                                                \
  const auto err = x;                                                                    \
  if( err != CUTENSOR_STATUS_SUCCESS )                                                   \
  { printf("Error: %s in line %d\n", cutensorGetErrorString(err), __LINE__); exit(-1); } \
};

// Handle CUDA errors
#define HANDLE_CUDA_ERROR(x) {                                                       \
  const auto err = x;                                                                \
  if( err != cudaSuccess )                                                           \
  { printf("Error: %s in line %d\n", cudaGetErrorString(err), __LINE__); exit(-1); } \
};

class CPUTimer
{
public:
    void start()
    {
        start_ = std::chrono::steady_clock::now();
    }

    double seconds()
    {
        end_ = std::chrono::steady_clock::now();
        elapsed_ = end_ - start_;
        //return in ms
        return elapsed_.count() * 1000;
    }

private:
    typedef std::chrono::steady_clock::time_point tp;
    tp start_;
    tp end_;
    std::chrono::duration<double> elapsed_;
};

struct GPUTimer
{
  GPUTimer()
  {
    cudaEventCreate(&start_);
    cudaEventCreate(&stop_);
    cudaEventRecord(start_, 0);
  }

  ~GPUTimer()
  {
    cudaEventDestroy(start_);
    cudaEventDestroy(stop_);
  }

  void start()
  {
    cudaEventRecord(start_, 0);
  }

  float seconds()
  {
    cudaEventRecord(stop_, 0);
    cudaEventSynchronize(stop_);
    float time;
    cudaEventElapsedTime(&time, start_, stop_);
    return time * 1e-3;
  }
  private:
  cudaEvent_t start_, stop_;
};

int main()
{
  typedef std::complex<float> TypeA;
  typedef std::complex<float> TypeB;
  typedef std::complex<float> TypeC;
  typedef std::complex<float> TypeScalar;

  auto alpha = TypeScalar(1.1, 0.0);
  auto beta  = TypeScalar(0.0, 0.0);

  // CUDA types
  cutensorDataType_t typeA = CUTENSOR_C_32F;
  cutensorDataType_t typeB = CUTENSOR_C_32F;
  cutensorDataType_t typeC = CUTENSOR_C_32F;
  cutensorComputeDescriptor_t descCompute = CUTENSOR_COMPUTE_DESC_3XTF32;


  /* ***************************** */

  // Create vector of modes
  std::vector<int> modeC{0,1,2,3,4,6,8,9,25,26,10,12,14,27,15,28,17,19,29,20,21,30,23,24};
  std::vector<int> modeA{0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24};
  std::vector<int> modeB{25,26,27,28,29,30,5,7,11,13,16,18,22};
  int nmodeA = modeA.size();
  int nmodeB = modeB.size();
  int nmodeC = modeC.size();

  // Extents
  std::unordered_map<int, int64_t> extent;
  for (auto i = 0; i <= 30; i++)
      extent[i] = 2;

  // Create a vector of extents for each tensor
  std::vector<int64_t> extentC;
  for (auto mode : modeC)
      extentC.push_back(extent[mode]);
  std::vector<int64_t> extentA;
  for (auto mode : modeA)
      extentA.push_back(extent[mode]);
  std::vector<int64_t> extentB;
  for (auto mode : modeB)
      extentB.push_back(extent[mode]);

  /**********************
   * Allocating data
   **********************/

  // Number of elements of each tensor
  size_t elementsA = 1;
  for (auto mode : modeA)
      elementsA *= extent[mode];
  size_t elementsB = 1;
  for (auto mode : modeB)
      elementsB *= extent[mode];
  size_t elementsC = 1;
  for (auto mode : modeC)
      elementsC *= extent[mode];

  // Size in bytes
  size_t sizeA = sizeof(TypeA) * elementsA;
  size_t sizeB = sizeof(TypeB) * elementsB;
  size_t sizeC = sizeof(TypeC) * elementsC;
  printf("Total memory: %.2f GiB\n", (sizeA + sizeB + sizeC)/1024./1024./1024);

  // Allocate on device
  void *A_d, *B_d, *C_d;
  HANDLE_CUDA_ERROR(cudaMalloc((void**) &A_d, sizeA));
  HANDLE_CUDA_ERROR(cudaMalloc((void**) &B_d, sizeB));
  HANDLE_CUDA_ERROR(cudaMalloc((void**) &C_d, sizeC));

  // Allocate on host
  TypeA *A = (TypeA*) malloc(sizeof(TypeA) * elementsA);
  TypeB *B = (TypeB*) malloc(sizeof(TypeB) * elementsB);
  TypeC *C = (TypeC*) malloc(sizeof(TypeC) * elementsC);

  if (A == nullptr || B == nullptr || C == nullptr)
  {
      printf("Error: Host allocation of A, B, or C.\n");
      exit(-1);
  }

  /*******************
   * Initialize data
   *******************/

  for (int64_t i = 0; i < elementsA; i++)
      A[i] = (((float) rand())/RAND_MAX - 0.5)*100;
  for (int64_t i = 0; i < elementsB; i++)
      B[i] = (((float) rand())/RAND_MAX - 0.5)*100;
  for (int64_t i = 0; i < elementsC; i++)
      C[i] = (((float) rand())/RAND_MAX - 0.5)*100;

  // Copy to device
  HANDLE_CUDA_ERROR(cudaMemcpy(A_d, A, sizeA, cudaMemcpyHostToDevice));
  HANDLE_CUDA_ERROR(cudaMemcpy(B_d, B, sizeB, cudaMemcpyHostToDevice));
  HANDLE_CUDA_ERROR(cudaMemcpy(C_d, C, sizeC, cudaMemcpyHostToDevice));

  const uint32_t kAlignment = 128; // Alignment of the global-memory device pointers (bytes)

  /*************************
   * cuTENSOR
   *************************/

  // Initialize cuTENSOR library
  cutensorHandle_t handle;
  HANDLE_ERROR(cutensorCreate(&handle));

  /**********************
   * Create Tensor Descriptors
   **********************/

  cutensorTensorDescriptor_t descA;
  HANDLE_ERROR(cutensorCreateTensorDescriptor(handle,
                                              &descA,
                                              nmodeA,
                                              extentA.data(),
                                              NULL,/*stride*/
                                              typeA, kAlignment));

  cutensorTensorDescriptor_t descB;
  HANDLE_ERROR(cutensorCreateTensorDescriptor(handle,
                                              &descB,
                                              nmodeB,
                                              extentB.data(),
                                              NULL,/*stride*/
                                              typeB, kAlignment));

  cutensorTensorDescriptor_t descC;
  HANDLE_ERROR(cutensorCreateTensorDescriptor(handle,
                                              &descC,
                                              nmodeC,
                                              extentC.data(),
                                              NULL,/*stride*/
                                              typeC, kAlignment));

  /*******************************
   * Create Contraction Descriptor
   *******************************/

  cutensorOperationDescriptor_t desc;
  HANDLE_ERROR(cutensorCreateContraction(handle,
                                         &desc,
                                         descA, modeA.data(), /* unary operator A*/CUTENSOR_OP_IDENTITY,
                                         descB, modeB.data(), /* unary operator B*/CUTENSOR_OP_IDENTITY,
                                         descC, modeC.data(), /* unary operator C*/CUTENSOR_OP_IDENTITY,
                                         descC, modeC.data(),
                                         descCompute));

  /**************************
  * Set the algorithm to use -- without just-in-time compilation
  ***************************/

  const cutensorAlgo_t algo = CUTENSOR_ALGO_DEFAULT;

  cutensorPlanPreference_t planPref;
  HANDLE_ERROR(cutensorCreatePlanPreference(handle,
                                            &planPref,
                                            algo,
                                            CUTENSOR_JIT_MODE_NONE));

  /**********************
   * Query workspace estimate
   **********************/

  uint64_t workspaceSizeEstimate = 0;
  const cutensorWorksizePreference_t workspacePref = CUTENSOR_WORKSPACE_DEFAULT;
  HANDLE_ERROR(cutensorEstimateWorkspaceSize(handle,
                                             desc,
                                             planPref,
                                             workspacePref,
                                             &workspaceSizeEstimate));
  // Allocate workspace
  void *work = nullptr;
  if (workspaceSizeEstimate > 0)
  {
      HANDLE_CUDA_ERROR(cudaMalloc(&work, workspaceSizeEstimate));
  }

  /**************************
   * Create Contraction Plan -- without just-in-time compilation
   **************************/

  cutensorPlan_t plan;
  HANDLE_ERROR(cutensorCreatePlan(handle,
                                  &plan,
                                  desc,
                                  planPref,
                                  workspaceSizeEstimate));

  /**********************
   * Execute the tensor contraction
   **********************/

  cudaStream_t stream;
  HANDLE_CUDA_ERROR(cudaStreamCreate(&stream));

  double minTimeCUTENSOR = 1e100;
  for (int i=0; i < 3; ++i)
  {
      cudaMemcpy(C_d, C, sizeC, cudaMemcpyHostToDevice);

      // Set up timing
      GPUTimer timer;
      timer.start();

      HANDLE_ERROR(cutensorContract(handle,
                                    plan,
                                    (void*) &alpha, A_d, B_d,
                                    (void*) &beta,  C_d, C_d,
                                    work, workspaceSizeEstimate, stream))

      // Synchronize and measure timing
      auto time = timer.seconds();

      minTimeCUTENSOR = (minTimeCUTENSOR < time) ? minTimeCUTENSOR : time;
  }

  /*************************/

  /**************************
  * Set the algorithm to use -- with just-in-time compilation
  ***************************/

  cutensorPlanPreference_t planPrefJit;
  HANDLE_ERROR(cutensorCreatePlanPreference(handle,
                                            &planPrefJit,
                                            algo,
                                            CUTENSOR_JIT_MODE_DEFAULT));

  /**********************
   * Query workspace estimate
   **********************/

  uint64_t workspaceSizeEstimateJit = 0;
  const cutensorWorksizePreference_t workspacePrefJit = CUTENSOR_WORKSPACE_DEFAULT;
  HANDLE_ERROR(cutensorEstimateWorkspaceSize(handle,
                                             desc,
                                             planPrefJit,
                                             workspacePrefJit,
                                             &workspaceSizeEstimateJit));
  // Allocate workspace
  void *workJit = nullptr;
  if (workspaceSizeEstimateJit > 0)
  {
      HANDLE_CUDA_ERROR(cudaMalloc(&workJit, workspaceSizeEstimateJit));
  }

  /**************************
   * Create Contraction Plan -- with just-in-time compilation
   **************************/

  cutensorPlan_t planJit;
  CPUTimer jitPlanTimer;
  jitPlanTimer.start();
  // This is where the kernel is actually compiled
  HANDLE_ERROR(cutensorCreatePlan(handle,
                                  &planJit,
                                  desc,
                                  planPrefJit,
                                  workspaceSizeEstimateJit));
  auto jitPlanTime = jitPlanTimer.seconds();

  /**********************
   * Execute the tensor contraction using the JIT compiled kernel
   **********************/

  double minTimeCUTENSORJit = 1e100;
  for (int i=0; i < 3; ++i)
  {
      cudaMemcpy(C_d, C, sizeC, cudaMemcpyHostToDevice);

      // Set up timing
      GPUTimer timer;
      timer.start();

      HANDLE_ERROR(cutensorContract(handle,
                                    planJit,
                                    (void*) &alpha, A_d, B_d,
                                    (void*) &beta,  C_d, C_d,
                                    workJit, workspaceSizeEstimateJit, stream))

      // Synchronize and measure timing
      auto time = timer.seconds();

      minTimeCUTENSORJit = (minTimeCUTENSORJit < time) ? minTimeCUTENSORJit : time;
  }

  /*************************/

  float flops;
  HANDLE_ERROR(cutensorOperationDescriptorGetAttribute(handle,
                                                       desc,
                                                       CUTENSOR_OPERATION_DESCRIPTOR_FLOPS,
                                                       (void*)&flops,
                                                       sizeof(flops)));
  auto gflops = flops / 1e9;
  auto gflopsPerSec = gflops / minTimeCUTENSOR;
  auto gflopsPerSecJit = gflops / minTimeCUTENSORJit;

  printf("cuTENSOR    : %6.0f GFLOPs/s\n", gflopsPerSec);
  printf("cuTENSOR JIT: %6.0f GFLOPs/s\n", gflopsPerSecJit);
  printf("Speedup: %.1fx\n", gflopsPerSecJit / gflopsPerSec);
  printf("JIT Compilation time: %.1f seconds\n", jitPlanTime / 1e3);

  HANDLE_ERROR(cutensorDestroy(handle));
  HANDLE_ERROR(cutensorDestroyOperationDescriptor(desc));
  HANDLE_ERROR(cutensorDestroyTensorDescriptor(descA));
  HANDLE_ERROR(cutensorDestroyTensorDescriptor(descB));
  HANDLE_ERROR(cutensorDestroyTensorDescriptor(descC));
  HANDLE_CUDA_ERROR(cudaStreamDestroy(stream));
  HANDLE_ERROR(cutensorDestroyPlanPreference(planPref));
  HANDLE_ERROR(cutensorDestroyPlan(plan));
  HANDLE_ERROR(cutensorDestroyPlanPreference(planPrefJit));
  HANDLE_ERROR(cutensorDestroyPlan(planJit));

  if (A) free(A);
  if (B) free(B);
  if (C) free(C);
  if (A_d) cudaFree(A_d);
  if (B_d) cudaFree(B_d);
  if (C_d) cudaFree(C_d);
  if (work) cudaFree(work);
  if (workJit) cudaFree(workJit);

  printf("Successful completion\n");
  return 0;
}

Below is the output of the above program on an NVIDIA H100 PCIe GPU:

cuTENSOR    :    774 GFLOPs/s
cuTENSOR JIT:   5374 GFLOPs/s
Speedup: 6.9x
JIT Compilation time: 8.3 seconds

This concludes the introductory example.

Reading and writing the kernel cache to disk

JIT compilation can take a significant amount of time, especially when it is applied to tens or hundreds of different contractions. To amortize this overhead, we provide functions to write the kernel cache to disk once all plans have been created. This way, the compilation cost for each plan is incurred only once across different executions of the program.

Note

A kernel cache file stores information about the version of cuTENSOR that was used (CUTENSOR_VERSION), the version of CUDA on the system (CUDA_VERSION), and the model of the GPU (GPU_MODEL). For a kernel cache file to be read successfully, all three of these values must match exactly on the target system.

To write the kernel cache to a file, use the cutensorWriteKernelCacheToFile() function after all plans that use JIT compilation have been created:

cutensorWriteKernelCacheToFile(handle, "kernelCache.bin");

To read the file and load the kernels into a running instance of cuTENSOR, simply use the cutensorReadKernelCacheFromFile() function:

cutensorReadKernelCacheFromFile(handle, "kernelCache.bin");

Note

After reading a kernel cache from a file, users must still enable JIT compilation during cutensorCreatePlanPreference() (by providing the CUTENSOR_JIT_MODE_DEFAULT argument) for the contractions that are meant to use the previously JIT-compiled kernels.
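
Putting the pieces together, the intended workflow looks roughly as follows; this is a condensed sketch of what the full listing below does (identifiers are reused from the example above):

// Try to load a cache written by a previous run; a missing file is not fatal.
auto readStatus = cutensorReadKernelCacheFromFile(handle, "kernelCache.bin");
if (readStatus != CUTENSOR_STATUS_SUCCESS && readStatus != CUTENSOR_STATUS_IO_ERROR)
    HANDLE_ERROR(readStatus);

// JIT compilation must still be requested for every plan that should use
// the cached kernels.
cutensorPlanPreference_t planPrefJit;
HANDLE_ERROR(cutensorCreatePlanPreference(handle, &planPrefJit,
                                          CUTENSOR_ALGO_DEFAULT,
                                          CUTENSOR_JIT_MODE_DEFAULT));

// ... create all JIT-enabled plans here; compilation happens only on a cache miss ...

// Persist the (possibly extended) cache so that the next run can skip compilation.
HANDLE_ERROR(cutensorWriteKernelCacheToFile(handle, "kernelCache.bin"));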

Below, we repeat the Introductory Example, but in line 188 we check whether a kernel cache file can be read, and in line 399 we write the kernel cache to a file. On the second execution of the sample, the kernel cache is read and compilation is avoided.

  1#include <chrono>
  2#include <complex>
  3#include <stdlib.h>
  4#include <stdio.h>
  5#include <unordered_map>
  6#include <vector>
  7
  8#include <cuda_runtime.h>
  9#include <cutensor.h>
 10
 11// Handle cuTENSOR errors
 12#define HANDLE_ERROR(x) {                                                                \
 13  const auto err = x;                                                                    \
 14  if( err != CUTENSOR_STATUS_SUCCESS )                                                   \
 15  { printf("Error: %s in line %d\n", cutensorGetErrorString(err), __LINE__); exit(-1); } \
 16};
 17
 18// Handle CUDA errors
 19#define HANDLE_CUDA_ERROR(x) {                                                       \
 20  const auto err = x;                                                                \
 21  if( err != cudaSuccess )                                                           \
 22  { printf("Error: %s in line %d\n", cudaGetErrorString(err), __LINE__); exit(-1); } \
 23};
 24
 25class CPUTimer
 26{
 27public:
 28    void start()
 29    {
 30        start_ = std::chrono::steady_clock::now();
 31    }
 32
 33    double seconds()
 34    {
 35        end_ = std::chrono::steady_clock::now();
 36        elapsed_ = end_ - start_;
 37        //return in ms
 38        return elapsed_.count() * 1000;
 39    }
 40
 41private:
 42    typedef std::chrono::steady_clock::time_point tp;
 43    tp start_;
 44    tp end_;
 45    std::chrono::duration<double> elapsed_;
 46};
 47
 48struct GPUTimer
 49{
 50  GPUTimer()
 51  {
 52    cudaEventCreate(&start_);
 53    cudaEventCreate(&stop_);
 54    cudaEventRecord(start_, 0);
 55  }
 56
 57  ~GPUTimer()
 58  {
 59    cudaEventDestroy(start_);
 60    cudaEventDestroy(stop_);
 61  }
 62
 63  void start()
 64  {
 65    cudaEventRecord(start_, 0);
 66  }
 67
 68  float seconds()
 69  {
 70    cudaEventRecord(stop_, 0);
 71    cudaEventSynchronize(stop_);
 72    float time;
 73    cudaEventElapsedTime(&time, start_, stop_);
 74    return time * 1e-3;
 75  }
 76  private:
 77  cudaEvent_t start_, stop_;
 78};
 79
 80int main()
 81{
 82  typedef std::complex<float> TypeA;
 83  typedef std::complex<float> TypeB;
 84  typedef std::complex<float> TypeC;
 85  typedef std::complex<float> TypeScalar;
 86
 87  auto alpha = TypeScalar(1.1, 0.0);
 88  auto beta  = TypeScalar(0.0, 0.0);
 89
 90  // CUDA types
 91  cutensorDataType_t typeA = CUTENSOR_C_32F;
 92  cutensorDataType_t typeB = CUTENSOR_C_32F;
 93  cutensorDataType_t typeC = CUTENSOR_C_32F;
 94  cutensorComputeDescriptor_t descCompute = CUTENSOR_COMPUTE_DESC_3XTF32;
 95
 96
 97  /* ***************************** */
 98
 99  // Create vector of modes
100  std::vector<int> modeC{0,1,2,3,4,6,8,9,25,26,10,12,14,27,15,28,17,19,29,20,21,30,23,24};
101  std::vector<int> modeA{0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24};
102  std::vector<int> modeB{25,26,27,28,29,30,5,7,11,13,16,18,22};
103  int nmodeA = modeA.size();
104  int nmodeB = modeB.size();
105  int nmodeC = modeC.size();
106
107  // Extents
108  std::unordered_map<int, int64_t> extent;
109  for (auto i = 0; i <= 30; i++)
110      extent[i] = 2;
111
112  // Create a vector of extents for each tensor
113  std::vector<int64_t> extentC;
114  for (auto mode : modeC)
115      extentC.push_back(extent[mode]);
116  std::vector<int64_t> extentA;
117  for (auto mode : modeA)
118      extentA.push_back(extent[mode]);
119  std::vector<int64_t> extentB;
120  for (auto mode : modeB)
121      extentB.push_back(extent[mode]);
122
123  /**********************
124   * Allocating data
125   **********************/
126
127  // Number of elements of each tensor
128  size_t elementsA = 1;
129  for (auto mode : modeA)
130      elementsA *= extent[mode];
131  size_t elementsB = 1;
132  for (auto mode : modeB)
133      elementsB *= extent[mode];
134  size_t elementsC = 1;
135  for (auto mode : modeC)
136      elementsC *= extent[mode];
137
138  // Size in bytes
139  size_t sizeA = sizeof(TypeA) * elementsA;
140  size_t sizeB = sizeof(TypeB) * elementsB;
141  size_t sizeC = sizeof(TypeC) * elementsC;
142  printf("Total memory: %.2f GiB\n", (sizeA + sizeB + sizeC)/1024./1024./1024);
143
144  // Allocate on device
145  void *A_d, *B_d, *C_d;
146  HANDLE_CUDA_ERROR(cudaMalloc((void**) &A_d, sizeA));
147  HANDLE_CUDA_ERROR(cudaMalloc((void**) &B_d, sizeB));
148  HANDLE_CUDA_ERROR(cudaMalloc((void**) &C_d, sizeC));
149
150  // Allocate on host
151  TypeA *A = (TypeA*) malloc(sizeof(TypeA) * elementsA);
152  TypeB *B = (TypeB*) malloc(sizeof(TypeB) * elementsB);
153  TypeC *C = (TypeC*) malloc(sizeof(TypeC) * elementsC);
154
155  if (A == nullptr || B == nullptr || C == nullptr)
156  {
157      printf("Error: Host allocation of A, B, or C.\n");
158      exit(-1);
159  }
160
161  /*******************
162   * Initialize data
163   *******************/
164
165  for (int64_t i = 0; i < elementsA; i++)
166      A[i] = (((float) rand())/RAND_MAX - 0.5)*100;
167  for (int64_t i = 0; i < elementsB; i++)
168      B[i] = (((float) rand())/RAND_MAX - 0.5)*100;
169  for (int64_t i = 0; i < elementsC; i++)
170      C[i] = (((float) rand())/RAND_MAX - 0.5)*100;
171
172  // Copy to device
173  HANDLE_CUDA_ERROR(cudaMemcpy(A_d, A, sizeA, cudaMemcpyHostToDevice));
174  HANDLE_CUDA_ERROR(cudaMemcpy(B_d, B, sizeB, cudaMemcpyHostToDevice));
175  HANDLE_CUDA_ERROR(cudaMemcpy(C_d, C, sizeC, cudaMemcpyHostToDevice));
176
177  const uint32_t kAlignment = 128; // Alignment of the global-memory device pointers (bytes)
178
179  /*************************
180   * cuTENSOR
181   *************************/
182
183  // Initialize cuTENSOR library
184  cutensorHandle_t handle;
185  HANDLE_ERROR(cutensorCreate(&handle));
186
187  // Read kernel cache from file (if the file was generated by a prior execution)
188  auto readKernelCacheStatus = cutensorReadKernelCacheFromFile(handle, "kernelCache.bin");
189
190  if (readKernelCacheStatus == CUTENSOR_STATUS_IO_ERROR)
191      printf("No kernel cache found. It will be generated before the end of this execution.\n");
192  else if (readKernelCacheStatus == CUTENSOR_STATUS_SUCCESS)
193      printf("Kernel cache found and read successfully.\n");
194  else
195      HANDLE_ERROR(readKernelCacheStatus);
196
197  /**********************
198   * Create Tensor Descriptors
199   **********************/
200
201  cutensorTensorDescriptor_t descA;
202  HANDLE_ERROR(cutensorCreateTensorDescriptor(handle,
203                                              &descA,
204                                              nmodeA,
205                                              extentA.data(),
206                                              NULL,/*stride*/
207                                              typeA, kAlignment));
208
209  cutensorTensorDescriptor_t descB;
210  HANDLE_ERROR(cutensorCreateTensorDescriptor(handle,
211                                              &descB,
212                                              nmodeB,
213                                              extentB.data(),
214                                              NULL,/*stride*/
215                                              typeB, kAlignment));
216
217  cutensorTensorDescriptor_t descC;
218  HANDLE_ERROR(cutensorCreateTensorDescriptor(handle,
219                                              &descC,
220                                              nmodeC,
221                                              extentC.data(),
222                                              NULL,/*stride*/
223                                              typeC, kAlignment));
224
225  /*******************************
226   * Create Contraction Descriptor
227   *******************************/
228
229  cutensorOperationDescriptor_t desc;
230  HANDLE_ERROR(cutensorCreateContraction(handle,
231                                         &desc,
232                                         descA, modeA.data(), /* unary operator A*/CUTENSOR_OP_IDENTITY,
233                                         descB, modeB.data(), /* unary operator B*/CUTENSOR_OP_IDENTITY,
234                                         descC, modeC.data(), /* unary operator C*/CUTENSOR_OP_IDENTITY,
235                                         descC, modeC.data(),
236                                         descCompute));
237
238  /**************************
239  * Set the algorithm to use -- without just-in-time compilation
240  ***************************/
241
242  const cutensorAlgo_t algo = CUTENSOR_ALGO_DEFAULT;
243
244  cutensorPlanPreference_t planPref;
245  HANDLE_ERROR(cutensorCreatePlanPreference(handle,
246                                            &planPref,
247                                            algo,
248                                            CUTENSOR_JIT_MODE_NONE));
249
250  /**********************
251   * Query workspace estimate
252   **********************/
253
254  uint64_t workspaceSizeEstimate = 0;
255  const cutensorWorksizePreference_t workspacePref = CUTENSOR_WORKSPACE_DEFAULT;
256  HANDLE_ERROR(cutensorEstimateWorkspaceSize(handle,
257                                             desc,
258                                             planPref,
259                                             workspacePref,
260                                             &workspaceSizeEstimate));
261  // Allocate workspace
262  void *work = nullptr;
263  if (workspaceSizeEstimate > 0)
264  {
265      HANDLE_CUDA_ERROR(cudaMalloc(&work, workspaceSizeEstimate));
266  }
267
268  /**************************
269   * Create Contraction Plan -- without just-in-time compilation
270   **************************/
271
272  cutensorPlan_t plan;
273  HANDLE_ERROR(cutensorCreatePlan(handle,
274                                  &plan,
275                                  desc,
276                                  planPref,
277                                  workspaceSizeEstimate));
278
279  /**********************
280   * Execute the tensor contraction
281   **********************/
282
283  cudaStream_t stream;
284  HANDLE_CUDA_ERROR(cudaStreamCreate(&stream));
285
286  double minTimeCUTENSOR = 1e100;
287  for (int i=0; i < 3; ++i)
288  {
289      cudaMemcpy(C_d, C, sizeC, cudaMemcpyHostToDevice);
290
291      // Set up timing
292      GPUTimer timer;
293      timer.start();
294
295      HANDLE_ERROR(cutensorContract(handle,
296                                    plan,
297                                    (void*) &alpha, A_d, B_d,
298                                    (void*) &beta,  C_d, C_d,
299                                    work, workspaceSizeEstimate, stream))
300
301      // Synchronize and measure timing
302      auto time = timer.seconds();
303
304      minTimeCUTENSOR = (minTimeCUTENSOR < time) ? minTimeCUTENSOR : time;
305  }
306
307  /*************************/
308
309  /**************************
310  * Set the algorithm to use -- with just-in-time compilation
311  ***************************/
312
313  cutensorPlanPreference_t planPrefJit;
314  HANDLE_ERROR(cutensorCreatePlanPreference(handle,
315                                            &planPrefJit,
316                                            algo,
317                                            CUTENSOR_JIT_MODE_DEFAULT));
318
319  /**********************
320   * Query workspace estimate
321   **********************/
322
323  uint64_t workspaceSizeEstimateJit = 0;
324  const cutensorWorksizePreference_t workspacePrefJit = CUTENSOR_WORKSPACE_DEFAULT;
325  HANDLE_ERROR(cutensorEstimateWorkspaceSize(handle,
326                                             desc,
327                                             planPrefJit,
328                                             workspacePrefJit,
329                                             &workspaceSizeEstimateJit));
330  // Allocate workspace
331  void *workJit = nullptr;
332  if (workspaceSizeEstimateJit > 0)
333  {
334      HANDLE_CUDA_ERROR(cudaMalloc(&workJit, workspaceSizeEstimateJit));
335  }
336
337  /**************************
338   * Create Contraction Plan -- with just-in-time compilation
339   **************************/
340
341  cutensorPlan_t planJit;
342  CPUTimer jitPlanTimer;
343  jitPlanTimer.start();
344  // This is where the kernel is actually compiled
345  HANDLE_ERROR(cutensorCreatePlan(handle,
346                                  &planJit,
347                                  desc,
348                                  planPrefJit,
349                                  workspaceSizeEstimateJit));
350  auto jitPlanTime = jitPlanTimer.seconds();
351
352  /**********************
353   * Execute the tensor contraction using the JIT compiled kernel
354   **********************/
355
356  double minTimeCUTENSORJit = 1e100;
357  for (int i=0; i < 3; ++i)
358  {
359      cudaMemcpy(C_d, C, sizeC, cudaMemcpyHostToDevice);
360
361      // Set up timing
362      GPUTimer timer;
363      timer.start();
364
365      HANDLE_ERROR(cutensorContract(handle,
366                                    planJit,
367                                    (void*) &alpha, A_d, B_d,
368                                    (void*) &beta,  C_d, C_d,
369                                    workJit, workspaceSizeEstimateJit, stream))
370
371      // Synchronize and measure timing
372      auto time = timer.seconds();
373
374      minTimeCUTENSORJit = (minTimeCUTENSORJit < time) ? minTimeCUTENSORJit : time;
375  }
376
377  /*************************/
378
379  float flops;
380  HANDLE_ERROR(cutensorOperationDescriptorGetAttribute(handle,
381                                                       desc,
382                                                       CUTENSOR_OPERATION_DESCRIPTOR_FLOPS,
383                                                       (void*)&flops,
384                                                       sizeof(flops)));
385  auto gflops = flops / 1e9;
386  auto gflopsPerSec = gflops / minTimeCUTENSOR;
387  auto gflopsPerSecJit = gflops / minTimeCUTENSORJit;
388
389  printf("cuTENSOR    : %6.0f GFLOPs/s\n", gflopsPerSec);
390  printf("cuTENSOR JIT: %6.0f GFLOPs/s\n", gflopsPerSecJit);
391  printf("Speedup: %.1fx\n", gflopsPerSecJit / gflopsPerSec);
392  printf("JIT Compilation time: %.1f seconds ", jitPlanTime / 1e3);
393  if (readKernelCacheStatus == CUTENSOR_STATUS_SUCCESS)
394      printf("(Kernel cache file was read successfully; Compilation was not required)\n");
395  else
396      printf("\n");
397
398  // Write kernel cache to file
399  HANDLE_ERROR(cutensorWriteKernelCacheToFile(handle, "kernelCache.bin"))
400  printf("Kernel cache written to file. Will be read in next execution.\n");
401
402  HANDLE_ERROR(cutensorDestroy(handle));
403  HANDLE_ERROR(cutensorDestroyOperationDescriptor(desc));
404  HANDLE_ERROR(cutensorDestroyTensorDescriptor(descA));
405  HANDLE_ERROR(cutensorDestroyTensorDescriptor(descB));
406  HANDLE_ERROR(cutensorDestroyTensorDescriptor(descC));
407  HANDLE_CUDA_ERROR(cudaStreamDestroy(stream));
408  HANDLE_ERROR(cutensorDestroyPlanPreference(planPref));
409  HANDLE_ERROR(cutensorDestroyPlan(plan));
410  HANDLE_ERROR(cutensorDestroyPlanPreference(planPrefJit));
411  HANDLE_ERROR(cutensorDestroyPlan(planJit));
412
413  if (A) free(A);
414  if (B) free(B);
415  if (C) free(C);
416  if (A_d) cudaFree(A_d);
417  if (B_d) cudaFree(B_d);
418  if (C_d) cudaFree(C_d);
419  if (work) cudaFree(work);
420  if (workJit) cudaFree(workJit);
421
422  printf("Successful completion\n");
423  return 0;
424}