cuequivariance-jax

RepArray

RepArray(reps, array[, layout])

A jax.Array annotated with a dict of cue.Rep, keyed by the axes that transform under a group representation.
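As a rough mental model (a sketch, not the library's implementation), the size of an annotated axis is fixed by its representation. For SO(3) irreps written as multiplicity times degree l, each copy of degree l contributes 2l + 1 components. The `(mul, l)` pair encoding, the helper `irreps_dim`, and the axis key `1` below are hypothetical choices for illustration; the real class stores `cue.Rep` objects.

```python
import jax.numpy as jnp

# Hypothetical encoding of SO(3) irreps "2x0 + 1x1" as (multiplicity, degree l) pairs.
irreps = [(2, 0), (1, 1)]

def irreps_dim(irreps):
    # Each irrep of degree l has dimension 2l + 1; the multiplicity repeats it.
    return sum(mul * (2 * l + 1) for mul, l in irreps)

# A batch of 4 feature vectors whose last axis carries the representation.
array = jnp.zeros((4, irreps_dim(irreps)))

# Annotation mapping an axis index to its representation, mirroring the
# "dict of cue.Rep" idea with the hypothetical irreps list as the value.
reps = {1: irreps}
print(array.shape)  # (4, 5)
```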

from_segments(irreps, segments, shape[, ...])

Construct a RepArray from segments.
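Conceptually, building an array from segments means flattening one block per irrep and concatenating the blocks along the representation axis. The sketch below assumes a multiplicity-major layout in which each segment has shape (..., mul, 2l + 1); the function `from_segments_sketch` is hypothetical and does not reproduce the signature, the `shape` argument, or the checks of `from_segments`.

```python
import jax.numpy as jnp

def from_segments_sketch(irreps, segments):
    """Flatten per-irrep segments of shape (..., mul, 2l+1) and concatenate them."""
    flat = []
    for (mul, l), seg in zip(irreps, segments):
        # Each segment must match its irrep's multiplicity and dimension.
        if seg.shape[-2:] != (mul, 2 * l + 1):
            raise ValueError(f"segment shape {seg.shape} does not match {(mul, l)}")
        # Multiplicity-major flattening: copy 0's components, then copy 1's, ...
        flat.append(seg.reshape(seg.shape[:-2] + (mul * (2 * l + 1),)))
    return jnp.concatenate(flat, axis=-1)

irreps = [(2, 0), (1, 1)]                # "2x0 + 1x1"
scalars = jnp.array([[[1.0], [2.0]]])    # shape (1, 2, 1)
vector = jnp.array([[[3.0, 4.0, 5.0]]])  # shape (1, 1, 3)
x = from_segments_sketch(irreps, [scalars, vector])
print(x.shape)  # (1, 5)
```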

as_irreps_array(input[, layout, like])

Convert input to a RepArray.
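The optional `layout` argument controls how multiplicity and irrep components are interleaved along the flattened axis. Assuming that `cue.mul_ir` orders a segment multiplicity-major and `cue.ir_mul` orders it component-major, the plain-JAX sketch below converts one segment between the two orderings. The helper names are hypothetical and are not part of cuequivariance-jax.

```python
import jax.numpy as jnp

def mul_ir_to_ir_mul(seg_flat, mul, dim):
    # (mul * dim,) ordered copy by copy -> ordered component by component.
    return seg_flat.reshape(mul, dim).T.reshape(mul * dim)

def ir_mul_to_mul_ir(seg_flat, mul, dim):
    # Inverse permutation: view as (dim, mul), transpose back to (mul, dim).
    return seg_flat.reshape(dim, mul).T.reshape(mul * dim)

# Two copies of an l=1 vector: (1, 2, 3) and (10, 20, 30).
x_mul_ir = jnp.array([1.0, 2.0, 3.0, 10.0, 20.0, 30.0])
x_ir_mul = mul_ir_to_ir_mul(x_mul_ir, mul=2, dim=3)
# Values after conversion: 1, 10, 2, 20, 3, 30
```

The conversion is a pure permutation, so converting back with `ir_mul_to_mul_ir` recovers the original ordering exactly.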

concatenate(arrays)

Concatenate a list of cuex.RepArray objects.
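Concatenation along the representation axis can be pictured as concatenating the underlying arrays and joining their irreps lists in the same order. This minimal sketch uses hypothetical `(irreps, array)` pairs in place of cuex.RepArray and assumes all inputs share the same leading shape and layout; it does not model any validation the library performs.

```python
import jax.numpy as jnp

def concatenate_sketch(pairs):
    """Concatenate hypothetical (irreps, array) pairs along the last axis."""
    irreps = [ir for ir_list, _ in pairs for ir in ir_list]
    array = jnp.concatenate([a for _, a in pairs], axis=-1)
    return irreps, array

a = ([(1, 0)], jnp.ones((3, 1)))   # "1x0", batch of 3
b = ([(2, 1)], jnp.zeros((3, 6)))  # "2x1", batch of 3
irreps, array = concatenate_sketch([a, b])
print(irreps, array.shape)  # [(1, 0), (2, 1)] (3, 7)
```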

randn(key, rep[, leading_shape, dtype])

Generate a random cuex.RepArray.
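A random array for a representation mainly needs the right trailing size. The sketch below draws standard-normal entries with `jax.random.normal` and builds the shape as `leading_shape` followed by the representation dimension; `randn_sketch` and `irreps_dim` are hypothetical helpers, and whether `cuex.randn` applies any extra normalization per irrep is not modeled here.

```python
import jax
import jax.numpy as jnp

def irreps_dim(irreps):
    # Hypothetical (mul, l) encoding: each copy of degree l has 2l + 1 components.
    return sum(mul * (2 * l + 1) for mul, l in irreps)

def randn_sketch(key, irreps, leading_shape=(), dtype=jnp.float32):
    # Shape is leading_shape followed by the flattened representation axis.
    shape = tuple(leading_shape) + (irreps_dim(irreps),)
    return jax.random.normal(key, shape, dtype=dtype)

key = jax.random.PRNGKey(0)
x = randn_sketch(key, [(2, 0), (1, 1)], leading_shape=(8,))
print(x.shape, x.dtype)  # (8, 5) float32
```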

Tensor Products

equivariant_polynomial(poly, inputs[, ...])

Compute an equivariant polynomial.
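Equivariance means that rotating the inputs transforms each output according to its own representation. As a plain-JAX illustration (not a call to `equivariant_polynomial`), the degree-2 polynomial below maps two l = 1 vectors to an l = 0 output (the dot product) and an l = 1 output (the cross product). The check confirms f(Rx, Ry) = (f0, R f1) for a proper rotation R; the rotation helper is a hypothetical construction for the test.

```python
import jax.numpy as jnp

def poly(x, y):
    # l=0 output: invariant dot product; l=1 output: cross product.
    return jnp.dot(x, y), jnp.cross(x, y)

def rotation(theta, phi):
    # Rotation about z by theta followed by rotation about x by phi (det = +1).
    c, s = jnp.cos(theta), jnp.sin(theta)
    rz = jnp.array([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]])
    c, s = jnp.cos(phi), jnp.sin(phi)
    rx = jnp.array([[1.0, 0.0, 0.0], [0.0, c, -s], [0.0, s, c]])
    return rx @ rz

x = jnp.array([1.0, 2.0, 3.0])
y = jnp.array([-0.5, 0.4, 2.0])
R = rotation(0.3, 1.1)

s, v = poly(x, y)
s_rot, v_rot = poly(R @ x, R @ y)
# The scalar is unchanged; the vector output rotates with R.
print(bool(jnp.allclose(s_rot, s, atol=1e-4)), bool(jnp.allclose(v_rot, R @ v, atol=1e-4)))  # True True
```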

segmented_polynomial(polynomial, inputs, ...)

Compute a segmented polynomial.
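A segmented polynomial operates on operands split into segments and sums contributions over a list of paths, each path selecting one segment per operand together with a coefficient tensor. The sketch below evaluates a single segmented bilinear product in plain JAX; the `(i, j, k, coeff)` path format and the helper `segmented_product` are hypothetical simplifications, not the cuequivariance data structures or the `segmented_polynomial` signature.

```python
import jax.numpy as jnp

def segmented_product(x_segs, y_segs, out_dims, paths):
    """For each path (i, j, k, coeff): out[k] += einsum('a,b,abc->c', x[i], y[j], coeff)."""
    out = [jnp.zeros(d) for d in out_dims]
    for i, j, k, coeff in paths:
        out[k] = out[k] + jnp.einsum("a,b,abc->c", x_segs[i], y_segs[j], coeff)
    return out

# x has a scalar segment and a vector segment; y has one vector segment.
x_segs = [jnp.array([2.0]), jnp.array([1.0, 0.0, 1.0])]
y_segs = [jnp.array([3.0, 4.0, 5.0])]

coeff_scale = jnp.eye(3)[None, :, :]  # (1, 3, 3): scalar times vector -> vector
coeff_dot = jnp.eye(3)[:, :, None]    # (3, 3, 1): vector dot vector -> scalar
paths = [(0, 0, 0, coeff_scale), (1, 0, 1, coeff_dot)]

out = segmented_product(x_segs, y_segs, out_dims=[3, 1], paths=paths)
# out[0] holds 2 * y; out[1] holds the dot product of x_segs[1] and y.
```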

Extra Modules

flax_linen.LayerNorm(epsilon, parent, ...)

spherical_harmonics(ls, vector[, normalize])

Compute the spherical harmonics of a vector.
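For low degrees the spherical harmonics are simple polynomials in the (optionally normalized) vector. The sketch below covers only l = 0 and l = 1 and assumes a component normalization in which each degree-l block has squared norm 2l + 1 for a unit input; the component ordering and normalization used by `spherical_harmonics` may differ, and `spherical_harmonics_01` is a hypothetical helper.

```python
import jax.numpy as jnp

def spherical_harmonics_01(vector, normalize=True):
    """Return the l=0 and l=1 blocks concatenated along the last axis: shape (..., 4)."""
    if normalize:
        # Project onto the unit sphere before evaluating the polynomials.
        vector = vector / jnp.linalg.norm(vector, axis=-1, keepdims=True)
    y0 = jnp.ones(vector.shape[:-1] + (1,), dtype=vector.dtype)  # l=0: constant
    y1 = jnp.sqrt(3.0) * vector                                   # l=1: linear in x, y, z
    return jnp.concatenate([y0, y1], axis=-1)

v = jnp.array([0.0, 3.0, 4.0])
sh = spherical_harmonics_01(v)
print(sh.shape)  # (4,)
```

With `normalize=True` the l = 1 block depends only on the direction of `v`, so its squared norm is 3 regardless of the input length.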