segmented_polynomial#

cuequivariance_jax.segmented_polynomial(
polynomial,
inputs,
outputs_shape_dtype,
indices=None,
*,
method='',
math_dtype=None,
name=None,
precision='undefined',
)#

Compute a segmented polynomial.

Evaluates a segmented polynomial, a mathematical operation composed of several tensor products. The function provides both a pure-JAX implementation and CUDA-accelerated implementations, selected via the method argument.

Parameters:
  • polynomial (SegmentedPolynomial) – The segmented polynomial to compute.

  • inputs (list[Array]) – List of input buffers as JAX arrays.

  • outputs_shape_dtype (list[ShapeDtypeStruct]) – List of shape/dtype specifications for the outputs. The last shape dimension can be set to -1 to infer its size from the polynomial descriptor.

  • indices (None | list[None | Array | tuple[Array | slice]]) – Optional list of indices for the inputs and outputs. If None, no indexing is applied. Defaults to None. Note that not all methods support indexing.

  • method (str) –

    Specifies the implementation method to use. Options are:

    • "naive": Uses a naive JAX implementation. It always works but is not optimized.

    • "uniform_1d": Uses a CUDA implementation for polynomials with a single uniform mode.

    • "indexed_linear": Uses a CUDA implementation for linear layers with indexed weights.

    Note

    The "fused_tp" method is only available in the PyTorch implementation.

  • math_dtype (str | None) –

    Data type used for internal computations. If None, it is determined automatically from the input dtypes. Defaults to None.

    Supported options vary by method:

    • "naive": String dtype names (e.g., "float32", "float64", "float16", "bfloat16"). Also supports "tensor_float32" for TensorFloat-32 mode.

    • "uniform_1d": String values "float32" or "float64" only.

    • "indexed_linear": CUBLAS compute type strings such as "CUBLAS_COMPUTE_32F", "CUBLAS_COMPUTE_32F_FAST_TF32", "CUBLAS_COMPUTE_32F_PEDANTIC", "CUBLAS_COMPUTE_64F", etc.

  • name (str | None) – Optional name for the operation.

  • precision (Precision) – Precision setting for the underlying JAX operations.

Returns:

List of JAX arrays containing the computed polynomial outputs.

Return type:

list[Array]

Notes

JAX Transformations Support:
  • Supports JAX transformations: jit, grad, jvp, vmap

  • Supports derivatives of arbitrary order through JVP and transpose rules (see the gradient example below)

  • Full batching support

Examples

Simple example computing the spherical harmonics:

>>> p = cue.descriptors.spherical_harmonics(cue.SO3(1), [0, 1, 2]).polynomial
>>> cuex.segmented_polynomial(
...     p, [jnp.array([0.0, 1.0, 0.0])], [jax.ShapeDtypeStruct((-1,), jnp.float32)], method="naive"
... )
[Array([1.       , 0.       , 1.7320508, 0.       , 0.       , 0.       ,
       2.236068 , 0.       , 0.       ], dtype=float32)]
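
As noted under JAX Transformations Support, the call is differentiable. A minimal sketch (reusing p from the example above; the reduction to a scalar via sum is only for illustration) takes a gradient through it:

>>> def f(x):
...     [out] = cuex.segmented_polynomial(
...         p, [x], [jax.ShapeDtypeStruct((-1,), jnp.float32)], method="naive"
...     )
...     return out.sum()
>>> jax.grad(f)(jnp.array([0.0, 1.0, 0.0])).shape
(3,)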

Example computing a tensor product with indexing using the "uniform_1d" method:

>>> poly: cue.SegmentedPolynomial = cue.descriptors.channelwise_tensor_product(
...     cue.Irreps(cue.O3, "32x0e + 32x1o + 32x1e + 32x2o"),
...     cue.Irreps(cue.O3, "0e + 1o + 1e"),
...     cue.Irreps(cue.O3, "32x0e + 32x1o + 32x1e"),
... ).polynomial
>>> a = np.random.randn(1, 50, poly.inputs[0].size)
>>> b = np.random.randn(10, 50, poly.inputs[1].size)
>>> c = np.random.randn(100, 1, poly.inputs[2].size)
>>> i = np.random.randint(0, 10, (100, 50))
>>> D = jax.ShapeDtypeStruct(shape=(11, 12, poly.outputs[0].size), dtype=jnp.float32)
>>> j1 = np.random.randint(0, 11, (100, 50))
>>> j2 = np.random.randint(0, 12, (100, 1))
>>> [D] = cuex.segmented_polynomial(
...     poly, [a, b, c], [D], [None, np.s_[i, :], None, np.s_[j1, j2]], method="uniform_1d"
... )
>>> D.shape
(11, 12, 1056)
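
The last output dimension comes straight from the polynomial descriptor (the shape/dtype spec above was built from it); a quick consistency check using only objects already defined:

>>> poly.outputs[0].size
1056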

Example computing a linear layer with indexed weights using the "indexed_linear" method:

>>> input_irreps = cue.Irreps(cue.O3, "10x0e + 10x1o")
>>> output_irreps = cue.Irreps(cue.O3, "20x0e + 20x1o")
>>> poly = cue.descriptors.linear(input_irreps, output_irreps).polynomial
>>> counts = jnp.array([3, 4, 3], dtype=jnp.int32)  # Number of elements in each partition
>>> w = jax.random.normal(jax.random.key(0), (3, poly.inputs[0].size), dtype=jnp.float32)
>>> x = jax.random.normal(jax.random.key(1), (10, poly.inputs[1].size), dtype=jnp.float32)
>>> y = jax.ShapeDtypeStruct((10, poly.outputs[0].size), jnp.float32)
>>> [y] = cuex.segmented_polynomial(
...     poly,
...     [w, x], [y], [cuex.Repeats(counts), None, None],
...     method="indexed_linear"
... )
>>> y.shape
(10, 80)
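
A reasonable reading of the example above is that the partition counts must cover the whole batch (here 3 + 4 + 3 = 10); a quick check with the objects already defined:

>>> int(counts.sum()) == x.shape[0]
True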

Note

This operation uses custom CUDA kernels for performance. When running this function on multiple devices, manual sharding is required: without explicit sharding, performance degrades significantly. See the JAX shard_map documentation for details on manual parallelism; a sketch follows below.
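
As a hedged illustration of the point above, the following minimal sketch (not taken from the library docs) assumes a 1D device mesh over a "batch" axis and reuses the spherical-harmonics polynomial p from the first example; the function name sharded_sh is hypothetical, and on GPU one would pass a CUDA method such as "uniform_1d" instead of "naive":

>>> import numpy as np
>>> from functools import partial
>>> from jax.sharding import Mesh, PartitionSpec as P
>>> from jax.experimental.shard_map import shard_map
>>> mesh = Mesh(np.array(jax.devices()), ("batch",))  # 1D mesh over all devices
>>> @partial(shard_map, mesh=mesh, in_specs=P("batch"), out_specs=P("batch"))
... def sharded_sh(x):  # x is the per-device batch shard, shape (batch/ndev, 3)
...     [out] = cuex.segmented_polynomial(
...         p, [x], [jax.ShapeDtypeStruct((x.shape[0], 9), jnp.float32)],
...         method="naive",
...     )
...     return out
>>> xs = jnp.zeros((4 * len(jax.devices()), 3))  # global batch divisible by ndev
>>> ys = sharded_sh(xs)  # each device evaluates only its own shard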