segmented_polynomial
cuequivariance_jax.segmented_polynomial(
    polynomial: SegmentedPolynomial,
    inputs: list[Array],
    outputs_shape_dtype: list[ShapeDtypeStruct],
    indices: None | list[None | Array | tuple[Array | slice]] = None,
    *,
    math_dtype: dtype | None = None,
    name: str | None = None,
    impl: str = 'auto',
) -> list[Array]
Compute a segmented polynomial.
Evaluates a segmented polynomial, which represents a mathematical operation composed of several tensor products. The function supports both JAX and CUDA implementations for maximum performance and flexibility.
- Parameters:
polynomial – The segmented polynomial to compute.
inputs – List of input buffers as JAX arrays.
outputs_shape_dtype – List of output shapes and dtypes specifications. The last shape dimension can be set to -1 to infer the size from the polynomial descriptor.
indices – Optional list of indices for inputs and outputs. If None, no indexing is applied. Defaults to None.
math_dtype – Data type for computational operations. If None, automatically determined from input types, defaulting to float32 if no float64 inputs are present. Defaults to None.
name – Optional name for the operation. Defaults to None.
impl – Implementation to use, one of [“auto”, “cuda”, “jax”, “naive_jax”]. If “auto”, uses CUDA when available and efficient, falling back to JAX otherwise. Defaults to “auto”.
- Returns:
List of JAX arrays containing the computed polynomial outputs.
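For instance, a minimal sketch of these parameters in use, adapted from the spherical-harmonics example below (the import aliases cue, cuex, jnp, and jax are assumed to follow the library's usual conventions):
>>> import jax
>>> import jax.numpy as jnp
>>> import cuequivariance as cue
>>> import cuequivariance_jax as cuex
>>> p = cue.descriptors.spherical_harmonics(cue.SO3(1), [0, 1, 2]).polynomial
>>> [out] = cuex.segmented_polynomial(
...     p,
...     [jnp.array([0.0, 1.0, 0.0])],                # one input buffer
...     [jax.ShapeDtypeStruct((-1,), jnp.float32)],  # -1: last dim inferred from the descriptor
...     math_dtype=jnp.float32,                      # dtype used for the internal computation
... )
>>> out.shape
(9,)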
- Performance Considerations:
- CUDA acceleration requirements (a sketch targeting the CUDA path follows this section):
STPs have a single mode (e.g. channelwise tensor product with subscripts u,u,,u)
Math data type is float32 or float64
Input/output data types are float32, float64, float16, or bfloat16
- Automatic optimizations:
Based on STP symmetries
Based on input buffer repetition patterns
Automatic pruning of unused buffers and indices
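For illustration, a sketch of a call aimed at the CUDA path, reusing the channelwise tensor product descriptor from the advanced example below (the single-mode case cited above). It keeps all buffers and the math dtype in float32 and requests the CUDA implementation explicitly; it assumes a CUDA-capable device and the same imports as the examples:
>>> poly = cue.descriptors.channelwise_tensor_product(
...     cue.Irreps(cue.O3, "32x0e + 32x1o + 32x1e + 32x2o"),
...     cue.Irreps(cue.O3, "0e + 1o + 1e"),
...     cue.Irreps(cue.O3, "32x0e + 32x1o + 32x1e"),
... ).polynomial.flatten_coefficient_modes().squeeze_modes()
>>> w = jnp.zeros((128, poly.inputs[0].size), dtype=jnp.float32)  # float32 buffers
>>> x = jnp.zeros((128, poly.inputs[1].size), dtype=jnp.float32)
>>> y = jnp.zeros((128, poly.inputs[2].size), dtype=jnp.float32)
>>> [out] = cuex.segmented_polynomial(
...     poly, [w, x, y],
...     [jax.ShapeDtypeStruct((128, -1), jnp.float32)],
...     math_dtype=jnp.float32,  # float32 or float64 math
...     impl="cuda",             # request the CUDA implementation explicitly
... )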
- Implementation Details:
Supports JAX transformations: jit, grad, jvp, vmap (see the sketch after this list)
Supports arbitrary-order derivatives through JVP and transpose rules
Full batching support
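A sketch of these transformations around the call, using the spherical-harmonics polynomial from the examples below (the helper f is hypothetical):
>>> p = cue.descriptors.spherical_harmonics(cue.SO3(1), [0, 1, 2]).polynomial
>>> def f(x):
...     [out] = cuex.segmented_polynomial(
...         p, [x], [jax.ShapeDtypeStruct((-1,), jnp.float32)]
...     )
...     return jnp.sum(out ** 2)
>>> g = jax.jit(jax.grad(f))(jnp.array([0.0, 1.0, 0.0]))  # jit + grad
>>> g.shape
(3,)
>>> gs = jax.vmap(jax.grad(f))(jnp.zeros((8, 3), dtype=jnp.float32))  # vmap over a batch
>>> gs.shape
(8, 3)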
Note
For maximum performance on CUDA-capable hardware, ensure inputs match the CUDA kernel activation conditions listed above. To verify whether the CUDA implementation is used, set impl="cuda" or set logging.basicConfig(level=logging.INFO).
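A minimal sketch of the logging approach (the exact messages emitted depend on the installed version):
>>> import logging
>>> logging.basicConfig(level=logging.INFO)
>>> # subsequent cuex.segmented_polynomial calls log which implementation was selected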
Examples
Simple example with spherical harmonics:
>>> p = cue.descriptors.spherical_harmonics(cue.SO3(1), [0, 1, 2]).polynomial
>>> cuex.segmented_polynomial(
...     p, [jnp.array([0.0, 1.0, 0.0])], [jax.ShapeDtypeStruct((-1,), jnp.float32)]
... )
[Array([1. , 0. , 1.7320508, 0. , 0. , 0. , 2.236068 , 0. , 0. ], dtype=float32)]
Advanced example with tensor product and indexing:
>>> poly: cue.SegmentedPolynomial = cue.descriptors.channelwise_tensor_product(
...     cue.Irreps(cue.O3, "32x0e + 32x1o + 32x1e + 32x2o"),
...     cue.Irreps(cue.O3, "0e + 1o + 1e"),
...     cue.Irreps(cue.O3, "32x0e + 32x1o + 32x1e"),
... ).polynomial.flatten_coefficient_modes().squeeze_modes()
>>> a = np.random.randn(1, 50, poly.inputs[0].size)
>>> b = np.random.randn(10, 50, poly.inputs[1].size)
>>> c = np.random.randn(100, 1, poly.inputs[2].size)
>>> i = np.random.randint(0, 10, (100, 50))
>>> D = jax.ShapeDtypeStruct(shape=(11, 12, poly.outputs[0].size), dtype=np.float32)
>>> j1 = np.random.randint(0, 11, (100, 50))
>>> j2 = np.random.randint(0, 12, (100, 1))
>>> [D] = cuex.segmented_polynomial(
...     poly, [a, b, c], [D], [None, np.s_[i, :], None, np.s_[j1, j2]]
... )
>>> D.shape
(11, 12, 1056)