cuquantum.cutensornet.compute_gradients_backward(intptr_t handle, intptr_t plan, raw_data_in, intptr_t output_gradient, gradients, int32_t accumulate_output, intptr_t work_desc, intptr_t stream)[source]

Computes the gradients of the network w.r.t. the input tensors whose gradients are required. The network must already have been contracted, with its intermediate tensors cached in the CUTENSORNET_WORKSPACE_CACHE workspace of work_desc. This function operates only on networks with a single slice and no singleton modes.

  • handle (intptr_t) – Opaque handle holding cuTensorNet’s library context.

  • plan (intptr_t) – Encodes the execution of a tensor network contraction (see create_contraction_plan() and contraction_autotune()). Some internal meta-data may be updated upon contraction.

  • raw_data_in (object) –

    Array of N pointers (N being the number of input tensors specified in create_network_descriptor()): raw_data_in[i] points to the data associated with the i-th input tensor (in device memory). It can be:

    • an int as the pointer address to the array, or

    • a Python sequence of ints (as pointer addresses).

  • output_gradient (intptr_t) – Gradient of the output tensor (in device memory). Must have the same memory layout (strides) as the output tensor of the tensor network.

  • gradients (object) –

    Array of N pointers: gradients[i] points to the gradient data associated with the i-th input tensor (in device memory). Setting gradients[i] to null skips computing the gradient of the i-th input tensor. The generated gradient data has the same memory layout (strides) as the corresponding input tensor. It can be:

    • an int as the pointer address to the array, or

    • a Python sequence of ints (as pointer addresses).

  • accumulate_output (int32_t) – If 0, the gradient results are written into gradients; otherwise, the results are accumulated into gradients.

  • work_desc (intptr_t) – Opaque structure describing the workspace. The provided CUTENSORNET_WORKSPACE_SCRATCH workspace must be valid (its size must be equal to or larger than the minimum required); see workspace_compute_contraction_sizes(), workspace_get_memory_size() & workspace_set_memory(). The provided CUTENSORNET_WORKSPACE_CACHE workspace must also be valid and must contain the cached intermediate tensors from the corresponding contract_slices() call. If a device memory handler is set, and work_desc is null (or the memory pointer of either workspace kind in work_desc is null) for both the contract_slices() and compute_gradients_backward() calls, memory will be drawn from the memory pool. See contract_slices() for details.

  • stream (intptr_t) – The CUDA stream on which the computation is performed.
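As an illustration of how the pointer arguments above can be assembled, the sketch below builds raw_data_in and gradients as Python sequences of ints, with a 0 (null) entry to skip one tensor's gradient. NumPy host arrays stand in for what must be device buffers in real usage (e.g. CuPy arrays, taking pointers from `arr.data.ptr`), and the names `handle`, `plan`, `output_grad_ptr`, `work_desc`, and `stream_ptr` are assumed to come from earlier setup calls not shown here:

```python
import numpy as np

# Illustration only: in real usage these must be DEVICE buffers
# (e.g. cupy.ndarray, with pointers taken from arr.data.ptr).
# Host NumPy arrays are used here only to show how the pointer
# sequences are assembled.
t0 = np.zeros((2, 3), dtype=np.complex64)
t1 = np.zeros((3, 4), dtype=np.complex64)
t2 = np.zeros((4, 2), dtype=np.complex64)
grad0 = np.zeros_like(t0)
grad2 = np.zeros_like(t2)

# raw_data_in: one pointer per input tensor, in the order the tensors
# were specified in create_network_descriptor().
raw_data_in = [t.ctypes.data for t in (t0, t1, t2)]

# gradients: same length as raw_data_in; a 0 (null) entry skips the
# gradient of that input tensor (here: the gradient of t1 is skipped).
gradients = [grad0.ctypes.data, 0, grad2.ctypes.data]

assert len(gradients) == len(raw_data_in)

# Hypothetical call, assuming handle, plan, output_grad_ptr, work_desc,
# and stream_ptr were created earlier, and the network was already
# contracted via contract_slices() with a valid
# CUTENSORNET_WORKSPACE_CACHE workspace:
#
# cutensornet.compute_gradients_backward(
#     handle, plan, raw_data_in, output_grad_ptr, gradients,
#     0,          # accumulate_output=0: overwrite, don't accumulate
#     work_desc, stream_ptr)
```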