dict_to_flat
- cuequivariance_jax.ir_dict.dict_to_flat(irreps, x)
Convert dict[Irrep, Array] back to a flat contiguous array.
Flattens the trailing (multiplicity, irrep_dim) dimensions of each array and concatenates the results for all irreps along the last axis.
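- Parameters:
  - irreps (Irreps) – The irreps that define the layout of the flat array.
  - x (dict[Irrep, Array]) – Mapping from each irrep to an array of shape (…, multiplicity, irrep_dim).
- Returns:
Flat array with shape (…, irreps.dim).
- Return type:
Array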
Example
>>> import cuequivariance as cue
>>> import jax.numpy as jnp
>>> from cuequivariance_jax.ir_dict import dict_to_flat
>>> irreps = cue.Irreps(cue.O3, "128x0e + 64x1o")
>>> batch = 32
>>> d = {cue.O3(0, 1): jnp.ones((batch, 128, 1)),
...      cue.O3(1, -1): jnp.ones((batch, 64, 3))}
>>> flat = dict_to_flat(irreps, d)
>>> flat.shape
(32, 320)
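To make the shape arithmetic concrete (128 × 1 + 64 × 3 = 320), the following is a minimal sketch of the flattening described above, written with plain jax.numpy instead of the library. It is not the cuequivariance_jax implementation. The helper flatten_irrep_dict and the layout list of (key, multiplicity, irrep_dim) tuples are hypothetical stand-ins for dict_to_flat and cue.Irreps. The sketch also assumes that each block is merged with a row-major reshape and that blocks are concatenated in the order the irreps are listed.

```python
import jax.numpy as jnp


def flatten_irrep_dict(layout, d):
    """Hypothetical sketch of the dict -> flat conversion (not the library code).

    layout: list of (key, mul, ir_dim) tuples, a stand-in for cue.Irreps.
    d: dict mapping key -> array of shape (..., mul, ir_dim).
    """
    pieces = []
    for key, mul, ir_dim in layout:
        block = d[key]
        # Each block must end with the (multiplicity, irrep_dim) axes.
        assert block.shape[-2:] == (mul, ir_dim)
        # Merge (mul, ir_dim) into one trailing axis of size mul * ir_dim.
        pieces.append(block.reshape(block.shape[:-2] + (mul * ir_dim,)))
    # Concatenate all blocks along the last axis, in layout order.
    return jnp.concatenate(pieces, axis=-1)


# Same sizes as the doctest: 128x0e (dim 1) and 64x1o (dim 3).
layout = [("0e", 128, 1), ("1o", 64, 3)]
d = {"0e": jnp.zeros((32, 128, 1)), "1o": jnp.ones((32, 64, 3))}
flat = flatten_irrep_dict(layout, d)
print(flat.shape)  # (32, 320)
```

Filling the two blocks with different constants makes it easy to check that the 0e block occupies the first 128 entries of the last axis and the 1o block occupies the remaining 192.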