network_compute_gradients_backward

cuquantum.bindings.cutensornet.network_compute_gradients_backward(
intptr_t handle,
intptr_t network_desc,
int32_t accumulate_output,
intptr_t work_desc,
intptr_t slice_group,
intptr_t stream,
)

Computes the gradients of the network with respect to the input tensors whose gradients are required. The network must already have been contracted, with the results loaded in the CACHE workspace of work_desc. Operates only on networks with a single slice and no singleton modes.

Parameters:
  • handle (intptr_t) – Opaque handle holding cuTensorNet’s library context.

  • network_desc (intptr_t) – The network descriptor for which the gradients of the specified slices (slice_group) will be computed (see network_prepare_gradients_backward()). Some internal metadata may be updated upon contraction.

  • accumulate_output (int32_t) – If 0, writes the gradient results into the gradients memory buffers; otherwise, accumulates the results into the gradients memory buffers.

  • work_desc (intptr_t) – Opaque structure describing the workspace. The provided CUTENSORNET_WORKSPACE_SCRATCH workspace must be valid: the workspace pointer must be device accessible (see cutensornetMemspace_t), and the workspace size must be at least the minimum needed. See workspace_compute_contraction_sizes(), workspace_get_memory_size(), and workspace_set_memory(). The provided CUTENSORNET_WORKSPACE_CACHE workspace must also be valid (the workspace pointer must be device accessible, see cutensornetMemspace_t) and must contain the cached intermediate tensors from the corresponding network_contract() call. If a device memory handler is set, and either work_desc is null or the memory pointer in work_desc for either workspace kind is null, in the calls to both network_contract() and network_compute_gradients_backward(), memory will be drawn from the memory pool. See network_contract() for details.

  • slice_group (intptr_t) – Opaque object specifying the slices of the gradients to be computed (see create_slice_group_from_id_range() and cutensornetCreateSliceGroupFromIDs()). If set to null, all slices will be computed.

  • stream (intptr_t) – The CUDA stream on which the computation is performed.
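
The argument order documented above can be illustrated with a small sketch. The helper compute_gradients() below is hypothetical and not part of cuQuantum; it only forwards its arguments to network_compute_gradients_backward() in the documented positional order, maps a Python bool onto the int32_t accumulate_output flag, and assumes that a null slice_group is passed as the integer 0. So that the sketch runs without a GPU, the bindings module is replaced by a stand-in object that records the calls; the handle values are placeholders, not real pointers.

```python
from types import SimpleNamespace


def compute_gradients(cutn, handle, network_desc, work_desc, stream,
                      accumulate=False, slice_group=0):
    """Hypothetical wrapper (not a cuQuantum API).

    `cutn` is expected to be the bindings module, e.g. obtained via
    `from cuquantum.bindings import cutensornet as cutn`.
    Assumption: slice_group=0 stands for a null slice group, which
    requests all slices.
    """
    cutn.network_compute_gradients_backward(
        handle,
        network_desc,
        1 if accumulate else 0,  # accumulate_output: 0 = write, nonzero = accumulate
        work_desc,
        slice_group,
        stream,
    )


# Stand-in for the real bindings module: it only records the arguments.
calls = []
fake_cutn = SimpleNamespace(
    network_compute_gradients_backward=lambda *args: calls.append(args)
)

# Placeholder integer "handles"; real code would pass pointers obtained
# from the library and a CUDA stream handle.
compute_gradients(fake_cutn, 0x10, 0x20, 0x30, 0)                   # write gradients
compute_gradients(fake_cutn, 0x10, 0x20, 0x30, 0, accumulate=True)  # accumulate into buffers
```

In real use, the network would first have to be contracted with network_contract() so that the CACHE workspace holds the intermediate tensors, as described for work_desc above.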