segmented_polynomial

cuequivariance_jax.segmented_polynomial(
polynomial: SegmentedPolynomial,
inputs: list[Array],
outputs_shape_dtype: list[ShapeDtypeStruct],
indices: None | list[None | Array | tuple[Array | slice]] = None,
*,
math_dtype: dtype | None = None,
name: str | None = None,
impl: str = 'auto',
) -> list[Array]

Compute a segmented polynomial.

Evaluates a segmented polynomial, which represents a mathematical operation composed of several tensor products. The function supports both JAX and CUDA implementations for maximum performance and flexibility.

Parameters:
  • polynomial – The segmented polynomial to compute.

  • inputs – List of input buffers as JAX arrays.

  • outputs_shape_dtype – List of output shape/dtype specifications (jax.ShapeDtypeStruct). The last dimension of each shape can be set to -1 to infer its size from the polynomial descriptor.

  • indices – Optional list of index specifications, one entry per input buffer followed by one entry per output buffer. Each entry is None (no indexing), an index array, or a tuple mixing index arrays and slices (see the advanced example below). If None, no indexing is applied. Defaults to None.

  • math_dtype – Data type used for the internal computation. If None, it is inferred from the input dtypes: float64 if any input is float64, otherwise float32. Defaults to None.

  • name – Optional name for the operation. Defaults to None.

  • impl – Implementation to use, one of [“auto”, “cuda”, “jax”, “naive_jax”]. If “auto”, uses CUDA when available and efficient, falling back to JAX otherwise. Defaults to “auto”.

Returns:

List of JAX arrays containing the computed polynomial outputs.

Performance Considerations:
  • CUDA acceleration requirements:
    • Segmented tensor products (STPs) have a single mode (e.g. the channelwise tensor product with subscripts u,u,,u)

    • Math data type is float32 or float64

    • Input/output data types are float32, float64, float16, or bfloat16

  • Automatic optimizations:
    • Based on STP symmetries

    • Based on input buffer repetition patterns

    • Automatic pruning of unused buffers and indices

Implementation Details:
  • Supports JAX transformations: jit, grad, jvp, vmap (see the sketch after the first example below)
    • Supports arbitrary-order derivatives through JVP and transpose rules

    • Full batching support

Note

For maximum performance with CUDA-capable hardware, ensure inputs match the CUDA kernel activation conditions listed above. To verify whether the CUDA implementation is used, set impl="cuda" or call logging.basicConfig(level=logging.INFO).
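
For example, INFO-level logging reports which implementation was selected, and impl="cuda" makes the choice explicit (a sketch; the commented call only indicates where the keyword goes):

>>> import logging
>>> logging.basicConfig(level=logging.INFO)
>>> # cuex.segmented_polynomial(polynomial, inputs, outputs_shape_dtype, impl="cuda")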

Examples
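
The examples assume the conventional import aliases (matching the names used below):

>>> import jax
>>> import jax.numpy as jnp
>>> import numpy as np
>>> import cuequivariance as cue
>>> import cuequivariance_jax as cuex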

Simple example with spherical harmonics:

>>> p = cue.descriptors.spherical_harmonics(cue.SO3(1), [0, 1, 2]).polynomial
>>> cuex.segmented_polynomial(
...     p, [jnp.array([0.0, 1.0, 0.0])], [jax.ShapeDtypeStruct((-1,), jnp.float32)]
... )
[Array([1.       , 0.       , 1.7320508, 0.       , 0.       , 0.       ,
       2.236068 , 0.       , 0.       ], dtype=float32)]
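
Because jit, grad, jvp, and vmap are supported (see Implementation Details above), the call composes with the usual JAX transformations. A minimal sketch reusing p from the example above; f is a scalar-valued wrapper introduced here only so that jax.grad applies, and 9 is the output size visible in the printed result:

>>> def f(x):
...     [y] = cuex.segmented_polynomial(
...         p, [x], [jax.ShapeDtypeStruct((9,), jnp.float32)]
...     )
...     return y.sum()
>>> jax.jit(f)(jnp.array([0.0, 1.0, 0.0])).shape
()
>>> jax.grad(f)(jnp.array([0.0, 1.0, 0.0])).shape
(3,)
>>> jax.vmap(f)(jnp.zeros((5, 3))).shape
(5,)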

Advanced example with tensor product and indexing:

>>> poly: cue.SegmentedPolynomial = cue.descriptors.channelwise_tensor_product(
...     cue.Irreps(cue.O3, "32x0e + 32x1o + 32x1e + 32x2o"),
...     cue.Irreps(cue.O3, "0e + 1o + 1e"),
...     cue.Irreps(cue.O3, "32x0e + 32x1o + 32x1e"),
... ).polynomial.flatten_coefficient_modes().squeeze_modes()
>>> a = np.random.randn(1, 50, poly.inputs[0].size)
>>> b = np.random.randn(10, 50, poly.inputs[1].size)
>>> c = np.random.randn(100, 1, poly.inputs[2].size)
>>> i = np.random.randint(0, 10, (100, 50))
>>> D = jax.ShapeDtypeStruct(shape=(11, 12, poly.outputs[0].size), dtype=np.float32)
>>> j1 = np.random.randint(0, 11, (100, 50))
>>> j2 = np.random.randint(0, 12, (100, 1))
>>> [out] = cuex.segmented_polynomial(
...     poly, [a, b, c], [D], [None, np.s_[i, :], None, np.s_[j1, j2]]
... )
>>> out.shape
(11, 12, 1056)
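
The same call can also be pinned to the pure-JAX implementation through the documented impl option, e.g. on machines without the CUDA kernels (a sketch; only the keyword differs from the call above):

>>> [out_jax] = cuex.segmented_polynomial(
...     poly, [a, b, c], [D], [None, np.s_[i, :], None, np.s_[j1, j2]],
...     impl="jax",
... )
>>> out_jax.shape
(11, 12, 1056)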