.. _python jax APIs:

.. currentmodule:: cuquantum.densitymat.jax

******************
JAX Extension APIs
******************

cuQuantum Python JAX is an extension for cuQuantum Python. It provides selected functionality of the cuQuantum SDK in a JAX-compatible manner, so that JAX-based frameworks can integrate directly with the cuQuantum API. In this initial release, cuQuantum Python JAX offers a JAX interface to the Operator Action API from cuDensityMat through the :func:`operator_action` function of the :mod:`cuquantum.densitymat.jax` module.

API usage
=========

Basic usage
-----------

For basic usage, we first need to import JAX. cuQuantum Python JAX requires double-precision arrays for callback parameters, so we need to set ``"jax_enable_x64"`` to ``True`` in the JAX configuration.

.. code-block:: python

    import jax
    import jax.numpy as jnp

    # Enable 64-bit (double-precision) arrays, required for callback parameters.
    jax.config.update("jax_enable_x64", True)

cuQuantum Python JAX currently supports the GPU version of JAX, versions >=0.5.0 and <0.7.0.

To compute the action of an operator on a quantum state in a quantum dynamics simulation with :func:`operator_action`, we need to construct both the operator and the state. Suppose we have a two-site system whose Lindblad master equation is defined by

.. math::

    \dot{\rho} = \mathcal{L}[\rho] &= -i[H, \rho] + \sum_{i} \left( L_i \rho L_i^\dagger - \frac{1}{2} L_i^\dagger L_i \rho - \frac{1}{2} \rho L_i^\dagger L_i \right) \\
    H &= h_{01} \\
    \{L_i\} &= \{l_0, l_1\}

An operator action computes the right-hand side of the master equation: the Liouvillian operator acts on the input quantum state to yield the output quantum state. To construct the Liouvillian operator, we first define JAX arrays for the data buffers, wrap them as :class:`ElementaryOperator`\ s, build :class:`OperatorTerm`\ s for the Hamiltonian and the dissipators, and finally assemble the Liouvillian :class:`Operator`.

.. code-block:: python

    from cuquantum.densitymat.jax import ElementaryOperator, OperatorTerm, Operator

    dims = (3, 5)
    dtype = jnp.complex128

    # Elementary operators
    h01_data = jax.random.normal(jax.random.key(0), (*dims, *dims)).astype(dtype)
    l0_data = jax.random.normal(jax.random.key(1), (dims[0], dims[0])).astype(dtype)
    l1_data = jax.random.normal(jax.random.key(2), (dims[1], dims[1])).astype(dtype)
    l0d_data = l0_data.conj().T
    l1d_data = l1_data.conj().T

    h01 = ElementaryOperator(h01_data)
    l0 = ElementaryOperator(l0_data)
    l1 = ElementaryOperator(l1_data)
    l0d = ElementaryOperator(l0d_data)
    l1d = ElementaryOperator(l1d_data)

    # Hamiltonian operator term
    H = OperatorTerm(dims)
    H.append([h01], modes=[0, 1], duals=[False])

    # Dissipators operator term: L rho L^dagger - 1/2 L^dagger L rho - 1/2 rho L^dagger L
    Ls = OperatorTerm(dims)
    for i, l, ld in zip(range(len(dims)), [l0, l1], [l0d, l1d]):
        Ls.append([l, ld], modes=[i, i], duals=[False, True])
        Ls.append([l, ld], modes=[i, i], duals=[False, False], coeff=-0.5)
        Ls.append([ld, l], modes=[i, i], duals=[True, True], coeff=-0.5)

    # Liouvillian operator: -i[H, rho] plus the dissipators
    liouvillian = Operator(dims)
    liouvillian.append(H, dual=False, coeff=-1j)
    liouvillian.append(H, dual=True, coeff=1j)
    liouvillian.append(Ls)

We also need to construct the input quantum state for the operator action. Here we start from a pure state initialized with random values. Because the Liouvillian contains dissipators, we must convert the pure state into a density matrix before passing it to :func:`operator_action`.

.. code-block:: python

    key = jax.random.key(42)
    psi0 = jax.random.normal(key, dims)
    psi0 /= jnp.linalg.norm(psi0)
    # Density matrix rho0 = |psi0><psi0| with shape (*dims, *dims)
    rho0 = jnp.einsum('ij,kl->ijkl', psi0, psi0.conj()).astype(dtype)

After these steps, we can invoke the operator action by passing the operator, the time variable, and the state to :func:`operator_action`. The time variable ``t`` is a required positional argument even when the operator is not time-dependent.

.. code-block:: python

    from cuquantum.densitymat.jax import operator_action

    t = 0.0
    rho1 = operator_action(liouvillian, t, rho0)

The array ``rho1`` is the output quantum state of the operator action.

JIT compilation
---------------

The :func:`operator_action` function is compatible with JAX's JIT transformation ``jax.jit``. We can apply ``jax.jit`` to the API call directly:

.. code-block:: python

    rho1 = jax.jit(operator_action)(liouvillian, t, rho0)

Alternatively, we can apply ``jax.jit`` to an entire function that calls :func:`operator_action`:

.. code-block:: python

    @jax.jit
    def main():
        # Define operator and state here.
        rho1 = operator_action(liouvillian, t, rho0)
        return rho1

Backward differentiation
------------------------

The :func:`operator_action` function supports JAX's automatic differentiation transformations. To differentiate with respect to the callback parameters, the user needs to define ``callback`` and ``grad_callback`` for the elementary or matrix operators, and/or ``coeff_callback`` and ``coeff_grad_callback`` for the coefficients.

Below we demonstrate how to define a tensor callback ``callback`` for the elementary operator ``h01`` and a tensor gradient callback ``grad_callback`` so that the parameter gradients can be computed. Note that all data buffers inside the callbacks (``storage``, ``tensor_grad``, ``params_grad``) are reconstructed as CuPy arrays from the user-provided data buffers, so any modification to them must use the corresponding CuPy functions.

.. code-block:: python

    import cupy as cp
    from cuquantum.bindings import cudensitymat as cudm
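A tensor callback returns its result by filling the ``storage`` buffer it is given, so the values must be written into that buffer in place (for example with ``storage[...] = values``); rebinding the local name ``storage`` would leave the caller's buffer untouched. The nested loops in the ``h01`` callback also assign one element at a time, and with CuPy arrays each such assignment typically launches a separate GPU operation. The following is a minimal sketch, assuming NumPy as a CPU stand-in for CuPy (the two share these indexing and in-place assignment semantics) and a 4-D ``storage`` buffer without a batch dimension, that fills the same values with broadcasting. The function names ``h01_fill_loop`` and ``h01_fill_vectorized`` are hypothetical and not part of cuQuantum.

.. code-block:: python

    import numpy as np

    # Assumption: NumPy stands in for CuPy; in a real callback, use xp = cp instead.
    xp = np


    def h01_fill_loop(t, args, storage):
        """Hypothetical helper: element-by-element fill mirroring the h01 callback."""
        storage[:] = 0.0
        for m in range(storage.shape[0]):
            for n in range(storage.shape[1]):
                for p in range(storage.shape[2]):
                    for q in range(storage.shape[3]):
                        storage[m, n, p, q] = (m + n) * (p + q) * xp.tan(args[0] * t)
                        if storage.dtype.kind == 'c':
                            storage[m, n, p, q] += 1j * (m + n) * (p + q) / xp.tan(args[0] * t)


    def h01_fill_vectorized(t, args, storage):
        """Hypothetical helper: same values via broadcasting (assumes 4-D storage)."""
        d0, d1, d2, d3 = storage.shape
        m = xp.arange(d0).reshape(d0, 1, 1, 1)
        n = xp.arange(d1).reshape(1, d1, 1, 1)
        p = xp.arange(d2).reshape(1, 1, d2, 1)
        q = xp.arange(d3).reshape(1, 1, 1, d3)
        weight = (m + n) * (p + q)  # shape (d0, d1, d2, d3)
        values = weight * xp.tan(args[0] * t)
        if storage.dtype.kind == 'c':
            values = values + 1j * weight / xp.tan(args[0] * t)
        # In-place write into the caller's buffer; `storage = values` would be lost.
        storage[...] = values


    # Usage: both fills produce the same tensor for a nonzero t.
    dims = (3, 5)
    args = np.array([0.7])
    t = 0.3
    loop_out = np.empty((*dims, *dims), dtype=np.complex128)
    vec_out = np.empty_like(loop_out)
    h01_fill_loop(t, args, loop_out)
    h01_fill_vectorized(t, args, vec_out)
    print(np.allclose(loop_out, vec_out))  # True

Swapping ``xp`` to ``cp`` would keep the whole computation as a handful of array operations on the GPU rather than one operation per tensor element.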
.. code-block:: python

    # Define the regular callback function.
    def h01_callback(t, args, storage):
        storage[:] = 0.0
        for m in range(storage.shape[0]):
            for n in range(storage.shape[1]):
                for p in range(storage.shape[2]):
                    for q in range(storage.shape[3]):
                        storage[m, n, p, q] = (m + n) * (p + q) * cp.tan(args[0] * t)
                        if storage.dtype.kind == 'c':
                            storage[m, n, p, q] += 1j * (m + n) * (p + q) / cp.tan(args[0] * t)

    # Define the gradient callback function.
    def h01_grad_callback(t, args, tensor_grad, params_grad):
        for m in range(tensor_grad.shape[0]):
            for n in range(tensor_grad.shape[1]):
                for p in range(tensor_grad.shape[2]):
                    for q in range(tensor_grad.shape[3]):
                        params_grad[0] += 2 * (
                            tensor_grad[m, n, p, q]
                            * ((m + n) * (p + q) * t / cp.cos(args[0] * t) ** 2)
                        ).real
                        if tensor_grad.dtype.kind == 'c':
                            params_grad[0] += 2 * (
                                tensor_grad[m, n, p, q]
                                * (-1j * (m + n) * (p + q) * t / cp.sin(args[0] * t) ** 2)
                            ).real

    # Wrap the callback functions.
    h01_callback = cudm.WrappedTensorCallback(h01_callback, cudm.CallbackDevice.GPU)
    h01_grad_callback = cudm.WrappedTensorGradientCallback(h01_grad_callback, cudm.CallbackDevice.GPU)

    # Define the ElementaryOperator with callback and grad_callback.
    h01 = ElementaryOperator(h01_data, callback=h01_callback, grad_callback=h01_grad_callback)

The Liouvillian operator is defined as in the previous section, with ``h01`` replaced by the new ``h01`` that carries the callbacks. Because this example callback divides by ``cp.tan(args[0] * t)`` and its gradient divides by ``cp.sin(args[0] * t) ** 2``, it must be evaluated at a time ``t`` where these are nonzero, and in particular not at the ``t = 0.0`` used in the basic example.

To compute gradients, the user applies ``jax.vjp`` to :func:`operator_action`. This returns the output quantum state of the regular operator action, ``rho1``, together with the VJP function ``vjp_f``. Passing the adjoint of the output quantum state to ``vjp_f`` then yields the gradients with respect to the inputs. The following snippet shows how to obtain the input quantum state adjoint ``rho0_adj`` and the gradients ``params_grad`` with respect to the callback parameters ``params``.

.. code-block:: python

    rho1, vjp_f = jax.vjp(operator_action, liouvillian, t, rho0, params)
    # vjp_f returns one cotangent per primal input; the first two belong to the operator and t.
    _, _, rho0_adj, params_grad = vjp_f(jnp.conj(rho1))

API reference
=============

.. module:: cuquantum.densitymat.jax

Functions
---------

.. autosummary::
    :toctree: generated/

    operator_action

Objects
-------

.. autosummary::
    :toctree: generated/

    ElementaryOperator
    MatrixOperator
    OperatorTerm
    Operator
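For small systems such as the two-site example in the Basic usage section, it can help to cross-check :func:`operator_action` against a dense evaluation of the master equation. The sketch below is a hypothetical NumPy reference (``lindblad_rhs`` and ``embed`` are not part of cuQuantum) that treats the density matrix as a :math:`d \times d` matrix with :math:`d = d_0 d_1` and embeds the single-site jump operators with Kronecker products. It assumes that a state or two-site operator of shape ``(*dims, *dims)`` reshapes to :math:`d \times d` in C order, matching the ``einsum`` used to build ``rho0`` above; whether this matches the tensor layout of :func:`operator_action` is an assumption to verify before comparing numbers.

.. code-block:: python

    import numpy as np


    def lindblad_rhs(H, jump_ops, rho):
        """Hypothetical dense reference for -i[H, rho] + sum_i D[L_i](rho).

        All inputs are d x d NumPy matrices.
        """
        out = -1j * (H @ rho - rho @ H)
        for L in jump_ops:
            Ld = L.conj().T
            LdL = Ld @ L
            out += L @ rho @ Ld - 0.5 * (LdL @ rho + rho @ LdL)
        return out


    def embed(op, site, dims):
        """Hypothetical helper: embed a single-site operator into the full space (C order)."""
        mats = [op if k == site else np.eye(d) for k, d in enumerate(dims)]
        full = mats[0]
        for m in mats[1:]:
            full = np.kron(full, m)
        return full


    rng = np.random.default_rng(0)
    dims = (3, 5)
    d = int(np.prod(dims))

    # Two-site Hamiltonian stored as a (*dims, *dims) tensor, reshaped to d x d (assumed layout).
    H = (rng.normal(size=(*dims, *dims)) + 0j).reshape(d, d)
    l0 = rng.normal(size=(dims[0], dims[0])) + 0j
    l1 = rng.normal(size=(dims[1], dims[1])) + 0j
    jump_ops = [embed(l0, 0, dims), embed(l1, 1, dims)]

    # Pure state -> density matrix, mirroring the einsum in the Basic usage section.
    psi0 = rng.normal(size=dims)
    psi0 /= np.linalg.norm(psi0)
    rho0 = np.einsum('ij,kl->ijkl', psi0, psi0.conj()).reshape(d, d)

    rho1 = lindblad_rhs(H, jump_ops, rho0)
    # The Lindblad generator is trace-preserving, so Tr(rho1) is zero up to rounding.
    print(abs(np.trace(rho1)))

The trace check holds for any ``H`` by cyclicity of the trace, which makes it a cheap sanity test; comparing entries against ``rho1`` from :func:`operator_action` additionally requires using the same random data in both places.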