cuequivariance-jax

RepArray

RepArray

A jax.Array decorated with a dict of cue.Rep for the axes transforming under a group representation.
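Here is a toy stand-in that makes the idea concrete. The class ToyRepArray is hypothetical and not part of cuequivariance-jax. It pairs a jax.Array with a dict mapping axis indices to representation dimensions, and it checks that the shapes agree. The real RepArray stores cue.Rep objects in that dict rather than plain integers.

```python
from dataclasses import dataclass

import jax.numpy as jnp


@dataclass
class ToyRepArray:
    """Hypothetical stand-in for RepArray: an array plus {axis: rep_dim}."""

    array: jnp.ndarray
    reps: dict  # axis index -> dimension of the representation on that axis

    def __post_init__(self):
        # Every decorated axis must have exactly the representation's dimension.
        for axis, dim in self.reps.items():
            if self.array.shape[axis] != dim:
                raise ValueError(
                    f"axis {axis} has size {self.array.shape[axis]}, expected {dim}"
                )


# A batch of 10 elements. Axis -1 carries a 4-dimensional representation
# (for example one scalar plus one 3D vector). Axis 0 is an ordinary batch axis.
x = ToyRepArray(jnp.zeros((10, 4)), reps={-1: 4})
```

Axes that are absent from the dict, such as the batch axis here, are left undecorated and do not transform.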

from_segments(irreps, segments, shape[, ...])

Construct a RepArray from segments.

as_irreps_array(input[, layout, like])

Convert input to a RepArray.

concatenate(arrays)

Concatenate a list of cuex.RepArray objects.
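Conceptually, when you concatenate arrays whose last axis carries irreps, both the data and the irreps lists are concatenated. The sketch below is a simplification under several assumptions. Irreps are plain lists of (mul, l) pairs with parity ignored, and the representation axis is the last one. The helpers concatenate_sketch and irreps_dim are hypothetical and are not the library's API.

```python
import jax.numpy as jnp


def irreps_dim(irreps):
    # irreps: list of (mul, l) pairs; an irrep of degree l has dimension 2l + 1.
    return sum(mul * (2 * l + 1) for mul, l in irreps)


def concatenate_sketch(parts):
    """parts: list of (irreps, array) pairs, irreps living on the last axis."""
    for irreps, array in parts:
        # Each array's last axis must match the dimension of its irreps.
        assert array.shape[-1] == irreps_dim(irreps)
    # The irreps of the result are the irreps of the parts, in order.
    irreps_out = [ir for irreps, _ in parts for ir in irreps]
    array_out = jnp.concatenate([array for _, array in parts], axis=-1)
    return irreps_out, array_out
```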

randn(key, rep[, leading_shape, dtype])

Generate a random cuex.RepArray.

Tensor Products

equivariant_polynomial(poly, inputs[, ...])

Compute an equivariant polynomial.

segmented_polynomial(polynomial, inputs, ...)

Compute a segmented polynomial.
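One simplified way to picture a segmented polynomial: each operand is split into segments, and a list of paths says which input segments are multiplied together, with which coefficient, and which output segment receives the result. The sketch below only handles degree 2 with elementwise products between equally sized 1D segments. Real descriptors can contract segments of other shapes. The function name and the (i, j, k, coeff) path format are hypothetical.

```python
import jax.numpy as jnp


def segmented_bilinear_sketch(x_segments, y_segments, paths, num_out, seg_size):
    """Toy degree-2 segmented polynomial with uniform 1D segments.

    x_segments, y_segments: lists of arrays of shape (seg_size,)
    paths: list of (i, j, k, coeff) meaning out[k] += coeff * x[i] * y[j]
    """
    # Start every output segment at zero, then accumulate one term per path.
    out = [jnp.zeros(seg_size) for _ in range(num_out)]
    for i, j, k, coeff in paths:
        out[k] = out[k] + coeff * x_segments[i] * y_segments[j]
    return out
```

Several paths may write into the same output segment. That is how a single output can receive contributions from different pairs of inputs.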

ir_dict

Utilities for working with the dict[Irrep, Array] representation, an alternative to RepArray.

ir_dict.segmented_polynomial_uniform_1d(...)

Execute a segmented polynomial with uniform 1D method on tree-structured inputs/outputs.

ir_dict.assert_mul_ir_dict(irreps, x)

Assert that a dict[Irrep, Array] matches the expected irreps structure.
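Below is a minimal sketch of this kind of structural check. It relies on these assumptions:

* Each entry has shape (..., mul, ir.dim), as stated for flat_to_dict.
* Each irrep appears once as a key.
* Irreps are given as hypothetical (mul, name, dim) tuples rather than cue.Irreps objects.

```python
import jax.numpy as jnp

# Hypothetical irreps description: (mul, name, dim), e.g. "1o" has dim 3.
IRREPS = [(4, "0e", 1), (2, "1o", 3)]


def assert_mul_ir_dict_sketch(irreps, x):
    # The dict must contain exactly one entry per irrep.
    assert set(x) == {name for _, name, _ in irreps}, "keys do not match irreps"
    for mul, name, dim in irreps:
        # Trailing axes must be (multiplicity, irrep dimension).
        assert x[name].shape[-2:] == (mul, dim), f"{name}: got {x[name].shape}"


x = {"0e": jnp.zeros((8, 4, 1)), "1o": jnp.zeros((8, 2, 3))}
assert_mul_ir_dict_sketch(IRREPS, x)  # passes silently
```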

ir_dict.mul_ir_dict(irreps, data)

Create a dict[Irrep, data] by broadcasting data to match irreps structure.

ir_dict.flat_to_dict(irreps, data, *[, layout])

Convert a flat array to dict[Irrep, Array] with shape (..., mul, ir.dim).

ir_dict.dict_to_flat(irreps, x)

Convert dict[Irrep, Array] back to a flat contiguous array.
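The round trip between the flat layout and the dict layout can be sketched as follows. It relies on these assumptions:

* Irreps are hypothetical (mul, name, dim) tuples.
* Each irrep appears once.
* The flat array stores each irrep's block as mul consecutive vectors of length ir.dim (a "mul_ir" ordering).

The library's layout argument may select a different ordering.

```python
import jax.numpy as jnp

IRREPS = [(4, "0e", 1), (2, "1o", 3)]  # hypothetical (mul, name, dim) list


def flat_to_dict_sketch(irreps, flat):
    out, offset = {}, 0
    for mul, name, dim in irreps:
        # Slice this irrep's contiguous block and split it into (mul, dim).
        block = flat[..., offset : offset + mul * dim]
        out[name] = block.reshape(flat.shape[:-1] + (mul, dim))
        offset += mul * dim
    assert offset == flat.shape[-1], "flat size does not match irreps"
    return out


def dict_to_flat_sketch(irreps, x):
    # Merge (mul, dim) back into one axis and concatenate in irreps order.
    blocks = [x[name].reshape(x[name].shape[:-2] + (mul * dim,)) for mul, name, dim in irreps]
    return jnp.concatenate(blocks, axis=-1)


flat = jnp.arange(20, dtype=jnp.float32).reshape(2, 10)
d = flat_to_dict_sketch(IRREPS, flat)
```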

ir_dict.irreps_add(x, y)

Element-wise addition of two dict[Irrep, Array] representations.

ir_dict.irreps_zeros_like(x)

Create a dict[Irrep, Array] of zeros with the same structure.
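A dict[Irrep, Array] is a JAX pytree, so element-wise operations like these map naturally onto jax.tree_util.tree_map. The two helpers below are hypothetical sketches, not the library's implementation. Plain string keys stand in for Irrep objects.

```python
import jax
import jax.numpy as jnp


def irreps_add_sketch(x, y):
    # tree_map requires x and y to have the same keys; shapes must broadcast.
    return jax.tree_util.tree_map(jnp.add, x, y)


def irreps_zeros_like_sketch(x):
    # Same keys, shapes, and dtypes, filled with zeros.
    return jax.tree_util.tree_map(jnp.zeros_like, x)


x = {"0e": jnp.ones((3, 4, 1)), "1o": jnp.ones((3, 2, 3))}
s = irreps_add_sketch(x, x)
z = irreps_zeros_like_sketch(x)
```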

NNX Layers

Flax NNX modules using the dict[Irrep, Array] representation.

Extra Modules

flax_linen.LayerNorm

LayerNorm(epsilon: float = 0.01, parent: Union[flax.linen.Module, flax.core.scope.Scope, None] = flax.linen.module._unspecified_parent, name: Optional[str] = None)

spherical_harmonics(ls, vector[, normalize])

Compute the spherical harmonics of a vector.
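For reference, the low-degree real spherical harmonics can be written out explicitly. This sketch covers l = 0, 1, 2 with "component" normalization, meaning the degree-l block has squared norm 2l + 1 for a unit vector. It uses an (x, y, z) component ordering. Both conventions are assumptions and may differ from what spherical_harmonics uses. The function name is hypothetical.

```python
import jax.numpy as jnp


def spherical_harmonics_sketch(ls, vector, normalize=True):
    """Real spherical harmonics for l in {0, 1, 2}, component normalization.

    Returns the blocks for the requested ls, concatenated on the last axis.
    """
    if normalize:
        vector = vector / jnp.linalg.norm(vector, axis=-1, keepdims=True)
    x, y, z = vector[..., 0], vector[..., 1], vector[..., 2]
    s3, s5, s15 = jnp.sqrt(3.0), jnp.sqrt(5.0), jnp.sqrt(15.0)
    blocks = {
        0: [jnp.ones_like(x)],
        1: [s3 * x, s3 * y, s3 * z],
        2: [
            s15 * x * y,
            s15 * y * z,
            s5 / 2 * (2 * z**2 - x**2 - y**2),
            s15 * x * z,
            s15 / 2 * (x**2 - y**2),
        ],
    }
    return jnp.concatenate([jnp.stack(blocks[l], axis=-1) for l in ls], axis=-1)
```

For each l, the 2l + 1 outputs form one irrep block. The squared norm of that block is constant on the unit sphere, which is a quick sanity check on the coefficients.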

Utilities

Repeats

A class to represent a sequence of repeated elements.
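One natural reading of "a sequence of repeated elements" is run-length encoding: store each distinct value with a repeat count and expand it lazily. RepeatsSketch below is a hypothetical illustration of that idea. The actual Repeats class may store and expose its data differently.

```python
from bisect import bisect_right
from dataclasses import dataclass
from itertools import accumulate


@dataclass(frozen=True)
class RepeatsSketch:
    """Hypothetical run-length sequence: values[i] is repeated counts[i] times."""

    values: tuple
    counts: tuple

    def __len__(self):
        return sum(self.counts)

    def __iter__(self):
        # Expand lazily without materializing the full sequence.
        for value, count in zip(self.values, self.counts):
            for _ in range(count):
                yield value

    def __getitem__(self, index):
        n = len(self)
        if index < 0:
            index += n
        if not 0 <= index < n:
            raise IndexError(index)
        ends = list(accumulate(self.counts))  # exclusive end position of each run
        return self.values[bisect_right(ends, index)]


r = RepeatsSketch(values=(7, 9), counts=(2, 3))
```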

Triangle

triangle_multiplicative_update(x[, ...])

Apply triangle multiplicative update operation.
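The "outgoing edges" triangle multiplicative update from AlphaFold 2 updates each pair (i, j) by summing the products of edges (i, k) and (j, k) over the third node k. The sketch below illustrates the idea under a few assumptions:

* It omits the layer normalizations of the published algorithm.
* It uses hypothetical parameter names.
* It is not the fused implementation behind triangle_multiplicative_update.

```python
import jax
import jax.numpy as jnp


def triangle_mult_outgoing_sketch(z, params):
    """Simplified outgoing triangle update on a pair tensor z of shape [N, N, C].

    params holds hypothetical projection matrices: a_gate, a_proj, b_gate,
    b_proj of shape [C, H]; out_gate of shape [C, C]; out_proj of shape [H, C].
    """
    # Gated projections of the edge features, each of shape [N, N, H].
    a = jax.nn.sigmoid(z @ params["a_gate"]) * (z @ params["a_proj"])
    b = jax.nn.sigmoid(z @ params["b_gate"]) * (z @ params["b_proj"])
    # Combine edges (i, k) and (j, k) over the shared node k.
    t = jnp.einsum("ikh,jkh->ijh", a, b)
    # Output gate and projection back to C channels.
    g = jax.nn.sigmoid(z @ params["out_gate"])
    return g * (t @ params["out_proj"])
```

The sum runs over all k, so relabeling the nodes permutes the output the same way.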

triangle_attention(q, k, v, bias, mask, scale)

Compute triangle attention.
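Triangle attention (also from AlphaFold 2) attends along the rows or columns of the pair representation, using an additive bias derived from the pair representation itself. Its core primitive is scaled dot-product attention with a bias and a mask. The sketch below shows that primitive under assumed shapes. The real triangle_attention defines its own layout, and the function name here is hypothetical.

```python
import jax
import jax.numpy as jnp


def attention_with_bias_sketch(q, k, v, bias, mask, scale):
    """Scaled dot-product attention with an additive bias and a boolean mask.

    Assumed shapes: q [..., Q, D]; k, v [..., K, D]; bias broadcastable to
    [..., Q, K]; mask boolean, broadcastable to [..., Q, K] (True = attend).
    """
    logits = scale * jnp.einsum("...qd,...kd->...qk", q, k) + bias
    # Masked-out keys get a very negative logit, so their softmax weight is ~0.
    logits = jnp.where(mask, logits, -1e9)
    weights = jax.nn.softmax(logits, axis=-1)
    return jnp.einsum("...qk,...kd->...qd", weights, v)
```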

Experimental

experimental.indexed_linear(poly, counts, w, x)