JAX Extension APIs

cuQuantum Python JAX is an extension for cuQuantum Python, designed to provide selected functionality of the cuQuantum SDK in a JAX-compatible manner. This extension allows JAX-based frameworks to directly integrate with the cuQuantum API. In this initial release, cuQuantum Python JAX offers a JAX interface to the Operator Action API from cuDensityMat, via the operator_action() function of the cuquantum.densitymat.jax module.

API usage

Basic usage

For basic usage, first import JAX. cuQuantum Python JAX requires double-precision arrays for callback parameters, so the "jax_enable_x64" configuration option must be set to True:

import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)
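You can confirm the setting is active by checking the default dtypes. Set the option at program startup, before any arrays are created. This check is our own addition and does not involve cuQuantum.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

# With x64 enabled, JAX creates double-precision arrays by default.
x = jnp.ones(3)                # float64 rather than float32
z = x.astype(jnp.complex128)   # stays complex128 instead of being truncated to complex64
print(x.dtype, z.dtype)
```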

cuQuantum Python JAX currently supports GPU builds of JAX with versions >=0.5.0 and <0.7.0.
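A small guard like the following can fail early on an unsupported JAX installation. The helper name jax_version_supported is hypothetical. It parses only the leading major.minor.patch digits of the version string. jax.default_backend() reports "gpu" when a GPU build of JAX sees a device.

```python
import re

import jax


def jax_version_supported(version):
    """Check a JAX version string against the supported range >=0.5.0, <0.7.0."""
    # Keep only the leading numeric components, e.g. "0.6.2" from "0.6.2.dev20250101".
    match = re.match(r"(\d+)\.(\d+)(?:\.(\d+))?", version)
    if match is None:
        return False
    major, minor, patch = (int(g) if g is not None else 0 for g in match.groups())
    return (0, 5, 0) <= (major, minor, patch) < (0, 7, 0)


print(jax.__version__, jax_version_supported(jax.__version__), jax.default_backend())
```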

To compute the action of an operator on a quantum state in a quantum dynamics simulation with operator_action(), we first need to construct both the operator and the state. Suppose we have a two-site system whose Lindblad master equation is

\[\begin{split}\dot{\rho} = \mathcal{L}[\rho] &= -i[H, \rho] + \sum_{L_i\in \{L_i\}} \left( L_i\rho L_i^\dagger - \frac{1}{2} L_i^\dagger L_i \rho - \frac{1}{2} \rho L_i^\dagger L_i \right) \\ H &= h_{01} \\ \{L_i\} &= \{l_0, l_1\}\end{split}\]
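For small systems, a dense reference for the right-hand side of this equation is useful, for example to cross-check operator_action() results. The sketch below is our own pure-JAX implementation using explicit matrices. The function name lindblad_rhs is hypothetical and not part of cuQuantum. The example also confirms that the generator is traceless and preserves Hermiticity.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def lindblad_rhs(H, jump_ops, rho):
    """Dense L[rho] = -i[H, rho] + sum_i (L_i rho L_i^dag - 1/2 L_i^dag L_i rho - 1/2 rho L_i^dag L_i)."""
    out = -1j * (H @ rho - rho @ H)
    for L in jump_ops:
        Ld = L.conj().T
        LdL = Ld @ L
        out = out + L @ rho @ Ld - 0.5 * (LdL @ rho) - 0.5 * (rho @ LdL)
    return out


# Small random example: D = 4, Hermitian H, one jump operator, a pure state.
D = 4
k1, k2, k3, k4 = jax.random.split(jax.random.key(0), 4)
A = jax.random.normal(k1, (D, D)) + 1j * jax.random.normal(k2, (D, D))
H = 0.5 * (A + A.conj().T)
L = jax.random.normal(k3, (D, D)).astype(jnp.complex128)
psi = jax.random.normal(k4, (D,)).astype(jnp.complex128)
psi = psi / jnp.linalg.norm(psi)
rho = jnp.outer(psi, psi.conj())

drho = lindblad_rhs(H, [L], rho)
print(jnp.trace(drho))  # zero up to round-off
```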

The operator action computes the right-hand side of the master equation: the Liouvillian acts on the input quantum state to produce the output quantum state. To construct the Liouvillian, we first define JAX arrays for the data buffers, wrap them as ElementaryOperator objects, build OperatorTerm objects for the Hamiltonian and the dissipators, and finally combine them into the Liouvillian Operator.

from cuquantum.densitymat.jax import ElementaryOperator, OperatorTerm, Operator

dims = (3, 5)
dtype = jnp.complex128

# Elementary operators
h01_data = jax.random.normal(jax.random.key(0), (*dims, *dims)).astype(dtype)
l0_data = jax.random.normal(jax.random.key(1), (dims[0], dims[0])).astype(dtype)
l1_data = jax.random.normal(jax.random.key(2), (dims[1], dims[1])).astype(dtype)
l0d_data = l0_data.conj().T
l1d_data = l1_data.conj().T

h01 = ElementaryOperator(h01_data)
l0 = ElementaryOperator(l0_data)
l1 = ElementaryOperator(l1_data)
l0d = ElementaryOperator(l0d_data)
l1d = ElementaryOperator(l1d_data)

# Hamiltonian operator term
H = OperatorTerm(dims)
H.append([h01], modes=[0, 1], duals=[False])

# Dissipators operator term
Ls = OperatorTerm(dims)
for i, l, ld in zip(range(len(dims)), [l0, l1], [l0d, l1d]):
    Ls.append([l, ld], modes=[i, i], duals=[False, True])
    Ls.append([l, ld], modes=[i, i], duals=[False, False], coeff=-0.5)
    Ls.append([ld, l], modes=[i, i], duals=[True, True], coeff=-0.5)

# Liouvillian operator: Hamiltonian commutator plus dissipators
liouvillian = Operator(dims)
liouvillian.append(H, dual=False, coeff=-1j)
liouvillian.append(H, dual=True, coeff=1j)
liouvillian.append(Ls)
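To compare against a dense matrix picture, you can embed the single-mode data buffers into the full 15-dimensional space with Kronecker products. The sketch assumes that a two-mode buffer such as h01_data, of shape (*dims, *dims), is indexed as data[i0, i1, j0, j1], with output modes first and then input modes. The shape suggests this, but the section does not state it. embed and as_matrix are hypothetical helper names.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def embed(op, mode, dims):
    """Kronecker-embed a single-mode matrix `op` acting on `mode` into the full space."""
    full = jnp.eye(1, dtype=op.dtype)
    for k, d in enumerate(dims):
        factor = op if k == mode else jnp.eye(d, dtype=op.dtype)
        full = jnp.kron(full, factor)
    return full


def as_matrix(tensor, dims):
    """Flatten an operator tensor of shape (*dims, *dims) into a (D, D) matrix (row-major)."""
    D = 1
    for d in dims:
        D *= d
    return tensor.reshape(D, D)


dims = (3, 5)
dtype = jnp.complex128
h01_data = jax.random.normal(jax.random.key(0), (*dims, *dims)).astype(dtype)
l0_data = jax.random.normal(jax.random.key(1), (dims[0], dims[0])).astype(dtype)
l1_data = jax.random.normal(jax.random.key(2), (dims[1], dims[1])).astype(dtype)
psi = jax.random.normal(jax.random.key(3), dims).astype(dtype)

# Acting on the flattened state should match contracting the corresponding tensor indices.
v0 = embed(l0_data, 0, dims) @ psi.reshape(-1)
v1 = embed(l1_data, 1, dims) @ psi.reshape(-1)
vh = as_matrix(h01_data, dims) @ psi.reshape(-1)
```

With these helpers, as_matrix(h01_data, dims), embed(l0_data, 0, dims) and embed(l1_data, 1, dims) give dense H, L_0 and L_1 for small-system cross-checks.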

We also need to construct the input quantum state for the operator action. Here, suppose the input is a pure state initialized with random values. Because the Liouvillian contains dissipators, we convert it into a density matrix before passing it to operator_action().

key = jax.random.key(42)
psi0 = jax.random.normal(key, dims)
psi0 /= jnp.linalg.norm(psi0)
rho0 = jnp.einsum('ij,kl->ijkl', psi0, psi0.conj()).astype(dtype)
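Before calling operator_action(), it is worth checking that rho0 is a valid density matrix. Flattening the (3, 5, 3, 5) tensor into a 15 x 15 matrix (row-major, consistent with how the einsum builds it) turns the usual checks into one-liners. The maximally mixed state at the end is an illustrative alternative input, not part of the original example.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

dims = (3, 5)
D = dims[0] * dims[1]

# Pure state -> density matrix, built the same way as rho0 in this section.
psi0 = jax.random.normal(jax.random.key(42), dims)
psi0 /= jnp.linalg.norm(psi0)
rho0 = jnp.einsum("ij,kl->ijkl", psi0, psi0.conj()).astype(jnp.complex128)

# Flatten (d0, d1, d0, d1) -> (D, D) and check the standard properties.
rho_mat = rho0.reshape(D, D)
trace = jnp.trace(rho_mat).real               # 1 for a normalized state
purity = jnp.trace(rho_mat @ rho_mat).real    # 1 for a pure state
is_hermitian = bool(jnp.allclose(rho_mat, rho_mat.conj().T))

# A mixed alternative with the same layout: the maximally mixed state I / D.
rho_mixed = (jnp.eye(D, dtype=jnp.complex128) / D).reshape(*dims, *dims)
```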

After these steps, we can compute the operator action by passing the operator, the time t, and the state to operator_action(). The time t is a required positional argument even when the operator is not time-dependent:

from cuquantum.densitymat.jax import operator_action

t = 0.0
rho1 = operator_action(liouvillian, t, rho0)

The array rho1 is the output quantum state from the operator action.
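rho1 presumably has the same (d0, d1, d0, d1) layout as rho0, so ordinary JAX reductions apply to it. For a generator in Lindblad form, the trace of rho1 should vanish up to round-off, which makes a cheap sanity check. A partial trace gives the reduced state of one mode. The helpers below (total_trace, reduce_to_mode0) are hypothetical names. They are demonstrated on a product state so the expected results are known.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def total_trace(rho):
    """Tr(rho) for an operator tensor of shape (d0, d1, d0, d1)."""
    d0, d1 = rho.shape[:2]
    return jnp.trace(rho.reshape(d0 * d1, d0 * d1))


def reduce_to_mode0(rho):
    """Partial trace over mode 1: sum_j rho[i, j, k, j], giving a (d0, d0) matrix."""
    return jnp.trace(rho, axis1=1, axis2=3)


# Product state rho_A (x) rho_B laid out as (d0, d1, d0, d1).
rho_A = jnp.diag(jnp.array([0.5, 0.3, 0.2], dtype=jnp.complex128))
rho_B = jnp.eye(5, dtype=jnp.complex128) / 5
rho_AB = jnp.einsum("ac,bd->abcd", rho_A, rho_B)

tr = total_trace(rho_AB)        # 1
red = reduce_to_mode0(rho_AB)   # equals rho_A
```

Applied to the output of an operator action, total_trace(rho1) would be expected to be close to zero.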

JIT compilation

The operator_action() API function is compatible with JAX’s JIT transformation jax.jit. We can apply jax.jit to the API call directly, as in the following:

rho1 = jax.jit(operator_action)(liouvillian, t, rho0)

Alternatively, we can apply jax.jit to the entire function that calls operator_action(), as in the following:

@jax.jit
def main():
    # Define operator and state here.
    rho1 = operator_action(liouvillian, t, rho0)
    return rho1
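The operator action evaluates the right-hand side of the master equation, so jitting a whole function is a natural fit for a compiled time-stepping loop. The sketch below uses jax.lax.fori_loop with a fixed-step fourth-order Runge-Kutta (RK4) scheme. To keep it runnable without a GPU, rhs is a dense stand-in for operator_action(liouvillian, t, rho). Calling operator_action inside such a loop instead is an untested assumption on our part.

```python
from functools import partial

import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def rhs(t, rho, H, L):
    """Dense stand-in for operator_action: Lindblad right-hand side with one jump operator."""
    Ld = L.conj().T
    LdL = Ld @ L
    return -1j * (H @ rho - rho @ H) + L @ rho @ Ld - 0.5 * (LdL @ rho + rho @ LdL)


@partial(jax.jit, static_argnames="n_steps")
def evolve(rho0, H, L, dt, n_steps):
    """Integrate d(rho)/dt = rhs(t, rho) with fixed-step RK4; the whole loop is compiled once."""
    def rk4_step(i, rho):
        t = i * dt
        k1 = rhs(t, rho, H, L)
        k2 = rhs(t + dt / 2, rho + dt / 2 * k1, H, L)
        k3 = rhs(t + dt / 2, rho + dt / 2 * k2, H, L)
        k4 = rhs(t + dt, rho + dt * k3, H, L)
        return rho + dt / 6 * (k1 + 2 * k2 + 2 * k3 + k4)
    return jax.lax.fori_loop(0, n_steps, rk4_step, rho0)


# Usage: a decaying two-level system with L = |0><1| (decay rate 1).
H = jnp.zeros((2, 2), dtype=jnp.complex128)
L = jnp.array([[0.0, 1.0], [0.0, 0.0]], dtype=jnp.complex128)
rho_init = jnp.array([[0.0, 0.0], [0.0, 1.0]], dtype=jnp.complex128)  # excited state |1><1|
rho_T = evolve(rho_init, H, L, 0.01, n_steps=100)  # evolve to T = 1
print(rho_T[1, 1].real)  # close to exp(-1)
```

Marking n_steps as static fixes the loop length at compile time. dt stays a traced argument, so changing it does not trigger recompilation.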

Backward differentiation

The operator_action() API supports JAX's automatic differentiation transformations. To differentiate with respect to the callback parameters, the user needs to define callback and grad_callback for the elementary/matrix operators and/or coeff_callback and coeff_grad_callback for the coefficients. Below we show how to define a tensor callback (passed as callback) for the elementary operator h01 and a tensor gradient callback (passed as grad_callback) so that the parameter gradients can be computed. All data buffers inside the callbacks (storage, tensor_grad, params_grad) are CuPy arrays reconstructed from the user-provided data buffers, so any modification to them must be made with CuPy operations.

import cupy as cp

from cuquantum.bindings import cudensitymat as cudm

# Define the regular callback function.
def h01_callback(t, args, storage):
    storage[:] = 0.0
    for m in range(storage.shape[0]):
        for n in range(storage.shape[1]):
            for p in range(storage.shape[2]):
                for q in range(storage.shape[3]):
                    storage[m, n, p, q] = (m + n) * (p + q) * cp.tan(args[0] * t)
                    if storage.dtype.kind == 'c':
                        storage[m, n, p, q] += 1j * (m + n) * (p + q) / cp.tan(args[0] * t)

# Define the gradient callback function.
def h01_grad_callback(t, args, tensor_grad, params_grad):
    for m in range(tensor_grad.shape[0]):
        for n in range(tensor_grad.shape[1]):
            for p in range(tensor_grad.shape[2]):
                for q in range(tensor_grad.shape[3]):
                    params_grad[0] += 2 * (
                        tensor_grad[m, n, p, q] * ((m + n) * (p + q) * t / cp.cos(args[0] * t) ** 2)
                        ).real
                    if tensor_grad.dtype.kind == 'c':
                        params_grad[0] += 2 * (
                           tensor_grad[m, n, p, q] * (-1j * (m + n) * (p + q) * t / cp.sin(args[0] * t) ** 2)
                           ).real

# Define the wrapped callbacks from the callback functions.
h01_callback = cudm.WrappedTensorCallback(h01_callback, cudm.CallbackDevice.GPU)
h01_grad_callback = cudm.WrappedTensorGradientCallback(h01_grad_callback, cudm.CallbackDevice.GPU)

# Define the ElementaryOperator with callback and grad_callback.
h01 = ElementaryOperator(h01_data, callback=h01_callback, grad_callback=h01_grad_callback)

The Liouvillian operator is built as in the previous section, with h01 replaced by the new h01 defined with callbacks. To differentiate, apply jax.vjp to operator_action(). It returns the output quantum state rho1 of the regular operator action together with the VJP function vjp_f. Passing the output quantum state adjoint to vjp_f then computes the gradients with respect to the inputs. The following snippet uses jax.vjp on operator_action() to obtain the input quantum state adjoint rho0_adj and the gradients with respect to the parameters, params_grad. With this particular callback, the time t must be nonzero, because the callback evaluates 1 / tan(args[0] * t).

# params: float64 JAX array of callback parameters (received as args inside the callbacks)
rho1, vjp_f = jax.vjp(operator_action, liouvillian, t, rho0, params)
_, _, rho0_adj, params_grad = vjp_f(jnp.conj(rho1))  # cotangents for op and t are discarded
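You can try the same jax.vjp pattern on a small dense analogue without cuQuantum. In the sketch below, h_of is a hypothetical Hamiltonian whose elements mimic those set by h01_callback, and action stands in for operator_action. For a real parameter a, JAX's VJP returns Re(sum(cotangent * d(output)/da)). The analytic derivative therefore reproduces a_grad; it uses the same tan and 1/tan derivatives as h01_grad_callback. The sketch does not attempt to reproduce the factor of 2 that appears in h01_grad_callback.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

D, t = 3, 0.7
M = jnp.arange(D * D, dtype=jnp.float64).reshape(D, D).astype(jnp.complex128)


def h_of(a):
    """Hypothetical parameter-dependent Hamiltonian: M tan(a t) + i M / tan(a t)."""
    return M * jnp.tan(a * t) + 1j * M / jnp.tan(a * t)


def action(rho, a):
    """Dense stand-in for operator_action: -i[H(a), rho]."""
    H = h_of(a)
    return -1j * (H @ rho - rho @ H)


psi = jnp.array([1.0, 2.0, -1.0], dtype=jnp.complex128)
psi = psi / jnp.linalg.norm(psi)
rho0 = jnp.outer(psi, psi.conj())
a = jnp.float64(0.4)

# Same call pattern as above: forward result plus VJP function, then pull back a cotangent.
rho1, vjp_f = jax.vjp(action, rho0, a)
rho0_adj, a_grad = vjp_f(jnp.conj(rho1))

# Analytic check: dH/da = M t / cos^2(a t) - i M t / sin^2(a t).
dH = M * t / jnp.cos(a * t) ** 2 - 1j * M * t / jnp.sin(a * t) ** 2
d_action = -1j * (dH @ rho0 - rho0 @ dH)
expected = jnp.real(jnp.sum(jnp.conj(rho1) * d_action))
```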

API reference

Functions

operator_action(op, t, state_in_bufs[, ...])

Compute the action of an operator on a state.

Objects

ElementaryOperator(data[, callback, ...])

PyTree class for cuDensityMat's elementary operator.

MatrixOperator(data[, callback, grad_callback])

PyTree class for cuDensityMat's matrix operator.

OperatorTerm(dims)

PyTree class for cuDensityMat's operator term.

Operator(dims)

PyTree class for cuDensityMat's operator.