network_compute_gradients_backward
cuquantum.bindings.cutensornet.network_compute_gradients_backward(intptr_t handle, intptr_t network_desc, int32_t accumulate_output, intptr_t work_desc, intptr_t slice_group, intptr_t stream)
Computes the gradients of the network w.r.t. the input tensors whose gradients are required. The network must have been contracted and loaded in the work_desc CACHE. Operates only on networks with a single slice and no singleton modes.

Parameters:
handle (intptr_t) – Opaque handle holding cuTensorNet’s library context.

network_desc (intptr_t) – The network descriptor for which the gradients of the specified slices (slice_group) will be computed (see network_prepare_gradients_backward()). Some internal meta-data may be updated upon contraction.

accumulate_output (int32_t) – If 0, write the gradient results into the gradient memory buffers; otherwise, accumulate the results into the gradient memory buffers.

work_desc (intptr_t) – Opaque structure describing the workspace. The provided CUTENSORNET_WORKSPACE_SCRATCH workspace must be valid (the workspace pointer must be device accessible, see cutensornetMemspace_t, and the workspace size must be the same as or larger than the minimum needed). See workspace_compute_contraction_sizes(), workspace_get_memory_size() & workspace_set_memory(). The provided CUTENSORNET_WORKSPACE_CACHE workspace must be valid (the workspace pointer must be device accessible, see cutensornetMemspace_t) and must contain the cached intermediate tensors from the corresponding network_contract() call. If a device memory handler is set and either work_desc is null or the memory pointer in work_desc for either workspace kind is null, in both the network_contract() and network_compute_gradients_backward() calls, memory will be drawn from the memory pool. See network_contract() for details.

slice_group (intptr_t) – Opaque object specifying the slices of the gradients to be computed (see create_slice_group_from_id_range() and cutensornetCreateSliceGroupFromIDs()). If set to null, all slices will be computed.

stream (intptr_t) – The CUDA stream on which the computation is performed.
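
The effect of accumulate_output can be illustrated without a GPU. The following is a minimal pure-Python sketch that emulates the documented write-versus-accumulate behaviour for a two-tensor network C = A·B; it does not call cuTensorNet. The helper names grad_wrt_a and store_gradient, and the example values, are hypothetical and exist only for this illustration.

```python
# Pure-Python emulation of the accumulate_output flag; this does NOT call cuTensorNet.
# Network: C[i][k] = sum_j A[i][j] * B[j][k]. Given the output gradient dC,
# the gradient w.r.t. A is dA[i][j] = sum_k dC[i][k] * B[j][k].


def grad_wrt_a(d_c, b):
    """Return dA for C = A * B, given the upstream gradient dC (hypothetical helper)."""
    n_k = len(b[0])
    return [
        [sum(d_c[i][k] * b[j][k] for k in range(n_k)) for j in range(len(b))]
        for i in range(len(d_c))
    ]


def store_gradient(buffer, grad, accumulate_output):
    """Mirror the documented flag: 0 overwrites the buffer, nonzero adds to it in place."""
    for i, row in enumerate(grad):
        for j, value in enumerate(row):
            previous = buffer[i][j] if accumulate_output else 0.0
            buffer[i][j] = previous + value


b = [[1.0, 2.0], [3.0, 4.0]]
d_c = [[1.0, 0.0], [0.0, 1.0]]              # identity as the upstream gradient
grad_buffer = [[10.0, 10.0], [10.0, 10.0]]  # stale contents from an earlier step

# accumulate_output = 0: stale contents are discarded.
store_gradient(grad_buffer, grad_wrt_a(d_c, b), accumulate_output=0)
overwritten = [row[:] for row in grad_buffer]

# accumulate_output != 0: the new gradient is added to what is already stored.
store_gradient(grad_buffer, grad_wrt_a(d_c, b), accumulate_output=1)
accumulated = [row[:] for row in grad_buffer]
```

With the flag set to 0 the gradient buffer ends up holding only the freshly computed gradient, so it does not need to be zeroed beforehand. With a nonzero flag the caller is responsible for the buffer's initial contents, which is the usual pattern when summing gradient contributions over several backward passes.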