IrrepsArray
- class cuequivariance_jax.IrrepsArray( )
Wrapper around a JAX array with a dict of Irreps for the non-trivial axes. Each key of the dict is an axis index, and each value is the Irreps describing that axis.
Creation
>>> import jax.numpy as jnp
>>> import cuequivariance as cue
>>> import cuequivariance_jax as cuex
>>> cuex.IrrepsArray(
...     {-1: cue.Irreps("SO3", "2x0")}, jnp.array([1.0, 2.0]), cue.ir_mul
... )
{0: 2x0} [1. 2.]
If you pass an Irreps instead of a dict, it is assigned to the last axis:
>>> cuex.IrrepsArray(
...     cue.Irreps("SO3", "2x0"), jnp.array([1.0, 2.0]), cue.ir_mul
... )
{0: 2x0} [1. 2.]
You can set a default group and layout with cue.assume and then pass the irreps as a plain string:
>>> with cue.assume(cue.SO3, cue.ir_mul):
...     cuex.IrrepsArray("2x0", jnp.array([1.0, 2.0]))
{0: 2x0} [1. 2.]
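The irreps attached to an axis fix the length of that axis. As a rough sketch for SO3 only, assuming an irrep with label l has dimension 2l + 1 and that a string such as "2x0 + 1" means "multiplicity x irrep" terms joined by "+", the required length can be computed as follows. The helpers parse_so3_irreps and so3_irreps_dim are hypothetical and not part of cuequivariance; in practice cue.Irreps does the parsing.

```python
# Hypothetical helpers (not part of cuequivariance) that model the shape rule
# for SO3 irreps strings: irrep l has dimension 2l + 1, and "2x0 + 1" means
# 2 copies of l=0 plus 1 copy of l=1.


def parse_so3_irreps(text: str) -> list[tuple[int, int]]:
    """Parse e.g. "2x0 + 1" into [(mul, l), ...] = [(2, 0), (1, 1)]."""
    result = []
    for term in text.split("+"):
        term = term.strip()
        if "x" in term:
            mul, l = term.split("x")
            result.append((int(mul), int(l)))
        else:
            result.append((1, int(term)))
    return result


def so3_irreps_dim(text: str) -> int:
    """Required size of the annotated axis: sum of mul * (2l + 1)."""
    return sum(mul * (2 * l + 1) for mul, l in parse_so3_irreps(text))


print(so3_irreps_dim("2x0"))      # 2, matching jnp.array([1.0, 2.0])
print(so3_irreps_dim("2x0 + 1"))  # 5, matching a length-5 array
```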
Arithmetic
Basic arithmetic operations are supported as long as they are equivariant, for example adding two arrays with the same irreps or multiplying by a scalar:
>>> with cue.assume(cue.SO3, cue.ir_mul):
...     x = cuex.IrrepsArray("2x0", jnp.array([1.0, 2.0]))
...     y = cuex.IrrepsArray("2x0", jnp.array([3.0, 4.0]))
...     x + y
{0: 2x0} [4. 6.]
>>> 3.0 * x
{0: 2x0} [3. 6.]
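Equivariance here means that the operation commutes with the group action. The NumPy check below, which does not use cuequivariance, illustrates the idea for a single l=1 irrep (a 3-vector) under a rotation about the z axis: adding two arrays with the same irreps and multiplying by a scalar commute with the rotation, while adding a fixed constant vector does not. The helper rot_z is hypothetical, and how cuex reports a non-equivariant operation is not covered here.

```python
import numpy as np


def rot_z(theta: float) -> np.ndarray:
    """Hypothetical helper: 3x3 rotation about z, acting on an l=1 irrep."""
    c, s = np.cos(theta), np.sin(theta)
    return np.array([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]])


R = rot_z(0.7)
x = np.array([1.0, 2.0, 3.0])   # data for one "1" irrep
y = np.array([-1.0, 0.5, 4.0])  # another array with the same irreps

# Equivariant: rotating after the operation equals operating on rotated inputs.
assert np.allclose(R @ (x + y), R @ x + R @ y)    # addition of same irreps
assert np.allclose(R @ (3.0 * x), 3.0 * (R @ x))  # scalar multiplication

# Not equivariant: adding a fixed constant vector to an l=1 irrep,
# because R @ c differs from c.
c = np.array([1.0, 1.0, 1.0])
print(np.allclose(R @ (x + c), R @ x + c))  # False
```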
Attributes
- layout
Data layout of the array (for example cue.ir_mul).
- shape
Shape of the array.
- ndim
Number of dimensions of the array.
- dtype
Data type of the array.
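The layout determines how each segment "mul x ir" is flattened along its axis. A minimal NumPy sketch for a single segment 2x1, assuming the layout names are read literally (mul_ir: multiplicity index outer, irrep component inner; ir_mul: the reverse); the arrays below are illustrative and do not call cuequivariance.

```python
import numpy as np

# One segment "2x1": mul=2 copies of an irrep of dimension 3.
mul, ir_dim = 2, 3
copies = np.array([[1.0, 2.0, 3.0],   # copy 0
                   [4.0, 5.0, 6.0]])  # copy 1; shape (mul, ir_dim)

# mul_ir: multiplicity index outer, irrep component inner.
mul_ir = copies.reshape(-1)           # [1, 2, 3, 4, 5, 6]

# ir_mul: irrep component outer, multiplicity index inner.
ir_mul = copies.T.reshape(-1)         # [1, 4, 2, 5, 3, 6]

# Converting this segment from ir_mul back to mul_ir.
back = ir_mul.reshape(ir_dim, mul).T.reshape(-1)
assert np.array_equal(back, mul_ir)
```

For multiplicity 1 or for scalars (dimension 1) the two layouts coincide, which is why the l=0 examples on this page look the same either way.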
Methods
- is_simple() -> bool
Return True if the last axis is the only axis of the IrrepsArray with non-trivial irreps.
Examples
>>> cuex.IrrepsArray(
...     cue.Irreps("SO3", "2x0"), jnp.array([1.0, 2.0]), cue.ir_mul
... ).is_simple()
True
- irreps(axis: int = -1)
Return the Irreps for a given axis.
Examples
>>> cuex.IrrepsArray(
...     cue.Irreps("SO3", "2x0"), jnp.array([1.0, 2.0]), cue.ir_mul
... ).irreps()
2x0
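Both the axis argument of irreps() and the definition of is_simple() refer to the dict of irreps. As a hypothetical plain-Python model, assuming the dict keys are stored as non-negative axis indices (as the {0: ...} repr suggests) and that axis=-1 is normalized the usual Python way, the lookups could look like this. normalize_axis, irreps_for_axis, and is_simple here are illustrative stand-ins, with strings in place of Irreps.

```python
# Hypothetical model: dirreps maps a non-negative axis index to its irreps
# (plain strings stand in for cue.Irreps).


def normalize_axis(axis: int, ndim: int) -> int:
    """Map a possibly negative axis to 0..ndim-1, as Python indexing does."""
    if not -ndim <= axis < ndim:
        raise ValueError(f"axis {axis} out of range for ndim={ndim}")
    return axis % ndim


def irreps_for_axis(dirreps: dict[int, str], ndim: int, axis: int = -1) -> str:
    return dirreps[normalize_axis(axis, ndim)]


def is_simple(dirreps: dict[int, str], ndim: int) -> bool:
    # Simple means the last axis is the only one carrying irreps.
    return set(dirreps) == {ndim - 1}


batched = {1: "2x0"}  # shape (batch, 2): irreps only on the last axis
print(irreps_for_axis(batched, ndim=2))        # 2x0
print(is_simple(batched, ndim=2))              # True
print(is_simple({0: "1", 1: "2x0"}, ndim=2))   # False
```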
- slice_by_mul(axis: int = -1)
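Return an object that slices the array by multiplicity: slice indices count copies of irreps across all segments of the axis, rather than individual array entries.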
Examples
>>> x = cuex.IrrepsArray(
...     cue.Irreps("SO3", "2x0 + 1"),
...     jnp.array([1.0, 2.0, 0.0, 0.0, 0.0]),
...     cue.ir_mul,
... )
>>> x.slice_by_mul()[1:4]
{0: 0+1} [2. 0. 0. 0.]
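In the example above, 2x0 contributes copies 0 and 1 and the single 1 contributes copy 2, so [1:4] (clipped to the 3 available copies, as with Python slices) keeps the second scalar and the vector. The sketch below reproduces that bookkeeping with NumPy for a 1D SO3 array in ir_mul layout. slice_by_mul_ir_mul is a hypothetical helper that only supports non-negative start/stop; it is not how cuex implements slicing.

```python
import numpy as np


def slice_by_mul_ir_mul(irreps, data, start, stop):
    """Keep copies start:stop (counted across all segments) of a 1D array in
    ir_mul layout. irreps is a list of (mul, l) pairs with dim = 2l + 1.
    Hypothetical helper; non-negative start/stop only."""
    out_irreps, out_chunks = [], []
    offset = 0  # position in the flat array
    first = 0   # global index of the first copy in the current segment
    for mul, l in irreps:
        dim = 2 * l + 1
        block = data[offset:offset + mul * dim].reshape(dim, mul)  # ir_mul
        lo, hi = max(start - first, 0), min(stop - first, mul)
        if lo < hi:  # this segment overlaps the requested copies
            out_irreps.append((hi - lo, l))
            out_chunks.append(block[:, lo:hi].reshape(-1))
        offset += mul * dim
        first += mul
    out = np.concatenate(out_chunks) if out_chunks else data[:0]
    return out_irreps, out


irreps = [(2, 0), (1, 1)]  # "2x0 + 1"
data = np.array([1.0, 2.0, 0.0, 0.0, 0.0])
irr, out = slice_by_mul_ir_mul(irreps, data, 1, 4)
print(irr, out.tolist())  # [(1, 0), (1, 1)] [2.0, 0.0, 0.0, 0.0]
```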
- filter(
      *,
      keep: str | Sequence[Irrep] | Callable[[MulIrrep], bool] | None = None,
      drop: str | Sequence[Irrep] | Callable[[MulIrrep], bool] | None = None,
      mask: Sequence[bool] | None = None,
      axis: int = -1,
  )
Keep or drop irreps along an axis and return a new IrrepsArray containing only the selected ones.
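- Parameters:
keep – Irreps to keep, given as an irreps string (e.g. "0"), a sequence of Irrep, or a predicate on MulIrrep.
drop – Irreps to drop; accepts the same forms as keep.
mask – Boolean mask over the segments of the axis; True keeps the segment.
axis – Axis to filter (default: the last axis).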
Examples
>>> x = cuex.IrrepsArray(
...     cue.Irreps("SO3", "2x0 + 1"),
...     jnp.array([1.0, 2.0, 0.0, 0.0, 0.0]),
...     cue.ir_mul,
... )
>>> x.filter(keep="0")
{0: 2x0} [1. 2.]
>>> x.filter(drop="0")
{0: 1} [0. 0. 0.]
>>> x.filter(mask=[True, False])
{0: 2x0} [1. 2.]
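In these examples whole segments are kept or removed, so the entries of each kept segment are copied unchanged regardless of layout. A NumPy sketch of the three selection modes, assuming SO3 irreps given as (mul, l) pairs; filter_segments is hypothetical, and its keep/drop accept a set of l values rather than the string, Irrep sequence, or callable forms listed above.

```python
import numpy as np


def filter_segments(irreps, data, keep=None, drop=None, mask=None):
    """Select whole segments of a 1D SO3 array by l value or by a boolean
    mask over segments. Hypothetical helper, not the cuex implementation."""
    out_irreps, out_chunks, offset = [], [], 0
    for i, (mul, l) in enumerate(irreps):
        size = mul * (2 * l + 1)
        chunk = data[offset:offset + size]
        offset += size
        if keep is not None and l not in keep:
            continue
        if drop is not None and l in drop:
            continue
        if mask is not None and not mask[i]:
            continue
        out_irreps.append((mul, l))
        out_chunks.append(chunk)
    out = np.concatenate(out_chunks) if out_chunks else data[:0]
    return out_irreps, out


irreps = [(2, 0), (1, 1)]  # "2x0 + 1"
data = np.array([1.0, 2.0, 0.0, 0.0, 0.0])
for kwargs in ({"keep": {0}}, {"drop": {0}}, {"mask": [True, False]}):
    irr, out = filter_segments(irreps, data, **kwargs)
    print(kwargs, irr, out.tolist())
```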
- sort(axis: int = -1)
Sort the irreps along the given axis, reordering the array entries to match.
Examples
>>> x = cuex.IrrepsArray(
...     cue.Irreps("SO3", "1 + 2x0"),
...     jnp.array([1.0, 1.0, 1.0, 2.0, 3.0]),
...     cue.ir_mul,
... )
>>> x.sort()
{0: 2x0+1} [2. 3. 1. 1. 1.]
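Sorting moves each segment's entries together with its irreps. A NumPy sketch for SO3, assuming segments are ordered by increasing l and that segments with equal l keep their relative order; sort_segments is a hypothetical helper.

```python
import numpy as np


def sort_segments(irreps, data):
    """Reorder the segments of a 1D SO3 array by increasing l (stable sort),
    moving each segment's data with it. Hypothetical helper."""
    chunks, offset = [], 0
    for mul, l in irreps:
        size = mul * (2 * l + 1)
        chunks.append(data[offset:offset + size])
        offset += size
    # sorted() is stable, so equal l values keep their original order.
    order = sorted(range(len(irreps)), key=lambda i: irreps[i][1])
    out = np.concatenate([chunks[i] for i in order]) if order else data[:0]
    return [irreps[i] for i in order], out


irr, out = sort_segments([(1, 1), (2, 0)], np.array([1.0, 1.0, 1.0, 2.0, 3.0]))
print(irr, out.tolist())  # [(2, 0), (1, 1)] [2.0, 3.0, 1.0, 1.0, 1.0]
```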
- regroup(axis: int = -1)
Clean up the irreps: sort them and merge adjacent segments of the same irrep, reordering the array entries to match.
Examples
>>> x = cuex.IrrepsArray(
...     cue.Irreps("SO3", "0 + 1 + 0"),
...     jnp.array([0.0, 1.0, 2.0, 3.0, -1.0]),
...     cue.ir_mul,
... )
>>> x.regroup()
{0: 2x0+1} [ 0. -1.  1.  2.  3.]
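The example above is consistent with a stable sort followed by merging neighbouring segments of the same irrep into one segment with the summed multiplicity. The sketch below assumes exactly that for SO3 in ir_mul layout. Because ir_mul stores each segment as a (dim, mul) block, merging two segments of an irrep with dim > 1 interleaves their entries instead of simply concatenating them. regroup_ir_mul is a hypothetical helper, not the cuex implementation.

```python
import numpy as np


def regroup_ir_mul(irreps, data):
    """Stable-sort segments by l, then merge neighbours with equal l by
    concatenating their (dim, mul) blocks along the multiplicity axis.
    Hypothetical helper for 1D SO3 data in ir_mul layout."""
    blocks, offset = [], 0
    for mul, l in irreps:
        dim = 2 * l + 1
        blocks.append((l, data[offset:offset + mul * dim].reshape(dim, mul)))
        offset += mul * dim
    blocks.sort(key=lambda item: item[0])  # list.sort is stable
    merged = []
    for l, block in blocks:
        if merged and merged[-1][0] == l:
            merged[-1] = (l, np.concatenate([merged[-1][1], block], axis=1))
        else:
            merged.append((l, block))
    out_irreps = [(block.shape[1], l) for l, block in merged]
    parts = [block.reshape(-1) for _, block in merged]
    return out_irreps, (np.concatenate(parts) if parts else data[:0])


irr, out = regroup_ir_mul([(1, 0), (1, 1), (1, 0)], np.array([0.0, 1.0, 2.0, 3.0, -1.0]))
print(irr, out.tolist())  # [(2, 0), (1, 1)] [0.0, -1.0, 1.0, 2.0, 3.0]

# Merging "1 + 1" interleaves the components in ir_mul layout.
irr, out = regroup_ir_mul([(1, 1), (1, 1)], np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0]))
print(irr, out.tolist())  # [(2, 1)] [1.0, 4.0, 2.0, 5.0, 3.0, 6.0]
```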
- segments(axis: int = -1) -> list[Array]
Split the array along the given axis into a list with one array per irreps segment.
Examples
>>> x = cuex.IrrepsArray(
...     cue.Irreps("SO3", "2x0 + 1"),
...     jnp.array([1.0, 2.0, 0.0, 0.0, 0.0]),
...     cue.ir_mul,
... )
>>> x.segments()
[Array(...), Array(...)]
Note
See also cuex.from_segments.
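A NumPy sketch of segments() for a 1D SO3 array, using cumulative segment sizes as split points. Reshaping each piece to (dim, mul) follows the ir_mul layout; whether cuex segments() returns exactly these shapes is an assumption, and split_segments_ir_mul is a hypothetical helper.

```python
import numpy as np


def split_segments_ir_mul(irreps, data):
    """Split a 1D SO3 array into one block per (mul, l) segment, each
    reshaped to (dim, mul) as in the ir_mul layout. Hypothetical helper."""
    sizes = [mul * (2 * l + 1) for mul, l in irreps]
    boundaries = np.cumsum(sizes)[:-1]  # split points between segments
    pieces = np.split(data, boundaries)
    return [p.reshape(2 * l + 1, mul) for p, (mul, l) in zip(pieces, irreps)]


segs = split_segments_ir_mul([(2, 0), (1, 1)], np.array([1.0, 2.0, 0.0, 0.0, 0.0]))
print([s.shape for s in segs])  # [(1, 2), (3, 1)]
```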