segmented_polynomial_uniform_1d

cuequivariance_jax.ir_dict.segmented_polynomial_uniform_1d(
    polynomial,
    inputs,
    outputs=None,
    input_indices=None,
    output_indices=None,
    *,
    math_dtype=None,
    name=None,
)

Execute a segmented polynomial using the uniform 1D method on pytree-structured inputs and outputs.

This function wraps cuex.segmented_polynomial with method="uniform_1d" and handles the flattening and unflattening of pytree-structured inputs and outputs. It is designed for dict[Irrep, Array] representations, in which each array has shape (…, num_segments, *segment_shape).
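The flatten/unflatten step described above can be sketched with jax.tree_util. This is a minimal illustration under stated assumptions, not the library's implementation: `toy_flat_kernel` is a hypothetical stand-in for the flat call to cuex.segmented_polynomial (it simply scales each feature array by a scalar weight), and the string keys "0e" and "1o" stand in for Irrep objects.

```python
import jax
import jax.numpy as jnp


def toy_flat_kernel(flat_inputs, out_templates):
    # Hypothetical stand-in for the flat-list kernel: scales every
    # per-irrep feature array by the scalar weight w.
    w, *xs = flat_inputs
    return [(w * x).astype(t.dtype) for x, t in zip(xs, out_templates)]


def run_on_pytrees(inputs, outputs):
    # Flatten nested inputs/outputs into flat lists of leaves.
    in_leaves, _ = jax.tree_util.tree_flatten(inputs)
    out_leaves, out_def = jax.tree_util.tree_flatten(outputs)
    # The kernel only ever sees flat lists of arrays.
    results = toy_flat_kernel(in_leaves, out_leaves)
    # Rebuild a pytree with the same structure as `outputs`.
    return jax.tree_util.tree_unflatten(out_def, results)


batch = 4
# Leaves follow the (..., num_segments, *segment_shape) layout;
# string keys stand in for Irrep objects (assumption of this sketch).
x = {"0e": jnp.ones((batch, 8, 1)), "1o": jnp.ones((batch, 8, 3))}
w = jnp.asarray(2.0)
# Outputs given as ShapeDtypeStruct templates (shape and dtype only).
y_template = {k: jax.ShapeDtypeStruct(v.shape, v.dtype) for k, v in x.items()}

y = run_on_pytrees([w, x], y_template)
print({k: v.shape for k, v in y.items()})
```

In this toy, x and y share the same keys, and JAX flattens dict leaves in sorted key order, so the i-th feature leaf lines up with the i-th output leaf in the flat lists.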

Parameters:
  • polynomial (SegmentedPolynomial) – The segmented polynomial to execute.

  • inputs (Any) – Pytree of input arrays. Leaves must have shape (…, num_segments, *segment_shape) matching the polynomial’s input descriptors.

  • outputs (Any) – Pytree of output arrays or ShapeDtypeStruct objects, or None to use default zeros. Leaves must have shape (…, num_segments, *segment_shape) matching the polynomial’s output descriptors.

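  • input_indices (Any) – Pytree of index arrays for gather operations, matching the inputs structure, or None for no indexing. It is broadcast to match the inputs structure, so a single index array can cover a whole subtree (for example, every irrep of a dict input).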

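  • output_indices (Any) – Pytree of index arrays for scatter operations, matching the outputs structure, or None for no indexing. It is broadcast to match the outputs structure, so a single index array can cover a whole subtree (for example, every irrep of a dict output).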

  • math_dtype (Any) – Optional dtype for internal computation.

  • name (str | None) – Optional name for profiling/debugging.
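The gather/scatter indexing and the broadcasting of index trees can be illustrated with plain JAX. This is a hedged sketch of the semantics, not the library's code. It assumes that gather indexes the leading axis of each input leaf, that scatter accumulates by summation into the leading axis of each output leaf, and that one index array given for a subtree applies to every leaf of that subtree. The helpers `broadcast_indices` and `apply_indexed` are hypothetical, and only list/tuple index prefixes are handled.

```python
import jax
import jax.numpy as jnp


def broadcast_indices(index_tree, tree):
    """Return one index entry (None or an array) per leaf of `tree`.

    Hypothetical helper: only list/tuple prefixes are handled.
    """
    if not isinstance(index_tree, (list, tuple)):
        # A bare entry applies to every leaf of the matching subtree.
        return [index_tree] * len(jax.tree_util.tree_leaves(tree))
    flat = []
    for idx, sub in zip(index_tree, tree):
        flat.extend(broadcast_indices(idx, sub))
    return flat


def apply_indexed(inputs, outputs, input_indices, output_indices, edge_fn):
    in_leaves = jax.tree_util.tree_leaves(inputs)
    in_idx = broadcast_indices(input_indices, inputs)
    # Gather: index the leading axis of each leaf that has an index array.
    gathered = [x if i is None else x[i] for x, i in zip(in_leaves, in_idx)]
    messages = edge_fn(gathered)
    out_leaves, out_def = jax.tree_util.tree_flatten(outputs)
    out_idx = broadcast_indices(output_indices, outputs)
    results = []
    for m, t, i in zip(messages, out_leaves, out_idx):
        if i is None:
            results.append(m.astype(t.dtype))
        else:
            # Scatter-add: rows of m are summed into rows i of a zero buffer.
            results.append(jnp.zeros(t.shape, t.dtype).at[i].add(m))
    return jax.tree_util.tree_unflatten(out_def, results)


def edge_fn(gathered):
    # Toy per-edge computation: weight times gathered node features.
    w_e, *x_e = gathered
    return [w_e * xe for xe in x_e]


num_nodes = 3
senders = jnp.array([0, 1, 2, 0])
receivers = jnp.array([1, 2, 0, 1])
w = jnp.full((4, 8, 1), 0.5)  # one weight row per edge, not indexed
x = {"0e": jnp.ones((num_nodes, 8, 1)), "1o": jnp.ones((num_nodes, 8, 3))}
y = {k: jax.ShapeDtypeStruct(v.shape, v.dtype) for k, v in x.items()}

out = apply_indexed([w, x], y, [None, senders], receivers, edge_fn)
```

Node 1 receives two edges here, so under the summation assumption it accumulates two messages while nodes 0 and 2 receive one each.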

Returns:

Pytree with the same structure as outputs, containing the computed results, each with shape (…, num_segments, *segment_shape).

Return type:

Any

Example

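>>> from cuequivariance_jax.ir_dict import segmented_polynomial_uniform_1d
>>> # Assumed context: `descriptor` has operands [w, x, y] (inputs w, x;
>>> # output y), and `senders` / `receivers` are integer index arrays.
>>> # Split operand 1 (x) and the last operand (y) by irrep, so that
>>> # x and y are passed as dict[Irrep, Array].
>>> e = descriptor.split_operand_by_irrep(1).split_operand_by_irrep(-1)
>>> p = e.polynomial
>>> y = segmented_polynomial_uniform_1d(
...     p, [w, x], y,
...     input_indices=[None, senders],  # w is not indexed; x is gathered
...     output_indices=receivers,  # results are scattered into y
... )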