JAX Extension APIs

cuQuantum Python JAX is an experimental extension for cuQuantum Python, designed to provide selected functionality of the cuQuantum SDK in a JAX-compatible manner. This extension allows JAX-based frameworks to directly integrate with the cuQuantum API. Currently, cuQuantum Python JAX offers a JAX interface to the Operator Action API from cuDensityMat, via the operator_action() function of the cuquantum.densitymat.jax module.

Release Notes

cuQuantum Python JAX v0.0.4

  • New features:

    • Added support for the vector-Jacobian product (VJP) transformation of operator action with batched input operators and batched input states.

    • Users can now differentiate with respect to parameters implicit in the input operator or input state instead of explicitly specifying parameter gradients in gradient callbacks. See example8_gradient_attachment.py for an example.

  • Bugs fixed:

    • Fixed a bug that caused double freeing of cuDensityMat library pointers.

    • Fixed a bug that triggered an insufficient-workspace runtime error when a regular operator action was executed after evaluating its vector-Jacobian product (VJP) transformation.

  • Other changes:

    • Users need to pass in jax.ShapeDtypeStruct objects instead of jax.Array objects for data buffers dynamically constructed by callbacks. The affected input arguments are:

    • The diagonal-offsets argument of the ElementaryOperator constructor, which is used to construct multidiagonal elementary operators, has been renamed from offsets to diag_offsets.

    • Modification of an Operator after it has been used in an operator action is now disabled.

    • When installing cuQuantum Python JAX, the user needs to pass the --no-build-isolation option to pip and ensure that all build dependencies are pre-installed.

cuQuantum Python JAX v0.0.3

  • Previously, cuQuantum Python JAX set jax_enable_x64=True as a side effect on import. Now, users must set jax_enable_x64 to True before importing the cuQuantum Python JAX module.

Compatibility notes:

  • cuQuantum Python JAX now supports CUDA 13 in addition to CUDA 12.

  • cuQuantum Python JAX supports JAX version >=0.8.0 and <0.9.0 for CUDA 13.

  • cuQuantum Python JAX supports JAX version >=0.5.0 and <0.7.0 for CUDA 12.

cuQuantum Python JAX v0.0.2

  • Bugs fixed:

    • Fixed an issue with packaging of the cuQuantum Python JAX extension which rendered the package uninstallable in certain situations.

cuQuantum Python JAX v0.0.1

  • The initial release of the JAX extension exposes the cuDensityMat Operator Action API as cuquantum.densitymat.jax.operator_action(), enabling integration of cuQuantum Python with JAX-based quantum dynamics simulation frameworks.

  • Known issues:

    • Multiple structurally equivalent operators that involve callback functions for elementary or matrix operators result in undefined behavior when their actions on a quantum state are evaluated in consecutive calls within the same Python interpreter session. As a workaround, call jax.clear_caches() between consecutive calls to operator_action() that involve these operators.

    • Inside a jax.jit scope, providing data buffers that may be overwritten by callbacks but were initialized with identical values results in undefined behavior. As a workaround, initialize the arrays used as dynamically constructed data buffers of different elementary or matrix operators with different values.

Compatibility notes:

  • cuQuantum Python JAX supports CUDA 12.

  • cuQuantum Python JAX supports JAX version >=0.5.0 and <0.7.0.

API usage

Basic usage

For basic usage, first import JAX. cuQuantum Python JAX requires double-precision arrays for callback parameters, so "jax_enable_x64" must be set to True in the JAX configuration before the cuQuantum Python JAX module is imported.

import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

cuQuantum Python JAX requires the GPU version of JAX: versions >=0.5.0 and <0.7.0 with CUDA 12, or versions >=0.8.0 and <0.9.0 with CUDA 13.

To compute the action of an operator on a quantum state with operator_action(), we first need to construct both the operator and the state. Consider a two-site system with local dimensions 3 and 5 whose dynamics are governed by the Lindblad master equation

\[\begin{split}\dot{\rho} = \mathcal{L}[\rho] &= -i[H, \rho] + \sum_{i} \left( L_i\rho L_i^\dagger - \frac{1}{2} L_i^\dagger L_i \rho - \frac{1}{2} \rho L_i^\dagger L_i \right) \\ H &= h_{01} \\ \{L_i\} &= \{l_0, l_1\}\end{split}\]
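For small systems it can be useful to have an independent dense reference for the right-hand side above, for example to sanity-check results. The following is a plain NumPy sketch, not part of cuQuantum; the function name lindblad_rhs is hypothetical, and it assumes H and each L_i are already given as dense matrices on the full 3 x 5 = 15-dimensional space, with mode 0 as the slower index when lifting l_0 and l_1 by Kronecker products.

```python
import numpy as np

def lindblad_rhs(H, jump_ops, rho):
    """Dense reference for -i[H, rho] + sum_i (L_i rho L_i^dag - 1/2 {L_i^dag L_i, rho}).

    H, each jump operator L_i, and rho are (D, D) matrices on the full space.
    """
    out = -1j * (H @ rho - rho @ H)  # Hamiltonian (commutator) part
    for L in jump_ops:
        Ld = L.conj().T
        LdL = Ld @ L
        out = out + L @ rho @ Ld - 0.5 * (LdL @ rho + rho @ LdL)  # dissipator for L_i
    return out

# Example with the two-site dimensions (3, 5) used in this section.
rng = np.random.default_rng(0)
d0, d1 = 3, 5
D = d0 * d1
A = rng.normal(size=(D, D)) + 1j * rng.normal(size=(D, D))
H = (A + A.conj().T) / 2  # Hermitian Hamiltonian on the full space
l0 = rng.normal(size=(d0, d0)) + 1j * rng.normal(size=(d0, d0))
l1 = rng.normal(size=(d1, d1)) + 1j * rng.normal(size=(d1, d1))
jump_ops = [np.kron(l0, np.eye(d1)), np.kron(np.eye(d0), l1)]  # l_0 on mode 0, l_1 on mode 1
psi = rng.normal(size=D) + 1j * rng.normal(size=D)
psi /= np.linalg.norm(psi)
rho = np.outer(psi, psi.conj())
drho = lindblad_rhs(H, jump_ops, rho)
```

Every Lindbladian is trace preserving, so the trace of drho vanishes up to rounding error; with a Hermitian H the result is also Hermitian. These are cheap properties to check on any computed output state.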

An operator action evaluates the right-hand side of the master equation: the Liouvillian operator acts on the input quantum state to yield the output quantum state. To construct the Liouvillian, first define JAX arrays for the data buffers, wrap them as ElementaryOperator objects, combine them into OperatorTerm objects for the Hamiltonian and the dissipators, and finally assemble these terms into the Liouvillian Operator.

from cuquantum.densitymat.jax import ElementaryOperator, OperatorTerm, Operator

dims = (3, 5)
dtype = jnp.complex128

# Elementary operators
h01_data = jax.random.normal(jax.random.key(0), (*dims, *dims)).astype(dtype)
l0_data = jax.random.normal(jax.random.key(1), (dims[0], dims[0])).astype(dtype)
l1_data = jax.random.normal(jax.random.key(2), (dims[1], dims[1])).astype(dtype)
l0d_data = l0_data.conj().T
l1d_data = l1_data.conj().T

h01 = ElementaryOperator(h01_data)
l0 = ElementaryOperator(l0_data)
l1 = ElementaryOperator(l1_data)
l0d = ElementaryOperator(l0d_data)
l1d = ElementaryOperator(l1d_data)

# Hamiltonian operator term
H = OperatorTerm(dims)
H.append([h01], modes=[0, 1], duals=[False, False])

# Dissipators operator term
Ls = OperatorTerm(dims)
for i, l, ld in zip(range(len(dims)), [l0, l1], [l0d, l1d]):
    Ls.append([l, ld], modes=[i, i], duals=[False, True])
    Ls.append([l, ld], modes=[i, i], duals=[False, False], coeff=-0.5)
    Ls.append([ld, l], modes=[i, i], duals=[True, True], coeff=-0.5)

# Liouvillian operator: -i[H, rho] plus the dissipators
liouvillian = Operator(dims)
liouvillian.append(H, dual=False, coeff=-1j)
liouvillian.append(H, dual=True, coeff=1j)
liouvillian.append(Ls, dual=False, coeff=1.0)
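Each ElementaryOperator above acts only on the modes listed in modes. To relate this mode-wise construction to dense matrices on the full 15-dimensional space, a single-mode operator can be lifted with a Kronecker product and a two-mode tensor such as h01_data can be reshaped into a matrix. The NumPy sketch below assumes row-major flattening with mode 0 as the slower index and, for the two-mode tensor, that axes (0, 1) index the output state and axes (2, 3) the input state; these layout conventions and both helper names are assumptions for illustration, not documented cuDensityMat behavior.

```python
import numpy as np

def lift_single_mode(op, mode, dims):
    """Lift an operator acting on one mode of a two-mode system to the full space.

    Assumes row-major flattening: joint index = i0 * dims[1] + i1.
    """
    d0, d1 = dims
    if mode == 0:
        return np.kron(op, np.eye(d1))
    if mode == 1:
        return np.kron(np.eye(d0), op)
    raise ValueError("mode must be 0 or 1 for a two-mode system")

def two_mode_tensor_to_matrix(op_tensor):
    """Reshape a (d0, d1, d0, d1) operator tensor into a (d0*d1, d0*d1) matrix.

    Assumes axes (0, 1) are output (row) indices and axes (2, 3) input (column) indices.
    """
    d0, d1 = op_tensor.shape[:2]
    return op_tensor.reshape(d0 * d1, d0 * d1)
```

With these conventions, applying a lifted matrix to a flattened state of shape (3, 5) gives the same result as contracting the operator directly with the corresponding tensor index.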

We also need to construct the input quantum state for the operator action. Here we start from a pure state initialized with random values. Because the Liouvillian contains dissipators, the state must be converted into a density matrix before it is passed to operator_action().

key = jax.random.key(42)
psi0 = jax.random.normal(key, dims)
psi0 /= jnp.linalg.norm(psi0)
rho0 = jnp.einsum('ij,kl->ijkl', psi0, psi0.conj()).astype(dtype)
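The einsum above stores the density matrix with the ket indices first and the bra indices last, rho0[i, j, k, l] = psi0[i, j] * conj(psi0[k, l]). As a minimal NumPy sketch (the helper name as_density_matrix is hypothetical, and row-major flattening is assumed), such a tensor can be reshaped into the familiar 15 x 15 matrix and checked for the properties every pure state has: unit trace, Hermiticity, and unit purity. np.asarray also accepts JAX arrays, copying them to host memory.

```python
import numpy as np

def as_density_matrix(rho_tensor):
    """Reshape a (d0, d1, d0, d1) density tensor into a (d0*d1, d0*d1) matrix."""
    d0, d1 = rho_tensor.shape[:2]
    return np.asarray(rho_tensor).reshape(d0 * d1, d0 * d1)

# NumPy version of the random pure state above.
rng = np.random.default_rng(42)
dims = (3, 5)
psi0 = rng.normal(size=dims)
psi0 /= np.linalg.norm(psi0)
rho0 = np.einsum('ij,kl->ijkl', psi0, psi0.conj()).astype(np.complex128)

rho_mat = as_density_matrix(rho0)
trace = np.trace(rho_mat).real              # 1 for a normalized state
purity = np.trace(rho_mat @ rho_mat).real   # 1 for a pure state
hermitian = np.allclose(rho_mat, rho_mat.conj().T)
```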

After these steps, we can invoke the operator action by passing the operator, the time variable and the state to operator_action() (even when the operator is not time-dependent, the time variable t is a required positional argument):

from cuquantum.densitymat.jax import operator_action

t = 0.0
rho1 = operator_action(liouvillian, t, rho0)

The array rho1 is the output quantum state of the operator action, that is, the right-hand side of the master equation evaluated for rho0.
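Because the operator action evaluates the time derivative of the state, a time-evolution loop only needs it as a right-hand side. The following classical fourth-order Runge-Kutta sketch is an illustration, not part of the cuQuantum API; rk4_step and evolve are hypothetical names. It accepts any callable rhs(t, rho), for example lambda t, rho: operator_action(liouvillian, t, rho), and uses only array arithmetic, so it should work with NumPy or JAX arrays.

```python
import numpy as np

def rk4_step(rhs, t, rho, dt):
    """Advance rho by one classical fourth-order Runge-Kutta step of size dt."""
    k1 = rhs(t, rho)
    k2 = rhs(t + dt / 2, rho + dt / 2 * k1)
    k3 = rhs(t + dt / 2, rho + dt / 2 * k2)
    k4 = rhs(t + dt, rho + dt * k3)
    return rho + dt / 6 * (k1 + 2 * k2 + 2 * k3 + k4)

def evolve(rhs, rho0, t0, t1, num_steps):
    """Integrate d(rho)/dt = rhs(t, rho) from t0 to t1 using num_steps equal RK4 steps."""
    dt = (t1 - t0) / num_steps
    rho, t = rho0, t0
    for _ in range(num_steps):
        rho = rk4_step(rhs, t, rho, dt)
        t += dt
    return rho

# Toy check with a known solution: d(rho)/dt = -rho gives rho(t) = exp(-t) * rho(0).
rho_final = evolve(lambda t, rho: -rho, np.eye(2, dtype=np.complex128), 0.0, 1.0, 20)
```

A fixed step size keeps the sketch short; whether it is accurate enough for a given Liouvillian depends on the problem.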

JIT compilation

The operator_action() API function is compatible with JAX’s JIT transformation jax.jit. We can apply jax.jit to the API call directly, as in the following:

rho1 = jax.jit(operator_action)(liouvillian, t, rho0)

Alternatively, we can apply jax.jit to the entire function that calls operator_action(), as in the following:

@jax.jit
def main():
    # Define the operator and the state here.
    rho1 = operator_action(liouvillian, t, rho0)
    return rho1

rho1 = main()
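When the operator action is evaluated at many times, passing t and the state as arguments of the jitted function lets JAX reuse one compiled version, since arguments with the same shapes and dtypes hit the compilation cache. By contrast, Python values that a jitted function reads from an enclosing scope are baked in as constants when the function is first traced. The sketch below illustrates this general JAX behavior with a toy stand-in for operator_action(); make_jitted_rhs and toy_action are hypothetical names, and the trace counter is only a debugging aid.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

def make_jitted_rhs(action):
    """Wrap action(t, rho) in jax.jit with t and rho as traced arguments.

    Also returns a list that gains one entry each time JAX traces the function,
    which makes recompilations visible.
    """
    traces = []

    @jax.jit
    def rhs(t, rho):
        traces.append(1)  # Python side effect: runs only while tracing
        return action(t, rho)

    return rhs, traces

# Toy stand-in for: lambda t, rho: operator_action(liouvillian, t, rho)
toy_action = lambda t, rho: jnp.cos(t) * rho

rhs, traces = make_jitted_rhs(toy_action)
rho = jnp.ones((2, 2), dtype=jnp.complex128)
for t in (0.0, 0.5, 1.0):
    out = rhs(t, rho)  # same shapes and dtypes each time, so the compiled function is reused
```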

Backward differentiation

The operator_action() API supports JAX's automatic differentiation transformations. To differentiate with respect to callback parameters, define callback and grad_callback for the elementary or matrix operators and/or coeff_callback and coeff_grad_callback for the coefficients. Below we demonstrate how to define a tensor callback, callback, for the elementary operator h01 and a tensor gradient callback, grad_callback, so that the parameter gradients can be computed. Note that all data buffers inside the callbacks (storage, tensor_grad, params_grad) are CuPy arrays reconstructed from the user-provided data buffers, so any modification to them must be performed with CuPy operations.

import cupy as cp

from cuquantum.bindings import cudensitymat as cudm

# Define the regular callback function.
def f_h01_callback(t, args, storage):
    storage[:] = 0.0
    for m in range(storage.shape[0]):
        for n in range(storage.shape[1]):
            for p in range(storage.shape[2]):
                for q in range(storage.shape[3]):
                    storage[m, n, p, q] = (m + n) * (p + q) * cp.tan(args[0] * t)
                    if storage.dtype.kind == 'c':
                        storage[m, n, p, q] += 1j * (m + n) * (p + q) / cp.tan(args[0] * t)

# Define the gradient callback function.
def f_h01_grad_callback(t, args, tensor_grad, params_grad):
    for m in range(tensor_grad.shape[0]):
        for n in range(tensor_grad.shape[1]):
            for p in range(tensor_grad.shape[2]):
                for q in range(tensor_grad.shape[3]):
                    params_grad[0] += 2 * (
                        tensor_grad[m, n, p, q] * ((m + n) * (p + q) * t / cp.cos(args[0] * t) ** 2)
                        ).real
                    if tensor_grad.dtype.kind == 'c':
                        params_grad[0] += 2 * (
                           tensor_grad[m, n, p, q] * (-1j * (m + n) * (p + q) * t / cp.sin(args[0] * t) ** 2)
                           ).real

# Define the wrapped callbacks from the callback functions.
h01_callback = cudm.WrappedTensorCallback(f_h01_callback, cudm.CallbackDevice.GPU)
h01_grad_callback = cudm.WrappedTensorGradientCallback(f_h01_grad_callback, cudm.CallbackDevice.GPU)

# Define the ElementaryOperator with callback and grad_callback.
h01_data = jax.ShapeDtypeStruct((*dims, *dims), jnp.complex128)
h01 = ElementaryOperator(h01_data, callback=h01_callback, grad_callback=h01_grad_callback)
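The element-by-element loops in f_h01_callback are easy to read but issue many tiny array operations. A vectorized equivalent can compute the whole tensor at once with broadcasting. The sketch below is written against an array module xp so it can be tested with NumPy and, presumably, called with xp=cupy inside a callback; it assumes storage has exactly four axes and that the parameter a (args[0] in the callback) is a scalar, which may not hold for batched callbacks. The function name fill_h01_vectorized is hypothetical.

```python
import numpy as np

def fill_h01_vectorized(t, a, storage, xp=np):
    """Write (m+n)(p+q) tan(a t), plus 1j (m+n)(p+q) / tan(a t) for complex storage, in place.

    Produces the same values as the nested loops in f_h01_callback.
    Assumes storage.ndim == 4 and a scalar parameter a.
    """
    d0, d1, d2, d3 = storage.shape
    mn = xp.arange(d0)[:, None] + xp.arange(d1)[None, :]  # m + n, shape (d0, d1)
    pq = xp.arange(d2)[:, None] + xp.arange(d3)[None, :]  # p + q, shape (d2, d3)
    weight = mn[:, :, None, None] * pq[None, None, :, :]   # (m + n) * (p + q)
    values = weight * xp.tan(a * t)
    if storage.dtype.kind == 'c':
        values = values + 1j * weight / xp.tan(a * t)
    storage[...] = values
```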

The Liouvillian is built as in the previous section, with this new h01 in place of the original one. Note that h01_data is now a jax.ShapeDtypeStruct object rather than a jax.Array object, because its contents are constructed dynamically by the callback function. To compute gradients, apply jax.vjp to operator_action(): this returns both the output state of the regular operator action, rho1, and the VJP function vjp_f. Passing the adjoint of the output state to vjp_f then yields the gradients with respect to the inputs. The following snippet uses jax.vjp on operator_action() to obtain the adjoint of the input state, rho0_adj, and the gradients with respect to the parameters, params_grad.

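# params is assumed to be the double-precision JAX array of callback parameters
# (received as args by the callbacks above).
rho1, vjp_f = jax.vjp(operator_action, liouvillian, t, rho0, params)
# The first two cotangents returned by vjp_f correspond to liouvillian and t.
_, _, rho0_adj, params_grad = vjp_f(jnp.conj(rho1))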
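Hand-written gradient callbacks are easy to get wrong, so it can help to check them against finite differences before running jax.vjp. The NumPy sketch below mirrors the complex case of f_h01_grad_callback, which accumulates 2 Re(sum(G * dT/da)) for the cotangent G (tensor_grad) and the derivative of the tensor T written by f_h01_callback, and compares it with a central difference of the scalar 2 Re(sum(G * T(a))). It assumes a scalar parameter and a four-axis tensor; all function names are hypothetical, and the factor-of-two convention is taken from the callback above rather than from the cuDensityMat documentation.

```python
import numpy as np

def h01_weight(shape):
    """(m + n) * (p + q) for every index (m, n, p, q)."""
    d0, d1, d2, d3 = shape
    mn = np.arange(d0)[:, None] + np.arange(d1)[None, :]
    pq = np.arange(d2)[:, None] + np.arange(d3)[None, :]
    return mn[:, :, None, None] * pq[None, None, :, :]

def h01_tensor(a, t, shape):
    """Tensor written by f_h01_callback (complex case)."""
    w = h01_weight(shape)
    return w * np.tan(a * t) + 1j * w / np.tan(a * t)

def h01_param_grad(a, t, tensor_grad):
    """Value that f_h01_grad_callback adds to params_grad[0] (complex case)."""
    w = h01_weight(tensor_grad.shape)
    dT_da = w * t / np.cos(a * t) ** 2 - 1j * w * t / np.sin(a * t) ** 2
    return 2 * np.sum(tensor_grad * dT_da).real

def finite_difference_grad(a, t, tensor_grad, eps=1e-6):
    """Central difference of 2 Re(sum(tensor_grad * T(a))) with respect to a."""
    phi = lambda x: 2 * np.sum(tensor_grad * h01_tensor(x, t, tensor_grad.shape)).real
    return (phi(a + eps) - phi(a - eps)) / (2 * eps)

rng = np.random.default_rng(0)
G = rng.normal(size=(3, 5, 3, 5)) + 1j * rng.normal(size=(3, 5, 3, 5))
analytic = h01_param_grad(0.7, 0.3, G)
numeric = finite_difference_grad(0.7, 0.3, G)
```

If the two values disagree beyond rounding error, the derivative formula in the gradient callback is the first place to look.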

API reference

Functions

operator_action(op, t, state_in_bufs[, ...])

Compute the action of an operator on a state.

Objects

ElementaryOperator(data[, callback, ...])

PyTree class for cuDensityMat's elementary operator.

MatrixOperator(data[, callback, grad_callback])

PyTree class for cuDensityMat's matrix operator.

OperatorTerm(dims)

PyTree class for cuDensityMat's operator term.

Operator(dims)

PyTree class for cuDensityMat's operator.