cuStabilizer-JAX#
cuStabilizer-JAX provides a JAX interface to selected cuStabilizer routines
through the cuquantum.stabilizer.jax module.
Currently, it exposes matmul_gf2_spdn() for GF(2) sparse-dense matrix multiplication.
GF(2) sparse-dense matrix multiplication#
matmul_gf2_spdn() computes
\[C = A B \pmod{2},\]
where \(A\) is a sparse matrix in CSR format (shape \(m \times k\)) and
\(B\) is a bit-packed dense matrix (shape \(k \times n\)), yielding a
bit-packed dense result \(C\) (shape \(m \times n\)).
Addition and multiplication are over GF(2): each inner product reduces to a
bitwise AND of a row of \(A\) with a column of \(B\), followed by an XOR sum of
the resulting bits. This is equivalent to computing (A @ B) % 2 on the unpacked
0/1 matrices in ordinary (non-modular) arithmetic.
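As a small illustration of the GF(2) arithmetic described above, the following NumPy sketch compares the AND/XOR formulation with (A @ B) % 2 on unpacked 0/1 matrices. It runs on the host and does not call cuStabilizer; the helper name gf2_matmul_and_xor is made up for this example.

```python
import numpy as np

def gf2_matmul_and_xor(a: np.ndarray, b: np.ndarray) -> np.ndarray:
    """GF(2) product of 0/1 matrices using AND for products and XOR for sums."""
    m, k = a.shape
    k2, n = b.shape
    assert k == k2, "inner dimensions must match"
    c = np.zeros((m, n), dtype=np.uint8)
    for i in range(m):
        for j in range(n):
            # AND a row of A with a column of B, then XOR-reduce (parity).
            c[i, j] = np.bitwise_xor.reduce(a[i, :] & b[:, j])
    return c

rng = np.random.default_rng(0)
A = rng.integers(0, 2, size=(5, 7), dtype=np.uint8)
B = rng.integers(0, 2, size=(7, 9), dtype=np.uint8)

C_and_xor = gf2_matmul_and_xor(A, B)
# Widen before the integer matmul so the sums cannot overflow uint8.
C_mod2 = ((A.astype(np.int64) @ B.astype(np.int64)) % 2).astype(np.uint8)
```

Both results should agree entry by entry, which is a convenient property to rely on when checking packed results against a dense reference.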
A complete, runnable example is at python/extensions/samples/stabilizer/example_matmul_gf2_spdn.py.
Input conventions#
The API accepts three device arrays:
A_rowoff: (m+1,) uint64. CSR row offsets of \(A\).
A_colidx: (>= nnz,) uint64. CSR column indices of \(A\). The kernel reads nnz = A_rowoff[m]; any trailing entries are ignored.
B_packed: (k, n // 32) uint32. Bit-packed dense \(B\), using the standard cuStabilizer bit-packed layout (see Bit-Packed Format): little-endian bit order, with n a multiple of 32. Pad \(B\) with zero columns before packing if necessary (see Walkthrough below).
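To make the shapes and dtypes concrete, the sketch below builds the three inputs on the host with NumPy for a small 3 x 4 matrix \(A\) and a 4 x 32 matrix \(B\). The helper check_spdn_inputs is hypothetical and not part of cuStabilizer; it only restates the conventions listed above as assertions. In practice the arrays would then be moved to the device, for example with jnp.asarray.

```python
import numpy as np

def check_spdn_inputs(a_rowoff, a_colidx, b_packed, m, n, k):
    """Hypothetical helper: assert the documented input conventions."""
    assert a_rowoff.dtype == np.uint64 and a_rowoff.shape == (m + 1,)
    nnz = int(a_rowoff[m])              # the last row offset is nnz
    assert a_colidx.dtype == np.uint64 and a_colidx.shape[0] >= nnz
    assert np.all(a_colidx[:nnz] < k)   # column indices address rows of B
    assert n % 32 == 0                  # n must be a multiple of 32
    assert b_packed.dtype == np.uint32 and b_packed.shape == (k, n // 32)
    return nnz

# A = [[1, 0, 1, 0],
#      [0, 0, 0, 0],
#      [0, 1, 1, 1]]   (m = 3, k = 4, nnz = 5)
m, k, n = 3, 4, 32
A_rowoff = np.array([0, 2, 2, 5], dtype=np.uint64)
# Two trailing entries beyond nnz are allowed and ignored.
A_colidx = np.array([0, 2, 1, 2, 3, 0, 0], dtype=np.uint64)
# One uint32 word per row of B: bit j of the word is column j.
B_packed = np.array([[0b0001], [0b0010], [0b0100], [0b1000]], dtype=np.uint32)

nnz = check_spdn_inputs(A_rowoff, A_colidx, B_packed, m, n, k)
```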
Note
Column indices of \(A\) do not need to be sorted within each row.
This is a key difference from the sparse-sparse variant
(custabilizerGF2SparseSparseMatrixMultiply), which requires \(B\)’s
column indices to be sorted in ascending order within each row.
The call issues a small implicit device-to-host synchronization to read
A_rowoff[m] (8 bytes) in order to pass an explicit nnz to the C API.
This is a fixed per-call overhead, not a full stream barrier.
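The reason unsorted column indices are harmless can be seen in a host-side reference. The NumPy sketch below is an assumed formulation of the product, not the library's kernel: row i of the packed result is the XOR of the packed rows of \(B\) selected by the column indices of row i, and XOR is commutative, so the order within a row does not matter. The function name spdn_reference is hypothetical.

```python
import numpy as np

def spdn_reference(a_rowoff, a_colidx, b_packed):
    """Host reference: C_packed[i] = XOR of B_packed[j] for j in row i of A."""
    m = a_rowoff.shape[0] - 1
    nnz = int(a_rowoff[m])              # entries past nnz are never read
    c_packed = np.zeros((m, b_packed.shape[1]), dtype=np.uint32)
    for i in range(m):
        start, stop = int(a_rowoff[i]), int(a_rowoff[i + 1])
        assert stop <= nnz
        for j in a_colidx[start:stop]:
            c_packed[i] ^= b_packed[int(j)]
    return c_packed

rng = np.random.default_rng(1)
k, words = 6, 2
B_packed = rng.integers(0, 2**32, size=(k, words), dtype=np.uint32)

A_rowoff = np.array([0, 3, 5], dtype=np.uint64)
colidx_sorted = np.array([0, 2, 5, 1, 4], dtype=np.uint64)
colidx_shuffled = np.array([5, 0, 2, 4, 1], dtype=np.uint64)  # same rows, reordered

C_sorted = spdn_reference(A_rowoff, colidx_sorted, B_packed)
C_shuffled = spdn_reference(A_rowoff, colidx_shuffled, B_packed)
```

A reference of this kind is also useful for cross-checking device results on small inputs.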
Walkthrough#
Step 1 — Prepare A in CSR format.
Use jax.experimental.sparse.CSR.fromdense() for a JAX-native path. If the data
already lives in GPU arrays, cupyx.scipy.sparse.csr_matrix() or the
cusparseDenseToSparse_* routines exposed through nvmath.bindings.cusparse are
alternatives.
The indptr and indices fields must be cast to uint64:
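import jax
import jax.numpy as jnp
from jax.experimental import sparse

def pack_a_csr(a_dense: jax.Array) -> tuple[jax.Array, jax.Array]:
    """Dense ``(m, k) uint8`` -> ``(rowoff, colidx) uint64`` CSR."""
    m, k = int(a_dense.shape[0]), int(a_dense.shape[1])
    a_i8 = a_dense.view(jnp.int8) if a_dense.dtype == jnp.uint8 else a_dense.astype(jnp.int8)
    # nse=m * k reserves room for every entry, so colidx may carry trailing padding.
    csr = sparse.CSR.fromdense(a_i8, nse=m * k)
    # uint64 needs 64-bit mode (jax_enable_x64); otherwise JAX truncates to uint32.
    return csr.indptr.astype(jnp.uint64), csr.indices.astype(jnp.uint64)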
When \(A\) is already available as a cuquantum.,
the row-offset and column-index device arrays can be reused directly:
A_rowoff = jnp.asarray(bm.row_offsets, dtype=jnp.uint64)
A_colidx = jnp.asarray(bm.col_indices[:bm.nnz], dtype=jnp.uint64)
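For intuition about what Step 1 produces, here is a host-only NumPy sketch that derives the same two CSR arrays from a small dense 0/1 matrix. It is an illustration, not a replacement for the GPU paths above, and the helper name dense_to_csr_u64 is invented for this example. Unlike the fromdense call with nse=m * k, it returns exactly nnz column indices with no trailing padding.

```python
import numpy as np

def dense_to_csr_u64(a_dense: np.ndarray):
    """Dense (m, k) 0/1 matrix -> (rowoff, colidx) as uint64 CSR arrays."""
    rows, cols = np.nonzero(a_dense)    # row-major order, so rows are grouped
    m = a_dense.shape[0]
    counts = np.bincount(rows, minlength=m)
    rowoff = np.zeros(m + 1, dtype=np.uint64)
    rowoff[1:] = np.cumsum(counts)      # running total of nonzeros per row
    return rowoff, cols.astype(np.uint64)

A_dense = np.array([[1, 0, 1, 0],
                    [0, 0, 0, 0],
                    [0, 1, 1, 1]], dtype=np.uint8)
rowoff, colidx = dense_to_csr_u64(A_dense)
```

The empty middle row shows up as two equal consecutive offsets, and the last offset equals nnz.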
Step 2 — Bit-pack B in little-endian order.
Round n up to the next multiple of 32, pad \(B\) with zero columns,
then pack with numpy.packbits() using bitorder="little" and view as uint32:
import numpy as np

def pack_b_dense(b_dense: np.ndarray) -> tuple[np.ndarray, int]:
    """Dense ``(k, n) uint8`` -> ``(k, n_pad // 32) uint32`` bit-packed."""
    k, n = b_dense.shape
    n_pad = ((n + 31) // 32) * 32  # round n up to a multiple of 32
    if n_pad > n:
        # Pad with zero columns so each row fills whole 32-bit words.
        b_dense = np.concatenate(
            [b_dense, np.zeros((k, n_pad - n), dtype=np.uint8)], axis=1,
        )
    packed_u8 = np.packbits(np.ascontiguousarray(b_dense), axis=1, bitorder="little")
    return packed_u8.view(np.uint32), n_pad
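The layout that results from this packing can be checked on a tiny input. The sketch below assumes the convention stated above (little-endian bit order, 32 columns per word): column j of \(B\) is expected to land in word j // 32 at bit j % 32. It repeats the packing steps inline so that it runs on its own, and it uses the explicit "<u4" dtype so that the byte order does not depend on the host; on little-endian hosts this matches viewing as np.uint32. The helper get_bit is hypothetical.

```python
import numpy as np

k, n = 2, 40                      # n is not a multiple of 32
B = np.zeros((k, n), dtype=np.uint8)
B[0, 0] = 1                       # expect word 0, bit 0
B[0, 33] = 1                      # expect word 1, bit 1
B[1, 31] = 1                      # expect word 0, bit 31

n_pad = ((n + 31) // 32) * 32     # rounds 40 up to 64
B_padded = np.pad(B, ((0, 0), (0, n_pad - n)))
packed_u8 = np.packbits(B_padded, axis=1, bitorder="little")
B_packed = np.ascontiguousarray(packed_u8).view("<u4")

def get_bit(packed: np.ndarray, row: int, col: int) -> int:
    """Read column `col` of `row` back out of the packed words."""
    return int((packed[row, col // 32] >> np.uint32(col % 32)) & np.uint32(1))
```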
Step 3 — Call matmul_gf2_spdn().
Pass the padded n_pad (not the original n) as the n argument:
from cuquantum.stabilizer.jax import matmul_gf2_spdn
B_packed_jax = jnp.asarray(B_packed)
C_packed = matmul_gf2_spdn(A_rowoff, A_colidx, B_packed_jax, m=m, n=n_pad, k=k)
# C_packed: (m, n_pad // 32) uint32
Step 4 — Unpack C (optional, for inspection or downstream use):
def unpack_c(c_packed: jax.Array, m: int, n: int) -> np.ndarray:
    """Bit-packed ``(m, n_pad // 32) uint32`` -> ``(m, n) uint8``."""
    shifts = jnp.arange(32, dtype=jnp.uint32)
    # Broadcast each word against all 32 shifts to extract its bits.
    expanded = ((c_packed[..., None] >> shifts) & jnp.uint32(1)).astype(jnp.uint8)
    # Flatten the (word, bit) axes into columns and drop the padding.
    return np.asarray(expanded.reshape(m, -1)[:, :n])
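A host-side round trip is a convenient sanity check for Steps 2 and 4. The following NumPy sketch packs a matrix with n = 40 columns, unpacks it with np.unpackbits, and trims the zero padding. The helper names pack_bits_le and unpack_bits_le are hypothetical and independent of the JAX-based unpack_c above.

```python
import numpy as np

def pack_bits_le(dense: np.ndarray) -> np.ndarray:
    """Dense (rows, n) 0/1 uint8 -> (rows, ceil(n / 32)) packed 32-bit words."""
    rows, n = dense.shape
    n_pad = ((n + 31) // 32) * 32
    padded = np.pad(dense, ((0, 0), (0, n_pad - n)))
    return np.packbits(padded, axis=1, bitorder="little").view("<u4")

def unpack_bits_le(packed: np.ndarray, n: int) -> np.ndarray:
    """Packed (rows, n_pad // 32) words -> dense (rows, n) 0/1 uint8."""
    as_bytes = np.ascontiguousarray(packed.astype("<u4")).view(np.uint8)
    bits = np.unpackbits(as_bytes, axis=1, bitorder="little")
    return bits[:, :n]                  # drop the padded columns

rng = np.random.default_rng(2)
C = rng.integers(0, 2, size=(3, 40), dtype=np.uint8)
C_packed = pack_bits_le(C)
C_roundtrip = unpack_bits_le(C_packed, n=40)
```

Note that the original n, not n_pad, is passed when unpacking, so the padded columns never reach downstream code.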
Composition with JAX transformations#
matmul_gf2_spdn() is registered as a single-output FFI primitive and
supports the following JAX transformations:
jax.jit: directly, or wrapping a workflow that calls it. Chained matmuls fuse into one compiled launch.
jax.vmap: registered with vmap_method="sequential", so the primitive is replayed once per batch element. Composes freely with jax.jit.
Gradients are not defined.
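The phrase "replayed once per batch element" can be pictured with a plain host-side loop. The NumPy sketch below only emulates that semantics under assumed names (gf2_dense stands in for the primitive, batched_sequential for the batching rule); it uses neither JAX nor cuStabilizer. The point it illustrates is that a batch of size b costs b independent calls, with the results stacked along a new leading axis.

```python
import numpy as np

def gf2_dense(a: np.ndarray, b: np.ndarray) -> np.ndarray:
    """Stand-in for the primitive: dense GF(2) product of 0/1 matrices."""
    return ((a.astype(np.int64) @ b.astype(np.int64)) % 2).astype(np.uint8)

def batched_sequential(a: np.ndarray, b_batch: np.ndarray) -> np.ndarray:
    """Emulate a sequential batching rule: one call per batch element, then stack."""
    return np.stack([gf2_dense(a, b) for b in b_batch])

rng = np.random.default_rng(3)
A = rng.integers(0, 2, size=(3, 5), dtype=np.uint8)
B_batch = rng.integers(0, 2, size=(4, 5, 8), dtype=np.uint8)   # batch of 4

C_batch = batched_sequential(A, B_batch)   # shape (4, 3, 8)
```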
Performance notes#
matmul_gf2_spdn() is best suited to cases where \(B\) is dense or close to dense.
When both \(A\) and \(B\) are sparse, the C-level
custabilizerGF2SparseSparseMatrixMultiply API may offer better performance by
operating entirely on CSR data without materialising a bit-packed \(B\).
Note that the sparse-sparse variant requires \(B\)’s CSR column indices to be
sorted in ascending order within each row, whereas matmul_gf2_spdn() does not
require \(A\)’s column indices to be sorted.
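One rough way to reason about this trade-off is memory footprint. The sketch below compares the size of a bit-packed \(B\) with the size of a CSR representation of the same matrix. It assumes 64-bit row offsets and column indices for the CSR form, carried over from the uint64 convention used for \(A\); the sparse-sparse API may use different types. The function names are hypothetical, and the comparison says nothing about kernel speed, which should be measured.

```python
def packed_bytes(k: int, n: int) -> int:
    """Bytes for a bit-packed (k, n) matrix with rows padded to 32-bit words."""
    words_per_row = (n + 31) // 32
    return k * words_per_row * 4

def csr_bytes(k: int, nnz: int, index_bytes: int = 8) -> int:
    """Bytes for CSR row offsets plus column indices (structure only, no values)."""
    return (k + 1) * index_bytes + nnz * index_bytes

k, n = 4096, 4096
dense_size = packed_bytes(k, n)                     # 2 MiB
sparse_size_1pct = csr_bytes(k, nnz=k * n // 100)   # about 1% density
sparse_size_5pct = csr_bytes(k, nnz=k * n // 20)    # about 5% density
```

Under these assumptions the two representations break even near a density of 1/64, since a packed entry costs one bit and a stored nonzero costs 64 bits.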
API reference#
Functions#
| matmul_gf2_spdn() | GF(2) sparse-dense matmul |