cuDensityMat-JAX

cuDensityMat-JAX provides a JAX interface to the Operator Action API from cuDensityMat, via the operator_action() function of the cuquantum.densitymat.jax module. It is designed for use in JAX-based quantum dynamics simulation frameworks.

API Usage

cuQuantum Python JAX provides a single entry point, operator_action(), for the action of an operator on an input state. The operator action corresponds to the same concept in cuDensityMat and typically represents the right-hand side of the master equation in quantum dynamics. For example, in the Lindblad form of the master equation:

\[\dot{\rho} = \mathcal{L}[\rho] = -i[H, \rho] + \sum_{i} \left( L_i\rho L_i^\dagger - \frac{1}{2} L_i^\dagger L_i \rho - \frac{1}{2} \rho L_i^\dagger L_i \right)\]

rho1 = operator_action(liouvillian, rho0) corresponds to a single evaluation of the right-hand side, \(\rho_1 = \mathcal{L}[\rho_0]\).
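When checking results on small systems, a dense reference for this right-hand side can be handy. The sketch below evaluates the Lindblad expression above using plain jax.numpy matrix products. `lindblad_rhs` is a hypothetical helper and is not part of cuQuantum. It forms full d x d matrices, so it only suits small Hilbert spaces.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

def lindblad_rhs(H, jump_ops, rho):
    """Dense reference for L[rho] = -i[H, rho] + sum_i D[L_i](rho).

    H, every L_i, and rho are (d, d) complex matrices. This is a plain
    jax.numpy sketch for cross-checking small cases, not how cuDensityMat
    evaluates the operator action.
    """
    out = -1j * (H @ rho - rho @ H)
    for L in jump_ops:
        LdL = L.conj().T @ L
        out = out + L @ rho @ L.conj().T - 0.5 * (LdL @ rho + rho @ LdL)
    return out

# Single qubit: H = sigma_z / 2 and one jump operator L_1 = 0.3 * sigma_-.
sz = jnp.array([[1.0, 0.0], [0.0, -1.0]], dtype=jnp.complex128)
sm = jnp.array([[0.0, 0.0], [1.0, 0.0]], dtype=jnp.complex128)
psi = jnp.array([1.0, 1.0], dtype=jnp.complex128) / jnp.sqrt(2.0)
rho0 = jnp.outer(psi, psi.conj())

# One right-hand-side evaluation, rho1 = L[rho0].
rho1 = lindblad_rhs(0.5 * sz, [0.3 * sm], rho0)

# The generator preserves trace and Hermiticity: Tr(rho1) = 0 and rho1 is Hermitian.
trace_rho1 = jnp.trace(rho1)
```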

Supported JAX transformations

The operator_action() function is compatible with the following JAX transformations:

  • jax.jit — JIT compilation applied directly to operator_action() or to an entire workflow that calls it.

  • jax.vmap — Batching over input states, coefficients associated with an operator product or operator term, or elementary or matrix operator data buffers.

  • jax.grad / jax.value_and_grad — Backward differentiation with respect to any JAX-traceable parameter entering the operator or state construction.

The following compositions are also supported:

  • jax.jit can be applied inside or outside jax.vmap.

  • jax.jit can be applied inside or outside jax.grad / jax.vjp, optionally with jax.vmap applied inside the differentiated function.

The following transformation combinations are not currently supported:

  • Nested jax.vmap (i.e., jax.vmap inside another jax.vmap).

  • jax.grad inside jax.vmap (i.e., jax.vmap(jax.grad(f), ...)).
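jax.vmap(jax.grad(f)) is not supported. One common workaround applies when each batch element depends only on its own slice of a batched parameter: differentiate the sum over the batch instead. The gradient with respect to the batched parameter then holds the per-element gradients. The sketch below demonstrates this identity in pure JAX. `per_sample_loss` is a hypothetical stand-in for a workflow that would call operator_action(), and it is not a cuQuantum API.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

def per_sample_loss(lam, x):
    # Hypothetical stand-in for one batch element of a workflow that would
    # call operator_action(); here just a smooth scalar function of lam.
    return jnp.sin(lam) * jnp.sum(x**2) + lam**2

def batch_loss(lams, xs):
    # jax.vmap sits inside the differentiated function (the supported order).
    return jnp.sum(jax.vmap(per_sample_loss)(lams, xs))

lams = jnp.array([0.1, 0.2, 0.3])
xs = jnp.arange(6.0).reshape(3, 2)

# d(sum_b loss_b)/d lam_b equals d loss_b/d lam_b because loss_b depends only on lam_b.
per_sample_grads = jax.jit(jax.grad(batch_loss))(lams, xs)

# Reference via jax.vmap(jax.grad(...)): fine for this pure-JAX function, but it is
# the composition that operator_action() does not support.
reference = jax.vmap(jax.grad(per_sample_loss))(lams, xs)
```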

Basic usage

For basic usage, first import JAX. cuQuantum Python JAX requires double-precision arrays, so set "jax_enable_x64" to True in the JAX configuration before importing cuQuantum Python JAX.

import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)
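To confirm that the flag took effect, check the dtypes JAX actually produces. This small sketch uses only plain JAX and makes no cuQuantum calls.

```python
import jax
import jax.numpy as jnp

# Enable 64-bit types before creating any arrays. With the flag off, JAX
# truncates requested 64-bit dtypes (e.g. complex128 -> complex64) and warns.
jax.config.update("jax_enable_x64", True)

rho = jnp.zeros((2, 10, 2, 10), dtype=jnp.complex128)
x = jnp.array(1.0)  # default floating type is float64 once x64 is enabled
```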

To leverage the operator_action() API to compute the action of an operator on a quantum state, we need to define the operator and the state. Suppose we want to simulate the Jaynes-Cummings model subject to dissipation on both the two-level system and the cavity mode:

\[\begin{split}H &= \frac{1}{2}\omega_0\sigma_z + \omega a^\dagger a + \lambda(\sigma_+ a + \sigma_- a^\dagger) \\ \{L_i\} &= \{\gamma \sigma_-, \kappa a\}\end{split}\]
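To cross-check the operator definitions, the same model can also be written as dense matrices. The sketch below assumes the index convention used later in this example, where qubit index 0 is the excited state, and embeds single-mode operators with jnp.kron. It also checks a standard property of the Jaynes-Cummings Hamiltonian: conservation of the excitation number \(\sigma_+\sigma_- + a^\dagger a\). Dense matrices are only practical for small truncations.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

omega0, omega, lam = 1.0, 1.0, 0.1
gamma, kappa = 0.05, 0.01
N_qubit, N_cav = 2, 10
dtype = jnp.complex128

# Same index convention as the example: qubit index 0 = excited, 1 = ground.
sz_data = jnp.array([[1.0, 0.0], [0.0, -1.0]], dtype=dtype)
sp_data = jnp.array([[0.0, 1.0], [0.0, 0.0]], dtype=dtype)
sm_data = sp_data.T
a_data = jnp.diag(jnp.sqrt(jnp.arange(1.0, N_cav)), k=1).astype(dtype)
ad_data = a_data.T
n_data = ad_data @ a_data  # diag(0, 1, ..., N_cav - 1)
Iq = jnp.eye(N_qubit, dtype=dtype)
Ic = jnp.eye(N_cav, dtype=dtype)

# Full-space matrices via jnp.kron(qubit, cavity); this row-major ordering
# matches a state of shape dims = (N_qubit, N_cav) flattened with reshape.
H = (0.5 * omega0 * jnp.kron(sz_data, Ic)
     + omega * jnp.kron(Iq, n_data)
     + lam * (jnp.kron(sp_data, a_data) + jnp.kron(sm_data, ad_data)))
jump_ops = [gamma * jnp.kron(sm_data, Ic), kappa * jnp.kron(Iq, a_data)]

# The Jaynes-Cummings Hamiltonian conserves sigma_+ sigma_- + a^dagger a,
# and this still holds exactly after truncating the cavity.
N_exc = jnp.kron(sp_data @ sm_data, Ic) + jnp.kron(Iq, n_data)
commutator = H @ N_exc - N_exc @ H
```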

To construct the Liouvillian operator, we need to define JAX arrays for the elementary operator data buffers, wrap them as ElementaryOperators, construct OperatorTerms for the Hamiltonian and dissipators, and finally assemble the Liouvillian Operator.

from cuquantum.densitymat.jax import ElementaryOperator, OperatorTerm, Operator

# Physical parameters
omega0 = 1.0  # qubit transition frequency
omega = 1.0  # cavity frequency
lam = 0.1  # Jaynes-Cummings coupling strength
gamma = 0.05  # qubit spontaneous emission rate
kappa = 0.01  # cavity photon loss rate
N_cav = 10  # cavity mode dimension (Fock space truncated to N_cav levels)
N_qubit = 2  # qubit mode dimension

# Hilbert space: mode 0 = qubit (dim 2), mode 1 = cavity (truncated at N_cav)
dims = (N_qubit, N_cav)
dtype = jnp.complex128

# Operator matrices
# Qubit (mode 0)
sz_data = jnp.array([[1.0, 0.0], [0.0, -1.0]], dtype=dtype)  # sigma_z
sp_data = jnp.array([[0.0, 1.0], [0.0, 0.0]], dtype=dtype)  # sigma_+
sm_data = jnp.array([[0.0, 0.0], [1.0, 0.0]], dtype=dtype)  # sigma_-

# Cavity (mode 1)
n_data = jnp.diag(jnp.arange(N_cav, dtype=dtype))  # a†a
a_data = jnp.diag(jnp.sqrt(jnp.arange(1, N_cav, dtype=dtype)), k=1)  # a
ad_data = jnp.diag(jnp.sqrt(jnp.arange(1, N_cav, dtype=dtype)), k=-1)  # a†

sz = ElementaryOperator(sz_data)
sp = ElementaryOperator(sp_data)
sm = ElementaryOperator(sm_data)
n = ElementaryOperator(n_data)
a = ElementaryOperator(a_data)
ad = ElementaryOperator(ad_data)

# Hamiltonian: (omega0/2) sz + omega * n + lam * (sp⊗a + sm⊗a†)
H = OperatorTerm(dims)
H.append([sz], modes=[0], duals=[False], coeff=omega0 / 2)
H.append([n], modes=[1], duals=[False], coeff=omega)
H.append([sp, a], modes=[0, 1], duals=[False, False], coeff=lam)
H.append([sm, ad], modes=[0, 1], duals=[False, False], coeff=lam)

# Lindblad dissipators: D[L]rho = L rho L† - (1/2) L†L rho - (1/2) rho L†L
Ls = OperatorTerm(dims)

# D[gamma * sigma_-] on mode 0 (qubit decay)
Ls.append([sm, sp], modes=[0, 0], duals=[False, True], coeff=gamma**2)
Ls.append([sm, sp], modes=[0, 0], duals=[False, False], coeff=-0.5 * gamma**2)
Ls.append([sp, sm], modes=[0, 0], duals=[True, True], coeff=-0.5 * gamma**2)

# D[kappa * a] on mode 1 (photon loss)
Ls.append([a, ad], modes=[1, 1], duals=[False, True], coeff=kappa**2)
Ls.append([a, ad], modes=[1, 1], duals=[False, False], coeff=-0.5 * kappa**2)
Ls.append([ad, a], modes=[1, 1], duals=[True, True], coeff=-0.5 * kappa**2)

# Liouvillian: L(rho) = -i[H, rho] + sum_i D[L_i](rho)
liouvillian = Operator(dims)
liouvillian.append(H, dual=False, coeff=-1.0j)
liouvillian.append(H, dual=True, coeff=1.0j)
liouvillian.append(Ls, dual=False, coeff=1.0)
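The dense sketch below makes the `duals` flags concrete by mirroring the three `Ls.append` calls for qubit decay. It assumes the operators in a product are applied to the state in the order listed, with `duals=False` acting from the left and `duals=True` acting from the right. This assumption matches the D[L] comments in the code above but is not a statement of the cuDensityMat implementation. `apply_product` is a hypothetical helper and is not part of cuQuantum.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)
dtype = jnp.complex128

def apply_product(ops, duals, rho):
    # Hypothetical helper (assumed convention): apply ops in list order;
    # a non-dual operator multiplies from the left, a dual one from the right.
    for op, dual in zip(ops, duals):
        rho = rho @ op if dual else op @ rho
    return rho

sp = jnp.array([[0.0, 1.0], [0.0, 0.0]], dtype=dtype)  # sigma_+
sm = jnp.array([[0.0, 0.0], [1.0, 0.0]], dtype=dtype)  # sigma_-
gamma = 0.05

# An arbitrary Hermitian, unit-trace qubit state.
rho = jnp.array([[0.7, 0.2 - 0.1j], [0.2 + 0.1j, 0.3]], dtype=dtype)

# The three products appended to Ls for D[gamma * sigma_-].
terms = [
    ([sm, sp], [False, True], gamma**2),          # sm rho sp
    ([sm, sp], [False, False], -0.5 * gamma**2),  # sp sm rho
    ([sp, sm], [True, True], -0.5 * gamma**2),    # rho sp sm
]
d_terms = sum(c * apply_product(ops, duals, rho) for ops, duals, c in terms)

# Direct formula: L rho L^dagger - 1/2 {L^dagger L, rho} with L = gamma * sigma_-.
L = gamma * sm
LdL = L.conj().T @ L
d_direct = L @ rho @ L.conj().T - 0.5 * (LdL @ rho + rho @ LdL)
```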

We also need to define the input quantum state for the operator action. Here suppose we have a pure input quantum state with a single excitation in the cavity mode:

# State vector: |g, n=1> — qubit ground state (index 1), cavity Fock state n=1 (index 1)
psi0 = jnp.zeros(dims, dtype=dtype)
psi0 = psi0.at[1, 1].set(1.0)
rho0 = jnp.einsum('ij,kl->ijkl', psi0, psi0.conj())

Note that the states, both input and output, are simply JAX arrays rather than types provided by cuQuantum Python JAX.
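Because the state is an ordinary array of shape (*dims, *dims), ordinary jax.numpy operations apply to it directly. The sketch below shows that grouping the ket modes and the bra modes yields the familiar D x D density matrix. It then computes the trace and purity of the input state from the example.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)
dims = (2, 10)
D = dims[0] * dims[1]
dtype = jnp.complex128

# |g, n=1>: qubit index 1 (ground), cavity Fock state n=1, as in the example.
psi0 = jnp.zeros(dims, dtype=dtype).at[1, 1].set(1.0)
rho0 = jnp.einsum('ij,kl->ijkl', psi0, psi0.conj())  # shape (*dims, *dims)

# Grouping ket modes and bra modes gives the usual D x D matrix.
rho_mat = rho0.reshape(D, D)
outer = jnp.outer(psi0.ravel(), psi0.ravel().conj())

# Plain jax.numpy operations on the state array.
trace = jnp.trace(rho_mat)             # 1 for a normalized state
purity = jnp.trace(rho_mat @ rho_mat)  # 1 for a pure state
```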

After these steps, we can invoke the operator action by passing the operator and the state to operator_action():

from cuquantum.densitymat.jax import operator_action

rho1 = operator_action(liouvillian, rho0)

The array rho1 is the output state from the operator action. It has the same shape as rho0, namely (*dims, *dims).

JIT compilation

The operator_action() API function is compatible with JAX’s JIT transformation jax.jit. We can apply jax.jit to the API call directly:

rho1 = jax.jit(operator_action)(liouvillian, rho0)

More commonly, jax.jit is applied to the entire workflow that invokes operator_action(), as is typical in a master equation solver:

from functools import partial

@partial(jax.jit, static_argnames=("N_cav", "N_qubit", "dtype"))
def main(omega0, omega, lam, gamma, kappa, N_cav, N_qubit, dtype):

    # Define operator and state as shown in the Basic usage section.
    # ...

    rho1 = operator_action(liouvillian, rho0)
    return rho1
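The sketch below shows one possible end-to-end workflow: a jit-compiled, fixed-step fourth-order Runge-Kutta loop. To stay runnable without a GPU, `rhs` is a hypothetical dense stand-in for the Liouvillian acting on a single decaying qubit. In a cuQuantum workflow, that function would instead call operator_action(liouvillian, rho). The loop length `n_steps` is passed through static_argnames, the same pattern used for N_cav above.

```python
from functools import partial

import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)
dtype = jnp.complex128

sz = jnp.array([[1.0, 0.0], [0.0, -1.0]], dtype=dtype)
sm = jnp.array([[0.0, 0.0], [1.0, 0.0]], dtype=dtype)

def rhs(rho, omega0, gamma):
    # Hypothetical stand-in for operator_action(liouvillian, rho):
    # H = (omega0/2) sigma_z with a single jump operator gamma * sigma_-.
    H = 0.5 * omega0 * sz
    L = gamma * sm
    LdL = L.conj().T @ L
    return (-1j * (H @ rho - rho @ H)
            + L @ rho @ L.conj().T - 0.5 * (LdL @ rho + rho @ LdL))

@partial(jax.jit, static_argnames=("n_steps",))
def evolve(rho0, omega0, gamma, dt, n_steps):
    def rk4_step(_, rho):
        k1 = rhs(rho, omega0, gamma)
        k2 = rhs(rho + 0.5 * dt * k1, omega0, gamma)
        k3 = rhs(rho + 0.5 * dt * k2, omega0, gamma)
        k4 = rhs(rho + dt * k3, omega0, gamma)
        return rho + dt / 6.0 * (k1 + 2 * k2 + 2 * k3 + k4)
    return jax.lax.fori_loop(0, n_steps, rk4_step, rho0)

# Start in the excited state (index 0) and integrate to t = 200 * 0.01 = 2.
rho_e = jnp.array([[1.0, 0.0], [0.0, 0.0]], dtype=dtype)
rho_t = evolve(rho_e, 1.0, 0.5, 0.01, 200)
```

With L = γσ-, the excited-state population decays as exp(-γ² t). For γ = 0.5 and t = 2, this gives exp(-0.5), a convenient check on the integrator.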

Batching

Batching is supported in cuQuantum Python JAX through the jax.vmap transformation. The user can apply jax.vmap to the operator, the state, or both. Suppose we want to run the same simulation on two different input states, one with a single excitation in the cavity mode and the other with an excitation in the qubit mode. The state construction block can be modified as follows:

# State vector: |g, n=1> — qubit ground state (index 1), cavity Fock state n=1 (index 1)
psi0a = jnp.zeros(dims, dtype=dtype)
psi0a = psi0a.at[1, 1].set(1.0)
rho0a = jnp.einsum('ij,kl->ijkl', psi0a, psi0a.conj())

# State vector: |e, n=0> — qubit excited state (index 0), cavity Fock state n=0 (index 0)
psi0b = jnp.zeros(dims, dtype=dtype)
psi0b = psi0b.at[0, 0].set(1.0)
rho0b = jnp.einsum('ij,kl->ijkl', psi0b, psi0b.conj())

rho0 = jnp.stack([rho0a, rho0b])

Wrap the operator_action() call in jax.vmap with in_axes=(None, 0). Here None means the operator is shared across the batch, and 0 means the state is mapped over its leading batch axis.

rho1 = jax.vmap(operator_action, in_axes=(None, 0))(liouvillian, rho0)

The array rho1 is the output state from the operator action. The shape of rho1 is (2, *dims, *dims), where the first dimension is the batch dimension.
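The in_axes=(None, 0) pattern can be seen in isolation with a pure-JAX stand-in. In the sketch below, `rhs` is a hypothetical dense substitute for operator_action that keeps only the coherent part -i[H, rho]. The Hamiltonian `H` is an arbitrary Hermitian matrix chosen for the demo. The batch contains the same two input states as above.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)
dims = (2, 10)
D = dims[0] * dims[1]
dtype = jnp.complex128

def rhs(H, rho):
    # Hypothetical dense stand-in for operator_action on a (*dims, *dims) state;
    # only the coherent part -i[H, rho] is kept for brevity.
    r = rho.reshape(D, D)
    return (-1j * (H @ r - r @ H)).reshape(rho.shape)

H = jnp.ones((D, D), dtype=dtype)  # arbitrary Hermitian matrix for the demo

def pure_state(q, c):
    psi = jnp.zeros(dims, dtype=dtype).at[q, c].set(1.0)
    return jnp.einsum('ij,kl->ijkl', psi, psi.conj())

# |g, n=1> and |e, n=0>, stacked along a new leading batch axis.
rho0 = jnp.stack([pure_state(1, 1), pure_state(0, 0)])

# None: H is shared by all batch elements; 0: rho0 is mapped over axis 0.
rho1 = jax.vmap(rhs, in_axes=(None, 0))(H, rho0)
```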

We can also batch the operator, using the same batch size as the batched state. Batching can be introduced in any of three components:

  • Elementary or matrix operator data buffers.

  • Coefficients of elementary or matrix operator products in an operator term.

  • Coefficients of operator terms in an operator.

For more details on batching, refer to the cuDensityMat documentation.

The example below batches the operator product coefficients in the Hamiltonian operator term with batch size 2:

# Hamiltonian: (omega0/2) sz + omega * n + lam * (sp⊗a + sm⊗a†)
# The second batch is defined by doubling the values of the first batch.
H = OperatorTerm(dims)
H.append([sz], modes=[0], duals=[False], coeff=jnp.array([omega0 / 2, omega0], dtype=dtype))
H.append([n], modes=[1], duals=[False], coeff=jnp.array([omega, 2 * omega], dtype=dtype))
H.append([sp, a], modes=[0, 1], duals=[False, False], coeff=jnp.array([lam, 2 * lam], dtype=dtype))
H.append([sm, ad], modes=[0, 1], duals=[False, False], coeff=jnp.array([lam, 2 * lam], dtype=dtype))

After rebuilding liouvillian from this batched H as in the Basic usage section, pass liouvillian.in_axes as the operator's entry in in_axes when invoking operator_action():

rho1 = jax.vmap(operator_action, in_axes=(liouvillian.in_axes, 0))(liouvillian, rho0)

Additionally, jax.jit can be applied on top of jax.vmap to further improve performance.
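In pure JAX terms, a batched coefficient means mapping over a coefficient axis and the state axis together, and jax.jit can wrap the mapped function. The sketch below uses a hypothetical dense stand-in with the batched Hamiltonian coefficients above stacked per term. It maps with in_axes=(0, 0). This only illustrates the semantics: with cuQuantum, the operator entry is liouvillian.in_axes rather than 0. A smaller truncation, N_cav = 4, keeps the example light.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)
dtype = jnp.complex128
N_qubit, N_cav = 2, 4
D = N_qubit * N_cav

sz_data = jnp.array([[1.0, 0.0], [0.0, -1.0]], dtype=dtype)
sp_data = jnp.array([[0.0, 1.0], [0.0, 0.0]], dtype=dtype)
a_data = jnp.diag(jnp.sqrt(jnp.arange(1.0, N_cav)), k=1).astype(dtype)
n_data = a_data.T @ a_data
Iq, Ic = jnp.eye(N_qubit, dtype=dtype), jnp.eye(N_cav, dtype=dtype)

# Dense versions of the three Hamiltonian terms (hypothetical stand-in for the OperatorTerm).
terms = jnp.stack([
    jnp.kron(sz_data, Ic),                                          # sigma_z
    jnp.kron(Iq, n_data),                                           # a^dagger a
    jnp.kron(sp_data, a_data) + jnp.kron(sp_data.T, a_data.T),      # sigma_+ a + sigma_- a^dagger
])

def rhs(coeffs, rho):
    # coeffs: (3,) for one batch element; rho: (D, D). Coherent part only.
    H = jnp.tensordot(coeffs, terms, axes=1)
    return -1j * (H @ rho - rho @ H)

omega0, omega, lam = 1.0, 1.0, 0.1
# The second batch element doubles the first, as in the batched H above.
coeffs = jnp.array([[omega0 / 2, omega, lam],
                    [omega0, 2 * omega, 2 * lam]], dtype=dtype)

psi = jnp.zeros(D, dtype=dtype).at[N_cav + 1].set(1.0)  # |g, n=1>
rho0 = jnp.stack([jnp.outer(psi, psi.conj())] * 2)

batched_rhs = jax.jit(jax.vmap(rhs, in_axes=(0, 0)))
rho1 = batched_rhs(coeffs, rho0)
```

Since the commutator is linear in H, doubling every coefficient doubles the second batch element's output.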

Backward differentiation

The operator_action() API supports backward differentiation through JAX's automatic differentiation. Because operators and states are built from JAX arrays, gradients can be taken with respect to any parameter that enters the operator or state construction as a JAX-traceable value. This includes elementary or matrix operator data buffers, operator product coefficients in an operator term, operator term coefficients in an operator, and the input state itself. No gradient callbacks are required.

To differentiate, wrap the operator construction and operator_action() call in a function and apply jax.grad. For example, to compute the gradient of an expectation value with respect to the coupling strength lam and the decay rate gamma from the Jaynes-Cummings model:

def expectation_value(lam, gamma):
    # Hamiltonian: (omega0/2) sz + omega * n + lam * (sp⊗a + sm⊗a†)
    H = OperatorTerm(dims)
    H.append([sz], modes=[0], duals=[False], coeff=omega0 / 2)
    H.append([n], modes=[1], duals=[False], coeff=omega)
    H.append([sp, a], modes=[0, 1], duals=[False, False], coeff=lam)
    H.append([sm, ad], modes=[0, 1], duals=[False, False], coeff=lam)

    # Lindblad dissipators: D[gamma * sigma_-] on mode 0 (qubit decay)
    Ls = OperatorTerm(dims)
    Ls.append([sm, sp], modes=[0, 0], duals=[False, True], coeff=gamma**2)
    Ls.append([sm, sp], modes=[0, 0], duals=[False, False], coeff=-0.5 * gamma**2)
    Ls.append([sp, sm], modes=[0, 0], duals=[True, True], coeff=-0.5 * gamma**2)

    # Liouvillian: L(rho) = -i[H, rho] + D[gamma * sigma_-](rho)
    liouvillian = Operator(dims)
    liouvillian.append(H, dual=False, coeff=-1.0j)
    liouvillian.append(H, dual=True, coeff=1.0j)
    liouvillian.append(Ls, dual=False, coeff=1.0)

    # rho0: the unbatched input state from the Basic usage section
    rho1 = operator_action(liouvillian, rho0)
    # Expectation value of the cavity photon number with respect to rho1: Tr[(I ⊗ n) rho1]
    obs = jnp.kron(jnp.eye(N_qubit, dtype=dtype), n_data).reshape(dims + dims)
    return jnp.real(jnp.einsum('ijkl,klij->', obs, rho1))

lam_grad, gamma_grad = jax.grad(expectation_value, argnums=(0, 1))(lam, gamma)

This works equally well when the parameter enters the operator data buffers directly (e.g., omega * sz_data) or enters through the input state.
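The dense sketch below mirrors the structure of expectation_value in pure JAX and checks jax.grad against central finite differences. The Liouvillian is built as dense matrices here rather than with OperatorTerm and Operator. The input state (|g,1> + i|e,0>)/sqrt(2) and the smaller truncation N_cav = 4 are hypothetical choices, picked so the result is easy to verify by hand.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)
dtype = jnp.complex128
N_qubit, N_cav = 2, 4
omega0, omega = 1.0, 1.0

sz_data = jnp.array([[1.0, 0.0], [0.0, -1.0]], dtype=dtype)
sp_data = jnp.array([[0.0, 1.0], [0.0, 0.0]], dtype=dtype)
sm_data = sp_data.T
a_data = jnp.diag(jnp.sqrt(jnp.arange(1.0, N_cav)), k=1).astype(dtype)
ad_data = a_data.T
n_data = ad_data @ a_data
Iq, Ic = jnp.eye(N_qubit, dtype=dtype), jnp.eye(N_cav, dtype=dtype)
obs = jnp.kron(Iq, n_data)  # cavity photon number on the full space

# (|g,1> + i|e,0>)/sqrt(2); flat index = qubit_index * N_cav + cavity_index.
psi = jnp.zeros(N_qubit * N_cav, dtype=dtype).at[N_cav + 1].set(1.0).at[0].set(1j)
psi = psi / jnp.sqrt(2.0)
rho0 = jnp.outer(psi, psi.conj())

def expectation_value(lam, gamma):
    # Dense stand-in for building the Liouvillian and calling operator_action.
    H = (0.5 * omega0 * jnp.kron(sz_data, Ic) + omega * jnp.kron(Iq, n_data)
         + lam * (jnp.kron(sp_data, a_data) + jnp.kron(sm_data, ad_data)))
    L = gamma * jnp.kron(sm_data, Ic)  # qubit decay only, as in the example
    LdL = L.conj().T @ L
    rho1 = (-1j * (H @ rho0 - rho0 @ H)
            + L @ rho0 @ L.conj().T - 0.5 * (LdL @ rho0 + rho0 @ LdL))
    return jnp.real(jnp.trace(obs @ rho1))

lam, gamma = 0.1, 0.05
lam_grad, gamma_grad = jax.grad(expectation_value, argnums=(0, 1))(lam, gamma)

# Central finite differences as an independent check.
eps = 1e-6
fd_lam = (expectation_value(lam + eps, gamma) - expectation_value(lam - eps, gamma)) / (2 * eps)
fd_gamma = (expectation_value(lam, gamma + eps) - expectation_value(lam, gamma - eps)) / (2 * eps)
```

For this state, Tr[n L[ρ0]] (that is, d⟨n⟩/dt at ρ0) works out to λ. Qubit decay alone does not change the photon number, so the gradients are 1 with respect to lam and 0 with respect to gamma.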

Backward differentiation is also supported for batched operator actions. Gradients flow through both the batched operator construction and the jax.vmap-wrapped operator_action() call:

def batched_expectation_value(lam, gamma):
    # Construct batched liouvillian and state as shown in the Batching section
    # ...
    rho1 = jax.vmap(operator_action, in_axes=(liouvillian.in_axes, 0))(liouvillian, rho0)
    # Expectation value of the cavity photon number, summed over the batch
    obs = jnp.kron(jnp.eye(N_qubit, dtype=dtype), n_data).reshape(dims + dims)
    return jnp.real(jnp.einsum('ijkl,bklij->', obs, rho1))

lam_grad, gamma_grad = jax.grad(batched_expectation_value, argnums=(0, 1))(lam, gamma)

Additionally, jax.jit can be applied on top of jax.grad, optionally with an inner jax.vmap, to further improve performance.
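The sketch below shows the supported ordering in pure JAX: jax.jit wraps jax.grad, and jax.vmap sits inside the differentiated function. `rhs` is a hypothetical dense stand-in for operator_action on a single qubit, and the contraction mirrors the 'ijkl,bklij->' batch sum above.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)
dtype = jnp.complex128

sx = jnp.array([[0.0, 1.0], [1.0, 0.0]], dtype=dtype)
sy = jnp.array([[0.0, -1j], [1j, 0.0]], dtype=dtype)

def rhs(H, rho):
    # Hypothetical dense stand-in for operator_action (coherent part only).
    return -1j * (H @ rho - rho @ H)

def batched_objective(lam, rho_batch):
    H = lam * sx
    # jax.vmap inside the differentiated function: the supported ordering.
    rho1 = jax.vmap(rhs, in_axes=(None, 0))(H, rho_batch)
    # Sum over the batch of Tr(sigma_y rho1[b]).
    return jnp.real(jnp.einsum('ij,bji->', sy, rho1))

ground_z = jnp.array([[1.0, 0.0], [0.0, 0.0]], dtype=dtype)  # sigma_z = +1 state
mixed = 0.5 * jnp.eye(2, dtype=dtype)                        # maximally mixed state
rho_batch = jnp.stack([ground_z, mixed])

grad_fn = jax.jit(jax.grad(batched_objective))
g = grad_fn(0.3, rho_batch)
```

Here Tr(σy(-i)[λσx, ρ]) = -2λ⟨σz⟩. The batch sum is therefore -2λ, and the gradient is -2.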

API reference

Functions

operator_action(op, state_in_bufs[, device])

Compute the action of an operator on a state.

Objects

ElementaryOperator(data[, diag_offsets])

PyTree class for cuDensityMat's elementary operator.

MatrixOperator(data)

PyTree class for cuDensityMat's matrix operator.

OperatorTerm(dims)

PyTree class for cuDensityMat's operator term.

Operator(dims)

PyTree class for cuDensityMat's operator.