JAX Extension APIs

cuQuantum Python JAX is an experimental extension for cuQuantum Python, designed to provide selected functionality of the cuQuantum SDK in a JAX-compatible manner. This extension allows JAX-based frameworks to directly integrate with the cuQuantum API. Currently, cuQuantum Python JAX offers a JAX interface to the Operator Action API from cuDensityMat, via the operator_action() function of the cuquantum.densitymat.jax module.

Release Notes

cuQuantum Python JAX v0.0.5

  • New features:

    • Gradients can now be computed with respect to any JAX-traceable parameter entering the operator or state construction without defining gradient callbacks (see example6_grad.py for an example). All callback-related arguments have been removed from the public API, and hence data buffers previously constructed dynamically via callbacks must now be pre-computed statically before being passed to operator_action(). The removed arguments are:

  • Known issues:

    • Nested jax.vmap transformations are not currently supported.

    • jax.grad transformations inside jax.vmap are not currently supported.

cuQuantum Python JAX v0.0.4

  • New features:

    • Added support for the vector-Jacobian product (VJP) transformation of operator action with batched input operators and batched input states.

    • Users can now differentiate with respect to parameters implicit in the input operator or input state instead of explicitly specifying parameter gradients in gradient callbacks. See example8_gradient_attachment.py for an example.

  • Bugs fixed:

    • Fixed a bug that caused double freeing of cuDensityMat library pointers.

    • Fixed a bug that triggered an insufficient-workspace runtime error when executing a regular operator action after evaluating its vector-Jacobian product (VJP) transformation.

  • Other changes:

    • Users need to pass in jax.ShapeDtypeStruct objects instead of jax.Array objects for data buffers dynamically constructed by callbacks. The affected input arguments are:

    • The ElementaryOperator constructor argument for diagonal offsets, which is used to construct multidiagonal elementary operators, has been renamed from offsets to diag_offsets.

    • Modification of an Operator after it has been used in an operator action is now disabled.

    • When installing cuQuantum Python JAX, the user needs to pass the --no-build-isolation option to pip and ensure that all build dependencies are pre-installed.

cuQuantum Python JAX v0.0.3

  • Previously, cuQuantum Python JAX set jax_enable_x64=True as a side effect on import. Now, users must set jax_enable_x64 to True before importing the cuQuantum Python JAX module.

Compatibility notes:

  • cuQuantum Python JAX now supports CUDA 13 in addition to CUDA 12.

  • cuQuantum Python JAX supports JAX versions >=0.8.0 and <0.9.0 with CUDA 13.

  • cuQuantum Python JAX supports JAX versions >=0.5.0 and <0.7.0 with CUDA 12.

cuQuantum Python JAX v0.0.2

  • Bugs fixed:

    • Fixed an issue with packaging of the cuQuantum Python JAX extension which rendered the package uninstallable in certain situations.

cuQuantum Python JAX v0.0.1

  • The initial release of the JAX extension exposes the Operator Action API from cuDensityMat as cuquantum.densitymat.jax.operator_action(). This enables integration of cuQuantum Python with JAX-based quantum dynamics simulation frameworks.

  • Known issues:

    • Evaluating the action of multiple structurally equivalent operators that involve callback functions for elementary/matrix operators, in consecutive calls within the same Python interpreter session, results in undefined behavior. As a workaround, call jax.clear_caches() between consecutive calls to operator_action() involving these operators.

    • Providing data buffers that may be overwritten by callbacks but were initialized with identical values results in undefined behavior inside a jax.jit scope. As a workaround, initialize the arrays used as dynamically constructed data buffers of different elementary/matrix operators with different values.

Compatibility notes:

  • cuQuantum Python JAX supports CUDA 12.

  • cuQuantum Python JAX supports JAX version >=0.5.0 and <0.7.0.

API Usage

cuQuantum Python JAX provides a single entry point, operator_action(), which computes the action of an operator on an input state. The operator action corresponds to the same concept in cuDensityMat and typically represents the right-hand side of the master equation in quantum dynamics. For example, in the Lindblad form of the master equation:

\[\dot{\rho} = \mathcal{L}[\rho] = -i[H, \rho] + \sum_{i} \left( L_i\rho L_i^\dagger - \frac{1}{2} L_i^\dagger L_i \rho - \frac{1}{2} \rho L_i^\dagger L_i \right)\]

rho1 = operator_action(liouvillian, rho0) corresponds to a single evaluation of the right-hand side, \(\rho_1 = \mathcal{L}[\rho_0]\).
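Here is a small worked instance of the formula. Take a single two-level system with \(H = 0\), one jump operator \(L_1 = \gamma\sigma_-\) (the qubit-decay operator used in the example below), and the excited state \(\rho = |e\rangle\langle e|\). Because \(\sigma_-|e\rangle = |g\rangle\) and \(\sigma_+\sigma_- = |e\rangle\langle e|\), the right-hand side reduces to:

\[\mathcal{L}[|e\rangle\langle e|] = \gamma^2\left(|g\rangle\langle g| - \frac{1}{2}|e\rangle\langle e| - \frac{1}{2}|e\rangle\langle e|\right) = \gamma^2\left(|g\rangle\langle g| - |e\rangle\langle e|\right)\]

The result is traceless, as expected for a trace-preserving generator: population flows from \(|e\rangle\) to \(|g\rangle\) at rate \(\gamma^2\).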

Supported JAX transformations

The operator_action() function is compatible with the following JAX transformations:

  • jax.jit — JIT compilation applied directly to operator_action() or to an entire workflow that calls it.

  • jax.vmap — Batching over input states, coefficients associated with an operator product or operator term, or elementary or matrix operator data buffers.

  • jax.grad / jax.value_and_grad — Backward differentiation with respect to any JAX-traceable parameter entering the operator or state construction.

The following compositions are also supported:

  • jax.jit applied on top of jax.vmap(operator_action, ...).

  • jax.grad applied on top of jax.vmap(operator_action, ...).

  • jax.jit applied on top of jax.grad, optionally with an inner jax.vmap.

The following transformation combinations are not currently supported:

  • Nested jax.vmap (i.e., jax.vmap inside another jax.vmap).

  • jax.grad inside jax.vmap (i.e., jax.vmap(jax.grad(f), ...)).

Basic usage

For basic usage, we first import JAX. cuQuantum Python JAX requires double-precision arrays, so the "jax_enable_x64" configuration option must be set to True before cuQuantum Python JAX is imported.

import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

cuQuantum Python JAX requires the GPU version of JAX. The supported versions are >=0.5.0 and <0.7.0 with CUDA 12, and >=0.8.0 and <0.9.0 with CUDA 13.
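The supported ranges can also be checked programmatically. Below is a minimal sketch. The helper name jax_version_supported is hypothetical and not part of cuQuantum Python JAX. The ranges are copied from the v0.0.3 compatibility notes, so treat them as an assumption to update if a later release changes them.

```python
import itertools

# Supported JAX ranges [low, high) per CUDA major version, copied from the
# v0.0.3 compatibility notes (assumption: adjust if later releases change them).
SUPPORTED_JAX = {12: ((0, 5, 0), (0, 7, 0)), 13: ((0, 8, 0), (0, 9, 0))}


def _leading_int(piece: str) -> int:
    """Parse the leading digits of a version component, e.g. '0rc1' -> 0."""
    digits = "".join(itertools.takewhile(str.isdigit, piece))
    return int(digits) if digits else 0


def jax_version_supported(jax_version: str, cuda_major: int) -> bool:
    """Hypothetical helper: is jax_version inside the range for cuda_major?"""
    if cuda_major not in SUPPORTED_JAX:
        return False
    version = tuple(_leading_int(p) for p in jax_version.split(".")[:3])
    version += (0,) * (3 - len(version))  # pad short versions such as "0.6"
    low, high = SUPPORTED_JAX[cuda_major]
    return low <= version < high
```

For example, jax_version_supported(jax.__version__, 12) checks the installed JAX against the CUDA 12 range. This simple parser ignores pre-release suffixes such as rc1.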

To leverage the operator_action() API to compute the action of an operator on a quantum state, we need to define the operator and the state. Suppose we want to simulate the Jaynes-Cummings model subject to dissipation on both the two-level system and the cavity mode:

\[\begin{split}H &= \frac{1}{2}\omega_0\sigma_z + \omega a^\dagger a + \lambda(\sigma_+ a + \sigma_- a^\dagger) \\ \{L_i\} &= \{\gamma \sigma_-, \kappa a\}\end{split}\]

To construct the Liouvillian operator, we need to define JAX arrays for the elementary operator data buffers, wrap them as ElementaryOperators, construct OperatorTerms for the Hamiltonian and dissipators, and finally assemble the Liouvillian Operator.

from cuquantum.densitymat.jax import ElementaryOperator, OperatorTerm, Operator

# Physical parameters
omega0 = 1.0  # qubit transition frequency
omega = 1.0  # cavity frequency
lam = 0.1  # Jaynes-Cummings coupling strength
gamma = 0.05  # qubit spontaneous emission rate
kappa = 0.01  # cavity photon loss rate
N_cav = 10  # cavity Hilbert-space dimension (Fock states n = 0..N_cav-1)
N_qubit = 2  # qubit Hilbert-space dimension

# Hilbert space: mode 0 = qubit (dim 2), mode 1 = cavity (truncated at N_cav)
dims = (N_qubit, N_cav)
dtype = jnp.complex128

# Operator matrices
# Qubit (mode 0)
sz_data = jnp.array([[1.0, 0.0], [0.0, -1.0]], dtype=dtype)  # sigma_z
sp_data = jnp.array([[0.0, 1.0], [0.0, 0.0]], dtype=dtype)  # sigma_+
sm_data = jnp.array([[0.0, 0.0], [1.0, 0.0]], dtype=dtype)  # sigma_-

# Cavity (mode 1)
n_data = jnp.diag(jnp.arange(N_cav, dtype=dtype))  # a†a
a_data = jnp.diag(jnp.sqrt(jnp.arange(1, N_cav, dtype=dtype)), k=1)  # a
ad_data = jnp.diag(jnp.sqrt(jnp.arange(1, N_cav, dtype=dtype)), k=-1)  # a†

sz = ElementaryOperator(sz_data)
sp = ElementaryOperator(sp_data)
sm = ElementaryOperator(sm_data)
n = ElementaryOperator(n_data)
a = ElementaryOperator(a_data)
ad = ElementaryOperator(ad_data)

# Hamiltonian: (omega0/2) sz + omega * n + lam * (sp⊗a + sm⊗a†)
H = OperatorTerm(dims)
H.append([sz], modes=[0], duals=[False], coeff=omega0 / 2)
H.append([n], modes=[1], duals=[False], coeff=omega)
H.append([sp, a], modes=[0, 1], duals=[False, False], coeff=lam)
H.append([sm, ad], modes=[0, 1], duals=[False, False], coeff=lam)

# Lindblad dissipators: D[L]rho = L rho L† - (1/2) L†L rho - (1/2) rho L†L
Ls = OperatorTerm(dims)

# D[gamma * sigma_-] on mode 0 (qubit decay)
Ls.append([sm, sp], modes=[0, 0], duals=[False, True], coeff=gamma**2)
Ls.append([sm, sp], modes=[0, 0], duals=[False, False], coeff=-0.5 * gamma**2)
Ls.append([sp, sm], modes=[0, 0], duals=[True, True], coeff=-0.5 * gamma**2)

# D[kappa * a] on mode 1 (photon loss)
Ls.append([a, ad], modes=[1, 1], duals=[False, True], coeff=kappa**2)
Ls.append([a, ad], modes=[1, 1], duals=[False, False], coeff=-0.5 * kappa**2)
Ls.append([ad, a], modes=[1, 1], duals=[True, True], coeff=-0.5 * kappa**2)

# Liouvillian: L(rho) = -i[H, rho] + sum_i D[L_i](rho)
liouvillian = Operator(dims)
liouvillian.append(H, dual=False, coeff=-1.0j)
liouvillian.append(H, dual=True, coeff=1.0j)
liouvillian.append(Ls, dual=False, coeff=1.0)
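Before assembling larger models, it can help to check the operator matrices themselves. The following plain-JAX sketch is independent of cuQuantum Python JAX. It rebuilds the same matrices, creating the aranges as real arrays and casting to complex afterwards, and then verifies the adjoint relations. It also shows that \([a, a^\dagger] = 1\) holds everywhere except the last diagonal entry, which is an artifact of truncating the cavity at N_cav levels.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

N_cav = 10
dtype = jnp.complex128

# Same matrices as in the example, built from real aranges and cast to complex
sp_data = jnp.array([[0.0, 1.0], [0.0, 0.0]], dtype=dtype)  # sigma_+
sm_data = jnp.array([[0.0, 0.0], [1.0, 0.0]], dtype=dtype)  # sigma_-
n_data = jnp.diag(jnp.arange(N_cav)).astype(dtype)  # a†a
a_data = jnp.diag(jnp.sqrt(jnp.arange(1, N_cav)), k=1).astype(dtype)  # a
ad_data = jnp.diag(jnp.sqrt(jnp.arange(1, N_cav)), k=-1).astype(dtype)  # a†

# Raising and lowering operators are Hermitian conjugates of each other
assert jnp.allclose(sp_data, sm_data.conj().T)
assert jnp.allclose(ad_data, a_data.conj().T)

# a†a reproduces the number operator exactly, even in the truncated space
assert jnp.allclose(ad_data @ a_data, n_data)

# [a, a†] equals the identity except in the last diagonal entry, which is
# -(N_cav - 1) because the Fock space is truncated at N_cav levels
comm = a_data @ ad_data - ad_data @ a_data
comm_diag = jnp.real(jnp.diag(comm))
```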

We also need to define the input quantum state for the operator action. Here, suppose the input is a pure state with the qubit in its ground state and a single excitation in the cavity mode:

# State vector: |g, n=1> — qubit ground state (index 1), cavity Fock state n=1 (index 1)
psi0 = jnp.zeros(dims, dtype=dtype)
psi0 = psi0.at[1, 1].set(1.0)
rho0 = jnp.einsum('ij,kl->ijkl', psi0, psi0.conj())

Note that the states, both input and output, are simply JAX arrays rather than types provided by cuQuantum Python JAX.
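Because the state is a plain JAX array, standard array operations apply. The short sketch below is not part of the library. It shows how the (*dims, *dims) layout maps to an ordinary density matrix, assuming (as in the einsum above) that ket modes are stored before bra modes and each group is flattened row-major.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

dims = (2, 10)
dtype = jnp.complex128
D = dims[0] * dims[1]  # total Hilbert-space dimension

psi0 = jnp.zeros(dims, dtype=dtype).at[1, 1].set(1.0)
rho0 = jnp.einsum('ij,kl->ijkl', psi0, psi0.conj())  # shape (*dims, *dims)

# Flatten ket modes (qubit, cavity) and bra modes separately: a D x D matrix.
# The state |g, n=1> sits at flat index 1 * 10 + 1 = 11.
rho_mat = rho0.reshape(D, D)

trace = jnp.trace(rho_mat)
purity = jnp.trace(rho_mat @ rho_mat)  # 1 for a pure state
```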

After these steps, we can invoke the operator action by passing the operator and the state to operator_action():

from cuquantum.densitymat.jax import operator_action

rho1 = operator_action(liouvillian, rho0)

The array rho1 is the output state from the operator action.
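For small systems, a dense reference implementation is a useful cross-check. The plain-JAX sketch below builds the Jaynes-Cummings Liouvillian action from full Kronecker-product matrices, with the same parameters and mode ordering as above. It does not call cuQuantum Python JAX, and it scales poorly with system size. The assumption is that mode 0 (the qubit) is the leading Kronecker factor, which matches the row-major flattening of the (N_qubit, N_cav) indices.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

# Same parameters and mode ordering as in the Basic usage section
omega0, omega, lam, gamma, kappa = 1.0, 1.0, 0.1, 0.05, 0.01
N_qubit, N_cav = 2, 10
dims = (N_qubit, N_cav)
D = N_qubit * N_cav
dtype = jnp.complex128

sz = jnp.diag(jnp.array([1.0, -1.0])).astype(dtype)
sp = jnp.array([[0.0, 1.0], [0.0, 0.0]], dtype=dtype)
sm = sp.conj().T
n = jnp.diag(jnp.arange(N_cav)).astype(dtype)
a = jnp.diag(jnp.sqrt(jnp.arange(1, N_cav)), k=1).astype(dtype)
ad = a.conj().T
Iq, Ic = jnp.eye(N_qubit, dtype=dtype), jnp.eye(N_cav, dtype=dtype)

# Full-space Hamiltonian and jump operators (qubit is the leading factor)
H = ((omega0 / 2) * jnp.kron(sz, Ic) + omega * jnp.kron(Iq, n)
     + lam * (jnp.kron(sp, a) + jnp.kron(sm, ad)))
jumps = [gamma * jnp.kron(sm, Ic), kappa * jnp.kron(Iq, a)]


def dense_rhs(rho):
    """-i[H, rho] + sum_i (L_i rho L_i† - 1/2 {L_i† L_i, rho}) on a (D, D) matrix."""
    out = -1j * (H @ rho - rho @ H)
    for L in jumps:
        LdL = L.conj().T @ L
        out = out + L @ rho @ L.conj().T - 0.5 * (LdL @ rho + rho @ LdL)
    return out


psi0 = jnp.zeros(dims, dtype=dtype).at[1, 1].set(1.0)
rho0 = jnp.einsum('ij,kl->ijkl', psi0, psi0.conj())

# Evaluate in matrix form, then restore the (*dims, *dims) layout
rho1_ref = dense_rhs(rho0.reshape(D, D)).reshape(dims + dims)

# Tr[(I ⊗ n) rho1_ref] is d<n>/dt at rho0
dn_dt = jnp.real(jnp.trace(jnp.kron(Iq, n) @ rho1_ref.reshape(D, D)))
```

For \(|g, n=1\rangle\), only photon loss changes the photon number, so dn_dt equals -kappa**2. Comparing rho1 from operator_action() against rho1_ref (for example with jnp.allclose) is one way to check that the operator was assembled as intended.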

JIT compilation

The operator_action() API function is compatible with JAX’s JIT transformation jax.jit. We can apply jax.jit to the API call directly:

rho1 = jax.jit(operator_action)(liouvillian, rho0)

The more common usage is to apply jax.jit to the entire workflow that invokes operator_action(), as is typical for a master equation solver.

from functools import partial

@partial(jax.jit, static_argnames=("N_cav", "N_qubit", "dtype"))
def main(omega0, omega, lam, gamma, kappa, N_cav, N_qubit, dtype):

    # Define operator and state as shown in the Basic usage section.
    # ...

    rho1 = operator_action(liouvillian, rho0)
    return rho1
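The static_argnames in this example matter because N_cav, N_qubit, and dtype determine array shapes and dtypes, and jax.jit must know these at trace time. Physical parameters such as omega only scale values, so they can remain traced arguments. The minimal stand-in below illustrates the distinction. cavity_hamiltonian is a hypothetical function and not part of the library.

```python
from functools import partial

import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


@partial(jax.jit, static_argnames=("N_cav", "dtype"))
def cavity_hamiltonian(omega, N_cav, dtype):
    # N_cav fixes the array shape, so it must be a static (hashable) Python value;
    # omega only scales the values, so it can stay a traced argument.
    n = jnp.diag(jnp.arange(N_cav)).astype(dtype)
    return omega * n


H1 = cavity_hamiltonian(1.0, N_cav=10, dtype=jnp.complex128)
H2 = cavity_hamiltonian(2.0, N_cav=10, dtype=jnp.complex128)  # same static args: cached compilation reused
H3 = cavity_hamiltonian(1.0, N_cav=5, dtype=jnp.complex128)   # new N_cav: traced and compiled again
```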

Batching

Batching is supported in cuQuantum Python JAX through the jax.vmap transformation. The user can apply jax.vmap to the operator, the state, or both. Suppose we want to run the same simulation on two different input states, one with a single excitation in the cavity mode and the other with an excitation in the qubit mode. The state construction block can be modified as follows:

# State vector: |g, n=1> — qubit ground state (index 1), cavity Fock state n=1 (index 1)
psi0a = jnp.zeros(dims, dtype=dtype)
psi0a = psi0a.at[1, 1].set(1.0)
rho0a = jnp.einsum('ij,kl->ijkl', psi0a, psi0a.conj())

# State vector: |e, n=0> — qubit excited state (index 0), cavity Fock state n=0 (index 0)
psi0b = jnp.zeros(dims, dtype=dtype)
psi0b = psi0b.at[0, 0].set(1.0)
rho0b = jnp.einsum('ij,kl->ijkl', psi0b, psi0b.conj())

rho0 = jnp.stack([rho0a, rho0b])

The operator_action() invocation then needs to be wrapped in a jax.vmap transformation, with in_axes set to None for the (unbatched) operator and 0 for the batched state:

rho1 = jax.vmap(operator_action, in_axes=(None, 0))(liouvillian, rho0)

The array rho1 is the output state from the operator action. The shape of rho1 is (2, *dims, *dims), where the first dimension is the batch dimension.
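The in_axes=(None, 0) pattern can be tried without cuQuantum Python JAX. The sketch below uses a hypothetical dense stand-in, toy_action, in place of operator_action(). It shows that None shares one operator across the batch, while 0 maps over the leading axis of the states, so the output gains a leading batch axis.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

dims = (2, 3)
D = dims[0] * dims[1]


def toy_action(op, rho):
    """Hypothetical stand-in for operator_action(): -i[op, rho] with a dense op."""
    r = rho.reshape(D, D)
    return (-1j * (op @ r - r @ op)).reshape(dims + dims)


op = jnp.diag(jnp.arange(D)).astype(jnp.complex128)  # toy Hermitian operator
rho_a = (jnp.eye(D) / D).astype(jnp.complex128)  # maximally mixed state
rho_b = (jnp.ones((D, D)) / D).astype(jnp.complex128)  # uniform pure superposition
rhos = jnp.stack([rho_a, rho_b]).reshape((2,) + dims + dims)  # batch axis first

# None: one operator shared by all batch entries; 0: map over the states' leading axis
out = jax.vmap(toy_action, in_axes=(None, 0))(op, rhos)
```

The result has shape (2, *dims, *dims). Its first entry vanishes because the maximally mixed state commutes with every operator.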

We can also use a batched operator with the same batch size as the batched state. Batching can be introduced in any of three components: (1) elementary or matrix operator data buffers, (2) coefficients of elementary or matrix operator products in an operator term, and (3) coefficients of operator terms in an operator. For more details on batching, please refer to the cuDensityMat documentation.

Below, we demonstrate an operator with batch size 2 in the operator product coefficients of the Hamiltonian operator term:

# Hamiltonian: (omega0/2) sz + omega * n + lam * (sp⊗a + sm⊗a†)
# The second batch is defined by doubling the values of the first batch.
H = OperatorTerm(dims)
H.append([sz], modes=[0], duals=[False], coeff=jnp.array([omega0 / 2, omega0], dtype=dtype))
H.append([n], modes=[1], duals=[False], coeff=jnp.array([omega, 2 * omega], dtype=dtype))
H.append([sp, a], modes=[0, 1], duals=[False, False], coeff=jnp.array([lam, 2 * lam], dtype=dtype))
H.append([sm, ad], modes=[0, 1], duals=[False, False], coeff=jnp.array([lam, 2 * lam], dtype=dtype))

In the operator_action() invocation, after assembling liouvillian from the batched H, the in_axes entry for the operator must be set to liouvillian.in_axes:

rho1 = jax.vmap(operator_action, in_axes=(liouvillian.in_axes, 0))(liouvillian, rho0)

Additionally, jax.jit can be applied on top of jax.vmap to further improve performance.
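The general jax.vmap feature behind passing an operator-specific in_axes is that in_axes may be a pytree mirroring the structure of an argument, with 0 for batched leaves and None for shared ones. The sketch below illustrates only this general mechanism. The dictionary operator and toy_action are hypothetical stand-ins, and we have not assumed anything about how Operator or liouvillian.in_axes are structured internally.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

sz = jnp.array([[1.0, 0.0], [0.0, -1.0]], dtype=jnp.complex128)

# Hypothetical pytree standing in for a batched operator: the coefficient has a
# batch axis of size 2, while the matrix is shared by both batch entries.
op_tree = {"coeff": jnp.array([0.5, 1.0], dtype=jnp.complex128), "mat": sz}
op_axes = {"coeff": 0, "mat": None}  # same structure as op_tree


def toy_action(op, rho):
    """Hypothetical stand-in for operator_action(): -i[coeff * mat, rho]."""
    H = op["coeff"] * op["mat"]
    return -1j * (H @ rho - rho @ H)


plus = jnp.full((2, 2), 0.5, dtype=jnp.complex128)  # |+><+|
rhos = jnp.stack([plus, plus])  # batch of 2 states

out = jax.vmap(toy_action, in_axes=(op_axes, 0))(op_tree, rhos)
```

Because the second coefficient is twice the first, out[1] equals 2 * out[0].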

Backward differentiation

The operator_action() API supports JAX’s automatic differentiation. Since operators and states are constructed from JAX arrays, gradients can be computed with respect to any parameter that enters the operator or state construction as a JAX-traceable value — including elementary or matrix operator data buffers, operator product coefficients in an operator term, operator term coefficients in an operator, and the input state itself. No gradient callbacks are required.

To differentiate, wrap the operator construction and operator_action() call in a function and apply jax.grad. For example, to compute the gradient of an expectation value with respect to the coupling strength lam and the decay rate gamma from the Jaynes-Cummings model:

def expectation_value(lam, gamma):
    # Hamiltonian: (omega0/2) sz + omega * n + lam * (sp⊗a + sm⊗a†)
    H = OperatorTerm(dims)
    H.append([sz], modes=[0], duals=[False], coeff=omega0 / 2)
    H.append([n], modes=[1], duals=[False], coeff=omega)
    H.append([sp, a], modes=[0, 1], duals=[False, False], coeff=lam)
    H.append([sm, ad], modes=[0, 1], duals=[False, False], coeff=lam)

    # Lindblad dissipators: D[gamma * sigma_-] on mode 0 (qubit decay)
    Ls = OperatorTerm(dims)
    Ls.append([sm, sp], modes=[0, 0], duals=[False, True], coeff=gamma**2)
    Ls.append([sm, sp], modes=[0, 0], duals=[False, False], coeff=-0.5 * gamma**2)
    Ls.append([sp, sm], modes=[0, 0], duals=[True, True], coeff=-0.5 * gamma**2)

    # Liouvillian: L(rho) = -i[H, rho] + D[gamma * sigma_-](rho)
    liouvillian = Operator(dims)
    liouvillian.append(H, dual=False, coeff=-1.0j)
    liouvillian.append(H, dual=True, coeff=1.0j)
    liouvillian.append(Ls, dual=False, coeff=1.0)

    rho1 = operator_action(liouvillian, rho0)
    # Photon-number observable I ⊗ n contracted with rho1 = L[rho0]: Tr[(I ⊗ n) rho1]
    obs = jnp.kron(jnp.eye(N_qubit, dtype=dtype), n_data).reshape(dims + dims)
    return jnp.real(jnp.einsum('ijkl,klij->', obs, rho1))

lam_grad, gamma_grad = jax.grad(expectation_value, argnums=(0, 1))(lam, gamma)

This works equally well when the parameter enters the operator data buffers directly (e.g., an ElementaryOperator constructed from omega * n_data) or enters the input state.
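Gradients obtained this way can be validated against central finite differences. The minimal sketch below uses a dense two-level stand-in. toy_objective is hypothetical and does not call operator_action(). The same check applies to expectation_value above by swapping in that function. Because the objective is linear in lam and quadratic in gamma, the central differences here agree with jax.grad up to round-off.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

sx = jnp.array([[0.0, 1.0], [1.0, 0.0]], dtype=jnp.complex128)
sz = jnp.array([[1.0, 0.0], [0.0, -1.0]], dtype=jnp.complex128)
sm = jnp.array([[0.0, 0.0], [1.0, 0.0]], dtype=jnp.complex128)
# Toy state with coherences so that both parameters matter
rho = jnp.array([[0.7, 0.2 - 0.1j], [0.2 + 0.1j, 0.3]], dtype=jnp.complex128)


def toy_objective(lam, gamma):
    """Hypothetical stand-in: H = lam * sx, jump L = gamma * sigma_-, observable sz."""
    H = lam * sx
    L = gamma * sm
    LdL = L.conj().T @ L
    rho1 = (-1j * (H @ rho - rho @ H)
            + L @ rho @ L.conj().T - 0.5 * (LdL @ rho + rho @ LdL))
    return jnp.real(jnp.trace(sz @ rho1))


lam, gamma = 0.1, 0.05
grads = jax.grad(toy_objective, argnums=(0, 1))(lam, gamma)

# Central finite differences as an independent check
h = 1e-4
fd_lam = (toy_objective(lam + h, gamma) - toy_objective(lam - h, gamma)) / (2 * h)
fd_gamma = (toy_objective(lam, gamma + h) - toy_objective(lam, gamma - h)) / (2 * h)
```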

Backward differentiation is also supported for batched operator actions. Gradients flow through both the batched operator construction and the jax.vmap-wrapped operator_action() call:

def batched_expectation_value(lam, gamma):
    # Construct batched liouvillian and state as shown in the Batching section
    # ...
    rho1 = jax.vmap(operator_action, in_axes=(liouvillian.in_axes, 0))(liouvillian, rho0)
    # Sum over the batch of Tr[(I ⊗ n) rho1[b]]
    obs = jnp.kron(jnp.eye(N_qubit, dtype=dtype), n_data).reshape(dims + dims)
    return jnp.real(jnp.einsum('ijkl,bklij->', obs, rho1))

lam_grad, gamma_grad = jax.grad(batched_expectation_value, argnums=(0, 1))(lam, gamma)

Additionally, jax.jit can be applied on top of jax.grad, optionally with an inner jax.vmap, to further improve performance.
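The supported composition, jax.jit over jax.value_and_grad with an inner jax.vmap, can be sketched with a dense stand-in for operator_action(). toy_action and loss_and_grad are hypothetical names. With cuQuantum Python JAX, one would call operator_action() in place of toy_action. Note that the reverse nesting, jax.vmap(jax.grad(f)), is listed above as not currently supported.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

sx = jnp.array([[0.0, 1.0], [1.0, 0.0]], dtype=jnp.complex128)
sy = jnp.array([[0.0, -1.0j], [1.0j, 0.0]], dtype=jnp.complex128)
sz = jnp.array([[1.0, 0.0], [0.0, -1.0]], dtype=jnp.complex128)


def toy_action(H, rho):
    """Hypothetical stand-in for operator_action(): coherent part -i[H, rho] only."""
    return -1j * (H @ rho - rho @ H)


@jax.jit
def loss_and_grad(lam, rhos):
    def loss(lam):
        H = lam * sx
        # Inner vmap: one Hamiltonian shared across a batch of states
        rho1 = jax.vmap(toy_action, in_axes=(None, 0))(H, rhos)
        # Sum over the batch of Tr(sz rho1[b])
        return jnp.real(jnp.einsum('ij,bji->', sz, rho1))
    return jax.value_and_grad(loss)(lam)


# Two states rho = (I + r * sy) / 2 with r = 0.6 and r = -0.2
rhos = jnp.stack([0.5 * (jnp.eye(2) + r * sy) for r in (0.6, -0.2)])
value, grad = loss_and_grad(0.1, rhos)
```

Here, the loss equals 2 * lam * (0.6 - 0.2). For lam = 0.1, value is therefore 0.08 and grad is 0.8.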

API reference

Functions

operator_action(op, state_in_bufs[, device])

Compute the action of an operator on a state.

Objects

ElementaryOperator(data[, diag_offsets])

PyTree class for cuDensityMat's elementary operator.

MatrixOperator(data)

PyTree class for cuDensityMat's matrix operator.

OperatorTerm(dims)

PyTree class for cuDensityMat's operator term.

Operator(dims)

PyTree class for cuDensityMat's operator.