network_compute_gradients_backward
cuquantum.bindings.cutensornet.network_compute_gradients_backward(intptr_t handle, intptr_t network_desc, int32_t accumulate_output, intptr_t work_desc, intptr_t slice_group, intptr_t stream)
Computes the gradients of the network w.r.t. the input tensors whose gradients are required. The network must have been contracted and loaded in the work_desc CACHE. Operates only on networks with a single slice and no singleton modes.

Parameters:
handle (intptr_t) – Opaque handle holding cuTensorNet’s library context.
network_desc (intptr_t) – The network descriptor for which the gradients of the specified slices (slice_group) will be computed (see network_prepare_gradients_backward()). Some internal meta-data may be updated upon contraction.
accumulate_output (int32_t) – If 0, writes the gradient results into the gradient memory buffers; otherwise accumulates the results into the gradient memory buffers.
work_desc (intptr_t) – Opaque structure describing the workspace. The provided CUTENSORNET_WORKSPACE_SCRATCH workspace must be valid (the workspace pointer must be device accessible, see cutensornetMemspace_t, and the workspace size must be the same as or larger than the minimum needed). See workspace_compute_contraction_sizes(), workspace_get_memory_size() & workspace_set_memory(). The provided CUTENSORNET_WORKSPACE_CACHE workspace must be valid (the workspace pointer must be device accessible, see cutensornetMemspace_t) and must contain the cached intermediate tensors from the corresponding network_contract() call. If a device memory handler is set, and work_desc is set to null, or the memory pointer in work_desc of either workspace kind is set to null, for both calls to network_contract() and network_compute_gradients_backward(), memory will be drawn from the memory pool. See network_contract() for details.
slice_group (intptr_t) – Opaque object specifying the slices of the gradients to be computed (see
create_slice_group_from_id_range() and cutensornetCreateSliceGroupFromIDs()). If set to null, all slices will be computed.
stream (intptr_t) – The CUDA stream on which the computation is performed.
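
The argument order documented above can be illustrated with a minimal sketch. It is not taken from the library's own examples. The wrapper name compute_gradients and the RecordingStub class are hypothetical, and the sketch assumes that a null slice_group is passed as 0 from Python. The bindings module is passed in as a parameter so that the call pattern can be exercised without a GPU.

```python
def compute_gradients(cutn, handle, network_desc, work_desc, stream,
                      accumulate=False, slice_group=0):
    """Hypothetical thin wrapper around network_compute_gradients_backward.

    `cutn` is expected to be the cuquantum.bindings.cutensornet module,
    or any object exposing a function with the same name and signature.
    """
    # accumulate_output is an int32 flag: 0 overwrites the gradient buffers,
    # any other value adds the new results to their current contents.
    accumulate_output = 1 if accumulate else 0
    # Assumption: a null slice_group is passed as 0, requesting all slices.
    cutn.network_compute_gradients_backward(
        handle, network_desc, accumulate_output, work_desc, slice_group, stream)


class RecordingStub:
    """Hypothetical stand-in for the bindings module; it only records calls."""

    def __init__(self):
        self.calls = []

    def network_compute_gradients_backward(self, handle, network_desc,
                                           accumulate_output, work_desc,
                                           slice_group, stream):
        self.calls.append((handle, network_desc, accumulate_output,
                           work_desc, slice_group, stream))


stub = RecordingStub()
# Placeholder integers stand in for the opaque handles and the stream pointer.
compute_gradients(stub, handle=1, network_desc=2, work_desc=3, stream=4)
compute_gradients(stub, handle=1, network_desc=2, work_desc=3, stream=4,
                  accumulate=True)
print(stub.calls)  # [(1, 2, 0, 3, 0, 4), (1, 2, 1, 3, 0, 4)]
```

With the real library, the first argument would be the cuquantum.bindings.cutensornet module, and the remaining arguments would be handles obtained from earlier calls. The network must already have been contracted with network_contract() so that the CACHE workspace holds the intermediate tensors.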