ElementaryOperator

class cuquantum.densitymat.jax.ElementaryOperator(
data: Array | ShapeDtypeStruct,
callback: WrappedTensorCallback | None = None,
grad_callback: WrappedTensorGradientCallback | None = None,
diag_offsets: Tuple[int, ...] = (),
)

PyTree class for cuDensityMat’s elementary operator.

Methods

__init__(
data: Array | ShapeDtypeStruct,
callback: WrappedTensorCallback | None = None,
grad_callback: WrappedTensorGradientCallback | None = None,
diag_offsets: Tuple[int, ...] = (),
) → None

Initialize an ElementaryOperator object.

Parameters:
  • data – Data specification of the elementary operator. If callback is None, data should be a jax.Array; otherwise, data should be a jax.ShapeDtypeStruct.

  • callback – Forward callback for the elementary operator.

  • grad_callback – Gradient callback for the elementary operator.

  • diag_offsets – Diagonal offsets of the elementary operator.
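The rule for data above can be illustrated with plain JAX objects. The following is a minimal sketch, not part of the library: the helper make_data_spec and the Pauli-Z example values are hypothetical names introduced here, and the actual constructor call is left as a comment because the shape conventions that ElementaryOperator expects are not stated in this reference.

```python
import jax
import jax.numpy as jnp


def make_data_spec(matrix, has_callback):
    """Hypothetical helper: choose the form of `data` per the parameter description."""
    if has_callback:
        # With a callback, only the shape and dtype are declared up front.
        return jax.ShapeDtypeStruct(matrix.shape, matrix.dtype)
    # Without a callback, the concrete jax.Array is passed directly.
    return matrix


# Illustrative values only: a Pauli-Z matrix for a single two-level mode.
pauli_z = jnp.array([[1.0, 0.0], [0.0, -1.0]], dtype=jnp.complex64)

static_data = make_data_spec(pauli_z, has_callback=False)   # a jax.Array
callback_data = make_data_spec(pauli_z, has_callback=True)  # a jax.ShapeDtypeStruct

# Assumed usage, following the signature documented above (not executed here):
#   from cuquantum.densitymat.jax import ElementaryOperator
#   op = ElementaryOperator(static_data)
```

In the callback case the ShapeDtypeStruct carries no values, only the shape and dtype of the tensor that the callback is expected to supply.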

copy() → ElementaryOperator

Return a copy of the elementary operator.