cuequivariance-jax

RepArray

RepArray

A jax.Array paired with a dict that maps axis indices to cue.Rep, with one entry for each axis that transforms under a group representation.
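To make the data structure concrete, here is a minimal sketch of the idea: an array plus per-axis representation metadata, validated against the array shape. ToyRepArray is a hypothetical class, not part of cuequivariance-jax, and it stores only the dimension of each axis's representation instead of a cue.Rep object.

```python
from dataclasses import dataclass

import jax
import jax.numpy as jnp


@dataclass(frozen=True)
class ToyRepArray:
    """Hypothetical stand-in for cuex.RepArray: an array plus the dimension
    of the representation carried by each transforming axis."""

    array: jax.Array
    rep_dims: dict[int, int]  # axis index -> dimension of that axis's representation

    def __post_init__(self):
        for axis, dim in self.rep_dims.items():
            # Every decorated axis must have exactly the size of its representation.
            if self.array.shape[axis] != dim:
                raise ValueError(
                    f"axis {axis} has size {self.array.shape[axis]}, expected {dim}"
                )


# "2x0 + 1x1" of SO(3): two scalars (dim 1 each) and one vector (dim 3), total dim 5.
# Axis 0 is a plain batch axis; only the last axis transforms.
x = ToyRepArray(jnp.zeros((10, 5)), {-1: 5})
```

Keeping the metadata in a dict keyed by axis lets batch axes stay undecorated while several axes could, in principle, each carry their own representation.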

from_segments(irreps, segments, shape[, ...])

Construct a RepArray from segments.
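As an illustration of what "constructing from segments" involves, the sketch below packs one block per irrep into a single flat axis. The helper toy_from_segments is hypothetical; it assumes irreps are given as (mul, l) pairs of SO(3) and that each segment uses a multiplicity-major (mul, 2l + 1) layout, which may differ from the library's conventions.

```python
import jax.numpy as jnp


def toy_from_segments(irreps, segments, leading_shape):
    """Hypothetical helper: pack per-irrep segments into one flat array.

    irreps   -- list of (mul, l) pairs, e.g. [(2, 0), (1, 1)] for "2x0 + 1x1"
    segments -- one array per irrep, shaped leading_shape + (mul, 2l + 1)
    """
    if len(irreps) != len(segments):
        raise ValueError("need exactly one segment per irrep")
    leading_shape = tuple(leading_shape)
    flat = []
    for (mul, l), seg in zip(irreps, segments):
        expected = leading_shape + (mul, 2 * l + 1)
        if seg.shape != expected:
            raise ValueError(f"segment shape {seg.shape} != {expected}")
        # Merge the (mul, 2l + 1) block into one axis of size mul * (2l + 1).
        flat.append(seg.reshape(leading_shape + (mul * (2 * l + 1),)))
    return jnp.concatenate(flat, axis=-1)


scalars = jnp.ones((4, 2, 1))  # segment for 2x0
vector = jnp.zeros((4, 1, 3))  # segment for 1x1
packed = toy_from_segments([(2, 0), (1, 1)], [scalars, vector], (4,))
```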

as_irreps_array(input[, layout, like])

Convert input to a RepArray.

concatenate(arrays)

Concatenate a list of cuex.RepArray objects.
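Conceptually, concatenating equivariant arrays joins their irreps in order and joins the data along the representation axis. The sketch below shows that bookkeeping with a hypothetical helper, toy_concatenate, using (mul, l) lists in place of cue.Irreps; it is not the library's implementation.

```python
import jax.numpy as jnp


def toy_concatenate(parts):
    """Hypothetical helper: concatenate (irreps, array) pairs along the last axis.

    irreps are lists of (mul, l) pairs; the result's irreps are the lists
    joined in order, and the arrays are joined along the representation axis.
    """
    irreps = [ir for ir_list, _ in parts for ir in ir_list]
    array = jnp.concatenate([arr for _, arr in parts], axis=-1)
    dim = sum(mul * (2 * l + 1) for mul, l in irreps)
    if array.shape[-1] != dim:
        raise ValueError("array width must match the total irreps dimension")
    return irreps, array


a = ([(2, 0)], jnp.ones((3, 2)))   # "2x0"
b = ([(1, 1)], jnp.zeros((3, 3)))  # "1x1"
irreps, arr = toy_concatenate([a, b])
```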

randn(key, rep[, leading_shape, dtype])

Generate a random cuex.RepArray.
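The shape logic behind a random equivariant array can be sketched in plain JAX: the trailing axis has the total dimension of the representation and leading_shape is prepended. toy_randn is a hypothetical helper that takes irreps as (mul, l) pairs; the real randn accepts a cue.Rep.

```python
import jax
import jax.numpy as jnp


def toy_randn(key, irreps, leading_shape=(), dtype=jnp.float32):
    """Hypothetical sketch: standard-normal entries for every component of
    the irreps (given as (mul, l) pairs), with the requested leading shape."""
    dim = sum(mul * (2 * l + 1) for mul, l in irreps)
    return jax.random.normal(key, tuple(leading_shape) + (dim,), dtype=dtype)


x = toy_randn(jax.random.PRNGKey(0), [(2, 0), (1, 1)], leading_shape=(16,))
```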

Tensor Products

equivariant_polynomial(poly, inputs[, ...])

Compute an equivariant polynomial.
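For intuition about what "equivariant polynomial" means, the sketch below uses the smallest nontrivial case in plain JAX: the degree-2 product of two l = 1 vectors decomposes into a scalar (dot product) and a vector (cross product). It does not call equivariant_polynomial; it only checks the defining property that rotating the inputs rotates the outputs accordingly.

```python
import jax
import jax.numpy as jnp


def toy_tensor_product(u, v):
    """Degree-2 equivariant polynomial of two 3D vectors (l = 1 irreps of SO(3)).

    Returns the l = 0 part (u . v) and the l = 1 part (u x v).
    """
    return jnp.dot(u, v), jnp.cross(u, v)


def rotation_z(theta):
    """Rotation matrix by angle theta about the z axis."""
    c, s = jnp.cos(theta), jnp.sin(theta)
    return jnp.array([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]])


key_u, key_v = jax.random.split(jax.random.PRNGKey(0))
u = jax.random.normal(key_u, (3,))
v = jax.random.normal(key_v, (3,))
R = rotation_z(0.7)

s, w = toy_tensor_product(u, v)
s_rot, w_rot = toy_tensor_product(R @ u, R @ v)
# Equivariance: the scalar is unchanged and the vector output rotates with R.
```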

segmented_polynomial(polynomial, inputs, ...)

Compute a segmented polynomial.
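A segmented polynomial can be pictured as inputs split into segments, plus a list of paths saying which segments are multiplied together, with which coefficient, and where the result accumulates. The sketch below is a hypothetical, much simplified bilinear version (equal-length segments, elementwise products) meant only to show that structure; the library's segmented polynomials support far more general per-segment contractions.

```python
import jax.numpy as jnp


def toy_segmented_bilinear(x, y, paths, num_out_segments, seg_size):
    """Hypothetical sketch of a segmented bilinear polynomial.

    x and y are flat vectors split into segments of length seg_size.
    Each path (i, j, k, c) adds c * x_seg[i] * y_seg[j] (elementwise) to out_seg[k].
    """
    xs = x.reshape(-1, seg_size)
    ys = y.reshape(-1, seg_size)
    out = jnp.zeros((num_out_segments, seg_size), dtype=x.dtype)
    for i, j, k, c in paths:
        out = out.at[k].add(c * xs[i] * ys[j])
    return out.reshape(-1)


x = jnp.array([1.0, 2.0, 3.0, 4.0])  # segments [1, 2] and [3, 4]
y = jnp.array([5.0, 6.0, 7.0, 8.0])  # segments [5, 6] and [7, 8]
paths = [(0, 0, 0, 1.0), (1, 1, 0, -1.0), (0, 1, 1, 2.0)]
out = toy_segmented_bilinear(x, y, paths, num_out_segments=2, seg_size=2)
# out == [-16, -20, 14, 32]
```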

Extra Modules

flax_linen.LayerNorm

LayerNorm(epsilon: float = 0.01, parent: Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType] = <flax.linen.module._Sentinel object>, name: Optional[str] = None)
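An ordinary LayerNorm would break equivariance because it mixes vector components individually. One common equivariant alternative rescales each irrep copy by a rotation-invariant statistic of the norms. The function below is a hypothetical sketch of that idea under stated assumptions ((mul, l) irreps, multiplicity-major layout, epsilon added to the mean squared norm); it is not claimed to be what flax_linen.LayerNorm computes, and it is a plain function rather than a Flax module.

```python
import jax
import jax.numpy as jnp


def toy_equivariant_layer_norm(x, irreps, epsilon=0.01):
    """Hypothetical equivariant normalization.

    x: array of shape (..., dim); irreps: list of (mul, l) pairs, each block laid
    out as (mul, 2l + 1). Every irrep copy is divided by the root-mean-square of
    the norms of all copies of that irrep, so the operation commutes with rotations.
    """
    outputs, start = [], 0
    for mul, l in irreps:
        d = 2 * l + 1
        block = x[..., start:start + mul * d].reshape(x.shape[:-1] + (mul, d))
        sq_norms = jnp.sum(block**2, axis=-1)  # (..., mul), rotation invariant
        rms = jnp.sqrt(jnp.mean(sq_norms, axis=-1, keepdims=True) + epsilon)
        outputs.append((block / rms[..., None]).reshape(x.shape[:-1] + (mul * d,)))
        start += mul * d
    return jnp.concatenate(outputs, axis=-1)


x = jax.random.normal(jax.random.PRNGKey(0), (4, 6))  # "2x1": two 3D vectors per row
c, s = jnp.cos(0.3), jnp.sin(0.3)
R = jnp.array([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]])
x_rot = (x.reshape(4, 2, 3) @ R.T).reshape(4, 6)  # rotate every vector by R
y = toy_equivariant_layer_norm(x, [(2, 1)])
y_rot = toy_equivariant_layer_norm(x_rot, [(2, 1)])
```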

spherical_harmonics(ls, vector[, normalize])

Compute the spherical harmonics of a vector.
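For the lowest degrees, real spherical harmonics are simple enough to write out by hand, which helps when checking conventions. The sketch below covers l = 0 and l = 1 with "component" normalization (each l block has squared norm 2l + 1 on the unit sphere). The function name, the (x, y, z) ordering of the l = 1 block, and the reading of normalize as "rescale the input to unit length" are assumptions of this sketch, not statements about spherical_harmonics.

```python
import jax.numpy as jnp


def toy_spherical_harmonics_l01(vector, normalize=True):
    """Real spherical harmonics for l = 0 and l = 1, component normalization.

    vector: (..., 3). Returns (..., 4) = [Y_0 | Y_1], with Y_1 in (x, y, z) order.
    """
    if normalize:
        vector = vector / jnp.linalg.norm(vector, axis=-1, keepdims=True)
    y0 = jnp.ones(vector.shape[:-1] + (1,), dtype=vector.dtype)
    y1 = jnp.sqrt(3.0) * vector  # squared norm 3 for a unit vector
    return jnp.concatenate([y0, y1], axis=-1)


sh = toy_spherical_harmonics_l01(jnp.array([[0.0, 0.0, 2.0]]))
# sh == [[1, 0, 0, sqrt(3)]]
```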

Triangle

triangle_multiplicative_update(x[, ...])

Apply the triangle multiplicative update operation.
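For reference, the triangle multiplicative update popularized by AlphaFold 2 ("outgoing edges" variant) updates each pair entry (i, j) by summing gated projections over the third node k of every triangle. The plain-JAX sketch below follows that published recipe in simplified form (learnable layer-norm parameters omitted, hypothetical parameter names); it is not the fused kernel behind triangle_multiplicative_update, and the library's argument shapes may differ.

```python
import jax
import jax.numpy as jnp


def layer_norm(x, eps=1e-5):
    """Non-parametric layer norm over the last axis."""
    mean = x.mean(-1, keepdims=True)
    var = x.var(-1, keepdims=True)
    return (x - mean) / jnp.sqrt(var + eps)


def toy_triangle_mult_outgoing(z, params):
    """Simplified outgoing triangle multiplicative update.

    z: pair representation of shape (N, N, c); params holds hypothetical weights.
    """
    z = layer_norm(z)
    a = jax.nn.sigmoid(z @ params["w_ag"]) * (z @ params["w_a"])  # (N, N, h)
    b = jax.nn.sigmoid(z @ params["w_bg"]) * (z @ params["w_b"])  # (N, N, h)
    # Edge (i, j) aggregates over the third node k: sum_k a[i, k] * b[j, k].
    t = jnp.einsum("ikh,jkh->ijh", a, b)
    g = jax.nn.sigmoid(z @ params["w_g"])  # (N, N, c) output gate
    return g * (layer_norm(t) @ params["w_o"])  # (N, N, c)


N, c, h = 8, 16, 4
keys = jax.random.split(jax.random.PRNGKey(0), 7)
shapes = {"w_a": (c, h), "w_ag": (c, h), "w_b": (c, h),
          "w_bg": (c, h), "w_g": (c, c), "w_o": (h, c)}
params = {name: 0.1 * jax.random.normal(k, s) for (name, s), k in zip(shapes.items(), keys[1:])}
z = jax.random.normal(keys[0], (N, N, c))
out = toy_triangle_mult_outgoing(z, params)
```

The "incoming" variant differs only in the contraction, summing a[k, i] * b[k, j] instead.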

triangle_attention(q, k, v, bias, mask, scale)

Compute triangle attention: scaled dot-product attention with an additive bias and a mask.
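A plain-JAX reference is useful for checking any attention kernel against its definition. The sketch below computes softmax(scale * q k^T + bias) v with masked keys excluded. The function name and the example shapes (batch B, rows N, heads H, sequence S, head dim D, with a bias shared across rows and a per-row key mask) are assumptions chosen for illustration, not the documented contract of triangle_attention.

```python
import jax
import jax.numpy as jnp


def reference_triangle_attention(q, k, v, bias, mask, scale):
    """Plain-JAX attention with an additive bias and a key mask.

    q: (..., S_q, D); k, v: (..., S_kv, D); bias broadcasts to (..., S_q, S_kv);
    mask broadcasts to the same shape, True marking keys that may be attended.
    """
    logits = scale * jnp.einsum("...qd,...kd->...qk", q, k) + bias
    logits = jnp.where(mask.astype(bool), logits, -1e9)  # drop masked keys
    weights = jax.nn.softmax(logits, axis=-1)
    return jnp.einsum("...qk,...kd->...qd", weights, v)


kq, kk, kv, kb = jax.random.split(jax.random.PRNGKey(0), 4)
B, N, H, S, D = 1, 4, 2, 4, 8
q = jax.random.normal(kq, (B, N, H, S, D))
k = jax.random.normal(kk, (B, N, H, S, D))
v = jax.random.normal(kv, (B, N, H, S, D))
bias = jax.random.normal(kb, (B, 1, H, S, S))  # shared across the N rows
mask = jnp.ones((B, N, 1, 1, S), dtype=bool).at[..., -1].set(False)  # hide the last key
out = reference_triangle_attention(q, k, v, bias, mask, scale=D**-0.5)
```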

Experimental

experimental.indexed_linear(poly, counts, w, x)

Linear layer with different weights for different parts of the input.
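One way to read "different weights for different parts of the input" is a grouped linear layer: consecutive rows of the input belong to groups, with counts giving each group's size, and each group has its own weight matrix. The sketch below implements that reading in plain JAX with a hypothetical helper; it omits the poly argument of experimental.indexed_linear, and the row-grouping interpretation of counts is an assumption.

```python
import jax
import jax.numpy as jnp


def toy_indexed_linear(counts, w, x):
    """Hypothetical grouped linear layer.

    counts: (G,) ints summing to x.shape[0]; w: (G, d_in, d_out); x: (n, d_in).
    Rows of x are grouped consecutively, and every row in group g uses w[g].
    """
    # Group index of each row, e.g. counts [2, 3] -> [0, 0, 1, 1, 1].
    group = jnp.repeat(jnp.arange(w.shape[0]), counts, total_repeat_length=x.shape[0])
    return jnp.einsum("ni,nio->no", x, w[group])


counts = jnp.array([2, 3])
w = jax.random.normal(jax.random.PRNGKey(0), (2, 4, 3))
x = jax.random.normal(jax.random.PRNGKey(1), (5, 4))
y = toy_indexed_linear(counts, w, x)
```

Gathering w[group] materializes one weight matrix per row, which is simple but memory-hungry; a dedicated kernel would avoid that copy.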