flat_to_dict
cuequivariance_jax.ir_dict.flat_to_dict(irreps, data, *, layout='mul_ir')
Convert a flat array to dict[Irrep, Array] with shape (…, mul, ir.dim).
Splits data along its last axis, which must have size irreps.dim, into one contiguous block per irrep, then reshapes each block to make the multiplicity and irrep dimensions explicit: (…, mul, ir.dim).
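The splitting step can be sketched in plain JAX for the default 'mul_ir' layout. This is a hypothetical re-implementation for illustration, not the library's code: `flat_to_dict_sketch` is an invented name, and irreps are simplified to a list of (key, mul, ir_dim) triples, each appearing once, instead of a cue.Irreps object.

```python
import jax.numpy as jnp


def flat_to_dict_sketch(irreps, data):
    """Hypothetical sketch of the 'mul_ir' split (not the cuequivariance_jax source).

    irreps: list of (key, mul, ir_dim) triples standing in for cue.Irreps;
            each key is assumed to appear only once.
    data:   array of shape (..., sum(mul * ir_dim)).
    """
    total = sum(mul * ir_dim for _, mul, ir_dim in irreps)
    assert data.shape[-1] == total, "last axis must equal irreps.dim"
    out = {}
    offset = 0
    for key, mul, ir_dim in irreps:
        size = mul * ir_dim
        # Each irrep occupies a contiguous slice of the last axis.
        block = data[..., offset:offset + size]
        # In 'mul_ir' order the block is multiplicity-major, so a row-major
        # reshape recovers the (mul, ir_dim) view.
        out[key] = block.reshape(data.shape[:-1] + (mul, ir_dim))
        offset += size
    return out


# Stand-in for "128x0e + 64x1o": 128 * 1 + 64 * 3 = 320 entries per row.
irreps = [("0e", 128, 1), ("1o", 64, 3)]
flat = jnp.arange(4 * 320, dtype=jnp.float32).reshape(4, 320)
d = flat_to_dict_sketch(irreps, flat)
print(d["0e"].shape, d["1o"].shape)  # (4, 128, 1) (4, 64, 3)
```

Because each block is only sliced and reshaped, the first 1o vector of each row, d["1o"][b, 0], is exactly flat[b, 128:131].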
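- Parameters:
  - irreps (Irreps) – The irreps that describe how the last axis of data is partitioned.
  - data (Array) – Array of shape (…, irreps.dim).
  - layout (str, keyword-only) – Ordering of the values inside each irrep block of data. The default 'mul_ir' stores each block as (mul, ir.dim) in row-major order.
- Returns:
  Dictionary mapping each irrep to an array with shape (…, mul, ir.dim).
- Return type:
  dict[Irrep, Array]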
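The layout argument only changes how a block is read from data, not the shape that is returned. The sketch below contrasts cuequivariance's two layout conventions, 'mul_ir' (multiplicity-major) and 'ir_mul' (irrep-dimension-major), for a single hypothetical "2x1o" block. Whether flat_to_dict accepts 'ir_mul' is an assumption here; only 'mul_ir' is documented above.

```python
import jax.numpy as jnp

mul, ir_dim = 2, 3  # a hypothetical "2x1o" block

# The same logical values v[m, i], stored in the two orders.
v = jnp.arange(mul * ir_dim).reshape(mul, ir_dim)
flat_mul_ir = v.reshape(-1)    # mul_ir: v00 v01 v02 v10 v11 v12
flat_ir_mul = v.T.reshape(-1)  # ir_mul: v00 v10 v01 v11 v02 v12

# Recover the (mul, ir_dim) view from each ordering.
from_mul_ir = flat_mul_ir.reshape(mul, ir_dim)
# Assumption: an ir_mul block is read as (ir_dim, mul) and then transposed.
from_ir_mul = flat_ir_mul.reshape(ir_dim, mul).T
```

Both from_mul_ir and from_ir_mul equal v, which is why the return shape is (…, mul, ir.dim) regardless of the storage order.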
Example
>>> import jax.numpy as jnp
>>> import cuequivariance as cue
>>> from cuequivariance_jax.ir_dict import flat_to_dict
>>> irreps = cue.Irreps(cue.O3, "128x0e + 64x1o")
>>> batch = 32
>>> flat = jnp.ones((batch, irreps.dim))
>>> d = flat_to_dict(irreps, flat)
>>> d[cue.O3(0, 1)].shape
(32, 128, 1)
>>> d[cue.O3(1, -1)].shape
(32, 64, 3)