equivariant_polynomial
cuequivariance_jax.equivariant_polynomial(
    poly: EquivariantPolynomial,
    inputs: list[RepArray | Array],
    outputs_shape_dtype: list[ShapeDtypeStruct] | ShapeDtypeStruct | None = None,
    indices: None | list[None | Array | tuple[Array | slice]] = None,
    math_dtype: dtype | None = None,
    name: str | None = None,
    impl: str = 'auto',
)
Compute an equivariant polynomial.
Evaluates an equivariant polynomial, which represents a mathematical operation that respects group symmetries. This function is the equivariant wrapper around the cuex.segmented_polynomial function, providing type checking and handling for representation-aware arrays.
- Parameters:
poly – The equivariant polynomial descriptor.
inputs – List of input cuex.RepArray.
outputs_shape_dtype – Shape and dtype specifications for outputs. If None, inferred from inputs when possible. When output indices are provided, this must be specified. The last shape dimension can be set to -1 to infer the size from the polynomial descriptor.
indices – Optional list of indices for inputs and outputs. Length must match total number of operands (inputs + outputs). Use None for unindexed operands. Defaults to None.
math_dtype – Data type for computational operations. If None, automatically determined from input types. Defaults to None.
name – Optional name for the operation. Defaults to None.
impl – Implementation to use, one of "auto", "cuda", "jax", or "naive_jax". If "auto", uses CUDA when available and efficient, falling back to JAX otherwise. Defaults to "auto".
- Returns:
cuex.RepArray or list of cuex.RepArray
Note
See cuex.segmented_polynomial for more details on the implementation and usage of the underlying CUDA and JAX implementations.
Examples
Create and compute spherical harmonics of degree 0, 1, and 2:
>>> e = cue.descriptors.spherical_harmonics(cue.SO3(1), [0, 1, 2])
>>> e
╭ a=1 -> B=0+1+2
│  []➜B[] ───────── num_paths=1
│  []·a[]➜B[] ───── num_paths=3
╰─ []·a[]·a[]➜B[] ─ num_paths=11
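The output representation B=0+1+2 is a direct sum of SO(3) irreps of degree 0, 1, and 2. Assuming the standard dimension 2l+1 for a degree-l irrep, its total dimension is 9, which should be the value that a trailing -1 in outputs_shape_dtype resolves to for this descriptor. The following plain-Python sketch only illustrates that arithmetic; irreps_dim and resolve_last_dim are hypothetical helpers, not part of cuequivariance.

```python
# Illustrative sketch only: these helpers are hypothetical and do not call
# cuequivariance. They mirror the arithmetic behind the output dimension.

def irreps_dim(degrees):
    """Total dimension of a direct sum of SO(3) irreps, one copy per degree.

    Assumes the standard dimension 2*l + 1 for a degree-l irrep.
    """
    return sum(2 * l + 1 for l in degrees)


def resolve_last_dim(shape, dim):
    """Replace a trailing -1 in `shape` with `dim`; other shapes pass through."""
    if shape and shape[-1] == -1:
        return shape[:-1] + (dim,)
    return shape


dim = irreps_dim([0, 1, 2])                  # 1 + 3 + 5
full_shape = resolve_last_dim((2, -1), dim)  # trailing -1 filled in with dim
print(dim, full_shape)
```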
Basic usage with single input:
>>> with cue.assume(cue.SO3, cue.ir_mul):
...     x = cuex.RepArray("1", jnp.array([0.0, 1.0, 0.0]))
>>> cuex.equivariant_polynomial(e, [x])
{0: 0+1+2}
[1. ... ]
Using indices:
>>> i_out = jnp.array([0, 1, 1], dtype=jnp.int32)
>>> with cue.assume(cue.SO3, cue.ir_mul):
...     x = cuex.RepArray("1", jnp.array([
...         [0.0, 1.0, 0.0],
...         [0.0, 0.0, 1.0],
...         [1.0, 0.0, 0.0],
...     ]))
>>> result = cuex.equivariant_polynomial(
...     e,
...     [x],
...     jax.ShapeDtypeStruct((2, e.outputs[0].dim), jnp.float32),
...     indices=[None, i_out],
... )
>>> result
{1: 0+1+2}
[[ 1. ... ]
 [ 2. ... ]]
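In the indexed example, three input rows are mapped to two output rows by i_out = [0, 1, 1], and the leading entry of the second output row is 2 rather than 1. That is consistent with rows that share an output index being summed (scatter-add). The sketch below reproduces this accumulation in plain Python under that assumption; scatter_add is a hypothetical helper and the row values are made up for illustration, not actual spherical harmonics.

```python
# Illustrative sketch only: scatter_add is a hypothetical helper, and the
# per-row values stand in for the polynomial outputs of each input row.

def scatter_add(rows, index, num_out):
    """Accumulate rows[k] into out[index[k]], summing rows that collide."""
    width = len(rows[0])
    out = [[0.0] * width for _ in range(num_out)]
    for row, i in zip(rows, index):
        for j, value in enumerate(row):
            out[i][j] += value
    return out


# One row per input; the first entry plays the role of the constant
# degree-0 component, which equals 1 for every input.
rows = [
    [1.0, 10.0],
    [1.0, 20.0],
    [1.0, 30.0],
]
i_out = [0, 1, 1]  # rows 1 and 2 both land in output row 1

result = scatter_add(rows, i_out, num_out=2)
print(result)
```

The number of output rows (2 here) cannot be deduced from the index values alone, which is presumably why outputs_shape_dtype must be given whenever output indices are provided.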