segmented_polynomial_uniform_1d
cuequivariance_jax.ir_dict.segmented_polynomial_uniform_1d(polynomial, inputs, outputs=None, input_indices=None, output_indices=None, *, math_dtype=None, name=None)
Execute a segmented polynomial with the uniform 1D method on pytree-structured inputs and outputs.
This function wraps cuex.segmented_polynomial (cuex being cuequivariance_jax) with method="uniform_1d", and handles flattening and unflattening of the pytree-structured inputs and outputs. It is designed for dict[Irrep, Array] representations, in which each array has shape (…, num_segments, *segment_shape).
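The flatten, call, unflatten pattern described above can be sketched with jax.tree_util. This is a minimal illustration, not the library's implementation: call_with_pytrees and flat_kernel are hypothetical names, flat_kernel is a toy stand-in for a function that only accepts flat lists of arrays, and the string keys "0e" and "1o" stand in for Irrep keys.

```python
import jax
import jax.numpy as jnp


def call_with_pytrees(flat_fn, inputs, outputs):
    """Flatten pytrees, call a flat-list function, and restore the output structure."""
    in_leaves, _ = jax.tree_util.tree_flatten(inputs)
    out_leaves, out_treedef = jax.tree_util.tree_flatten(outputs)
    results = flat_fn(in_leaves, out_leaves)  # one array per output leaf
    return jax.tree_util.tree_unflatten(out_treedef, results)


def flat_kernel(in_leaves, out_leaves):
    """Hypothetical flat kernel: adds the total of all inputs to every output."""
    total = sum(jnp.sum(a) for a in in_leaves)
    return [o + total for o in out_leaves]


# Stand-ins for dict[Irrep, Array]; each leaf is (batch, num_segments, segment_size).
x = {"0e": jnp.ones((4, 2, 1)), "1o": jnp.ones((4, 3, 3))}
y = {"0e": jnp.zeros((4, 2, 1)), "1o": jnp.zeros((4, 3, 3))}
out = call_with_pytrees(flat_kernel, x, y)
```

Keeping the treedef of outputs is what lets the result come back with the same keys and nesting as the outputs argument.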
- Parameters:
polynomial (SegmentedPolynomial) – The segmented polynomial to execute.
inputs (Any) – Pytree of input arrays. Leaves must have shape (…, num_segments, *segment_shape) matching the polynomial’s input descriptors.
outputs (Any) – Pytree of output arrays or ShapeDtypeStruct leaves, or None to use default zero-initialized outputs. Leaves must have shape (…, num_segments, *segment_shape) matching the polynomial’s output descriptors.
input_indices (Any) – Pytree matching the inputs structure, holding index arrays for gather operations, or None for no indexing. It is broadcast to match the inputs structure, so a single array or None can apply to a whole subtree.
output_indices (Any) – Pytree matching the outputs structure, holding index arrays for scatter operations, or None for no indexing. It is broadcast to match the outputs structure, so a single array or None can apply to a whole subtree.
math_dtype (Any) – Optional dtype for internal computation.
name (str | None) – Optional name for profiling/debugging.
- Returns:
Pytree with same structure as outputs, containing computed results with shape (…, num_segments, *segment_shape).
- Return type:
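The outputs argument can describe result buffers without allocating them, through jax.ShapeDtypeStruct leaves. The sketch below assumes that a ShapeDtypeStruct leaf stands for a zero-filled buffer of that shape and dtype (in line with None meaning default zeros); materialize_outputs is a hypothetical helper, not part of cuequivariance_jax, and the keys "0e" and "1o" stand in for Irrep keys.

```python
import jax
import jax.numpy as jnp


def materialize_outputs(outputs):
    """Hypothetical helper: replace ShapeDtypeStruct leaves with zeros, keep arrays as-is."""

    def to_array(leaf):
        if isinstance(leaf, jax.ShapeDtypeStruct):
            # Assumption: a shape/dtype spec stands for a zero-initialized buffer.
            return jnp.zeros(leaf.shape, leaf.dtype)
        return leaf

    return jax.tree_util.tree_map(to_array, outputs)


num_nodes = 5
# Mixed spec: one leaf described only by shape/dtype, one given as a real array.
spec = {
    "0e": jax.ShapeDtypeStruct((num_nodes, 2, 1), jnp.float32),
    "1o": jnp.ones((num_nodes, 3, 3)),
}
buffers = materialize_outputs(spec)
```

Describing outputs by shape and dtype keeps the layout (…, num_segments, *segment_shape) explicit without materializing a buffer at the call site.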
Example
>>> # After split_operand_by_irrep, inputs/outputs are dict[Irrep, Array]
>>> e = descriptor.split_operand_by_irrep(1).split_operand_by_irrep(-1)
>>> p = e.polynomial
>>> y = segmented_polynomial_uniform_1d(
...     p, [w, x], y,
...     input_indices=[None, senders],
...     output_indices=receivers,
... )
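In the example, input_indices=[None, senders] leaves w unindexed and gathers every leaf of x, while the single array receivers serves as the output index for every leaf of y. The pure-JAX reference below sketches those gather/scatter semantics on a small graph. It assumes that repeated output indices are summed (scatter-add) and replaces the polynomial with a toy per-edge product; gather_product_scatter and all data are hypothetical.

```python
import jax
import jax.numpy as jnp


def gather_product_scatter(w, x, y, senders, receivers):
    """Toy reference for inputs [w, x] with input_indices=[None, senders]
    and output_indices=receivers. The product w * x[senders] stands in
    for the segmented polynomial."""
    # input_indices=[None, senders]: w is used as-is; senders is broadcast
    # over every leaf of x and gathers rows along the leading (batch) axis.
    x_edges = jax.tree_util.tree_map(lambda a: a[senders], x)
    messages = jax.tree_util.tree_map(lambda a: w * a, x_edges)
    # output_indices=receivers: one array broadcast over every leaf of y;
    # .at[...].add sums contributions that land on the same row.
    return jax.tree_util.tree_map(
        lambda out, m: out.at[receivers].add(m), y, messages
    )


senders = jnp.array([0, 1, 2, 0])
receivers = jnp.array([1, 2, 0, 1])
w = jnp.full((4, 1, 1), 2.0)                    # one weight per edge
x = {"0e": jnp.arange(3.0).reshape(3, 1, 1),    # features for 3 nodes
     "1o": jnp.ones((3, 2, 3))}
y = {"0e": jnp.zeros((3, 1, 1)), "1o": jnp.zeros((3, 2, 3))}
y = gather_product_scatter(w, x, y, senders, receivers)
```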