segmented_polynomial
cuequivariance_jax.segmented_polynomial(polynomial, inputs, outputs_shape_dtype, indices=None, *, method='', math_dtype=None, name=None, precision=Precision.HIGHEST)
Compute a segmented polynomial.
Evaluates a segmented polynomial, which represents a mathematical operation composed of several tensor products. Both a pure JAX implementation and CUDA kernels are available; the method argument selects which one is used.
- Parameters:
polynomial (SegmentedPolynomial) – The segmented polynomial to compute.
inputs (list[Array]) – List of input arrays, one per input of the polynomial.
outputs_shape_dtype (list[ShapeDtypeStruct]) – List of output shape and dtype specifications, one per output of the polynomial. The last dimension of a shape can be set to -1 to infer its size from the polynomial descriptor.
indices (None | list[None | Array | tuple[Array | slice]]) – Optional list of indices, with one entry per input followed by one entry per output; a None entry leaves the corresponding array unindexed. If indices is None, no indexing is applied. Defaults to None. Not all methods support indices.
method (str) – Specifies the implementation method to use. Options are:
  - "naive": Uses a naive JAX implementation. It always works but is not optimized.
  - "uniform_1d": Uses a CUDA implementation for polynomials with a single uniform mode.
  - "indexed_linear": Uses a CUDA implementation for linear layers with indexed weights.
Note: The "fused_tp" method is only available in the PyTorch implementation.
math_dtype (dtype | None) – Data type for computational operations. If None, it is determined automatically from the input dtypes, defaulting to float32 when no input is float64.
name (str | None) – Optional name for the operation.
precision (Precision) – The precision to use for the computation. Defaults to HIGHEST. Precision is only supported by the "naive" method.
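The math_dtype default can be read as a small decision rule. The sketch below restates it in NumPy; resolve_math_dtype is a hypothetical helper written for illustration, not a cuequivariance_jax function, and the library may handle further cases (for example float16 or bfloat16 inputs) differently. In JAX itself, float64 inputs generally only occur when jax_enable_x64 is turned on.

```python
import numpy as np


def resolve_math_dtype(inputs, math_dtype=None):
    """Hypothetical helper mirroring the documented math_dtype default."""
    if math_dtype is not None:
        # An explicit choice always wins.
        return np.dtype(math_dtype)
    # No explicit choice: float64 if any input is float64, otherwise float32.
    if any(np.asarray(x).dtype == np.float64 for x in inputs):
        return np.dtype(np.float64)
    return np.dtype(np.float32)


x32 = np.ones(3, dtype=np.float32)
x64 = np.ones(3, dtype=np.float64)
print(resolve_math_dtype([x32, x32]))                    # float32
print(resolve_math_dtype([x32, x64]))                    # float64
print(resolve_math_dtype([x64], math_dtype=np.float32))  # float32
```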
- Returns:
List of JAX arrays containing the computed polynomial outputs.
- Return type:
list[Array]
Notes
- JAX transformations support:
  - Supports the JAX transformations jit, grad, jvp, and vmap.
  - Supports derivatives of arbitrary order through JVP and transpose rules.
  - Full batching support.
Examples
Simple example computing the spherical harmonics:
>>> import cuequivariance as cue
>>> import cuequivariance_jax as cuex
>>> import jax
>>> import jax.numpy as jnp
>>> import numpy as np
>>> p = cue.descriptors.spherical_harmonics(cue.SO3(1), [0, 1, 2]).polynomial
>>> cuex.segmented_polynomial(
...     p, [jnp.array([0.0, 1.0, 0.0])], [jax.ShapeDtypeStruct((-1,), jnp.float32)], method="naive"
... )
[Array([1. , 0. , 1.7320508, 0. , 0. , 0. , 2.236068 , 0. , 0. ], dtype=float32)]
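Here the output shape is passed as (-1,) and the result has nine entries. The plain-Python sketch below restates that shape-inference rule; resolve_output_shape is a hypothetical helper, and we assume the inferred size equals the size of the polynomial output, which for spherical harmonics of degrees 0, 1, and 2 is 1 + 3 + 5 = 9.

```python
def resolve_output_shape(shape, output_size):
    """Hypothetical helper: replace a trailing -1 with the polynomial's output size."""
    shape = tuple(shape)
    if shape and shape[-1] == -1:
        return shape[:-1] + (output_size,)
    # Fully specified shapes are returned unchanged.
    return shape


# Each degree l contributes 2l + 1 components: 1 + 3 + 5 = 9.
sh_size = sum(2 * l + 1 for l in [0, 1, 2])
print(resolve_output_shape((-1,), sh_size))    # (9,)
print(resolve_output_shape((7, -1), sh_size))  # (7, 9)
```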
Example computing a tensor product with indexing using the “uniform_1d” method:
>>> poly: cue.SegmentedPolynomial = cue.descriptors.channelwise_tensor_product(
...     cue.Irreps(cue.O3, "32x0e + 32x1o + 32x1e + 32x2o"),
...     cue.Irreps(cue.O3, "0e + 1o + 1e"),
...     cue.Irreps(cue.O3, "32x0e + 32x1o + 32x1e"),
... ).polynomial
>>> a = np.random.randn(1, 50, poly.inputs[0].size)
>>> b = np.random.randn(10, 50, poly.inputs[1].size)
>>> c = np.random.randn(100, 1, poly.inputs[2].size)
>>> i = np.random.randint(0, 10, (100, 50))
>>> D = jax.ShapeDtypeStruct(shape=(11, 12, poly.outputs[0].size), dtype=jnp.float32)
>>> j1 = np.random.randint(0, 11, (100, 50))
>>> j2 = np.random.randint(0, 12, (100, 1))
>>> [D] = cuex.segmented_polynomial(
...     poly, [a, b, c], [D], [None, np.s_[i, :], None, np.s_[j1, j2]], method="uniform_1d"
... )
>>> D.shape
(11, 12, 1056)
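The indices in this example can be read as a gather on the second input and a scatter into the output, with every operand broadcast to a common batch shape of (100, 50). The NumPy sketch below spells out that reading with small feature sizes. toy_poly is a hypothetical placeholder for the real polynomial, and two points are our assumptions rather than documented behavior: that ":" in np.s_[i, :] keeps axis 1 aligned with the batch, and that repeated output indices are accumulated (modelled here with np.add.at).

```python
import numpy as np


def toy_poly(x, y, z):
    # Hypothetical stand-in for the polynomial: a per-batch-element product
    # of feature sums. It is NOT what cuequivariance computes.
    return (x.sum(-1) * y.sum(-1) * z.sum(-1))[..., None]


rng = np.random.default_rng(0)
a = rng.standard_normal((1, 50, 4))    # batch dims (1, 50)
b = rng.standard_normal((10, 50, 4))   # batch dims (10, 50)
c = rng.standard_normal((100, 1, 4))   # batch dims (100, 1)
i = rng.integers(0, 10, (100, 50))
j1 = rng.integers(0, 11, (100, 50))
j2 = rng.integers(0, 12, (100, 1))

# Assumed meaning of np.s_[i, :] on b: gather along axis 0 with i while
# axis 1 stays aligned with the batch, i.e. b_g[p, q] = b[i[p, q], q].
b_g = b[i, np.arange(b.shape[1])[None, :]]  # (100, 50, 4)

# All operands broadcast to the common batch shape (100, 50).
batch = np.broadcast_shapes(a.shape[:-1], b_g.shape[:-1], c.shape[:-1])
out = toy_poly(a, b_g, c)  # (100, 50, 1)

# Assumed meaning of np.s_[j1, j2] on the output: accumulate batch element
# (p, q) into D[j1[p, q], j2[p, 0]]; duplicate targets are summed.
D = np.zeros((11, 12, 1))
np.add.at(D, (j1, j2), out)
```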
Example computing a linear layer with indexed weights using the “indexed_linear” method:
>>> input_irreps = cue.Irreps(cue.O3, "10x0e + 10x1o")
>>> output_irreps = cue.Irreps(cue.O3, "20x0e + 20x1o")
>>> poly = cue.descriptors.linear(input_irreps, output_irreps).polynomial
>>> counts = jnp.array([3, 4, 3], dtype=jnp.int32)  # Number of elements in each partition
>>> w = jax.random.normal(jax.random.key(0), (3, poly.inputs[0].size), dtype=jnp.float32)
>>> x = jax.random.normal(jax.random.key(1), (10, poly.inputs[1].size), dtype=jnp.float32)
>>> y = jax.ShapeDtypeStruct((10, poly.outputs[0].size), jnp.float32)
>>> [y] = cuex.segmented_polynomial(
...     poly,
...     [w, x], [y], [cuex.Repeats(counts), None, None],
...     method="indexed_linear"
... )
>>> y.shape
(10, 80)
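In this example counts sums to 10, the number of rows of x, and w holds one weight row per partition. A natural reading of cuex.Repeats(counts) is that consecutive rows of x are assigned to partitions 0, 1, and 2 in turn. The NumPy sketch below models that reading with a dense weight matrix per partition; W, n_in, n_out, and the dense matmul are hypothetical stand-ins for the irreps-structured linear polynomial, and the partition assignment itself is an assumption.

```python
import numpy as np

rng = np.random.default_rng(0)
counts = np.array([3, 4, 3])  # partition sizes, as in the example
n_in, n_out = 4, 5            # hypothetical toy feature sizes
W = rng.standard_normal((len(counts), n_in, n_out))  # one weight matrix per partition
x = rng.standard_normal((int(counts.sum()), n_in))   # 10 rows

# Assumed meaning of Repeats(counts): rows 0-2 use partition 0,
# rows 3-6 use partition 1, rows 7-9 use partition 2.
partition = np.repeat(np.arange(len(counts)), counts)

# Dense reference: gather one weight matrix per row, then apply it.
y_gather = np.einsum("ri,rio->ro", x, W[partition])

# Same result computed one contiguous segment at a time.
offsets = np.concatenate([[0], np.cumsum(counts)])
y_segments = np.concatenate(
    [x[s:e] @ W[k] for k, (s, e) in enumerate(zip(offsets[:-1], offsets[1:]))]
)
```

Because the partitions are contiguous, the per-row gather and the segment-by-segment loop give identical results; the second form avoids materializing one weight matrix per row.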