symmetric_tensor_product

cuequivariance_jax.symmetric_tensor_product(
ds: list[SegmentedTensorProduct],
*inputs: Array,
dtype_output: dtype | None = None,
dtype_math: dtype | None = None,
precision: Precision = Precision.HIGHEST,
algorithm: str = 'sliced',
use_custom_primitive: bool = True,
use_custom_kernels: bool = False,
) → Array

Compute the sum of the segmented tensor products (STPs) in ds, each evaluated on the same inputs (the repeated input operands of every STP receive the same array).
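To make "sum of the STPs evaluated on the input" concrete, here is a minimal pure-Python sketch of the idea. It does not use cuequivariance at all: the path-list layout `(input_indices, output_index, coefficient)` and the helper names `eval_descriptor` and `symmetric_sum` are hypothetical, chosen only to illustrate that each term reads the same operand `x` at every input position and that the terms are then added.

```python
from math import prod

# Hypothetical toy model, NOT the cuequivariance data structure:
# a "descriptor" is a list of paths (input_indices, output_index, coefficient).
def eval_descriptor(paths, x, out_size):
    out = [0.0] * out_size
    for in_idx, out_idx, coeff in paths:
        # The same operand x is read at every input position of the path.
        out[out_idx] += coeff * prod(x[i] for i in in_idx)
    return out

def symmetric_sum(descriptors, x, out_size):
    # Evaluate every descriptor on the same x and add the results elementwise.
    total = [0.0] * out_size
    for paths in descriptors:
        term = eval_descriptor(paths, x, out_size)
        total = [t + u for t, u in zip(total, term)]
    return total

linear = [((0,), 0, 2.0)]                          # 2 * x[0]       -> out[0]
quadratic = [((0, 1), 0, 1.0), ((1, 1), 1, 3.0)]   # x[0] * x[1]    -> out[0]
                                                   # 3 * x[1] ** 2  -> out[1]
x = [2.0, 5.0]
result = symmetric_sum([linear, quadratic], x, out_size=2)
print(result)  # [14.0, 75.0]
```

The degree-1 and degree-2 terms take one and two copies of `x` respectively, which is why a single input can serve descriptors with different numbers of operands.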

Parameters:
  • ds (list[stp.SegmentedTensorProduct]) – The segmented tensor product descriptors.

  • *inputs (jax.Array) – The input arrays. The last input is repeated to match the number of input operands of each STP.

  • dtype_output (jnp.dtype, optional) – The data type for the output array.

  • dtype_math (jnp.dtype, optional) – The data type for mathematical operations.

  • precision (jax.lax.Precision, optional) – The precision for the computation. Defaults to jax.lax.Precision.HIGHEST.

  • algorithm (str, optional) – One of “sliced”, “stacked”, “compact_stacked”, “indexed_compact”, “indexed_vmap”, “indexed_for_loop”. Defaults to “sliced”.

  • use_custom_primitive (bool, optional) – Whether to use custom JVP rules. Defaults to True.

  • use_custom_kernels (bool, optional) – Whether to use custom kernels. Defaults to False.
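The rule stated for *inputs ("the last input is repeated to match the number of input operands of each STP") can be sketched as follows. This is an illustration of the rule as described above, not the library's code: `expand_inputs` is a hypothetical helper, and rejecting too many inputs is an assumption of this sketch rather than documented behavior.

```python
def expand_inputs(inputs, num_input_operands):
    """Pad `inputs` by repeating the last one (hypothetical helper)."""
    if not inputs:
        raise ValueError("at least one input is required")
    if len(inputs) > num_input_operands:
        # Assumption of this sketch: extra inputs are treated as an error.
        raise ValueError("more inputs than input operands")
    missing = num_input_operands - len(inputs)
    return list(inputs) + [inputs[-1]] * missing

inputs = ("w", "x")  # placeholders standing in for arrays

# A descriptor with two input operands uses the inputs as given.
print(expand_inputs(inputs, 2))  # ['w', 'x']

# A descriptor with four input operands repeats the last input.
print(expand_inputs(inputs, 4))  # ['w', 'x', 'x', 'x']
```

Under this reading, the leading inputs are shared by all descriptors, and only the trailing one is duplicated as the degree grows.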

Returns:

The result of the tensor product computation.

Return type:

jax.Array

See also

SegmentedTensorProduct