JAX Extension APIs

cuQuantum Python JAX is an experimental extension for cuQuantum Python that provides selected functionality of the cuQuantum SDK in a JAX-compatible manner, allowing JAX-based frameworks to integrate directly with the cuQuantum API. Currently, cuQuantum Python JAX offers a JAX interface to the Operator Action API from cuDensityMat via the operator_action() function of the cuquantum.densitymat.jax module. Because cuQuantum Python JAX is experimental, it supports only a subset of the workflows expressible by its API, as described in Supported workflows. Unsupported workflows may result in undefined behavior.

Release Notes

cuQuantum Python JAX v0.0.3

  • Previously, cuQuantum Python JAX set jax_enable_x64=True as a side effect on import. Now, users must set jax_enable_x64 to True before importing the cuQuantum Python JAX module.

Compatibility notes:

  • cuQuantum Python JAX now supports CUDA 13 in addition to CUDA 12.

  • cuQuantum Python JAX supports JAX versions >=0.8.0 and <0.9.0 for CUDA 13.

  • cuQuantum Python JAX supports JAX versions >=0.5.0 and <0.7.0 for CUDA 12.

cuQuantum Python JAX v0.0.2

  • Bugs fixed:

    • Fixed an issue with packaging of the cuQuantum Python JAX extension which rendered the package uninstallable in certain situations.

cuQuantum Python JAX v0.0.1

  • Initial release of the JAX extension exposes the Operator Action API cuquantum.densitymat.jax.operator_action() from cuDensityMat to enable integration of cuQuantum Python with JAX-based quantum dynamics simulation frameworks.

  • Known issues:

    • Multiple structurally equivalent operators that involve callback functions for elementary/matrix operators result in undefined behavior when their action on a quantum state is evaluated in consecutive calls within the same Python interpreter session. As a workaround, the user may call jax.clear_caches() between consecutive calls to operator_action() involving these operators.

    • Providing data buffers that may be overwritten by callbacks but are initialized with the same initial values results in undefined behavior inside a jax.jit scope. As a workaround, the user should ensure that arrays used as dynamically constructed data buffers of different elementary/matrix operators are initialized with different values.

Compatibility notes:

  • cuQuantum Python JAX supports CUDA 12.

  • cuQuantum Python JAX supports JAX version >=0.5.0 and <0.7.0.

API usage

Basic usage

For basic usage, we first import JAX. cuQuantum Python JAX requires double-precision arrays for callback parameters, so "jax_enable_x64" must be set to True in the JAX configuration. Since v0.0.3, this must be done before importing the cuQuantum Python JAX module.

import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

cuQuantum Python JAX currently supports the GPU version of JAX: versions >=0.5.0 and <0.7.0 with CUDA 12, and versions >=0.8.0 and <0.9.0 with CUDA 13.
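When x64 is disabled, JAX downcasts explicitly requested float64 arrays to float32 and emits only a warning, so it can help to fail early. The sketch below enables the flag first, as v0.0.3 requires, and uses a hypothetical require_x64() helper (not part of cuQuantum or JAX) to confirm that float64 arrays are actually available before creating callback parameters. The cuQuantum import is omitted so the snippet runs with JAX alone, and the parameter value is an arbitrary example.

```python
import jax

# Enable double precision first; since v0.0.3 this must happen before
# cuquantum.densitymat.jax is imported (that import is omitted here).
jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp


def require_x64():
    """Hypothetical helper: raise early if float64 arrays would be downcast to float32."""
    probe = jnp.zeros((), dtype=jnp.float64)
    if probe.dtype != jnp.float64:
        raise RuntimeError('Call jax.config.update("jax_enable_x64", True) first.')


require_x64()
# Example double-precision callback parameters (arbitrary value).
params = jnp.asarray([0.5], dtype=jnp.float64)
```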

To use the operator_action() API to compute the action of an operator on a quantum state in a quantum dynamics simulation, we need to construct both the operator and the state. Suppose we have a two-site system whose Lindblad master equation is

\[\begin{split}\dot{\rho} = \mathcal{L}[\rho] &= -i[H, \rho] + \sum_{i} \left( L_i\rho L_i^\dagger - \frac{1}{2} L_i^\dagger L_i \rho - \frac{1}{2} \rho L_i^\dagger L_i \right) \\ H &= h_{01} \\ \{L_i\} &= \{l_0, l_1\}\end{split}\]

An operator action evaluates the right-hand side (RHS) of the master equation: the Liouvillian operator acts on the input quantum state to yield the output quantum state. To construct the Liouvillian operator, we first define JAX arrays for the data buffers, wrap them as ElementaryOperator objects, build OperatorTerm objects for the Hamiltonian and the dissipators, and finally assemble them into the Liouvillian Operator.

from cuquantum.densitymat.jax import ElementaryOperator, OperatorTerm, Operator

dims = (3, 5)
dtype = jnp.complex128

# Elementary operators
h01_data = jax.random.normal(jax.random.key(0), (*dims, *dims)).astype(dtype)
l0_data = jax.random.normal(jax.random.key(1), (dims[0], dims[0])).astype(dtype)
l1_data = jax.random.normal(jax.random.key(2), (dims[1], dims[1])).astype(dtype)
l0d_data = l0_data.conj().T
l1d_data = l1_data.conj().T

h01 = ElementaryOperator(h01_data)
l0 = ElementaryOperator(l0_data)
l1 = ElementaryOperator(l1_data)
l0d = ElementaryOperator(l0d_data)
l1d = ElementaryOperator(l1d_data)

# Hamiltonian operator term
H = OperatorTerm(dims)
H.append([h01], modes=[0, 1], duals=[False])

# Dissipators operator term
Ls = OperatorTerm(dims)
for i, (l, ld) in enumerate(zip([l0, l1], [l0d, l1d])):
    Ls.append([l, ld], modes=[i, i], duals=[False, True])
    Ls.append([l, ld], modes=[i, i], duals=[False, False], coeff=-0.5)
    Ls.append([ld, l], modes=[i, i], duals=[True, True], coeff=-0.5)

# Liouvillian operator: Hamiltonian part -i[H, rho] plus the dissipators
liouvillian = Operator(dims)
liouvillian.append(H, dual=False, coeff=-1j)
liouvillian.append(H, dual=True, coeff=1j)
liouvillian.append(Ls, dual=False, coeff=1.0)
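Read against the master equation, the terms built above correspond as sketched below. This reading assumes that dual=False applies an operator from the left (ket side), that dual=True applies it from the right (bra side), and that the operators in a product are applied in list order; these conventions are inferred from the snippet rather than stated in this document.

\[\begin{split}\mathcal{L}[\rho] &= \underbrace{-i\,h_{01}\rho}_{\texttt{dual=False},\ \texttt{coeff}=-i} \;\underbrace{+\,i\,\rho\,h_{01}}_{\texttt{dual=True},\ \texttt{coeff}=i} \\ &\quad + \sum_{i\in\{0,1\}} \Big( \underbrace{l_i\rho\, l_i^\dagger}_{\texttt{duals=[False, True]}} \;\underbrace{-\,\tfrac{1}{2}\, l_i^\dagger l_i\rho}_{\texttt{duals=[False, False]},\ \texttt{coeff}=-0.5} \;\underbrace{-\,\tfrac{1}{2}\,\rho\, l_i^\dagger l_i}_{\texttt{duals=[True, True]},\ \texttt{coeff}=-0.5} \Big)\end{split}\]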

We also need to construct the input quantum state for the operator action. Here we take a pure input state initialized with random values. Since the Liouvillian contains dissipators, we convert the pure state into a density matrix before passing it to operator_action().

key = jax.random.key(42)
psi0 = jax.random.normal(key, dims)
psi0 /= jnp.linalg.norm(psi0)
rho0 = jnp.einsum('ij,kl->ijkl', psi0, psi0.conj()).astype(dtype)
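The einsum places the ket indices of rho0 first and the bra indices last. The NumPy sketch below flattens each index pair to recover the D x D matrix |psi><psi| and checks the properties a pure density matrix must have: unit trace, Hermiticity, and purity Tr(rho^2) = 1. NumPy's random generator stands in for jax.random here, so the values differ from the JAX snippet; only the layout and the checks carry over.

```python
import numpy as np

dims = (3, 5)
D = dims[0] * dims[1]

rng = np.random.default_rng(42)  # NumPy stand-in for jax.random
psi0 = rng.normal(size=dims)
psi0 /= np.linalg.norm(psi0)  # Frobenius norm over both modes
rho0 = np.einsum('ij,kl->ijkl', psi0, psi0.conj()).astype(np.complex128)

# Ket indices (i, j) come first and bra indices (k, l) last, so flattening
# each pair yields the D x D matrix |psi><psi|.
rho_mat = rho0.reshape(D, D)
trace = np.trace(rho_mat).real
purity = np.trace(rho_mat @ rho_mat).real
is_hermitian = np.allclose(rho_mat, rho_mat.conj().T)
```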

After these steps, we can invoke the operator action by passing the operator, the time t, and the state to operator_action(). The time argument t is a required positional argument even when the operator is not time-dependent:

from cuquantum.densitymat.jax import operator_action

t = 0.0
rho1 = operator_action(liouvillian, t, rho0)

The array rho1 is the output quantum state from the operator action.
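For small systems, rho1 can be cross-checked against a dense evaluation of the master equation's right-hand side. The NumPy function below is a hypothetical helper, not part of cuQuantum. It assumes that a two-mode tensor such as h01_data, of shape (d0, d1, d0, d1), stores its output (ket) indices first, so that reshaping it to (d0*d1, d0*d1) gives the operator matrix, and that rho follows the einsum layout above (ket indices first, bra indices last).

```python
import numpy as np


def lindbladian_reference(h01, l0, l1, rho):
    """Hypothetical dense evaluation of -i[H, rho] + sum_i D[L_i](rho) for two sites.

    Assumed layouts: h01 and rho have shape (d0, d1, d0, d1) with ket (output)
    indices first; l0 has shape (d0, d0) and l1 has shape (d1, d1).
    """
    d0, d1 = l0.shape[0], l1.shape[0]
    dim = d0 * d1
    H = h01.reshape(dim, dim)
    R = rho.reshape(dim, dim)
    # Lift each single-site jump operator to the two-site space: l0 (x) 1 and 1 (x) l1.
    jumps = [np.kron(l0, np.eye(d1)), np.kron(np.eye(d0), l1)]
    out = -1j * (H @ R - R @ H)
    for L in jumps:
        LdL = L.conj().T @ L
        out = out + L @ R @ L.conj().T - 0.5 * (LdL @ R + R @ LdL)
    return out.reshape(d0, d1, d0, d1)
```

Under these assumptions, and provided liouvillian contains the dissipator term Ls alongside the two Hamiltonian terms, one would expect np.allclose(np.asarray(rho1), lindbladian_reference(np.asarray(h01_data), np.asarray(l0_data), np.asarray(l1_data), np.asarray(rho0))) to hold; a mismatch would point first at the layout assumption.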

JIT compilation

The operator_action() API function is compatible with JAX’s JIT transformation jax.jit. We can apply jax.jit to the API call directly, as in the following:

rho1 = jax.jit(operator_action)(liouvillian, t, rho0)

Alternatively, we can apply jax.jit to the entire function that calls operator_action(), as in the following:

@jax.jit
def main():
    # Define the operator and the state here.
    rho1 = operator_action(liouvillian, t, rho0)
    return rho1  # a jitted function's only outputs are its return values

rho1 = main()
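Both forms work because the operator objects are PyTrees (see the API reference): jax.jit flattens them into array leaves, which are traced, and auxiliary data, which is treated as static. The pure-JAX sketch below illustrates this with a hypothetical ToyOperator class and a hypothetical Hamiltonian-only toy_action; how cuquantum.densitymat.jax actually splits its objects into leaves and static data is not documented here.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class ToyOperator:
    """Hypothetical stand-in for a PyTree operator (not the real Operator layout)."""

    def __init__(self, data, coeff):
        self.data = data    # array leaf: becomes a traced input under jax.jit
        self.coeff = coeff  # auxiliary (static) data: part of the jit cache key

    def tree_flatten(self):
        return (self.data,), (self.coeff,)

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        (coeff,) = aux_data
        (data,) = children
        return cls(data, coeff)


def toy_action(op, t, rho):
    """Hypothetical Hamiltonian-only action -i[coeff * H, rho]; t is accepted but unused."""
    H = op.coeff * op.data
    return -1j * (H @ rho - rho @ H)


# Option 1: jit the call directly, as with jax.jit(operator_action)(...).
jitted_action = jax.jit(toy_action)


# Option 2: jit an enclosing function, which must return its result.
@jax.jit
def main(op, t, rho):
    rho1 = toy_action(op, t, rho)
    return rho1


H = jnp.diag(jnp.arange(4.0)).astype(jnp.complex64)
psi = jnp.ones(4, dtype=jnp.complex64) / 2.0
rho0 = jnp.outer(psi, psi.conj())
op = ToyOperator(H, 1.0)

rho1_direct = jitted_action(op, 0.0, rho0)
rho1_main = main(op, 0.0, rho0)
```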

Backward differentiation

The operator_action() API supports JAX’s automatic differentiation transformations. To differentiate with respect to the callback parameters, the user needs to define callback and grad_callback for the elementary/matrix operators and/or coeff_callback and coeff_grad_callback for the coefficients. Below we demonstrate how to define a tensor callback (callback) for the elementary operator h01 and a tensor gradient callback (grad_callback) so that the parameter gradients can be computed. Note that all data buffers inside the callbacks (storage, tensor_grad, params_grad) are CuPy arrays reconstructed from the user-provided data buffers, so they must be modified in place using CuPy operations.

import cupy as cp

from cuquantum.bindings import cudensitymat as cudm

# Define the regular callback function.
def h01_callback(t, args, storage):
    storage[:] = 0.0
    for m in range(storage.shape[0]):
        for n in range(storage.shape[1]):
            for p in range(storage.shape[2]):
                for q in range(storage.shape[3]):
                    storage[m, n, p, q] = (m + n) * (p + q) * cp.tan(args[0] * t)
                    if storage.dtype.kind == 'c':
                        storage[m, n, p, q] += 1j * (m + n) * (p + q) / cp.tan(args[0] * t)

# Define the gradient callback function.
def h01_grad_callback(t, args, tensor_grad, params_grad):
    for m in range(tensor_grad.shape[0]):
        for n in range(tensor_grad.shape[1]):
            for p in range(tensor_grad.shape[2]):
                for q in range(tensor_grad.shape[3]):
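                    # Contribution of d/d(args[0]) of (m + n)(p + q) tan(args[0] t).
                    params_grad[0] += 2 * (
                        tensor_grad[m, n, p, q] * ((m + n) * (p + q) * t / cp.cos(args[0] * t) ** 2)
                    ).real
                    if tensor_grad.dtype.kind == 'c':
                        # Contribution of d/d(args[0]) of 1j (m + n)(p + q) / tan(args[0] t).
                        params_grad[0] += 2 * (
                            tensor_grad[m, n, p, q] * (-1j * (m + n) * (p + q) * t / cp.sin(args[0] * t) ** 2)
                        ).real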

# Define the wrapped callbacks from the callback functions.
h01_callback = cudm.WrappedTensorCallback(h01_callback, cudm.CallbackDevice.GPU)
h01_grad_callback = cudm.WrappedTensorGradientCallback(h01_grad_callback, cudm.CallbackDevice.GPU)

# Define the ElementaryOperator with callback and grad_callback.
h01 = ElementaryOperator(h01_data, callback=h01_callback, grad_callback=h01_grad_callback)
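The hand-written derivative in h01_grad_callback can be checked against the values written by h01_callback. The NumPy sketch below is a stand-in for CuPy (which provides cp.indices, cp.tan, cp.cos and cp.sin with the same semantics) and vectorizes both callbacks for the complex case, using d/dθ tan(θt) = t/cos²(θt) and d/dθ [1/tan(θt)] = -t/sin²(θt), where θ stands for args[0]. The function names are hypothetical, and the check only confirms that the gradient formula is the θ-derivative of 2·Re Σ tensor_grad·storage; it says nothing about how cuDensityMat itself defines tensor_grad.

```python
import numpy as np


def h01_values(theta, t, shape):
    """Hypothetical vectorized version of what h01_callback writes (complex dtype)."""
    m, n, p, q = np.indices(shape)
    w = (m + n) * (p + q)  # integer weight (m + n)(p + q) of each element
    return w * np.tan(theta * t) + 1j * w / np.tan(theta * t)


def h01_values_dtheta(theta, t, shape):
    """Analytic derivative of h01_values with respect to theta (= args[0])."""
    m, n, p, q = np.indices(shape)
    w = (m + n) * (p + q)
    # d/dtheta tan(theta t) = t / cos^2(theta t); d/dtheta 1/tan(theta t) = -t / sin^2(theta t)
    return w * t / np.cos(theta * t) ** 2 - 1j * w * t / np.sin(theta * t) ** 2


def h01_param_grad(theta, t, tensor_grad):
    """Loop-free equivalent of the params_grad[0] accumulation in h01_grad_callback."""
    d_storage = h01_values_dtheta(theta, t, tensor_grad.shape)
    return 2.0 * np.sum(tensor_grad * d_storage).real


# Finite-difference check at a point where tan(theta t) is finite and nonzero.
shape = (3, 5, 3, 5)
rng = np.random.default_rng(0)
tensor_grad = rng.normal(size=shape) + 1j * rng.normal(size=shape)
theta, t, eps = 0.3, 0.7, 1e-6


def objective(th):
    # Scalar whose theta-derivative the gradient formula computes.
    return 2.0 * np.sum(tensor_grad * h01_values(th, t, shape)).real


fd_grad = (objective(theta + eps) - objective(theta - eps)) / (2 * eps)
analytic_grad = h01_param_grad(theta, t, tensor_grad)
```

Inside a real callback, assuming storage has the four-dimensional shape used above, the vectorized result could be written in place with storage[:] = ..., replacing the four nested Python loops with a few array operations.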

The Liouvillian operator is then constructed as in the previous section, with this new h01 in place of the original one. Applying jax.vjp to operator_action() returns both the output quantum state rho1 of the regular (forward) operator action and the VJP function vjp_f. Passing the adjoint of the output quantum state to vjp_f then computes the input quantum state adjoint rho0_adj and the gradients params_grad with respect to the callback parameters, as the following snippet illustrates. Because h01_callback and h01_grad_callback divide by cp.tan(args[0] * t) and cp.sin(args[0] * t), the product args[0] * t must be nonzero here; in particular, the value t = 0.0 used earlier would produce non-finite values.

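# params: double-precision JAX array of callback parameters (received by the callbacks as args).
rho1, vjp_f = jax.vjp(operator_action, liouvillian, t, rho0, params)
# vjp_f returns one cotangent per primal input; the first two correspond to op and t.
_, _, rho0_adj, params_grad = vjp_f(jnp.conj(rho1))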
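jax.vjp returns one cotangent per primal input, in the order the inputs were passed, which is why the first two outputs of vjp_f are discarded above. The pure-JAX sketch below shows this structure with a hypothetical toy_action that takes the same four arguments, with a dict standing in for the operator PyTree; it is not operator_action and uses no cuQuantum code.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def toy_action(op, t, rho, params):
    """Hypothetical stand-in for operator_action: -i[cos(theta t) H, rho], theta = params[0]."""
    H = jnp.cos(params[0] * t) * op["H"]
    return -1j * (H @ rho - rho @ H)


D = 4
op = {"H": jnp.diag(jnp.arange(D, dtype=jnp.float64)).astype(jnp.complex128)}  # toy operator PyTree
psi = jnp.ones(D, dtype=jnp.complex128) / 2.0
rho0 = jnp.outer(psi, psi.conj())
t = 0.7
params = jnp.array([0.3], dtype=jnp.float64)

rho1, vjp_f = jax.vjp(toy_action, op, t, rho0, params)
# One cotangent per primal input, in order: (op, t, rho0, params).
op_ct, t_ct, rho0_adj, params_grad = vjp_f(jnp.conj(rho1))
```

Each cotangent mirrors its input: op_ct is a dict with the same key as op, and params_grad is a real float64 array shaped like params. For this toy function, since JAX's VJP does not conjugate the cotangent, using conj(rho1) as the cotangent makes params_grad equal to half the derivative of sum(|rho1|**2) with respect to params.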

Supported workflows

Warning

  • JAX caches primitives (e.g. OperatorActionPrimitive and OperatorActionBackwardDiffPrimitive) and recompiles them only when their input arguments change. In operator_action(), a change in the shape or dtype of an input argument, such as a data buffer of an elementary/matrix operator within the operator op or the input state buffer state_in_bufs, triggers recompilation of the primitive. However, changes limited to the callbacks of elementary/matrix operators or of coefficients currently do not trigger recompilation, so a subsequent operator action may reuse the previously compiled primitive and return incorrect results. In such cases, users should call jax.clear_caches() to force recompilation of the primitive.

  • The API requires users to pass in the data buffers that the callbacks write to. Inside a jax.jit scope, JAX aliases arrays that have the same initial values, so arrays used as dynamically constructed data buffers of different elementary/matrix operators must be initialized with different values.

Given these caveats, the cuQuantum Python JAX API can currently be used only in one of the following settings:

  • a single instance of Operator per execution context;

  • multiple instances of Operator per execution context that differ in data type or Hilbert space dimensions;

  • multiple instances of Operator that do not involve any callback functions for elementary/matrix operators or coefficients.
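The recompilation caveat mirrors a general property of jax.jit: compiled code is cached by the shapes and dtypes of the array inputs (plus static arguments and PyTree structure), so Python objects that the traced function merely refers to are captured once, at trace time. The pure-JAX sketch below is an analogy rather than cuQuantum's actual primitive caching: it rebinds a plain Python function (standing in for a callback) used inside a jitted function, and the stale version keeps being used until a new input shape forces retracing or jax.clear_caches() discards the cache.

```python
import jax
import jax.numpy as jnp

trace_count = 0  # incremented only while JAX traces `apply`


def scale(x):
    # Plain Python function standing in for a callback used by jitted code.
    return 2.0 * x


@jax.jit
def apply(x):
    global trace_count
    trace_count += 1  # Python side effect: runs at trace time, not on cached calls
    return scale(x)


x = jnp.ones(3)
first = apply(x)  # traces and compiles with the original scale (factor 2)


def scale(x):  # rebind the "callback" to a new function (factor 3)
    return 3.0 * x


stale = apply(x)  # same shape and dtype: cached trace reused, still factor 2
new_shape = apply(jnp.ones(4))  # new shape forces retracing, which sees factor 3

jax.clear_caches()  # discard cached traces and compiled executables
fresh = apply(x)  # retraced for shape (3,), now factor 3
```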

API reference

Functions

operator_action(op, t, state_in_bufs[, ...])

Compute the action of an operator on a state.

Objects

ElementaryOperator(data[, callback, ...])

PyTree class for cuDensityMat's elementary operator.

MatrixOperator(data[, callback, grad_callback])

PyTree class for cuDensityMat's matrix operator.

OperatorTerm(dims)

PyTree class for cuDensityMat's operator term.

Operator(dims)

PyTree class for cuDensityMat's operator.