equivariant_polynomial
cuequivariance_jax.equivariant_polynomial(poly, inputs, outputs_shape_dtype=None, indices=None, *, method='', math_dtype=None, name=None, precision=Precision.HIGHEST)
Compute an equivariant polynomial.
Evaluates an equivariant polynomial, that is, an operation f that respects the group symmetry: f(g·x) = g·f(x) for every group element g. This function is the equivariant wrapper around the cuex.segmented_polynomial function. It adds type checking and handles representation-aware arrays (cuex.RepArray).
- Parameters:
poly (EquivariantPolynomial) – The equivariant polynomial descriptor.
inputs (list[RepArray | Array]) – List of inputs, one per input of the polynomial, given either as cuex.RepArray or as plain arrays.
outputs_shape_dtype (list[ShapeDtypeStruct] | ShapeDtypeStruct | None) – Shape and dtype specifications for the outputs. If None, they are inferred from the inputs when possible. Must be specified when output indices are provided. The last dimension of a shape can be set to -1 to infer its size from the polynomial descriptor.
indices (None | list[None | Array | tuple[Array | slice]]) – Optional list of indices for inputs and outputs. Its length must match the total number of operands (inputs + outputs). Use None for unindexed operands. Defaults to None. Indices are not supported by all methods.
method (str) – Method to use for the computation. See cuex.segmented_polynomial for available methods.
math_dtype (dtype | None) – Data type for computational operations. If None, automatically determined from input types. Defaults to None.
name (str | None) – Optional name for the operation. Defaults to None.
precision (Precision) – The precision to use for the computation. Defaults to HIGHEST. Note that precision is not supported for all methods.
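The argument handling described for inputs, indices, and outputs_shape_dtype can be sketched in plain Python. This is a hypothetical illustration under stated assumptions, not the library's code: ToyRepArray, prepare_operands, and the per-output dimensions passed in as output_dims are stand-ins invented here. The sketch unwraps representation-aware inputs after checking their irreps, checks that indices has one entry per operand (inputs plus outputs), and replaces a trailing -1 in each output shape with the dimension taken from the descriptor.

```python
from dataclasses import dataclass
from typing import Any, Optional, Sequence


@dataclass
class ToyRepArray:
    """Hypothetical stand-in for cuex.RepArray: an irreps label plus raw data."""
    irreps: str
    array: Any


def prepare_operands(
    input_irreps: Sequence[str],
    output_dims: Sequence[int],
    inputs: Sequence[Any],
    output_shapes: Sequence[tuple],
    indices: Optional[Sequence[Any]] = None,
):
    """Hypothetical sketch of the parameter handling described above."""
    if len(inputs) != len(input_irreps):
        raise ValueError(f"expected {len(input_irreps)} inputs, got {len(inputs)}")

    raw_inputs = []
    for x, expected in zip(inputs, input_irreps):
        if isinstance(x, ToyRepArray):
            # Representation-aware input: its irreps must match the descriptor.
            if x.irreps != expected:
                raise ValueError(f"irreps mismatch: {x.irreps!r} != {expected!r}")
            raw_inputs.append(x.array)
        else:
            # Plain array: accepted as-is, there are no irreps to check.
            raw_inputs.append(x)

    # One index entry per operand: all inputs followed by all outputs.
    num_operands = len(inputs) + len(output_shapes)
    if indices is None:
        indices = [None] * num_operands
    if len(indices) != num_operands:
        raise ValueError(f"indices must have {num_operands} entries, got {len(indices)}")

    # A trailing -1 means "take the last dimension from the descriptor".
    resolved = [
        shape[:-1] + (dim,) if shape[-1] == -1 else shape
        for shape, dim in zip(output_shapes, output_dims)
    ]
    return raw_inputs, resolved, list(indices)


x = ToyRepArray("1", [0.0, 1.0, 0.0])
raw, shapes, idx = prepare_operands(["1"], [9], [x], [(2, -1)], indices=[None, [0, 1, 1]])
# raw == [[0.0, 1.0, 0.0]], shapes == [(2, 9)], idx == [None, [0, 1, 1]]
```

Here 9 plays the role of the output dimension of a 0+1+2 descriptor, and the list [0, 1, 1] mimics an output index array.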
- Returns:
A cuex.RepArray if the polynomial has a single output, otherwise a list of cuex.RepArray.
- Return type:
RepArray | list[RepArray]
Note
See cuex.segmented_polynomial for details on the underlying CUDA and JAX implementations.
Examples
Create the descriptor for spherical harmonics of degrees 0, 1, and 2:
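>>> import jax
>>> import jax.numpy as jnp
>>> import cuequivariance as cue
>>> import cuequivariance_jax as cuex
>>> e = cue.descriptors.spherical_harmonics(cue.SO3(1), [0, 1, 2])
>>> e
╭ a=1 -> B=0+1+2
│  []➜B[] ───────── num_paths=1
│  []·a[]➜B[] ───── num_paths=3
╰─ []·a[]·a[]➜B[] ─ num_paths=11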
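The descriptor maps one input of irreps 1 (a 3-vector) to the output irreps 0+1+2, so the output has 1 + 3 + 5 = 9 components. Its three printed terms can be read as using the input a zero, one, and two times, i.e. polynomials of degree 0, 1, and 2. The sketch below makes this concrete with hand-written, unnormalized real spherical harmonics. It is an illustration only: toy_spherical_harmonics is a hypothetical helper, and its normalization and component ordering are assumptions that are not claimed to match cue.descriptors.spherical_harmonics.

```python
def irrep_dim(l: int) -> int:
    """Dimension of the SO(3) irrep of degree l."""
    return 2 * l + 1


def toy_spherical_harmonics(v):
    """Unnormalized real spherical harmonics of degrees 0, 1, 2 at v = (x, y, z).

    Hand-written for illustration; normalization and ordering are assumptions
    and need not match cue.descriptors.spherical_harmonics.
    """
    x, y, z = v
    deg0 = [1.0]                                   # degree 0: constant
    deg1 = [x, y, z]                               # degree 1: linear in v
    deg2 = [x * y, y * z, 2 * z * z - x * x - y * y, x * z, x * x - y * y]  # degree 2: quadratic
    return deg0 + deg1 + deg2


out = toy_spherical_harmonics((0.0, 1.0, 0.0))
print(len(out), sum(irrep_dim(l) for l in (0, 1, 2)))  # 9 9
print(out)  # [1.0, 0.0, 1.0, 0.0, 0.0, 0.0, -1.0, 0.0, -1.0]
```

Scaling v by t scales the degree-l block by t**l, which is what "polynomial of degree l" means for each block.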
Basic usage with a single input:
>>> with cue.assume(cue.SO3, cue.ir_mul):
...     x = cuex.RepArray("1", jnp.array([0.0, 1.0, 0.0]))
>>> cuex.equivariant_polynomial(e, [x], method="naive")
{0: 0+1+2} [1. ... ]
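The {0: 0+1+2} prefix of the printed result indicates that axis 0 of the array carries the irreps 0+1+2. Because each degree appears once with multiplicity 1, the 9 entries form consecutive blocks of sizes 1, 3, and 5, and the snippet below splits a flat vector along those boundaries. split_by_degree is a hypothetical helper, not part of cuequivariance, and it assumes distinct degrees with multiplicity 1; for larger multiplicities the in-block order also depends on the layout (here cue.ir_mul), which this sketch does not handle.

```python
from typing import Dict, List, Sequence


def split_by_degree(values: Sequence[float], degrees: Sequence[int]) -> Dict[int, List[float]]:
    """Split a flat vector into per-degree blocks of size 2l+1 (multiplicity 1, distinct degrees)."""
    blocks: Dict[int, List[float]] = {}
    offset = 0
    for l in degrees:
        size = 2 * l + 1
        blocks[l] = list(values[offset:offset + size])
        offset += size
    # The total size implied by the degrees must match the vector length.
    if offset != len(values):
        raise ValueError(f"expected {offset} values, got {len(values)}")
    return blocks


flat = [1.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8]  # placeholder values, not real outputs
blocks = split_by_degree(flat, [0, 1, 2])
# blocks == {0: [1.0], 1: [0.1, 0.2, 0.3], 2: [0.4, 0.5, 0.6, 0.7, 0.8]}
```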
Using output indices to accumulate three inputs into two output rows:
>>> i_out = jnp.array([0, 1, 1], dtype=jnp.int32)
>>> with cue.assume(cue.SO3, cue.ir_mul):
...     x = cuex.RepArray("1", jnp.array([
...         [0.0, 1.0, 0.0],
...         [0.0, 0.0, 1.0],
...         [1.0, 0.0, 0.0],
...     ]))
>>> result = cuex.equivariant_polynomial(
...     e,
...     [x],
...     jax.ShapeDtypeStruct((2, e.outputs[0].dim), jnp.float32),
...     indices=[None, i_out],
...     method="naive",
... )
>>> result
{1: 0+1+2} [[ 1. ... ]
 [ 2. ... ]]
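With output indices, input row k contributes to output row i_out[k], and contributions that land on the same row are summed. This is why the degree-0 entries of the result read 1 and 2: row 0 receives one constant term and row 1 receives two. The pure-Python sketch below reproduces only that accumulation pattern, using the degree-0 component (1.0 for every input row in this example) as the per-row value. scatter_sum is a hypothetical helper, not a library function.

```python
from typing import List, Sequence


def scatter_sum(values: Sequence[List[float]], index: Sequence[int], num_rows: int) -> List[List[float]]:
    """out[index[k]] += values[k]; rows that receive nothing stay zero."""
    width = len(values[0])
    out = [[0.0] * width for _ in range(num_rows)]
    for row, target in zip(values, index):
        for j, v in enumerate(row):
            out[target][j] += v
    return out


# Degree-0 component of each of the three input rows (1.0 for every row here).
per_row = [[1.0], [1.0], [1.0]]
i_out = [0, 1, 1]
print(scatter_sum(per_row, i_out, num_rows=2))  # [[1.0], [2.0]]
```

In JAX, the same accumulation pattern is commonly written as jnp.zeros((2, d)).at[i_out].add(values), where repeated indices are summed.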