dict_to_flat

cuequivariance_jax.ir_dict.dict_to_flat(irreps, x)

Convert dict[Irrep, Array] back to a flat contiguous array.

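For each irrep, merges the trailing (mul, ir.dim) axes into a single axis of size mul * ir.dim, then concatenates these blocks along the last axis in the order defined by irreps.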

Parameters:
  • irreps (Irreps) – Irreps specification; its entries define the order in which the blocks are concatenated.

  • x (dict[Irrep, Array]) – Dictionary mapping each irrep in irreps to an array of shape (…, mul, ir.dim), where mul is that irrep's multiplicity.

Returns:

Flat array with shape (…, irreps.dim).

Return type:

Array
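The size of the returned last axis, irreps.dim, is the sum of mul × ir.dim over the entries of irreps; for O(3) irreps, ir.dim = 2l + 1, so 0e has dimension 1 and 1o has dimension 3. For "128x0e + 64x1o", as in the example, this gives:

```latex
\text{irreps.dim} = \sum_i \text{mul}_i \cdot \dim(\text{ir}_i)
                  = 128 \cdot 1 + 64 \cdot 3
                  = 128 + 192
                  = 320
```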

Example

>>> import jax.numpy as jnp
>>> import cuequivariance as cue
>>> from cuequivariance_jax.ir_dict import dict_to_flat
>>> irreps = cue.Irreps(cue.O3, "128x0e + 64x1o")
>>> batch = 32
>>> d = {cue.O3(0, 1): jnp.ones((batch, 128, 1)),
...      cue.O3(1, -1): jnp.ones((batch, 64, 3))}
>>> flat = dict_to_flat(irreps, d)
>>> flat.shape
(32, 320)
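To make the flatten-and-concatenate step concrete, here is a minimal NumPy sketch of the behavior described above. It is not the library's implementation: dict_to_flat_sketch is a hypothetical helper, irreps are stood in for by plain (mul, ir, ir_dim) tuples with string keys such as "0e", and it assumes each block is laid out mul-major (index m * ir.dim + i), which is what a row-major reshape of the (mul, ir.dim) axes produces. jax.numpy provides the same reshape and concatenate calls.

```python
import numpy as np


def dict_to_flat_sketch(irreps, x):
    """Hypothetical sketch of dict_to_flat (not the cuequivariance_jax code).

    irreps: list of (mul, ir, ir_dim) tuples in output order; a stand-in
            for cue.Irreps (assumption for illustration).
    x: dict mapping each ir key to an array of shape (..., mul, ir_dim).
    """
    pieces = []
    for mul, ir, ir_dim in irreps:
        arr = x[ir]
        # The trailing axes must match the declared (mul, ir_dim).
        if arr.shape[-2:] != (mul, ir_dim):
            raise ValueError(f"{ir}: expected (..., {mul}, {ir_dim}), got {arr.shape}")
        # Merge (mul, ir_dim) into one axis of size mul * ir_dim (mul-major).
        pieces.append(arr.reshape(arr.shape[:-2] + (mul * ir_dim,)))
    # Concatenate the blocks along the last axis, following the order of irreps.
    return np.concatenate(pieces, axis=-1)


irreps = [(128, "0e", 1), (64, "1o", 3)]
batch = 32
d = {"0e": np.ones((batch, 128, 1)), "1o": np.ones((batch, 64, 3))}
flat = dict_to_flat_sketch(irreps, d)
print(flat.shape)  # (32, 320)
```

Because the reshape only merges the last two axes, any leading batch axes pass through unchanged.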