JAX Extension APIs#
cuQuantum Python JAX is an experimental extension for cuQuantum Python, designed to provide selected functionality
of the cuQuantum SDK in a JAX-compatible manner. This extension allows JAX-based frameworks to directly
integrate with the cuQuantum API. Currently, cuQuantum Python JAX offers a JAX interface
to the Operator Action API from cuDensityMat, via the operator_action() function of the
cuquantum.densitymat.jax module.
Release Notes#
cuQuantum Python JAX v0.0.5#
New features:
Gradients can now be computed with respect to any JAX-traceable parameter entering the operator or state construction without defining gradient callbacks (see example6_grad.py for an example). All callback-related arguments have been removed from the public API, and hence data buffers previously constructed dynamically via callbacks must now be pre-computed statically before being passed to operator_action(). The removed arguments are:
- operator_action(): removed the t and params arguments.
- ElementaryOperator and MatrixOperator: removed the callback and grad_callback arguments.
- OperatorTerm.append() and Operator.append(): removed the coeff_callback and coeff_grad_callback arguments.
Known issues:
- Nested jax.vmap transformations are not currently supported.
- jax.grad transformations inside jax.vmap are not currently supported.
cuQuantum Python JAX v0.0.4#
New features:
Added support for vector-Jacobian product (VJP) transformation of operator action with batched input operators and batched input states.
Users can now differentiate with respect to parameters implicit in the input operator or input state instead of explicitly specifying parameter gradients in gradient callbacks. See
example8_gradient_attachment.py for an example.
Bugs fixed:
Fixed a bug that caused double freeing of cuDensityMat library pointers.
Fixed a bug that triggered an insufficient-workspace runtime error when executing regular operator action after evaluating its vector-Jacobian product (VJP) transformation.
Other changes:
Users need to pass in jax.ShapeDtypeStruct objects instead of jax.Array objects for data buffers dynamically constructed by callbacks. The affected input arguments are:
- The data argument of ElementaryOperator's constructor
- The data argument of MatrixOperator's constructor
- The total_coeffs argument of OperatorTerm.append()
- The total_coeffs argument of Operator.append()
The diagonal offsets argument of the ElementaryOperator constructor, used to construct multidiagonal elementary operators, is renamed from offsets to diag_offsets.
Modification of an Operator after it has been used in an operator action is now disabled.
When installing cuQuantum Python JAX, the user needs to pass the --no-build-isolation option to pip and ensure that all build dependencies are pre-installed.
cuQuantum Python JAX v0.0.3#
Previously, cuQuantum Python JAX set jax_enable_x64=True as a side effect on import. Now, users must set jax_enable_x64 to True before importing the cuQuantum Python JAX module.
Compatibility notes:
cuQuantum Python JAX now supports CUDA 13 in addition to CUDA 12.
cuQuantum Python JAX supports JAX versions >=0.8.0 and <0.9.0 for CUDA 13.
cuQuantum Python JAX supports JAX versions >=0.5.0 and <0.7.0 for CUDA 12.
cuQuantum Python JAX v0.0.2#
Bugs fixed:
Fixed an issue with packaging of the cuQuantum Python JAX extension which rendered the package uninstallable in certain situations.
cuQuantum Python JAX v0.0.1#
Initial release of the JAX extension exposes the Operator Action API cuquantum.densitymat.jax.operator_action() from cuDensityMat to enable integration of cuQuantum Python with JAX-based quantum dynamics simulation frameworks.
Known issues:
Multiple structurally equivalent operators that involve callback functions for elementary/matrix operators will result in undefined behavior when their action on a quantum state is evaluated in consecutive calls in the same Python interpreter session. As a workaround, the user may call jax.clear_caches() in between consecutive calls to operator_action() involving these operators.
Providing data buffers that may be overwritten by callbacks but have been initialized with the same initial values will result in undefined behavior inside a jax.jit scope. As a workaround, the user should ensure arrays are initialized with different initial values when using them for dynamically constructed data buffers of different elementary/matrix operators.
Compatibility notes:
cuQuantum Python JAX supports CUDA 12.
cuQuantum Python JAX supports JAX version >=0.5.0 and <0.7.0.
API Usage#
cuQuantum Python JAX provides a single entry point operator_action() for the action of an operator on an input
state. The operator action corresponds to the same concept in cuDensityMat
and typically represents the right-hand side of the master equation in quantum dynamics. For example, in the Lindbladian form of the master equation
\(\dot{\rho} = \mathcal{L}[\rho] = -i[H, \rho] + \sum_i \mathcal{D}[L_i]\rho\),
rho1 = operator_action(liouvillian, rho0) corresponds to a single right-hand-side action \(\rho_1 = \mathcal{L}[\rho_0]\).
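The quantity computed by the operator action can be illustrated with a dense reference implementation (a plain NumPy sketch, independent of cuQuantum; the lindblad_action helper name is ours, not part of the API):

```python
import numpy as np

def lindblad_action(H, Ls, rho):
    """Dense reference for L[rho] = -i[H, rho] + sum_k (L rho L† - (1/2){L†L, rho})."""
    out = -1j * (H @ rho - rho @ H)
    for L in Ls:
        LdL = L.conj().T @ L
        out += L @ rho @ L.conj().T - 0.5 * (LdL @ rho + rho @ LdL)
    return out

# Two-level example: qubit decay through L = g * sigma_-
H = np.diag([0.5, -0.5]).astype(complex)
L = 0.1 * np.array([[0.0, 0.0], [1.0, 0.0]], dtype=complex)  # g = 0.1
rho = np.diag([1.0, 0.0]).astype(complex)                    # excited state |e><e|
rhodot = lindblad_action(H, [L], rho)
```

Since any Lindblad generator annihilates the trace, Tr of the result vanishes, which makes a convenient sanity check when assembling operators.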
Supported JAX transformations#
The operator_action() function is compatible with the following JAX transformations:
- jax.jit — JIT compilation applied directly to operator_action() or to an entire workflow that calls it.
- jax.vmap — Batching over input states, coefficients associated with an operator product or operator term, or elementary or matrix operator data buffers.
- jax.grad / jax.value_and_grad — Backward differentiation with respect to any JAX-traceable parameter entering the operator or state construction.
The following compositions are also supported:
- jax.jit applied on top of jax.vmap(operator_action, ...).
- jax.grad applied on top of jax.vmap(operator_action, ...).
- jax.jit applied on top of jax.grad, optionally with an inner jax.vmap.
The following transformation combinations are not currently supported:
- Nested jax.vmap (i.e., jax.vmap inside another jax.vmap).
- jax.grad inside jax.vmap (i.e., jax.vmap(jax.grad(f), ...)).
Basic usage#
For basic usage, we first need to import JAX. cuQuantum Python JAX requires double-precision arrays, so we need to set "jax_enable_x64" to True in the JAX configuration before importing cuQuantum Python JAX.
import jax
import jax.numpy as jnp
jax.config.update("jax_enable_x64", True)
cuQuantum Python JAX currently supports the GPU version of JAX with version number >=0.5.0 and <0.7.0.
To leverage the operator_action() API to compute the action of an operator on a quantum state, we need to define the operator and the state. Suppose we want to simulate the Jaynes-Cummings
model subject to dissipation on both the two-level system and the cavity mode, with Hamiltonian \(H = \frac{\omega_0}{2}\sigma_z + \omega\, a^\dagger a + \lambda\,(\sigma_+ a + \sigma_- a^\dagger)\) and dissipators \(\mathcal{D}[\gamma\,\sigma_-]\) (qubit decay) and \(\mathcal{D}[\kappa\, a]\) (cavity photon loss).
To construct the Liouvillian operator, we need to define
JAX arrays for the elementary operator data buffers, wrap them as ElementaryOperators, construct OperatorTerms for the Hamiltonian and dissipators, and finally assemble the Liouvillian Operator.
from cuquantum.densitymat.jax import ElementaryOperator, OperatorTerm, Operator
# Physical parameters
omega0 = 1.0 # qubit transition frequency
omega = 1.0 # cavity frequency
lam = 0.1 # Jaynes-Cummings coupling strength
gamma = 0.05 # qubit spontaneous emission rate
kappa = 0.01 # cavity photon loss rate
N_cav = 10 # cavity Hilbert space dimension (Fock truncation)
N_qubit = 2 # qubit Hilbert space dimension
# Hilbert space: mode 0 = qubit (dim 2), mode 1 = cavity (truncated at N_cav)
dims = (N_qubit, N_cav)
dtype = jnp.complex128
# Operator matrices
# Qubit (mode 0)
sz_data = jnp.array([[1.0, 0.0], [0.0, -1.0]], dtype=dtype) # sigma_z
sp_data = jnp.array([[0.0, 1.0], [0.0, 0.0]], dtype=dtype) # sigma_+
sm_data = jnp.array([[0.0, 0.0], [1.0, 0.0]], dtype=dtype) # sigma_-
# Cavity (mode 1)
n_data = jnp.diag(jnp.arange(N_cav, dtype=dtype)) # a†a
a_data = jnp.diag(jnp.sqrt(jnp.arange(1, N_cav, dtype=dtype)), k=1) # a
ad_data = jnp.diag(jnp.sqrt(jnp.arange(1, N_cav, dtype=dtype)), k=-1) # a†
sz = ElementaryOperator(sz_data)
sp = ElementaryOperator(sp_data)
sm = ElementaryOperator(sm_data)
n = ElementaryOperator(n_data)
a = ElementaryOperator(a_data)
ad = ElementaryOperator(ad_data)
# Hamiltonian: (omega0/2) sz + omega * n + lam * (sp⊗a + sm⊗a†)
H = OperatorTerm(dims)
H.append([sz], modes=[0], duals=[False], coeff=omega0 / 2)
H.append([n], modes=[1], duals=[False], coeff=omega)
H.append([sp, a], modes=[0, 1], duals=[False, False], coeff=lam)
H.append([sm, ad], modes=[0, 1], duals=[False, False], coeff=lam)
# Lindblad dissipators: D[L]rho = L rho L† - (1/2) L†L rho - (1/2) rho L†L
Ls = OperatorTerm(dims)
# D[gamma * sigma_-] on mode 0 (qubit decay)
Ls.append([sm, sp], modes=[0, 0], duals=[False, True], coeff=gamma**2)
Ls.append([sm, sp], modes=[0, 0], duals=[False, False], coeff=-0.5 * gamma**2)
Ls.append([sp, sm], modes=[0, 0], duals=[True, True], coeff=-0.5 * gamma**2)
# D[kappa * a] on mode 1 (photon loss)
Ls.append([a, ad], modes=[1, 1], duals=[False, True], coeff=kappa**2)
Ls.append([a, ad], modes=[1, 1], duals=[False, False], coeff=-0.5 * kappa**2)
Ls.append([ad, a], modes=[1, 1], duals=[True, True], coeff=-0.5 * kappa**2)
# Liouvillian: L(rho) = -i[H, rho] + sum_i D[L_i](rho)
liouvillian = Operator(dims)
liouvillian.append(H, dual=False, coeff=-1.0j)
liouvillian.append(H, dual=True, coeff=1.0j)
liouvillian.append(Ls, dual=False, coeff=1.0)
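As a cross-check of the assembled Liouvillian (a plain NumPy sketch, independent of cuQuantum Python JAX; a smaller cavity truncation is chosen for brevity), the same generator can be built densely with Kronecker products. Any valid Lindblad generator annihilates the trace and preserves Hermiticity:

```python
import numpy as np

N_cav = 4                                # smaller truncation for the sketch
omega0, omega, lam = 1.0, 1.0, 0.1
gamma, kappa = 0.05, 0.01

sz = np.diag([1.0, -1.0]).astype(complex)
sp = np.array([[0, 1], [0, 0]], dtype=complex)
sm = sp.conj().T
n = np.diag(np.arange(N_cav)).astype(complex)
a = np.diag(np.sqrt(np.arange(1, N_cav)), k=1).astype(complex)
I2, Ic = np.eye(2), np.eye(N_cav)

H = (omega0 / 2) * np.kron(sz, Ic) + omega * np.kron(I2, n) \
    + lam * (np.kron(sp, a) + np.kron(sm, a.conj().T))
Ls = [gamma * np.kron(sm, Ic), kappa * np.kron(I2, a)]

def liouvillian_dense(rho):
    out = -1j * (H @ rho - rho @ H)
    for L in Ls:
        LdL = L.conj().T @ L
        out += L @ rho @ L.conj().T - 0.5 * (LdL @ rho + rho @ LdL)
    return out

# |g, n=1>: qubit index 1, cavity index 1 -> flat index N_cav + 1
psi = np.zeros(2 * N_cav, dtype=complex)
psi[N_cav + 1] = 1.0
rho0 = np.outer(psi, psi.conj())
rho1 = liouvillian_dense(rho0)
```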
We also need to define the input quantum state for the operator action. Here suppose we have a pure input quantum state with a single excitation in the cavity mode:
# State vector: |g, n=1> — qubit ground state (index 1), cavity Fock state n=1 (index 1)
psi0 = jnp.zeros(dims, dtype=dtype)
psi0 = psi0.at[1, 1].set(1.0)
rho0 = jnp.einsum('ij,kl->ijkl', psi0, psi0.conj())
Note that the states, both input and output, are simply JAX arrays rather than types provided by cuQuantum Python JAX.
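Because the states are plain arrays of shape dims + dims, standard density-matrix checks apply directly. A NumPy sketch of the same state construction:

```python
import numpy as np

dims = (2, 10)
d = 2 * 10

psi0 = np.zeros(dims, dtype=complex)
psi0[1, 1] = 1.0                                    # |g, n=1>
rho0 = np.einsum('ij,kl->ijkl', psi0, psi0.conj())

# Flatten the (bra, ket) mode axes into a d x d density matrix
rho_mat = rho0.reshape(d, d)
```

rho_mat should be Hermitian with unit trace, and for a pure state Tr(rho^2) = 1.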
After these steps, we can invoke the operator action by passing the operator and the state to operator_action():
from cuquantum.densitymat.jax import operator_action
rho1 = operator_action(liouvillian, rho0)
The array rho1 is the output state from the operator action.
JIT compilation#
The operator_action() API function is compatible with JAX’s JIT transformation jax.jit.
We can apply jax.jit to the API call directly:
rho1 = jax.jit(operator_action)(liouvillian, rho0)
The more common usage is to apply jax.jit to the entire workflow that invokes operator_action(), as
is typical for a master equation solver.
from functools import partial

@partial(jax.jit, static_argnames=("N_cav", "N_qubit", "dtype"))
def main(omega0, omega, lam, gamma, kappa, N_cav, N_qubit, dtype):
    # Define operator and state as shown in the Basic usage section.
    # ...
    rho1 = operator_action(liouvillian, rho0)
    return rho1
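The static_argnames mechanics can be sketched with a stand-in workflow (the evolve function below is hypothetical, not part of the cuQuantum API): arguments that determine array shapes must be marked static so that each distinct value triggers one compilation specialization, while numeric parameters stay traced.

```python
import jax
import jax.numpy as jnp
from functools import partial

jax.config.update("jax_enable_x64", True)

@partial(jax.jit, static_argnames=("n",))
def evolve(omega, n):
    # n fixes the array shape, so it must be static; omega stays traced
    h = omega * jnp.diag(jnp.arange(n, dtype=jnp.float64))
    return jnp.trace(h)

result = evolve(2.0, 4)  # 2.0 * (0 + 1 + 2 + 3) = 12.0
```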
Batching#
Batching is supported in cuQuantum Python JAX through the jax.vmap transformation. The user can apply jax.vmap
to the operator, the state, or both. Suppose we want to run the same simulation on two different input states, one with
a single excitation in the cavity mode and the other with an excitation in the qubit mode. The state construction block
can be modified as follows:
# State vector: |g, n=1> — qubit ground state (index 1), cavity Fock state n=1 (index 1)
psi0a = jnp.zeros(dims, dtype=dtype)
psi0a = psi0a.at[1, 1].set(1.0)
rho0a = jnp.einsum('ij,kl->ijkl', psi0a, psi0a.conj())
# State vector: |e, n=0> — qubit excited state (index 0), cavity Fock state n=0 (index 0)
psi0b = jnp.zeros(dims, dtype=dtype)
psi0b = psi0b.at[0, 0].set(1.0)
rho0b = jnp.einsum('ij,kl->ijkl', psi0b, psi0b.conj())
rho0 = jnp.stack([rho0a, rho0b])
The operator_action() invocation needs to be wrapped in a jax.vmap transformation, with in_axes specified as None for the (unbatched) operator and 0 for the batched state:
rho1 = jax.vmap(operator_action, in_axes=(None, 0))(liouvillian, rho0)
The array rho1 is the output state from the operator action. The shape of rho1 is (2, *dims, *dims),
where the first dimension is the batch dimension.
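The in_axes=(None, 0) pattern can be sketched with a stand-in for operator_action (a hypothetical apply_op with the same (operator, state) call signature):

```python
import jax
import jax.numpy as jnp

def apply_op(op, state):
    # Stand-in with operator_action's (operator, state) signature
    return op @ state

op = 2.0 * jnp.eye(3)
states = jnp.stack([jnp.ones(3), jnp.arange(3.0)])  # batch of 2 states

# None: broadcast the operator; 0: map over the leading batch axis of the states
batched = jax.vmap(apply_op, in_axes=(None, 0))(op, states)
```

Here batched has shape (2, 3): the operator is shared and each state in the batch is transformed independently.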
We can also have a batched operator with the same batch size as the batched state. Batching can be introduced in any of three components: (1) elementary or matrix operator data buffers, (2) coefficients associated with an elementary or matrix operator product in an operator term, and (3) coefficients associated with an operator term in an operator. For more details on batching, please refer to the cuDensityMat documentation.
Below we will demonstrate an operator with batch size 2 on operator product coefficients in the Hamiltonian operator term:
# Hamiltonian: (omega0/2) sz + omega * n + lam * (sp⊗a + sm⊗a†)
# The second batch is defined by doubling the values of the first batch.
H = OperatorTerm(dims)
H.append([sz], modes=[0], duals=[False], coeff=jnp.array([omega0 / 2, omega0], dtype=dtype))
H.append([n], modes=[1], duals=[False], coeff=jnp.array([omega, 2 * omega], dtype=dtype))
H.append([sp, a], modes=[0, 1], duals=[False, False], coeff=jnp.array([lam, 2 * lam], dtype=dtype))
H.append([sm, ad], modes=[0, 1], duals=[False, False], coeff=jnp.array([lam, 2 * lam], dtype=dtype))
In the operator_action() invocation step, the in_axes entry for the batched liouvillian operator needs to be specified as liouvillian.in_axes:
rho1 = jax.vmap(operator_action, in_axes=(liouvillian.in_axes, 0))(liouvillian, rho0)
Additionally, jax.jit can be applied on top of jax.vmap to further improve performance.
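liouvillian.in_axes is a per-leaf in_axes pytree: leaves carrying a batch axis map to 0 and shared leaves map to None. The mechanics can be sketched with a plain dict-of-arrays operator (hypothetical names, not the actual cuQuantum pytree layout):

```python
import jax
import jax.numpy as jnp

def apply_op(op, state):
    # coeff carries the batch axis; mat is shared across the batch
    return op["coeff"] * (op["mat"] @ state)

op = {"mat": jnp.eye(2), "coeff": jnp.array([1.0, 2.0])}  # batched coefficients
states = jnp.ones((2, 2))                                  # batch of 2 states

# Per-leaf in_axes pytree: batch the coefficients, broadcast the matrix
op_in_axes = {"mat": None, "coeff": 0}
out = jax.vmap(apply_op, in_axes=(op_in_axes, 0))(op, states)
```

The i-th batch entry pairs the i-th coefficient with the i-th state, which is how a batched operator acts on a batched state of the same batch size.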
Backward differentiation#
The operator_action() API supports JAX’s automatic differentiation. Since operators and states
are constructed from JAX arrays, gradients can be computed with respect to any parameter that
enters the operator or state construction as a JAX-traceable value — including elementary or matrix operator data buffers,
operator product coefficients in an operator term, operator term coefficients in an operator, and the input state itself. No gradient callbacks are required.
To differentiate, wrap the operator construction and operator_action() call in a function and
apply jax.grad. For example, to compute the gradient of an expectation value with respect to
the coupling strength lam and the decay rate gamma from the Jaynes-Cummings model:
def expectation_value(lam, gamma):
    # Hamiltonian: (omega0/2) sz + omega * n + lam * (sp⊗a + sm⊗a†)
    H = OperatorTerm(dims)
    H.append([sz], modes=[0], duals=[False], coeff=omega0 / 2)
    H.append([n], modes=[1], duals=[False], coeff=omega)
    H.append([sp, a], modes=[0, 1], duals=[False, False], coeff=lam)
    H.append([sm, ad], modes=[0, 1], duals=[False, False], coeff=lam)
    # Lindblad dissipators: D[gamma * sigma_-] on mode 0 (qubit decay)
    Ls = OperatorTerm(dims)
    Ls.append([sm, sp], modes=[0, 0], duals=[False, True], coeff=gamma**2)
    Ls.append([sm, sp], modes=[0, 0], duals=[False, False], coeff=-0.5 * gamma**2)
    Ls.append([sp, sm], modes=[0, 0], duals=[True, True], coeff=-0.5 * gamma**2)
    # Liouvillian: L(rho) = -i[H, rho] + D[gamma * sigma_-](rho)
    liouvillian = Operator(dims)
    liouvillian.append(H, dual=False, coeff=-1.0j)
    liouvillian.append(H, dual=True, coeff=1.0j)
    liouvillian.append(Ls, dual=False, coeff=1.0)
    rho1 = operator_action(liouvillian, rho0)
    # Expectation value of cavity photon number: Tr(n_data * rho1)
    obs = jnp.kron(jnp.eye(N_qubit, dtype=dtype), n_data).reshape(dims + dims)
    return jnp.real(jnp.einsum('ijkl,klij->', obs, rho1))

lam_grad, gamma_grad = jax.grad(expectation_value, argnums=(0, 1))(lam, gamma)
This works equally well when the parameter enters operator tensor data directly (e.g., omega * sz_data)
or the input state.
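The pattern — a parameter flowing into operator data and differentiated end to end — can be sketched without cuQuantum (a plain JAX sketch with hypothetical names):

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

def expectation(lam):
    # lam enters the operator data directly, as with ElementaryOperator buffers
    h = lam * jnp.array([[0.0, 1.0], [1.0, 0.0]])
    psi = jnp.array([1.0, 1.0]) / jnp.sqrt(2.0)
    return jnp.real(psi @ (h @ psi))  # <psi|H|psi> = lam

g_lam = jax.grad(expectation)(0.3)  # d<psi|H|psi>/d(lam) = 1.0
```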
Backward differentiation is also supported for batched operator actions. Gradients flow through
both the batched operator construction and the jax.vmap-wrapped operator_action() call:
def batched_expectation_value(lam, gamma):
    # Construct batched liouvillian and state as shown in the Batching section
    # ...
    rho1 = jax.vmap(operator_action, in_axes=(liouvillian.in_axes, 0))(liouvillian, rho0)
    # Expectation value of cavity photon number: Tr(n_data * rho1)
    obs = jnp.kron(jnp.eye(N_qubit, dtype=dtype), n_data).reshape(dims + dims)
    return jnp.real(jnp.einsum('ijkl,bklij->', obs, rho1))

lam_grad, gamma_grad = jax.grad(batched_expectation_value, argnums=(0, 1))(lam, gamma)
Additionally, jax.jit can be applied on top of jax.grad, optionally with an inner jax.vmap, to further improve performance.
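The supported composition — jax.jit on top of jax.grad with an inner jax.vmap — can be sketched on a plain function (hypothetical loss, not a cuQuantum API):

```python
import jax
import jax.numpy as jnp

def loss(x, batch):
    # Inner vmap over the batch, reduced to a scalar for differentiation
    per_sample = jax.vmap(lambda b: (x * b) ** 2)(batch)
    return jnp.sum(per_sample)

step = jax.jit(jax.grad(loss))             # jit on top of grad, with an inner vmap
grad_x = step(2.0, jnp.array([1.0, 2.0]))  # d/dx sum((x*b)^2) = 2*x*sum(b^2) = 20.0
```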
API reference#
Functions#
operator_action() — Compute the action of an operator on a state.
Objects#
ElementaryOperator — PyTree class for cuDensityMat's elementary operator.
MatrixOperator — PyTree class for cuDensityMat's matrix operator.
OperatorTerm — PyTree class for cuDensityMat's operator term.
Operator — PyTree class for cuDensityMat's operator.