cuStabilizer-JAX#

cuStabilizer-JAX provides a JAX interface to selected cuStabilizer routines, via the cuquantum.stabilizer.jax module. Currently, it exposes matmul_gf2_spdn() for GF(2) sparse-dense matrix multiplication.

GF(2) sparse-dense matrix multiplication#

matmul_gf2_spdn() computes

\[C_{ij} = \sum_{l=0}^{k-1} A_{il} B_{lj} \pmod{2}\]

where \(A\) is a sparse matrix in CSR format (shape \(m \times k\)) and \(B\) is a bit-packed dense matrix (shape \(k \times n\)), yielding a bit-packed dense result \(C\) (shape \(m \times n\)). Addition and multiplication are over GF(2): each inner product is the bitwise AND of a row of \(A\) with a column of \(B\), XOR-reduced to a single bit. For 0/1 matrices this is equivalent to computing (A @ B) % 2 in ordinary integer arithmetic.
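As a point of reference, the definition can be checked on small dense 0/1 arrays with plain NumPy. The sketch below is not part of cuStabilizer: `matmul_gf2_reference` is a hypothetical helper name, and it works on unpacked matrices only.

```python
import numpy as np


def matmul_gf2_reference(a: np.ndarray, b: np.ndarray) -> np.ndarray:
    """Dense GF(2) product of 0/1 matrices: AND for multiply, XOR for add."""
    a = np.asarray(a, dtype=np.uint8)
    b = np.asarray(b, dtype=np.uint8)
    # products[i, l, j] = A[i, l] AND B[l, j]; XOR-reduce over the shared axis l.
    products = a[:, :, None] & b[None, :, :]
    return np.bitwise_xor.reduce(products, axis=1)


a = np.array([[1, 1, 0], [0, 1, 1]], dtype=np.uint8)
b = np.array([[1, 0], [1, 1], [0, 1]], dtype=np.uint8)
c = matmul_gf2_reference(a, b)

# Same result as an integer matmul reduced modulo 2.
assert np.array_equal(c, (a.astype(np.int64) @ b.astype(np.int64)) % 2)
```

The broadcasted intermediate has shape (m, k, n), so this form is only suitable for small test inputs.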

A complete, runnable example is at python/extensions/samples/stabilizer/example_matmul_gf2_spdn.py.

Input conventions#

The API accepts three device arrays:

  • A_rowoff: shape (m+1,), dtype uint64. CSR row offsets of \(A\).

  • A_colidx: shape (>= nnz,), dtype uint64. CSR column indices of \(A\). The kernel reads nnz = A_rowoff[m]; any trailing entries are ignored.

  • B_packed: shape (k, n // 32), dtype uint32. Bit-packed dense \(B\), using the standard cuStabilizer bit-packed layout (see Bit-Packed Format): little-endian bit order, with n a multiple of 32. Pad \(B\) with zero columns before packing if necessary (see Walkthrough below).
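To make the three inputs concrete, the following host-side NumPy sketch computes the same product directly from the CSR arrays and the packed words. It is an illustrative reference under the conventions listed above, not the cuStabilizer kernel, and `matmul_gf2_spdn_reference` is a hypothetical name.

```python
import numpy as np


def matmul_gf2_spdn_reference(a_rowoff, a_colidx, b_packed, m):
    """Row i of C is the XOR of the packed rows of B selected by row i of A."""
    c_packed = np.zeros((m, b_packed.shape[1]), dtype=np.uint32)
    for i in range(m):
        start, stop = int(a_rowoff[i]), int(a_rowoff[i + 1])
        for l in a_colidx[start:stop]:  # order within the row does not matter
            c_packed[i] ^= b_packed[int(l)]
    return c_packed


# A = [[1, 1, 0], [0, 1, 1]] in CSR form; the last two colidx entries are
# padding beyond nnz = a_rowoff[m] = 4 and are never read.
a_rowoff = np.array([0, 2, 4], dtype=np.uint64)
a_colidx = np.array([1, 0, 2, 1, 0, 0], dtype=np.uint64)
# B has k = 3 rows and n = 32 columns; only columns 0 and 1 are nonzero:
# B[:, :2] = [[1, 0], [1, 1], [0, 1]], with bit j of a word holding column j.
b_packed = np.array([[0b01], [0b11], [0b10]], dtype=np.uint32)

c_packed = matmul_gf2_spdn_reference(a_rowoff, a_colidx, b_packed, m=2)
```

Because each row of \(C\) is built by XOR-ing whole packed words, one step of the inner loop handles 32 columns of \(B\) at a time.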

Note

Column indices of \(A\) do not need to be sorted within each row. This is a key difference from the sparse-sparse variant (custabilizerGF2SparseSparseMatrixMultiply), which requires \(B\)’s column indices to be sorted in ascending order within each row.

The call issues a small implicit device-to-host synchronization to read A_rowoff[m] (8 bytes) in order to pass an explicit nnz to the C API. This is a fixed per-call overhead, not a full stream barrier.

Walkthrough#

Step 1 — Prepare A in CSR format.

Use jax.experimental.sparse.CSR.fromdense() for a JAX-native path. Alternatively, build the CSR arrays from existing GPU arrays with cupyx.scipy.sparse.csr_matrix() or the cusparseDenseToSparse_* routines in nvmath.bindings.cusparse. In either case the indptr and indices fields must be cast to uint64:

import jax
import jax.numpy as jnp
from jax.experimental import sparse

jax.config.update("jax_enable_x64", True)  # uint64 needs 64-bit mode; JAX otherwise truncates to uint32

def pack_a_csr(a_dense: jax.Array) -> tuple[jax.Array, jax.Array]:
    """Dense ``(m, k) uint8`` -> ``(rowoff, colidx) uint64`` CSR."""
    m, k = int(a_dense.shape[0]), int(a_dense.shape[1])
    a_i8 = a_dense.view(jnp.int8) if a_dense.dtype == jnp.uint8 else a_dense.astype(jnp.int8)
    csr = sparse.CSR.fromdense(a_i8, nse=m * k)
    return csr.indptr.astype(jnp.uint64), csr.indices.astype(jnp.uint64)
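For small inputs it can help to inspect the CSR arrays on the host. The NumPy sketch below builds the same two arrays without padding. It is an illustrative alternative, and `dense_to_csr_host` is a hypothetical helper rather than part of any library.

```python
import numpy as np


def dense_to_csr_host(a_dense: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
    """Dense ``(m, k)`` 0/1 matrix -> ``(rowoff, colidx)`` uint64 CSR arrays."""
    mask = np.asarray(a_dense) != 0
    # np.nonzero scans in row-major order, so column indices come out grouped by row.
    _, cols = np.nonzero(mask)
    rowoff = np.zeros(mask.shape[0] + 1, dtype=np.uint64)
    rowoff[1:] = np.cumsum(mask.sum(axis=1)).astype(np.uint64)
    return rowoff, cols.astype(np.uint64)


a_dense = np.array([[1, 1, 0], [0, 0, 0], [0, 1, 1]], dtype=np.uint8)
rowoff, colidx = dense_to_csr_host(a_dense)
# rowoff[-1] is nnz, the value the kernel reads from A_rowoff[m].
```

The empty middle row shows up as two equal consecutive offsets, which is how CSR encodes a row with no nonzero entries.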

When \(A\) is already available as a cuquantum.stabilizer.BitMatrixCSR, the row-offset and column-index device arrays can be reused directly:

A_rowoff = jnp.asarray(bm.row_offsets,          dtype=jnp.uint64)
A_colidx = jnp.asarray(bm.col_indices[:bm.nnz], dtype=jnp.uint64)

Step 2 — Bit-pack B in little-endian order.

Round n up to the next multiple of 32, pad \(B\) with zero columns, then pack with numpy.packbits() using bitorder="little" and view as uint32:

import numpy as np

def pack_b_dense(b_dense: np.ndarray) -> tuple[np.ndarray, int]:
    """Dense ``(k, n) uint8`` -> ``(k, n_pad // 32) uint32`` bit-packed."""
    k, n = b_dense.shape
    n_pad = ((n + 31) // 32) * 32
    if n_pad > n:
        b_dense = np.concatenate(
            [b_dense, np.zeros((k, n_pad - n), dtype=np.uint8)], axis=1,
        )
    packed_u8 = np.packbits(np.ascontiguousarray(b_dense), axis=1, bitorder="little")
    # Viewing 4 bytes as one uint32 gives the little-endian layout on little-endian hosts.
    return packed_u8.view(np.uint32), n_pad
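The layout is easiest to see on a single row: column j ends up in word j // 32 at bit j % 32. The check below is a small sketch using the same numpy.packbits() call. As an addition for illustration, it views the bytes as explicit little-endian "<u4" words so the result does not depend on host byte order.

```python
import numpy as np

# One row with n = 64 columns and ones at columns 0, 5 and 33.
row = np.zeros((1, 64), dtype=np.uint8)
row[0, [0, 5, 33]] = 1

packed_u8 = np.packbits(row, axis=1, bitorder="little")  # shape (1, 8)
packed = packed_u8.view("<u4")                            # shape (1, 2)

# Columns 0 and 5 set bits 0 and 5 of word 0; column 33 sets bit 1 of word 1.
word0 = int(packed[0, 0])
word1 = int(packed[0, 1])
```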

Step 3 — Call matmul_gf2_spdn().

Pass the padded n_pad (not the original n) as the n argument:

from cuquantum.stabilizer.jax import matmul_gf2_spdn

B_packed_jax = jnp.asarray(B_packed)
C_packed = matmul_gf2_spdn(A_rowoff, A_colidx, B_packed_jax, m=m, n=n_pad, k=k)
# C_packed: (m, n_pad // 32) uint32

Step 4 — Unpack C (optional, for inspection or downstream use):

def unpack_c(c_packed: jax.Array, m: int, n: int) -> np.ndarray:
    """Bit-packed ``(m, n_pad // 32) uint32`` -> ``(m, n) uint8``."""
    shifts = jnp.arange(32, dtype=jnp.uint32)
    expanded = ((c_packed[..., None] >> shifts) & jnp.uint32(1)).astype(jnp.uint8)
    return np.asarray(expanded.reshape(m, -1)[:, :n])
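If the packed result is already on the host, the inverse of Step 2 can also be written with numpy.unpackbits(). This is a sketch of an alternative to unpack_c(), assuming the little-endian layout described above; `unpack_c_host` is a hypothetical name.

```python
import numpy as np


def unpack_c_host(c_packed: np.ndarray, n: int) -> np.ndarray:
    """Bit-packed ``(m, n_pad // 32) uint32`` -> ``(m, n) uint8`` on the host."""
    # Force little-endian words so byte 0 of each word holds columns 0..7.
    as_bytes = np.ascontiguousarray(c_packed, dtype="<u4").view(np.uint8)
    bits = np.unpackbits(as_bytes, axis=1, bitorder="little")
    return bits[:, :n]  # drop the zero-padded columns


# Round trip: pack a (2, 40) matrix padded to 64 columns, then unpack it.
rng = np.random.default_rng(0)
dense = rng.integers(0, 2, size=(2, 40), dtype=np.uint8)
padded = np.pad(dense, ((0, 0), (0, 24)))
packed = np.packbits(padded, axis=1, bitorder="little").view("<u4")
restored = unpack_c_host(packed, n=40)
```

Passing the original n, not n_pad, is what removes the padding columns added before packing.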

Composition with JAX transformations#

matmul_gf2_spdn() is registered as a single-output FFI primitive and supports the following JAX transformations:

  • jax.jit — directly, or wrapping a workflow that calls it. Chained matmuls fuse into one compiled launch.

  • jax.vmap — registered with vmap_method="sequential": the primitive is replayed once per batch element. Composes freely with jax.jit.

Gradients are not defined.

Performance notes#

matmul_gf2_spdn() is best suited when \(B\) is dense or has high density. When both \(A\) and \(B\) are sparse, the C-level custabilizerGF2SparseSparseMatrixMultiply API may offer better performance by operating entirely on CSR data without materialising a bit-packed \(B\). Note that the sparse-sparse variant requires \(B\)’s CSR column indices to be sorted in ascending order within each row, whereas matmul_gf2_spdn() does not require \(A\)’s column indices to be sorted.

API reference#

Functions#

matmul_gf2_spdn(A_rowoff, A_colidx, ...)

GF(2) sparse-dense matmul C = A @ B via cuStabilizer SpDn FFI.