flat_to_dict

cuequivariance_jax.ir_dict.flat_to_dict(irreps, data, *, layout='mul_ir')

Convert a flat array into a dict[Irrep, Array] in which each array has shape (…, mul, ir.dim).

Splits a contiguous array along the last axis into separate arrays per irrep, reshaping each to have explicit (multiplicity, irrep_dim) dimensions.
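The splitting step described above amounts to slicing and reshaping. The following is a minimal sketch, not the library's implementation. It assumes the default "mul_ir" layout and uses NumPy, whose slicing and reshape semantics match jax.numpy. The irreps are represented as hypothetical (key, mul, ir_dim) tuples rather than a cue.Irreps object, and split_mul_ir is an illustrative name only.

```python
import numpy as np


def split_mul_ir(segments, data):
    """Hypothetical helper: split `data` of shape (..., total_dim) into blocks
    of shape (..., mul, ir_dim), assuming the 'mul_ir' layout.

    `segments` is a list of (key, mul, ir_dim) tuples standing in for Irreps.
    """
    out = {}
    offset = 0
    for key, mul, ir_dim in segments:
        size = mul * ir_dim
        # Take this irrep's contiguous slice along the last axis.
        block = data[..., offset:offset + size]
        # 'mul_ir': the slice is `mul` consecutive ir_dim-vectors, so a plain reshape works.
        out[key] = block.reshape(data.shape[:-1] + (mul, ir_dim))
        offset += size
    if offset != data.shape[-1]:
        raise ValueError(f"expected last axis of size {offset}, got {data.shape[-1]}")
    return out


segments = [("0e", 128, 1), ("1o", 64, 3)]  # mirrors "128x0e + 64x1o"
flat = np.ones((32, 128 * 1 + 64 * 3))     # irreps.dim = 320
d = split_mul_ir(segments, flat)
print(d["0e"].shape, d["1o"].shape)  # (32, 128, 1) (32, 64, 3)
```

Each irrep occupies mul * ir.dim consecutive entries of the last axis, so the offsets are simply running sums of those sizes.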

Parameters:
  • irreps (Irreps) – Irreps specification for splitting.

  • data (Array) – Flat array with shape (…, irreps.dim).

  • layout (str) – Memory layout of the flat data: either “mul_ir” (default), where each irrep's block is ordered as (mul, ir.dim), or “ir_mul”, where it is ordered as (ir.dim, mul).
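The difference between the two layouts is easiest to see on a single small block. This NumPy sketch assumes that the layout describes how each irrep's block is ordered, as in the parameter description. It is an illustration, not the library's code. For "ir_mul" the block is reshaped to (ir.dim, mul) and its last two axes are then swapped, which yields the same (mul, ir.dim) result as "mul_ir".

```python
import numpy as np

mul, ir_dim = 2, 3  # e.g. a "2x1o" block
# The same logical values v[u, i] (u = multiplicity index, i = irrep component).
v = np.arange(mul * ir_dim).reshape(mul, ir_dim)

flat_mul_ir = v.reshape(-1)    # ordered (mul, ir.dim): [0 1 2 3 4 5]
flat_ir_mul = v.T.reshape(-1)  # ordered (ir.dim, mul): [0 3 1 4 2 5]

# 'mul_ir': reshape directly to (mul, ir_dim).
from_mul_ir = flat_mul_ir.reshape(mul, ir_dim)
# 'ir_mul': reshape to (ir_dim, mul), then swap the last two axes.
from_ir_mul = np.swapaxes(flat_ir_mul.reshape(ir_dim, mul), -1, -2)

assert (from_mul_ir == v).all() and (from_ir_mul == v).all()
```

With leading batch axes, the same idea applies with a reshape to batch_shape + (ir_dim, mul) followed by swapping the last two axes.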

Returns:

Dictionary mapping each irrep to an array of shape (…, mul, ir.dim).

Return type:

dict[Irrep, Array]

Example

>>> import jax.numpy as jnp
>>> import cuequivariance as cue
>>> from cuequivariance_jax.ir_dict import flat_to_dict
>>> irreps = cue.Irreps(cue.O3, "128x0e + 64x1o")
>>> batch = 32
>>> flat = jnp.ones((batch, irreps.dim))
>>> d = flat_to_dict(irreps, flat)
>>> d[cue.O3(0, 1)].shape
(32, 128, 1)
>>> d[cue.O3(1, -1)].shape
(32, 64, 3)
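Because the split is only slicing and reshaping, no information is lost: flattening each block and concatenating the blocks in Irreps order restores the input. The sketch below checks this for the "mul_ir" layout with NumPy. The names split_mul_ir and join_mul_ir and the (key, mul, ir_dim) tuples are hypothetical stand-ins, not part of cuequivariance_jax.

```python
import numpy as np

# Hypothetical stand-in for Irreps "128x0e + 64x1o": (key, mul, ir_dim) tuples.
segments = [("0e", 128, 1), ("1o", 64, 3)]


def split_mul_ir(segments, data):
    # Slice each irrep's contiguous segment and reshape it to (..., mul, ir_dim).
    out, offset = {}, 0
    for key, mul, ir_dim in segments:
        size = mul * ir_dim
        out[key] = data[..., offset:offset + size].reshape(data.shape[:-1] + (mul, ir_dim))
        offset += size
    return out


def join_mul_ir(segments, blocks):
    # Inverse: flatten each (..., mul, ir_dim) block and concatenate in Irreps order.
    parts = [
        blocks[key].reshape(blocks[key].shape[:-2] + (mul * ir_dim,))
        for key, mul, ir_dim in segments
    ]
    return np.concatenate(parts, axis=-1)


rng = np.random.default_rng(0)
flat = rng.standard_normal((32, 320))  # 320 = 128*1 + 64*3
d = split_mul_ir(segments, flat)
restored = join_mul_ir(segments, d)
print(np.array_equal(restored, flat))  # True
```

The inverse needs the irreps (not just the dict) because the order of the blocks in the flat array is defined by the Irreps specification.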