equivariant_polynomial

cuequivariance_jax.equivariant_polynomial(
poly,
inputs,
outputs_shape_dtype=None,
indices=None,
*,
method='',
math_dtype=None,
name=None,
precision=Precision.HIGHEST,
)

Compute an equivariant polynomial.

Evaluates an equivariant polynomial, that is, a polynomial operation that respects group symmetries. This function is the equivariance-aware wrapper around cuex.segmented_polynomial: it adds type checking and handling of representation-aware arrays (cuex.RepArray).

Parameters:
  • poly (EquivariantPolynomial) – The equivariant polynomial descriptor.

  • inputs (list[RepArray | Array]) – List of inputs, each given as a cuex.RepArray or a plain array.

  • outputs_shape_dtype (list[ShapeDtypeStruct] | ShapeDtypeStruct | None) – Shape and dtype specifications for outputs. If None, inferred from inputs when possible. When output indices are provided, this must be specified. The last shape dimension can be set to -1 to infer the size from the polynomial descriptor.

  • indices (None | list[None | Array | tuple[Array | slice]]) – Optional list of indices for inputs and outputs. Its length must match the total number of operands (inputs + outputs). Use None for unindexed operands. Defaults to None. Not all methods support indices.

  • method (str) – Method to use for computation. See cuex.segmented_polynomial for available methods.

  • math_dtype (dtype | None) – Data type for computational operations. If None, automatically determined from input types. Defaults to None.

  • name (str | None) – Optional name for the operation. Defaults to None.

  • precision (Precision) – The precision to use for the computation. Defaults to HIGHEST. Note that precision is not supported for all methods.
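
The indexed example further down this page (three input rows, output index array [0, 1, 1], two output rows) suggests that an output index routes each batch row to an output row, and that rows sharing an index are summed. The sketch below illustrates that assumed scatter-add behaviour in plain Python. The helper scatter_add_rows and the numeric values are hypothetical and are not part of cuex; see cuex.segmented_polynomial for the actual semantics.

```python
def scatter_add_rows(rows, out_index, num_out):
    """Sum each row into the output row named by out_index (illustrative only)."""
    width = len(rows[0])
    out = [[0.0] * width for _ in range(num_out)]
    for row, target in zip(rows, out_index):
        for k, value in enumerate(row):
            out[target][k] += value
    return out


# Placeholder per-row results: the leading 1.0 stands in for a degree-0
# component, the second column is arbitrary illustrative data.
per_row = [[1.0, 0.5], [1.0, 0.25], [1.0, 0.25]]

# Row 0 goes to output 0; rows 1 and 2 are both accumulated into output 1.
result = scatter_add_rows(per_row, [0, 1, 1], num_out=2)
print(result)  # [[1.0, 0.5], [2.0, 0.5]]
```

With JAX arrays the same accumulation could be written as jnp.zeros((2, width)).at[i_out].add(rows). This also shows why outputs_shape_dtype is needed when output indices are given: the number of output rows cannot be read off the inputs.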

Returns:

cuex.RepArray or list of cuex.RepArray

Return type:

list[RepArray] | RepArray

Note

See cuex.segmented_polynomial for more details on the implementation and usage of the underlying CUDA and JAX implementations.

Examples

Create a descriptor for the spherical harmonics of degrees 0, 1, and 2:

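>>> import jax
>>> import jax.numpy as jnp
>>> import cuequivariance as cue
>>> import cuequivariance_jax as cuex
>>> e = cue.descriptors.spherical_harmonics(cue.SO3(1), [0, 1, 2])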
>>> e
╭ a=1 -> B=0+1+2
│  []➜B[] ───────── num_paths=1
│  []·a[]➜B[] ───── num_paths=3
╰─ []·a[]·a[]➜B[] ─ num_paths=11
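
The output label B=0+1+2 lists one SO(3) irrep for each of the degrees 0, 1 and 2. An SO(3) irrep of degree l has dimension 2l + 1, so the output is expected to hold 1 + 3 + 5 = 9 components, which is presumably the value that e.outputs[0].dim takes in the indexed example below. A small arithmetic sketch (the helper name so3_irreps_dim is hypothetical, not a cue function):

```python
def so3_irreps_dim(degrees):
    """Total dimension of one SO(3) irrep per listed degree (2l + 1 each)."""
    return sum(2 * l + 1 for l in degrees)


output_dim = so3_irreps_dim([0, 1, 2])
print(output_dim)  # 9
```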

Basic usage with single input:

>>> with cue.assume(cue.SO3, cue.ir_mul):
...     x = cuex.RepArray("1", jnp.array([0.0, 1.0, 0.0]))
>>> cuex.equivariant_polynomial(e, [x], method="naive")
{0: 0+1+2}
[1. ... ]

Using indices:

>>> i_out = jnp.array([0, 1, 1], dtype=jnp.int32)
>>> with cue.assume(cue.SO3, cue.ir_mul):
...     x = cuex.RepArray("1", jnp.array([
...         [0.0, 1.0, 0.0],
...         [0.0, 0.0, 1.0],
...         [1.0, 0.0, 0.0],
...     ]))
>>> result = cuex.equivariant_polynomial(
...     e,
...     [x],
...     jax.ShapeDtypeStruct((2, e.outputs[0].dim), jnp.float32),
...     indices=[None, i_out],
...     method="naive",
... )
>>> result
{1: 0+1+2}
[[ 1. ... ]
 [ 2. ... ]]