JAX Extension APIs
cuQuantum Python JAX is an extension for cuQuantum Python, designed to provide selected functionality
of the cuQuantum SDK in a JAX-compatible manner. This extension allows JAX-based frameworks to integrate
directly with the cuQuantum API. In this initial release, cuQuantum Python JAX offers a JAX interface
to the Operator Action API from cuDensityMat through the operator_action() function of the
cuquantum.densitymat.jax module.
API usage
Basic usage
For basic usage, we first need to import JAX. cuQuantum Python JAX requires double-precision arrays
for callback parameters, so "jax_enable_x64" must be set to True in the JAX configuration.
```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)
```
cuQuantum Python JAX currently supports the GPU version of JAX, in the version range >=0.5.0 and <0.7.0.
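Because only a narrow JAX version range is supported, it can help to fail early with a clear message. The following is a minimal sketch of a version guard. The helper name jax_version_supported is hypothetical and not part of cuQuantum, and it compares only the leading numeric major.minor.patch fields, ignoring pre-release or dev suffixes.

```python
def jax_version_supported(version: str) -> bool:
    """Hypothetical helper: True if `version` lies in the range >=0.5.0 and <0.7.0.

    Only the leading digits of the first three dot-separated fields are compared;
    suffixes such as "rc1" or ".dev20250101" are ignored in this simplified sketch.
    """
    fields = []
    for part in version.split(".")[:3]:
        digits = ""
        for ch in part:
            if not ch.isdigit():
                break
            digits += ch
        fields.append(int(digits) if digits else 0)
    # Pad short versions such as "1.0" to three fields.
    while len(fields) < 3:
        fields.append(0)
    return (0, 5, 0) <= tuple(fields) < (0, 7, 0)
```

In practice, the guard would be called as jax_version_supported(jax.__version__) right after importing JAX, possibly together with a check that jax.default_backend() returns "gpu".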
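To use the operator_action() API to compute the action of an operator on a quantum state in a quantum
dynamics simulation, we first need to construct both the operator and the state. Suppose we have a
two-site system whose Lindblad master equation is

$$\frac{d\rho}{dt} = -i\,[H, \rho] + \sum_{i=0}^{1} \left( L_i \rho L_i^\dagger - \frac{1}{2} L_i^\dagger L_i \rho - \frac{1}{2} \rho L_i^\dagger L_i \right),$$

where the Hamiltonian H acts jointly on both sites and the jump operator L_i acts on site i.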
An operator action evaluates the right-hand side (RHS) of the master equation: the Liouvillian operator
acts on the input quantum state to yield the output quantum state. To construct the Liouvillian operator,
we first define JAX arrays for the data buffers, wrap them as ElementaryOperators, build OperatorTerms
for the Hamiltonian and the dissipators, and finally assemble the Liouvillian Operator.
```python
from cuquantum.densitymat.jax import ElementaryOperator, OperatorTerm, Operator

dims = (3, 5)
dtype = jnp.complex128

# Elementary operators: a two-site Hamiltonian term and one jump operator per site.
h01_data = jax.random.normal(jax.random.key(0), (*dims, *dims)).astype(dtype)
l0_data = jax.random.normal(jax.random.key(1), (dims[0], dims[0])).astype(dtype)
l1_data = jax.random.normal(jax.random.key(2), (dims[1], dims[1])).astype(dtype)
l0d_data = l0_data.conj().T
l1d_data = l1_data.conj().T

h01 = ElementaryOperator(h01_data)
l0 = ElementaryOperator(l0_data)
l1 = ElementaryOperator(l1_data)
l0d = ElementaryOperator(l0d_data)
l1d = ElementaryOperator(l1d_data)

# Hamiltonian operator term
H = OperatorTerm(dims)
H.append([h01], modes=[0, 1], duals=[False])

# Dissipators operator term (one Lindblad dissipator per site)
Ls = OperatorTerm(dims)
for i, l, ld in zip(range(len(dims)), [l0, l1], [l0d, l1d]):
    Ls.append([l, ld], modes=[i, i], duals=[False, True])
    Ls.append([l, ld], modes=[i, i], duals=[False, False], coeff=-0.5)
    Ls.append([ld, l], modes=[i, i], duals=[True, True], coeff=-0.5)

# Liouvillian operator: the commutator -i[H, rho] plus the dissipators
liouvillian = Operator(dims)
liouvillian.append(H, dual=False, coeff=-1j)
liouvillian.append(H, dual=True, coeff=1j)
liouvillian.append(Ls, dual=False, coeff=1.0)
```
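When validating results on small systems, a dense reference implementation of the same two-site Lindbladian can be useful. The sketch below uses plain NumPy and is not part of cuQuantum. The helper name lindblad_rhs_dense is hypothetical. It assumes that h01 is laid out with its two output (ket) indices first, so it reshapes to a (D, D) matrix with D = 3 * 5, and that the state uses the (ket, ket, bra, bra) layout produced by the einsum in the next snippet. Callbacks and time dependence are ignored.

```python
import numpy as np


def lindblad_rhs_dense(h01, ls, rho, dims):
    """Hypothetical dense reference: -i[H, rho] + sum_i D[L_i](rho).

    Assumes h01 and rho have shape (*dims, *dims) with ket indices first,
    and that ls[i] is the single-site jump operator acting on site i.
    """
    D = int(np.prod(dims))
    H = h01.reshape(D, D)
    R = rho.reshape(D, D)
    out = -1j * (H @ R - R @ H)  # commutator part -i[H, rho]
    for site, l_local in enumerate(ls):
        # Embed the single-site operator into the full space, site 0 being the slowest index.
        factors = [np.eye(d) for d in dims]
        factors[site] = l_local
        L = factors[0]
        for f in factors[1:]:
            L = np.kron(L, f)
        Ld = L.conj().T
        out = out + L @ R @ Ld - 0.5 * (Ld @ L @ R) - 0.5 * (R @ Ld @ L)
    return out.reshape(*dims, *dims)
```

If the layout assumptions hold and the Liouvillian contains both H and Ls, lindblad_rhs_dense(np.asarray(h01_data), [np.asarray(l0_data), np.asarray(l1_data)], np.asarray(rho0), dims) could later be compared with np.allclose against the rho1 returned by operator_action(). Since the Lindbladian is trace preserving, the trace of the result should vanish in either case.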
We also need to construct the input quantum state for the operator action. Here we start from a pure
quantum state initialized with random values. Because the Liouvillian contains dissipators, we must
convert this pure state into a density matrix before passing it to operator_action().
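```python
# Random pure state on the two sites, normalized to unit norm.
key = jax.random.key(42)
psi0 = jax.random.normal(key, dims)
psi0 /= jnp.linalg.norm(psi0)

# Density matrix rho0[i, j, k, l] = psi0[i, j] * conj(psi0[k, l]).
rho0 = jnp.einsum('ij,kl->ijkl', psi0, psi0.conj()).astype(dtype)
```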
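A quick sanity check on the constructed state can catch layout or normalization mistakes before the operator action is invoked. The sketch below mirrors the JAX snippet in NumPy. The helper names pure_state_density_matrix and check_density_matrix are hypothetical. A density matrix built from a normalized pure state should have unit trace, be Hermitian, and have purity Tr(rho^2) equal to 1.

```python
import numpy as np


def pure_state_density_matrix(psi):
    """Build rho[i, j, k, l] = psi[i, j] * conj(psi[k, l]) from a two-site state."""
    psi = psi / np.linalg.norm(psi)  # Frobenius norm for a 2-D array
    return np.einsum("ij,kl->ijkl", psi, psi.conj())


def check_density_matrix(rho, dims, atol=1e-12):
    """Return (trace, is_hermitian, purity) of rho viewed as a (D, D) matrix."""
    D = int(np.prod(dims))
    R = rho.reshape(D, D)
    trace = np.trace(R)
    is_hermitian = np.allclose(R, R.conj().T, atol=atol)
    purity = np.trace(R @ R).real
    return trace, is_hermitian, purity
```

For the JAX state above, check_density_matrix(np.asarray(rho0), dims) would be expected to report a trace and purity of 1 up to rounding.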
With these in place, we can invoke the operator action by passing the operator, the time variable, and
the state to operator_action(). The time variable t is a required positional argument even when the
operator is not time-dependent:
```python
from cuquantum.densitymat.jax import operator_action

t = 0.0
rho1 = operator_action(liouvillian, t, rho0)
```
The array rho1 is the output quantum state of the operator action.
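Because the operator action evaluates the RHS d(rho)/dt of the master equation, it can serve as the derivative function of a time integrator. Time propagation is not part of the API described here. The following is a minimal sketch of one classical fourth-order Runge-Kutta (RK4) step, written against any callable rhs(t, rho). The helper name rk4_step is hypothetical, and a fixed step size dt is assumed.

```python
def rk4_step(rhs, t, rho, dt):
    """Hypothetical helper: advance d(rho)/dt = rhs(t, rho) by one classical RK4 step.

    `rho` may be a float or any array type supporting +, - and scalar *, for example
    a jax.numpy array; `rhs` must return a value of the same shape as `rho`.
    """
    k1 = rhs(t, rho)
    k2 = rhs(t + 0.5 * dt, rho + 0.5 * dt * k1)
    k3 = rhs(t + 0.5 * dt, rho + 0.5 * dt * k2)
    k4 = rhs(t + dt, rho + dt * k3)
    return rho + (dt / 6.0) * (k1 + 2.0 * k2 + 2.0 * k3 + k4)
```

With the objects above, one step might look like rho_next = rk4_step(lambda t, r: operator_action(liouvillian, t, r), t, rho0, 0.01). RK4 calls the operator action four times per step, which is the usual trade-off for its fourth-order accuracy.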
JIT compilation
The operator_action() API function is compatible with JAX's JIT transformation, jax.jit. We can apply
jax.jit to the API call directly:

```python
rho1 = jax.jit(operator_action)(liouvillian, t, rho0)
```
Alternatively, we can apply jax.jit to an entire function that calls operator_action():

```python
@jax.jit
def main():
    # Define the operator and the state here.
    rho1 = operator_action(liouvillian, t, rho0)
    return rho1
```
Backward differentiation
The operator_action() API supports JAX's automatic differentiation transformations. To differentiate
with respect to the callback parameters, the user needs to define callback and grad_callback for the
elementary/matrix operators and/or coeff_callback and coeff_grad_callback for the coefficients. Below we
demonstrate how to define a tensor callback, callback, for the elementary operator h01 together with a
tensor gradient callback, grad_callback, so that the parameter gradients can be computed. Note that all
the data buffers inside the callbacks (storage, tensor_grad, params_grad) are CuPy arrays reconstructed
from the user-provided data buffers, so any modifications to them must use the corresponding CuPy
functions.
```python
import cupy as cp
from cuquantum.bindings import cudensitymat as cudm

# Define the regular callback function.
def h01_callback(t, args, storage):
    storage[:] = 0.0
    for m in range(storage.shape[0]):
        for n in range(storage.shape[1]):
            for p in range(storage.shape[2]):
                for q in range(storage.shape[3]):
                    storage[m, n, p, q] = (m + n) * (p + q) * cp.tan(args[0] * t)
                    if storage.dtype.kind == 'c':
                        storage[m, n, p, q] += 1j * (m + n) * (p + q) / cp.tan(args[0] * t)

# Define the gradient callback function.
def h01_grad_callback(t, args, tensor_grad, params_grad):
    for m in range(tensor_grad.shape[0]):
        for n in range(tensor_grad.shape[1]):
            for p in range(tensor_grad.shape[2]):
                for q in range(tensor_grad.shape[3]):
                    params_grad[0] += 2 * (
                        tensor_grad[m, n, p, q] * ((m + n) * (p + q) * t / cp.cos(args[0] * t) ** 2)
                    ).real
                    if tensor_grad.dtype.kind == 'c':
                        params_grad[0] += 2 * (
                            tensor_grad[m, n, p, q] * (-1j * (m + n) * (p + q) * t / cp.sin(args[0] * t) ** 2)
                        ).real

# Define the wrapped callbacks from the callback functions.
h01_callback = cudm.WrappedTensorCallback(h01_callback, cudm.CallbackDevice.GPU)
h01_grad_callback = cudm.WrappedTensorGradientCallback(h01_grad_callback, cudm.CallbackDevice.GPU)

# Define the ElementaryOperator with callback and grad_callback.
h01 = ElementaryOperator(h01_data, callback=h01_callback, grad_callback=h01_grad_callback)
```
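The element-by-element loops above issue many tiny device operations. The same formulas can be expressed with array broadcasting, and the gradient formula can be cross-checked against a finite difference. The sketch below is written against NumPy so it runs on the CPU. The helper names h01_coefficients, h01_fill and h01_param_grad are hypothetical, and a single scalar parameter a (the value the callbacks read as args[0]) is assumed. The functions take an array-module argument xp, and CuPy provides the same meshgrid, arange, tan, cos, sin and sum calls. The gradient helper accumulates the same quantity as the loops, 2 * Re(sum(tensor_grad * d storage / d a)). Note that writing through storage[...] = ... updates the caller's buffer in place, whereas rebinding storage = ... would leave it unchanged.

```python
import numpy as np


def h01_coefficients(shape, xp=np):
    """(m + n) * (p + q) evaluated at every index (m, n, p, q) of a 4-D buffer."""
    m, n, p, q = xp.meshgrid(*[xp.arange(s) for s in shape], indexing="ij")
    return (m + n) * (p + q)


def h01_fill(t, a, storage, xp=np):
    """Vectorized form of the loops in h01_callback; writes into storage in place."""
    c = h01_coefficients(storage.shape, xp)
    storage[...] = c * xp.tan(a * t)  # slice assignment keeps the caller's buffer
    if storage.dtype.kind == "c":
        storage[...] += 1j * c / xp.tan(a * t)


def h01_param_grad(t, a, tensor_grad, xp=np):
    """Vectorized form of the loops in h01_grad_callback for one scalar parameter a."""
    c = h01_coefficients(tensor_grad.shape, xp)
    d_storage = c * t / xp.cos(a * t) ** 2  # d/da of c * tan(a t)
    if tensor_grad.dtype.kind == "c":
        d_storage = d_storage - 1j * c * t / xp.sin(a * t) ** 2  # d/da of i c cot(a t)
    return 2 * xp.sum(tensor_grad * d_storage).real
```

How these helpers would be wired into h01_callback and h01_grad_callback depends on the exact shape of args, which is not spelled out here. Before relying on them, compare their output against the loop versions above.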
The Liouvillian operator is then defined as in the previous section, with h01 replaced by this new h01.
To compute gradients, the user applies jax.vjp to operator_action(), which returns both the output
quantum state of the regular operator action, rho1, and the VJP function, vjp_f. Passing the adjoint of
the output quantum state to vjp_f then yields the gradients with respect to the inputs, including the
callback parameters.
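The following snippet illustrates how to use jax.vjp on operator_action() to obtain the adjoint of the
input quantum state, rho0_adj, and the gradients with respect to the callback parameters, params_grad.
Here params denotes the double-precision JAX array holding the callback parameters.

```python
rho1, vjp_f = jax.vjp(operator_action, liouvillian, t, rho0, params)
# vjp_f returns one cotangent per primal input; the first two correspond to the operator and t.
_, _, rho0_adj, params_grad = vjp_f(jnp.conj(rho1))
```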
API reference
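Functions

| Function | Description |
| --- | --- |
| operator_action() | Compute the action of an operator on a state. |

Objects

| Class | Description |
| --- | --- |
| ElementaryOperator | PyTree class for cuDensityMat's elementary operator. |
| MatrixOperator | PyTree class for cuDensityMat's matrix operator. |
| OperatorTerm | PyTree class for cuDensityMat's operator term. |
| Operator | PyTree class for cuDensityMat's operator. |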