IrrepsArray

class cuequivariance_jax.IrrepsArray(
irreps: Irreps | str | dict[int, Irreps | str],
array: Array,
layout: IrrepsLayout | None = None,
)

Wrapper around a JAX array that stores a dict mapping each non-trivial axis to its Irreps.

Creation

>>> cuex.IrrepsArray(
...     {-1: cue.Irreps("SO3", "2x0")}, jnp.array([1.0, 2.0]), cue.ir_mul
... )
{0: 2x0} [1. 2.]

If you don’t specify the axis, it defaults to the last axis:

>>> cuex.IrrepsArray(
...     cue.Irreps("SO3", "2x0"), jnp.array([1.0, 2.0]), cue.ir_mul
... )
{0: 2x0} [1. 2.]

You can use a default group and layout:

>>> with cue.assume(cue.SO3, cue.ir_mul):
...     cuex.IrrepsArray("2x0", jnp.array([1.0, 2.0]))
{0: 2x0} [1. 2.]

Arithmetic

Basic arithmetic operations are supported, as long as they are equivariant:

>>> with cue.assume(cue.SO3, cue.ir_mul):
...     x = cuex.IrrepsArray("2x0", jnp.array([1.0, 2.0]))
...     y = cuex.IrrepsArray("2x0", jnp.array([3.0, 4.0]))
...     x + y
{0: 2x0} [4. 6.]
>>> 3.0 * x
{0: 2x0} [3. 6.]
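
As a rough illustration of what "equivariant" means here, the following plain-JAX sketch (it does not use cuequivariance, and the helper `rot_z` is a hypothetical name introduced only for this example) checks that addition and scalar multiplication commute with a rotation acting on an l=1 vector, while adding a fixed constant does not.

```python
import jax.numpy as jnp


def rot_z(theta):
    """Rotation matrix about the z axis (illustrative helper, not a cuequivariance function)."""
    c, s = jnp.cos(theta), jnp.sin(theta)
    return jnp.array([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]])


R = rot_z(0.7)
x = jnp.array([1.0, 2.0, 3.0])
y = jnp.array([-1.0, 0.5, 2.0])

# Addition and scalar multiplication commute with the rotation.
add_ok = bool(jnp.allclose(R @ (x + y), R @ x + R @ y, atol=1e-5))
scale_ok = bool(jnp.allclose(R @ (3.0 * x), 3.0 * (R @ x), atol=1e-5))

# Adding a fixed constant to every component does not commute with the rotation.
shift_ok = bool(jnp.allclose(R @ (x + 1.0), R @ x + 1.0, atol=1e-5))
```

Here `add_ok` and `scale_ok` are true and `shift_ok` is false, which is one way to read the restriction stated above.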

Attributes

dirreps

Irreps for the non-trivial axes. See also irreps() below.

Type:

dict[int, cuequivariance.irreps_array.irreps.Irreps]

array

JAX array

Type:

jax.Array

layout

Data layout

Type:

cuequivariance.irreps_array.irreps_layout.IrrepsLayout

shape

Shape of the array

ndim

Number of dimensions of the array

dtype

Data type of the array

Methods

is_simple() → bool

Return True if only the last axis of the IrrepsArray is non-trivial.

Examples

>>> cuex.IrrepsArray(
...     cue.Irreps("SO3", "2x0"), jnp.array([1.0, 2.0]), cue.ir_mul
... ).is_simple()
True

irreps(
axis: int = -1,
) → Irreps

Return the Irreps for a given axis.

Examples

>>> cuex.IrrepsArray(
...     cue.Irreps("SO3", "2x0"), jnp.array([1.0, 2.0]), cue.ir_mul
... ).irreps()
2x0

slice_by_mul(
axis: int = -1,
) → _MulIndexSliceHelper

Return a helper that slices the array with respect to the multiplicities along the given axis.

Examples

>>> x = cuex.IrrepsArray(
...     cue.Irreps("SO3", "2x0 + 1"),
...     jnp.array([1.0, 2.0, 0.0, 0.0, 0.0]), cue.ir_mul
... )
>>> x.slice_by_mul()[1:4]
{0: 0+1} [2. 0. 0. 0.]

filter(
*,
keep: str | Sequence[Irrep] | Callable[[MulIrrep], bool] | None = None,
drop: str | Sequence[Irrep] | Callable[[MulIrrep], bool] | None = None,
mask: Sequence[bool] | None = None,
axis: int = -1,
) → IrrepsArray

Filter the irreps.

Parameters:
  • keep – Irreps to keep.

  • drop – Irreps to drop.

  • mask – Boolean mask for segments to keep.

  • axis – Axis to filter.

Examples

>>> x = cuex.IrrepsArray(
...     cue.Irreps("SO3", "2x0 + 1"),
...     jnp.array([1.0, 2.0, 0.0, 0.0, 0.0]), cue.ir_mul
... )
>>> x.filter(keep="0")
{0: 2x0} [1. 2.]
>>> x.filter(drop="0")
{0: 1} [0. 0. 0.]
>>> x.filter(mask=[True, False])
{0: 2x0} [1. 2.]
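
The mask variant can be pictured with a plain-JAX sketch. It does not call cuequivariance: the `(mul, l)` tuples and the helper `filter_segments` are hypothetical stand-ins for the SO3 irreps "2x0 + 1", relying on the fact that an SO3 irrep of degree l has dimension 2l + 1.

```python
import jax.numpy as jnp

# Assumed stand-in for "2x0 + 1": a list of (multiplicity, degree l) pairs.
irreps = [(2, 0), (1, 1)]
data = jnp.array([1.0, 2.0, 0.0, 0.0, 0.0])


def filter_segments(irreps, data, mask):
    """Keep the segments whose mask entry is True (illustrative helper)."""
    kept_irreps, kept_chunks, start = [], [], 0
    for (mul, l), keep in zip(irreps, mask):
        size = mul * (2 * l + 1)  # number of entries occupied by this segment
        if keep:
            kept_irreps.append((mul, l))
            kept_chunks.append(data[..., start:start + size])
        start += size
    if not kept_chunks:
        # Nothing kept: return an empty slice along the last axis.
        return kept_irreps, data[..., :0]
    return kept_irreps, jnp.concatenate(kept_chunks, axis=-1)


kept_irreps, kept = filter_segments(irreps, data, [True, False])
```

With the mask `[True, False]` only the two scalar entries survive, mirroring the last doctest above.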

sort(
axis: int = -1,
) → IrrepsArray

Sort the irreps.

Examples

>>> x = cuex.IrrepsArray(
...     cue.Irreps("SO3", "1 + 2x0"),
...     jnp.array([1.0, 1.0, 1.0, 2.0, 3.0]), cue.ir_mul
... )
>>> x.sort()
{0: 2x0+1} [2. 3. 1. 1. 1.]
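
A plain-JAX sketch of the reordering in this example is given below. It is only an illustration under stated assumptions: irreps are written as hypothetical `(mul, l)` tuples, segments are assumed to be ordered by the degree l, and Python's stable `sorted` stands in for whatever ordering the library applies.

```python
import jax.numpy as jnp

# Assumed stand-in for "1 + 2x0": (multiplicity, degree l) pairs.
irreps = [(1, 1), (2, 0)]
data = jnp.array([1.0, 1.0, 1.0, 2.0, 3.0])

# Offsets of each segment along the last axis (dimension of degree l is 2l + 1).
offsets = [0]
for mul, l in irreps:
    offsets.append(offsets[-1] + mul * (2 * l + 1))
chunks = [data[..., a:b] for a, b in zip(offsets[:-1], offsets[1:])]

# Stable sort of the segment indices by degree l, then reorder the data to match.
order = sorted(range(len(irreps)), key=lambda i: irreps[i][1])
sorted_irreps = [irreps[i] for i in order]
sorted_data = jnp.concatenate([chunks[i] for i in order], axis=-1)
```

The two scalar entries move in front of the three vector entries, as in the doctest output above.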

regroup(
axis: int = -1,
) → IrrepsArray

Clean up the irreps: sort them and merge repeated irreps into a single entry, as in the example below.

Examples

>>> x = cuex.IrrepsArray(
...     cue.Irreps("SO3", "0 + 1 + 0"), jnp.array([0., 1., 2., 3., -1.]),
...     cue.ir_mul
... )
>>> x.regroup()
{0: 2x0+1} [ 0. -1.  1.  2.  3.]
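
The merge in this example can be sketched in plain JAX. This is an illustration, not the library's implementation: irreps are hypothetical `(mul, l)` tuples, and each `ir_mul` segment is assumed to be stored with the irrep dimension first and the multiplicity last, so that merging equal irreps concatenates along the multiplicity axis.

```python
import jax.numpy as jnp

# Assumed stand-in for "0 + 1 + 0": (multiplicity, degree l) pairs.
irreps = [(1, 0), (1, 1), (1, 0)]
data = jnp.array([0.0, 1.0, 2.0, 3.0, -1.0])

# Collect the segments per degree l, each reshaped to (irrep_dim, mul).
groups = {}
start = 0
for mul, l in irreps:
    dim = 2 * l + 1
    segment = data[start:start + dim * mul].reshape(dim, mul)
    groups.setdefault(l, []).append(segment)
    start += dim * mul

# Visit the degrees in sorted order and merge equal ones along the mul axis.
regrouped_irreps, flat = [], []
for l in sorted(groups):
    merged = jnp.concatenate(groups[l], axis=-1)  # shape (irrep_dim, total_mul)
    regrouped_irreps.append((merged.shape[-1], l))
    flat.append(merged.reshape(-1))
regrouped = jnp.concatenate(flat)
```

The two separate scalars end up side by side as a single entry of multiplicity 2, followed by the vector.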

segments(axis: int = -1) → list[Array]

Split the array into segments.

Examples

>>> x = cuex.IrrepsArray(
...     cue.Irreps("SO3", "2x0 + 1"), jnp.array([1.0, 2.0, 0.0, 0.0, 0.0]),
...     cue.ir_mul
... )
>>> x.segments()
[Array(...), Array(...)]

Note

See also cuex.from_segments.
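
For intuition, splitting into segments and joining them back can be sketched in plain JAX. The per-segment shape used here, `(irrep_dim, mul)` for the `ir_mul` layout, is an assumption made for this sketch, as are the hypothetical `(mul, l)` tuples. The code does not call `segments()` or `cuex.from_segments`.

```python
import jax.numpy as jnp

# Assumed stand-in for "2x0 + 1": (multiplicity, degree l) pairs.
irreps = [(2, 0), (1, 1)]
data = jnp.array([1.0, 2.0, 0.0, 0.0, 0.0])

# Split the last axis into one block per irrep, reshaped to (..., irrep_dim, mul).
segments = []
start = 0
for mul, l in irreps:
    dim = 2 * l + 1
    block = data[..., start:start + dim * mul]
    segments.append(block.reshape(data.shape[:-1] + (dim, mul)))
    start += dim * mul

shapes = [s.shape for s in segments]

# Flatten each segment again and concatenate to recover the original array.
rebuilt = jnp.concatenate(
    [s.reshape(s.shape[:-2] + (-1,)) for s in segments], axis=-1
)
```

Flattening and concatenating the blocks returns the original array, which is the round trip the note above points to.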