equivariant_polynomial

cuequivariance_jax.equivariant_polynomial(
poly: EquivariantPolynomial,
inputs: list[RepArray | Array],
outputs_shape_dtype: list[ShapeDtypeStruct] | ShapeDtypeStruct | None = None,
indices: None | list[None | Array | tuple[Array | slice]] = None,
math_dtype: dtype | None = None,
name: str | None = None,
impl: str = 'auto',
) → list[RepArray] | RepArray

Compute an equivariant polynomial.

Evaluates an equivariant polynomial, that is, a polynomial map that commutes with the action of a symmetry group on its inputs and outputs. This function is the equivariant counterpart of cuex.segmented_polynomial: it wraps that function and adds type checking and handling of representation-aware arrays (cuex.RepArray).
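To make "commutes with the group action" concrete, the sketch below numerically checks the equivariance condition f(R x) = D(R) f(x) for a hypothetical toy polynomial f(x) = [x·x, x], which maps a 3D vector to a scalar block followed by a vector block. The names toy_poly, rotation_z and output_rep are illustrative only and are not part of cuequivariance; only NumPy is assumed.

```python
import numpy as np

def toy_poly(x: np.ndarray) -> np.ndarray:
    """Hypothetical toy polynomial (not from cuequivariance): x -> [x.x, x]."""
    return np.concatenate([[x @ x], x])

def rotation_z(theta: float) -> np.ndarray:
    """Rotation matrix about the z axis by angle theta (radians)."""
    c, s = np.cos(theta), np.sin(theta)
    return np.array([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]])

def output_rep(R: np.ndarray) -> np.ndarray:
    """Action on the output: identity on the scalar block, R on the vector block."""
    D = np.eye(4)
    D[1:, 1:] = R
    return D

x = np.array([0.3, -1.2, 0.5])
R = rotation_z(0.7)
# Equivariance: rotating the input equals rotating the output block by block.
print(np.allclose(toy_poly(R @ x), output_rep(R) @ toy_poly(x)))  # True
```

Here the output transforms as a scalar plus a vector, which in cuequivariance notation would correspond to the SO(3) irreps 0+1.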

Parameters:
  • poly – The equivariant polynomial descriptor.

  • inputs – List of inputs, each a cuex.RepArray or a plain array.

  • outputs_shape_dtype – Shape and dtype specifications for the outputs. If None, they are inferred from the inputs when possible. When output indices are provided, this argument must be specified. The last dimension of each shape can be set to -1 to infer its size from the polynomial descriptor.

  • indices – Optional list of indices for inputs and outputs. Its length must match the total number of operands (inputs plus outputs). Use None for unindexed operands. Defaults to None.

  • math_dtype – Data type used for the internal computation. If None, it is determined automatically from the input dtypes. Defaults to None.

  • name – Optional name for the operation. Defaults to None.

  • impl – Implementation to use, one of [“auto”, “cuda”, “jax”, “naive_jax”]. If “auto”, uses CUDA when available and efficient, falling back to JAX otherwise. Defaults to “auto”.

Returns:

cuex.RepArray or list of cuex.RepArray

Note

See cuex.segmented_polynomial for more details on the implementation and usage of the underlying CUDA and JAX implementations.

Examples

Build the descriptor for spherical harmonics of degrees 0, 1, and 2:

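>>> import jax
>>> import jax.numpy as jnp
>>> import cuequivariance as cue
>>> import cuequivariance_jax as cuex
>>> e = cue.descriptors.spherical_harmonics(cue.SO3(1), [0, 1, 2])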
>>> e
╭ a=1 -> B=0+1+2
│  []➜B[] ───────── num_paths=1
│  []·a[]➜B[] ───── num_paths=3
╰─ []·a[]·a[]➜B[] ─ num_paths=11
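The output B=0+1+2 concatenates SO(3) irreps of degrees 0, 1 and 2, and the three listed terms are of degree 0, 1 and 2 in the input a. Since an SO(3) irrep of degree l has dimension 2l + 1, B should have dimension 1 + 3 + 5 = 9, which is the size e.outputs[0].dim supplies in the indexed example. A minimal sketch of that arithmetic (the helper so3_irreps_dim is hypothetical, not a cuequivariance function):

```python
def so3_irreps_dim(degrees: list[int]) -> int:
    """Hypothetical helper: total dimension of a direct sum of SO(3) irreps.

    An SO(3) irrep of degree l has dimension 2 * l + 1.
    """
    return sum(2 * l + 1 for l in degrees)

print(so3_irreps_dim([0, 1, 2]))  # 9
```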

Basic usage with a single input:

>>> with cue.assume(cue.SO3, cue.ir_mul):
...    x = cuex.RepArray("1", jnp.array([0.0, 1.0, 0.0]))
>>> cuex.equivariant_polynomial(e, [x])
{0: 0+1+2}
[1. ... ]

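Using output indices. The three input rows are accumulated into two output rows according to i_out, so outputs_shape_dtype must be given explicitly: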

>>> i_out = jnp.array([0, 1, 1], dtype=jnp.int32)
>>> with cue.assume(cue.SO3, cue.ir_mul):
...    x = cuex.RepArray("1", jnp.array([
...         [0.0, 1.0, 0.0],
...         [0.0, 0.0, 1.0],
...         [1.0, 0.0, 0.0],
...    ]))
>>> result = cuex.equivariant_polynomial(
...   e,
...   [x],
...   jax.ShapeDtypeStruct((2, e.outputs[0].dim), jnp.float32),
...   indices=[None, i_out],
... )
>>> result
{1: 0+1+2}
[[ 1. ... ]
 [ 2. ... ]]
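The second output row is 2 in its first (degree-0) entry because i_out sends input rows 1 and 2 to the same output row. The NumPy sketch below is a conceptual model of the indices argument as read from this example, not the library's implementation: it assumes input indices gather rows before evaluation and output indices scatter-add results, and it takes the number of output rows separately because that cannot be recovered from the index array alone. The function indexed_eval and the per-row function f (a constant 1 followed by the input vector) are hypothetical stand-ins chosen so that the first column mirrors the degree-0 column of result.

```python
import numpy as np

def indexed_eval(f, x, in_idx, out_idx, num_out_rows):
    """Hypothetical reference model of `indices` (an assumption, not library code).

    f:            per-row function mapping one input row to one output row
    in_idx:       rows of x to read for each batch element (None = use x as-is)
    out_idx:      output row each result is added into (None = one-to-one)
    num_out_rows: output batch size, required whenever out_idx is given
    """
    xs = x if in_idx is None else x[in_idx]      # gather inputs
    ys = np.stack([f(row) for row in xs])        # evaluate row by row
    if out_idx is None:
        return ys
    out = np.zeros((num_out_rows,) + ys.shape[1:], dtype=ys.dtype)
    np.add.at(out, out_idx, ys)                  # scatter-add; repeated indices accumulate
    return out

# Hypothetical per-row function: a constant term followed by the input itself.
f = lambda v: np.concatenate([[1.0], v])
x = np.array([[0.0, 1.0, 0.0], [0.0, 0.0, 1.0], [1.0, 0.0, 0.0]])
y = indexed_eval(f, x, in_idx=None, out_idx=np.array([0, 1, 1]), num_out_rows=2)
print(y)
# [[1. 0. 1. 0.]
#  [2. 1. 0. 1.]]
```

In JAX, the scatter-add step of this model corresponds to jax.ops.segment_sum(ys, out_idx, num_segments=num_out_rows).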