segmented_polynomial

cuequivariance_jax.segmented_polynomial(
polynomial,
inputs,
outputs_shape_dtype,
indices=None,
*,
method='',
math_dtype=None,
name=None,
precision=Precision.HIGHEST,
)

Compute a segmented polynomial.

Evaluates a segmented polynomial, which represents a mathematical operation composed of several tensor products. The function provides both a pure-JAX implementation and CUDA implementations; the method argument selects which one is used.
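To make "composed of several tensor products" concrete, the toy sketch below builds a two-path polynomial by hand in plain JAX. The segment boundaries, the coefficient tensors C0 and C1, and the per-path einsum are illustrative assumptions; this is not how SegmentedPolynomial stores its operands or how any method evaluates them.

```python
import jax.numpy as jnp

# Two flat input buffers. x is split into two segments (sizes 1 and 3);
# y is a single segment of size 3.
x = jnp.array([2.0, 1.0, 0.0, 1.0])
y = jnp.array([1.0, 2.0, 3.0])
x_seg0, x_seg1 = x[0:1], x[1:4]

# Hypothetical coefficient tensors, one per path, indexed (i, j, k) for
# (x segment, y segment, output segment).
C0 = jnp.eye(3)[None, :, :]  # shape (1, 3, 3): scalar times vector
C1 = jnp.eye(3)[:, :, None]  # shape (3, 3, 1): dot product

# Each path is a small tensor product: out_k = sum_ij C_ijk x_i y_j.
path0 = jnp.einsum("i,j,ijk->k", x_seg0, y, C0)  # fills output segment 0
path1 = jnp.einsum("i,j,ijk->k", x_seg1, y, C1)  # fills output segment 1

# The flat output buffer is the concatenation of its segments.
out = jnp.concatenate([path0, path1])
print(out)  # [2. 4. 6. 4.]
```

In a real polynomial several paths may also write into the same output segment, in which case their contributions are summed.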

Parameters:
  • polynomial (SegmentedPolynomial) – The segmented polynomial to compute.

  • inputs (list[Array]) – List of input buffers as JAX arrays.

  • outputs_shape_dtype (list[ShapeDtypeStruct]) – List of output shape and dtype specifications. The last dimension of a shape can be set to -1 to infer its size from the polynomial descriptor.

  • indices (None | list[None | Array | tuple[Array | slice]]) – Optional list of indices, with one entry per input followed by one entry per output. If None, no indexing is applied. Defaults to None. Not all methods support indices.

  • method (str) –

    Specifies the implementation method to use. Options are:

    • "naive": Uses a naive JAX implementation. It always works but is not optimized.

    • "uniform_1d": Uses a CUDA implementation for polynomials with a single uniform mode.

    • "indexed_linear": Uses a CUDA implementation for linear layers with indexed weights.

    Note

    The "fused_tp" method is only available in the PyTorch implementation.

  • math_dtype (dtype | None) – Data type used for the internal computation. If None, it is determined from the input dtypes, defaulting to float32 when no input is float64.

  • name (str | None) – Optional name for the operation.

  • precision (Precision) – The precision to use for the computation. Defaults to HIGHEST. Note that precision is only supported for the "naive" method.
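The default math_dtype rule can be restated as a small function. default_math_dtype is a hypothetical helper written for illustration; it encodes one reading of the documented behavior (float64 when any input is float64, otherwise float32) and is not part of cuequivariance_jax.

```python
import numpy as np

def default_math_dtype(inputs):
    """Hypothetical restatement of the documented default for math_dtype=None."""
    # Assumed reading of the rule: float64 if any input is float64, else float32.
    if any(np.asarray(x).dtype == np.float64 for x in inputs):
        return np.dtype(np.float64)
    return np.dtype(np.float32)

print(default_math_dtype([np.zeros(3, np.float32), np.zeros(3, np.float16)]))  # float32
print(default_math_dtype([np.zeros(3, np.float32), np.zeros(3, np.float64)]))  # float64
```

Passing math_dtype explicitly overrides this default.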

Returns:

List of JAX arrays containing the computed polynomial outputs.

Return type:

list[Array]

Notes

JAX Transformations Support:
  • Supports JAX transformations: jit, grad, jvp, vmap

  • Supports derivatives of arbitrary order through JVP and transpose rules

  • Full batching support
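The snippet below applies the transformations listed above to a toy stand-in written in plain JAX. toy_tp is hypothetical and only plays the role that a function wrapping cuex.segmented_polynomial would play in real code; nesting jax.jacfwd over jax.grad illustrates the kind of higher-order derivative the notes refer to.

```python
import jax
import jax.numpy as jnp

def toy_tp(x, y):
    # Hypothetical stand-in built from an outer product, reduced to a scalar.
    return jnp.sum(jnp.outer(x, y) ** 2)

x = jnp.array([1.0, 2.0])
y = jnp.array([3.0, 4.0, 5.0])

f = jax.jit(toy_tp)                       # compiled version
g = jax.grad(toy_tp)(x, y)                # gradient w.r.t. x
_, tangent = jax.jvp(lambda x: toy_tp(x, y), (x,), (jnp.ones_like(x),))
h = jax.jacfwd(jax.grad(toy_tp))(x, y)    # second derivative w.r.t. x
batched = jax.vmap(toy_tp, in_axes=(0, None))(jnp.stack([x, 2 * x]), y)
```

Here toy_tp(x, y) equals sum(x**2) * sum(y**2) = 250, so the gradient is 100 * x and the Hessian is 100 times the identity.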

Examples

Simple example computing the spherical harmonics:

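>>> import jax
>>> import jax.numpy as jnp
>>> import numpy as np
>>> import cuequivariance as cue
>>> import cuequivariance_jax as cuex
>>> p = cue.descriptors.spherical_harmonics(cue.SO3(1), [0, 1, 2]).polynomial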
>>> cuex.segmented_polynomial(
...     p, [jnp.array([0.0, 1.0, 0.0])], [jax.ShapeDtypeStruct((-1,), jnp.float32)], method="naive"
... )
[Array([1.       , 0.       , 1.7320508, 0.       , 0.       , 0.       ,
       2.236068 , 0.       , 0.       ], dtype=float32)]
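The printed values can be checked by hand. Splitting the output into degree blocks of size 2l+1 (1, 3 and 5) shows that each block has squared norm 2l+1 for this unit-length input, which is why sqrt(3) ≈ 1.7320508 and sqrt(5) ≈ 2.236068 appear. The check below only inspects the numbers printed above and does not call the library; treating this as a general normalization convention is an inference from this single example.

```python
import numpy as np

# Output copied from the example above, for the input (0, 1, 0).
out = np.array([1.0, 0.0, 1.7320508, 0.0, 0.0, 0.0, 2.236068, 0.0, 0.0])

# Split into blocks for degrees l = 0, 1, 2 (sizes 1, 3, 5).
blocks = np.split(out, [1, 4])
sq_norms = [float(np.sum(b**2)) for b in blocks]
print(np.round(sq_norms, 5))  # [1. 3. 5.]
```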

Example computing a tensor product with indexing using the "uniform_1d" method:

>>> poly: cue.SegmentedPolynomial = cue.descriptors.channelwise_tensor_product(
...     cue.Irreps(cue.O3, "32x0e + 32x1o + 32x1e + 32x2o"),
...     cue.Irreps(cue.O3, "0e + 1o + 1e"),
...     cue.Irreps(cue.O3, "32x0e + 32x1o + 32x1e"),
... ).polynomial
>>> a = np.random.randn(1, 50, poly.inputs[0].size)
>>> b = np.random.randn(10, 50, poly.inputs[1].size)
>>> c = np.random.randn(100, 1, poly.inputs[2].size)
>>> i = np.random.randint(0, 10, (100, 50))
>>> D = jax.ShapeDtypeStruct(shape=(11, 12, poly.outputs[0].size), dtype=jnp.float32)
>>> j1 = np.random.randint(0, 11, (100, 50))
>>> j2 = np.random.randint(0, 12, (100, 1))
>>> [D] = cuex.segmented_polynomial(
...     poly, [a, b, c], [D], [None, np.s_[i, :], None, np.s_[j1, j2]], method="uniform_1d"
... )
>>> D.shape
(11, 12, 1056)
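One way to read the indexing in this example is as a gather on the inputs followed by a scatter-add into the output. The plain-JAX sketch below spells that reading out for a toy elementwise product. The shapes, the indexed_product helper, and the assumption that output indices accumulate (as .at[...].add does) are illustrative and not taken from the library's implementation.

```python
import jax.numpy as jnp
import numpy as np

def indexed_product(a, b, i, j, num_out):
    """Hypothetical: gather a with i, multiply by b, scatter-add into rows j."""
    gathered = a[i]                    # input index: pick a row of a per batch element
    per_item = gathered * b            # toy stand-in for the polynomial
    out = jnp.zeros((num_out, a.shape[-1]), a.dtype)
    return out.at[j].add(per_item)     # output index: accumulate into rows j

rng = np.random.default_rng(0)
a = jnp.asarray(rng.standard_normal((10, 3)), jnp.float32)
b = jnp.asarray(rng.standard_normal((100, 3)), jnp.float32)
i = jnp.asarray(rng.integers(0, 10, 100))
j = jnp.asarray(rng.integers(0, 7, 100))

D = indexed_product(a, b, i, j, num_out=7)
print(D.shape)  # (7, 3)
```

Because several batch elements can map to the same output row, accumulation (rather than overwriting) is what makes the result independent of batch order.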

Example computing a linear layer with indexed weights using the "indexed_linear" method:

>>> input_irreps = cue.Irreps(cue.O3, "10x0e + 10x1o")
>>> output_irreps = cue.Irreps(cue.O3, "20x0e + 20x1o")
>>> poly = cue.descriptors.linear(input_irreps, output_irreps).polynomial
>>> counts = jnp.array([3, 4, 3], dtype=jnp.int32)  # Number of elements in each partition
>>> w = jax.random.normal(jax.random.key(0), (3, poly.inputs[0].size), dtype=jnp.float32)
>>> x = jax.random.normal(jax.random.key(1), (10, poly.inputs[1].size), dtype=jnp.float32)
>>> y = jax.ShapeDtypeStruct((10, poly.outputs[0].size), jnp.float32)
>>> [y] = cuex.segmented_polynomial(
...     poly,
...     [w, x], [y], [cuex.Repeats(counts), None, None],
...     method="indexed_linear"
... )
>>> y.shape
(10, 80)
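Reading cuex.Repeats(counts) as "each row of x uses the weight of the partition it falls in", the same partitioning can be written as a gather with jnp.repeat followed by a per-row matrix-vector product. The sketch below assumes a hypothetical dense weight layout (partitions, in_dim, out_dim) and toy sizes; the real polynomial takes its weights as a flat buffer of size poly.inputs[0].size, whose layout is not described here, so this illustrates only the partitioning, not the library's computation.

```python
import jax
import jax.numpy as jnp

counts = jnp.array([3, 4, 3], dtype=jnp.int32)  # rows per partition, sums to 10
num_rows, in_dim, out_dim = 10, 4, 6             # toy sizes (assumption)

# Hypothetical dense weights: one (in_dim, out_dim) matrix per partition.
w = jax.random.normal(jax.random.key(0), (3, in_dim, out_dim))
x = jax.random.normal(jax.random.key(1), (num_rows, in_dim))

# Partition id of each row: [0, 0, 0, 1, 1, 1, 1, 2, 2, 2].
# total_repeat_length keeps the output shape static, e.g. under jit.
part = jnp.repeat(jnp.arange(3), counts, total_repeat_length=num_rows)

# Each row is multiplied by the weight matrix of its partition.
y = jnp.einsum("ni,nio->no", x, w[part])
print(y.shape)  # (10, 6)
```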