RepArray

class cuequivariance_jax.RepArray(
reps: dict[int, Rep] | Rep | Irreps | str | dict[int, Irreps] | dict[int, str],
array: Array,
layout: IrrepsLayout | None = None,
)

A jax.Array decorated with a dict that maps each axis transforming under a group representation to its cue.Rep.

Example:

You can create a cuex.RepArray by specifying the cue.Rep for each axis:

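>>> import jax.numpy as jnp
>>> import numpy as np
>>> import cuequivariance as cue
>>> import cuequivariance_jax as cuex
>>> cuex.RepArray({0: cue.SO3(1), 1: cue.SO3(1)}, jnp.eye(3))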
{0: 1, 1: 1}
[[1. 0. 0.]
 [0. 1. 0.]
 [0. 0. 1.]]

By default, arguments that are not cue.Rep will be automatically converted into cue.IrrepsAndLayout:

>>> with cue.assume(cue.SO3, cue.ir_mul):
...     x = cuex.RepArray({0: "1", 1: "2"}, jnp.ones((3, 5)))
>>> x
{0: 1, 1: 2}
[[1. 1. 1. 1. 1.]
 [1. 1. 1. 1. 1.]
 [1. 1. 1. 1. 1.]]
>>> x.rep(0).irreps, x.rep(0).layout
(1, (irrep,mul))

IrrepsArray

An IrrepsArray is just a special case of a RepArray where the last axis is a cue.IrrepsAndLayout:

>>> x = cuex.RepArray(
...     cue.Irreps("SO3", "2x0"), jnp.zeros((3, 2)), cue.ir_mul
... )
>>> x
{1: 2x0}
[[0. 0.]
 [0. 0.]
 [0. 0.]]
>>> x.is_irreps_array()
True

You can use a default group and layout:

>>> with cue.assume(cue.SO3, cue.ir_mul):
...     cuex.RepArray("2x0", jnp.array([1.0, 2.0]))
{0: 2x0} [1. 2.]

Arithmetic

Basic arithmetic operations are supported, as long as they are equivariant:

>>> with cue.assume(cue.SO3, cue.ir_mul):
...     x = cuex.RepArray("2x0", jnp.array([1.0, 2.0]))
...     y = cuex.RepArray("2x0", jnp.array([3.0, 4.0]))
...     x + y
{0: 2x0} [4. 6.]
>>> 3.0 * x
{0: 2x0} [3. 6.]
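
The restriction to equivariant operations can be checked numerically. The following is a plain-NumPy sketch, not a call into cuequivariance: it assumes a vector (l = 1) transforms by an ordinary 3x3 rotation matrix, and the helper `rotation_about_z` is hypothetical. Addition and scalar multiplication commute with the rotation, while adding a constant to every component does not.

```python
import numpy as np

def rotation_about_z(angle):
    """3x3 rotation matrix about the z axis (hypothetical helper)."""
    c, s = np.cos(angle), np.sin(angle)
    return np.array([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]])

R = rotation_about_z(0.7)
x = np.array([1.0, 2.0, 3.0])
y = np.array([-1.0, 0.5, 2.0])

# Addition and scaling give the same result whether applied before or after R.
add_commutes = np.allclose(R @ (x + y), R @ x + R @ y)
scale_commutes = np.allclose(R @ (3.0 * x), 3.0 * (R @ x))

# Adding a constant to every component does not commute with R in general.
shift_commutes = np.allclose(R @ (x + 1.0), R @ x + 1.0)
```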

property shape: tuple[int, ...]

Shape of the array.

property ndim: int

Number of dimensions of the array.

property dtype: dtype

Data type of the array.

is_irreps_array() → bool

Check if the RepArray is an IrrepsArray.

An IrrepsArray is a RepArray where the last axis is an IrrepsAndLayout.

rep(axis: int) → Rep

Return the Rep for a given axis.

property irreps: Irreps

Return the Irreps of the IrrepsArray.

Note

This property is only available for IrrepsArray. See is_irreps_array.

property layout: IrrepsLayout

Return the layout of the IrrepsArray.

Note

This property is only available for IrrepsArray. See is_irreps_array.

property slice_by_mul: _MulIndexSliceHelper

Return a helper that slices the array by multiplicity index: x.slice_by_mul[a:b] keeps the irrep copies numbered a through b - 1.

Examples

>>> x = cuex.RepArray(
...     cue.Irreps("SO3", "2x0 + 1"),
...     jnp.array([1.0, 2.0, 0.0, 0.0, 0.0]),
...     cue.ir_mul
... )
>>> x.slice_by_mul[1:4]
{0: 0+1} [2. 0. 0. 0.]

Note

This property is only available for IrrepsArray. See is_irreps_array.
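
How a multiplicity slice maps to array entries can be sketched in plain NumPy. This is an illustration and not the library's implementation: `slice_by_mul_sketch` is a hypothetical helper, the irreps are passed as (mul, irrep_dim) pairs, and each segment is assumed to be stored as (irrep_dim, mul) under the ir_mul layout.

```python
import numpy as np

def slice_by_mul_sketch(data, mul_dims, start, stop):
    """Keep the irrep copies numbered start..stop-1 from flat 1D data.

    mul_dims: list of (mul, irrep_dim) pairs, one per segment.
    Assumes each segment is laid out as (irrep_dim, mul), i.e. ir_mul.
    """
    pieces, kept = [], []
    offset = 0      # position in the flat data
    mul_offset = 0  # running multiplicity index
    for mul, dim in mul_dims:
        seg = data[offset:offset + mul * dim].reshape(dim, mul)
        lo = max(start - mul_offset, 0)
        hi = min(stop - mul_offset, mul)
        if lo < hi:
            pieces.append(seg[:, lo:hi].reshape(-1))
            kept.append((hi - lo, dim))
        offset += mul * dim
        mul_offset += mul
    values = np.concatenate(pieces) if pieces else data[:0]
    return kept, values

# "2x0 + 1" has multiplicity indices 0, 1 (scalars) and 2 (the vector).
kept, values = slice_by_mul_sketch(
    np.array([1.0, 2.0, 0.0, 0.0, 0.0]), [(2, 1), (1, 3)], 1, 4
)
```

With the inputs of the example above, the sketch keeps one scalar and the vector, matching the `0+1` result.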

transform(
v: Array,
) → RepArray

Transform the array according to the representation.

Parameters:

v – Vector of angles.

Examples

>>> x = cuex.RepArray(
...     {0: cue.SO3(1), 1: cue.SO3(1)}, jnp.ones((3, 3))
... )
>>> x
{0: 1, 1: 1}
[[1. 1. 1.]
 [1. 1. 1.]
 [1. 1. 1.]]
>>> x.transform(jnp.array([np.pi, 0.0, 0.0])).array.round(1)
Array([[ 1., -1., -1.],
       [-1.,  1.,  1.],
       [-1.,  1.,  1.]]...)
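
The numbers in this example can be reproduced with a plain-NumPy sketch. It rests on two assumptions that the text above does not state: that v is an axis-angle vector, and that cue.SO3(1) acts as an ordinary 3x3 rotation in this component order. The helper `rotation_from_axis_angle` is hypothetical and is not part of the library.

```python
import numpy as np

def rotation_from_axis_angle(v):
    """Rodrigues' formula: rotation by |v| radians about the axis v/|v|."""
    v = np.asarray(v, dtype=float)
    theta = np.linalg.norm(v)
    if theta == 0.0:
        return np.eye(3)
    kx, ky, kz = v / theta
    K = np.array([[0.0, -kz, ky], [kz, 0.0, -kx], [-ky, kx, 0.0]])
    return np.eye(3) + np.sin(theta) * K + (1.0 - np.cos(theta)) * (K @ K)

R = rotation_from_axis_angle([np.pi, 0.0, 0.0])
x = np.ones((3, 3))

# Both axes carry an l = 1 representation, so R is applied to each of them.
transformed = np.einsum("ia,jb,ab->ij", R, R, x)
```

Agreement on this one input does not establish the library's basis convention in general.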

property segments: list[Array]

Split the array into segments, one per entry of the Irreps.

Examples

>>> x = cuex.RepArray(
...     cue.Irreps("SO3", "2x0 + 1"), jnp.array([1.0, 2.0, 0.0, 0.0, 0.0]),
...     cue.ir_mul
... )
>>> x.segments
[Array(...), Array(...)]

Note

This property is only available for IrrepsArray. See is_irreps_array.
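
Since the example output above is elided, here is a plain-NumPy sketch of what such a split could look like. `split_segments` is a hypothetical helper, and the per-segment shape (irrep_dim, mul) is an assumption about the ir_mul layout, not something the text above confirms.

```python
import numpy as np

def split_segments(data, mul_dims):
    """Split flat 1D data into one block per (mul, irrep_dim) entry."""
    sizes = [mul * dim for mul, dim in mul_dims]
    flat = np.split(data, np.cumsum(sizes)[:-1])
    # Assumption: under ir_mul each block is viewed as (irrep_dim, mul).
    return [f.reshape(dim, mul) for f, (mul, dim) in zip(flat, mul_dims)]

# "2x0 + 1": two scalars followed by one 3-dimensional vector.
segments = split_segments(np.array([1.0, 2.0, 0.0, 0.0, 0.0]), [(2, 1), (1, 3)])
```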

filter(
*,
keep: str | Sequence[Irrep] | Callable[[MulIrrep], bool] | None = None,
drop: str | Sequence[Irrep] | Callable[[MulIrrep], bool] | None = None,
mask: Sequence[bool] | None = None,
) → RepArray

Filter the irreps.

Parameters:
  • keep – Irreps to keep.

  • drop – Irreps to drop.

  • mask – Boolean mask for segments to keep.

  • axis – Axis to filter.

Examples

>>> x = cuex.RepArray(
...     cue.Irreps("SO3", "2x0 + 1"),
...     jnp.array([1.0, 2.0, 0.0, 0.0, 0.0]), cue.ir_mul
... )
>>> x.filter(keep="0")
{0: 2x0} [1. 2.]
>>> x.filter(drop="0")
{0: 1} [0. 0. 0.]
>>> x.filter(mask=[True, False])
{0: 2x0} [1. 2.]

Note

This method is only available for IrrepsArray. See is_irreps_array.
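
The effect of keep and mask on the underlying data can be sketched in plain NumPy. `filter_segments` is a hypothetical helper, not the library's implementation; it describes SO(3) irreps as (mul, l) pairs, where the irrep l has dimension 2l + 1.

```python
import numpy as np

def filter_segments(data, irreps, keep=None, mask=None):
    """Keep whole segments of flat 1D data.

    irreps: list of (mul, l) pairs; keep: set of l values to retain;
    mask: one boolean per segment. Exactly one of keep/mask is expected.
    """
    if mask is None:
        mask = [l in keep for _, l in irreps]
    out_irreps, out_data, offset = [], [], 0
    for (mul, l), selected in zip(irreps, mask):
        size = mul * (2 * l + 1)
        if selected:
            out_irreps.append((mul, l))
            out_data.append(data[offset:offset + size])
        offset += size
    values = np.concatenate(out_data) if out_data else data[:0]
    return out_irreps, values

data = np.array([1.0, 2.0, 0.0, 0.0, 0.0])   # "2x0 + 1"
irreps = [(2, 0), (1, 1)]
kept_by_name = filter_segments(data, irreps, keep={0})
kept_by_mask = filter_segments(data, irreps, mask=[False, True])
```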

sort() → RepArray

Sort the irreps.

Examples

>>> x = cuex.RepArray(
...     cue.Irreps("SO3", "1 + 2x0"),
...     jnp.array([1.0, 1.0, 1.0, 2.0, 3.0]), cue.ir_mul
... )
>>> x.sort()
{0: 2x0+1} [2. 3. 1. 1. 1.]

Note

This method is only available for IrrepsArray. See is_irreps_array.

regroup() → RepArray

Clean up the irreps by sorting them and merging entries that share the same irrep.

Examples

>>> x = cuex.RepArray(
...     cue.Irreps("SO3", "0 + 1 + 0"), jnp.array([0., 1., 2., 3., -1.]),
...     cue.ir_mul
... )
>>> x.regroup()
{0: 2x0+1} [ 0. -1.  1.  2.  3.]

Note

This method is only available for IrrepsArray. See is_irreps_array.
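
The data movement behind this example can be sketched in plain NumPy. `regroup_sketch` is a hypothetical helper, not the library's implementation. It assumes SO(3) irreps given as (mul, l) pairs, ordered by increasing l, with each segment stored as (irrep_dim, mul) under ir_mul, so merging two segments of the same irrep is a concatenation along the multiplicity axis.

```python
import numpy as np

def regroup_sketch(data, irreps):
    """Sort segments by l and merge segments sharing the same l."""
    blocks, offset = {}, 0
    for mul, l in irreps:
        dim = 2 * l + 1
        seg = data[offset:offset + mul * dim].reshape(dim, mul)
        blocks.setdefault(l, []).append(seg)
        offset += mul * dim
    out_irreps, out_data = [], []
    for l in sorted(blocks):
        merged = np.concatenate(blocks[l], axis=1)  # (dim, total mul)
        out_irreps.append((merged.shape[1], l))
        out_data.append(merged.reshape(-1))
    return out_irreps, np.concatenate(out_data)

# "0 + 1 + 0" with the data from the example above.
new_irreps, new_data = regroup_sketch(
    np.array([0.0, 1.0, 2.0, 3.0, -1.0]), [(1, 0), (1, 1), (1, 0)]
)
```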

change_layout(
layout: IrrepsLayout,
) → RepArray

Change the layout of the IrrepsArray.

Note

This method is only available for IrrepsArray. See is_irreps_array.
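
What a layout change amounts to for a single segment can be sketched in plain NumPy. The sketch assumes that mul_ir stores a segment as (mul, irrep_dim) and ir_mul as (irrep_dim, mul), both flattened in row-major order. It does not call change_layout itself.

```python
import numpy as np

mul, dim = 2, 3  # one segment "2x1": two copies of a 3-dimensional irrep

# mul_ir: the flat segment is read as (mul, irrep_dim), one copy after another.
seg_mul_ir = np.array([1.0, 2.0, 3.0, 10.0, 20.0, 30.0])

# ir_mul: the same numbers read as (irrep_dim, mul), so the copies interleave.
seg_ir_mul = seg_mul_ir.reshape(mul, dim).T.reshape(-1)

# Converting back is the inverse transpose.
back = seg_ir_mul.reshape(dim, mul).T.reshape(-1)
```

Under these assumptions the two layouts coincide whenever mul or the irrep dimension is 1, which is why the scalar and single-vector examples on this page look the same in either layout.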

move_axis_to_mul(
axis: int,
) → RepArray

Move an axis to the multiplicities.

Note

This method is only available for IrrepsArray. See is_irreps_array.