Performing tensor SVD using cuTensorNet follows a workflow very similar to that of the QR example. Here, we highlight the notable differences between the two APIs. The full code can be found in the NVIDIA/cuQuantum repository.
Define SVD decomposition
As with QR decomposition, we first define the SVD decomposition to perform by specifying the data type, the mode partition, and the extents.
/******************************************************
 * Tensor SVD: T_{i,j,m,n} -> U_{i,x,m} S_{x} V_{n,x,j}
 *******************************************************/

typedef float floatType;
cudaDataType_t typeData = CUDA_R_32F;

// Create vector of modes
int32_t sharedMode = 'x';

std::vector<int32_t> modesT{'i','j','m','n'}; // input
std::vector<int32_t> modesU{'i', sharedMode,'m'};
std::vector<int32_t> modesV{'n', sharedMode,'j'}; // SVD output

// Extents
std::unordered_map<int32_t, int64_t> extentMap;
extentMap['i'] = 16;
extentMap['j'] = 16;
extentMap['m'] = 16;
extentMap['n'] = 16;

int64_t rowExtent = computeCombinedExtent(extentMap, modesU);
int64_t colExtent = computeCombinedExtent(extentMap, modesV);
// cuTensorNet tensor SVD operates in reduced mode, expecting k <= min(m, n)
int64_t fullSharedExtent = rowExtent <= colExtent ? rowExtent : colExtent;
const int64_t maxExtent = fullSharedExtent / 2; // fixed extent truncation: half of the singular values are trimmed out
extentMap[sharedMode] = maxExtent;

// Create a vector of extents for each tensor
std::vector<int64_t> extentT;
for (auto mode : modesT)
   extentT.push_back(extentMap[mode]);
std::vector<int64_t> extentU;
for (auto mode : modesU)
   extentU.push_back(extentMap[mode]);
std::vector<int64_t> extentV;
for (auto mode : modesV)
   extentV.push_back(extentMap[mode]);
Note
To perform fixed-extent truncation, we directly set maxExtent to half of the full shared extent that an exact SVD would yield.
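The snippet above relies on a small helper, computeCombinedExtent, which ships with the shared sample headers in the repository. A minimal sketch of what it might look like (our reconstruction under that assumption, not the verbatim sample code): it multiplies the extents of all modes of a tensor that already have an entry in the map, yielding the row or column extent of the matricized input; modes not yet assigned an extent (here the shared mode 'x') are skipped.

int64_t computeCombinedExtent(const std::unordered_map<int32_t, int64_t>& extentMap,
                              const std::vector<int32_t>& modes)
{
   int64_t combinedExtent{1};
   for (auto mode : modes)
   {
      // skip modes with no extent assigned yet (e.g., the shared mode 'x')
      auto it = extentMap.find(mode);
      if (it != extentMap.end())
         combinedExtent *= it->second;
   }
   return combinedExtent;
}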
Set up SVD truncation parameters
Once the SVD decomposition is defined, we can follow the same workflow as in the QR example for data allocation and tensor descriptor initialization.
Before querying the workspace, we can choose different SVD options via cutensornetTensorSVDConfig_t.
Meanwhile, we can create a cutensornetTensorSVDInfo_t object to keep track of runtime truncation information.
/**********************************************
 * Setup SVD algorithm and truncation parameters
 ***********************************************/

cutensornetTensorSVDConfig_t svdConfig;
HANDLE_ERROR( cutensornetCreateTensorSVDConfig(handle, &svdConfig) );

// set up truncation parameters
double absCutoff = 1e-2;
HANDLE_ERROR( cutensornetTensorSVDConfigSetAttribute(handle,
                  svdConfig,
                  CUTENSORNET_TENSOR_SVD_CONFIG_ABS_CUTOFF,
                  &absCutoff,
                  sizeof(absCutoff)) );
double relCutoff = 4e-2;
HANDLE_ERROR( cutensornetTensorSVDConfigSetAttribute(handle,
                  svdConfig,
                  CUTENSORNET_TENSOR_SVD_CONFIG_REL_CUTOFF,
                  &relCutoff,
                  sizeof(relCutoff)) );

// optional: choose the gesvdj algorithm with customized parameters (default is gesvd)
cutensornetTensorSVDAlgo_t svdAlgo = CUTENSORNET_TENSOR_SVD_ALGO_GESVDJ;
HANDLE_ERROR( cutensornetTensorSVDConfigSetAttribute(handle,
                  svdConfig,
                  CUTENSORNET_TENSOR_SVD_CONFIG_ALGO,
                  &svdAlgo,
                  sizeof(svdAlgo)) );
cutensornetGesvdjParams_t gesvdjParams{/*tol=*/1e-12, /*maxSweeps=*/80};
HANDLE_ERROR( cutensornetTensorSVDConfigSetAttribute(handle,
                  svdConfig,
                  CUTENSORNET_TENSOR_SVD_CONFIG_ALGO_PARAMS,
                  &gesvdjParams,
                  sizeof(gesvdjParams)) );
printf("Set up SVDConfig to use GESVDJ algorithm with truncation\n");

/********************************************************
 * Create SVDInfo to record runtime SVD truncation details
 *********************************************************/

cutensornetTensorSVDInfo_t svdInfo;
HANDLE_ERROR( cutensornetCreateTensorSVDInfo(handle, &svdInfo) );
Execution
Next, we can query and allocate the workspace with cutensornetWorkspaceComputeSVDSizes(), which is very similar to its QR counterpart.
At this stage, we can perform the SVD decomposition by calling cutensornetTensorSVD().
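For reference, a hedged sketch of that workspace step, following the general cuTensorNet workspace API; the exact calls and the size preference used in the full sample may differ by library version:

cutensornetWorkspaceDescriptor_t workDesc;
HANDLE_ERROR( cutensornetCreateWorkspaceDescriptor(handle, &workDesc) );
// query the workspace sizes required for SVD on the given descriptors and config
HANDLE_ERROR( cutensornetWorkspaceComputeSVDSizes(handle, descTensorIn, descTensorU, descTensorV, svdConfig, workDesc) );
int64_t requiredWorkspaceSize{0};
HANDLE_ERROR( cutensornetWorkspaceGetMemorySize(handle,
                  workDesc,
                  CUTENSORNET_WORKSIZE_PREF_RECOMMENDED,
                  CUTENSORNET_MEMSPACE_DEVICE,
                  CUTENSORNET_WORKSPACE_SCRATCH,
                  &requiredWorkspaceSize) );
// allocate device scratch memory and attach it to the workspace descriptor
void* devWork{nullptr};
HANDLE_CUDA_ERROR( cudaMalloc(&devWork, requiredWorkspaceSize) );
HANDLE_ERROR( cutensornetWorkspaceSetMemory(handle,
                  workDesc,
                  CUTENSORNET_MEMSPACE_DEVICE,
                  CUTENSORNET_WORKSPACE_SCRATCH,
                  devWork,
                  requiredWorkspaceSize) );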
/**********
 * Execution
 ***********/

GPUTimer timer{stream};
double minTimeCUTENSOR = 1e100;
const int numRuns = 3; // to get stable perf results
for (int i = 0; i < numRuns; ++i)
{
   // reset the output buffers
   cudaMemsetAsync(D_U, 0, sizeU, stream);
   cudaMemsetAsync(D_S, 0, sizeS, stream);
   cudaMemsetAsync(D_V, 0, sizeV, stream);
   cudaDeviceSynchronize();

   // With value-based truncation, `cutensornetTensorSVD` can potentially update the shared extent in descTensorU/V.
   // Here we restore descTensorU/V to the original problem.
   HANDLE_ERROR( cutensornetDestroyTensorDescriptor(descTensorU) );
   HANDLE_ERROR( cutensornetDestroyTensorDescriptor(descTensorV) );
   HANDLE_ERROR( cutensornetCreateTensorDescriptor(handle, numModesU, extentU.data(), strides, modesU.data(), typeData, &descTensorU) );
   HANDLE_ERROR( cutensornetCreateTensorDescriptor(handle, numModesV, extentV.data(), strides, modesV.data(), typeData, &descTensorV) );

   timer.start();
   HANDLE_ERROR( cutensornetTensorSVD(handle,
                     descTensorIn, D_T,
                     descTensorU, D_U,
                     D_S,
                     descTensorV, D_V,
                     svdConfig,
                     svdInfo,
                     workDesc,
                     stream) );
   // Synchronize and measure timing
   auto time = timer.seconds();
   minTimeCUTENSOR = (minTimeCUTENSOR < time) ? minTimeCUTENSOR : time;
}

printf("Performing SVD\n");

HANDLE_CUDA_ERROR( cudaMemcpyAsync(U, D_U, sizeU, cudaMemcpyDeviceToHost) );
HANDLE_CUDA_ERROR( cudaMemcpyAsync(S, D_S, sizeS, cudaMemcpyDeviceToHost) );
HANDLE_CUDA_ERROR( cudaMemcpyAsync(V, D_V, sizeV, cudaMemcpyDeviceToHost) );
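Once the computation has finished, the runtime truncation details recorded in svdInfo can be read back. A sketch using the SVDInfo attribute query; the attribute names below are assumed from the cuTensorNet headers, so verify them against your library version:

int64_t fullExtent{0}, reducedExtent{0};
double discardedWeight{0.0};
// full shared extent before truncation
HANDLE_ERROR( cutensornetTensorSVDInfoGetAttribute(handle, svdInfo,
                  CUTENSORNET_TENSOR_SVD_INFO_FULL_EXTENT,
                  &fullExtent, sizeof(fullExtent)) );
// shared extent actually kept after truncation
HANDLE_ERROR( cutensornetTensorSVDInfoGetAttribute(handle, svdInfo,
                  CUTENSORNET_TENSOR_SVD_INFO_REDUCED_EXTENT,
                  &reducedExtent, sizeof(reducedExtent)) );
// normalized weight of the discarded singular values
HANDLE_ERROR( cutensornetTensorSVDInfoGetAttribute(handle, svdInfo,
                  CUTENSORNET_TENSOR_SVD_INFO_DISCARDED_WEIGHT,
                  &discardedWeight, sizeof(discardedWeight)) );
printf("Reduced extent: %ld (full: %ld), discarded weight: %.6f\n",
       (long)reducedExtent, (long)fullExtent, discardedWeight);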
Note
Since we enabled value-based truncation options in this example, we need to restore the tensor descriptors for U and V if we wish to perform the same computation multiple times.
After the computation, we still need to free up all resources.
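A possible cleanup sequence, mirroring the objects created in this example (the names match the snippets above; host-side buffers are freed in the full sample as well):

HANDLE_ERROR( cutensornetDestroyTensorDescriptor(descTensorIn) );
HANDLE_ERROR( cutensornetDestroyTensorDescriptor(descTensorU) );
HANDLE_ERROR( cutensornetDestroyTensorDescriptor(descTensorV) );
HANDLE_ERROR( cutensornetDestroyTensorSVDConfig(svdConfig) );
HANDLE_ERROR( cutensornetDestroyTensorSVDInfo(svdInfo) );
HANDLE_ERROR( cutensornetDestroyWorkspaceDescriptor(workDesc) );
HANDLE_ERROR( cutensornetDestroy(handle) );
// free device buffers
HANDLE_CUDA_ERROR( cudaFree(D_T) );
HANDLE_CUDA_ERROR( cudaFree(D_U) );
HANDLE_CUDA_ERROR( cudaFree(D_S) );
HANDLE_CUDA_ERROR( cudaFree(D_V) );
HANDLE_CUDA_ERROR( cudaFree(devWork) );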