symmetric_tensor_product
cuequivariance_jax.symmetric_tensor_product(
    ds: list[SegmentedTensorProduct],
    *inputs: Array,
    dtype_output: dtype | None = None,
    dtype_math: dtype | None = None,
    precision: Precision = Precision.HIGHEST,
    algorithm: str = 'sliced',
    use_custom_primitive: bool = True,
    use_custom_kernels: bool = False,
)
Compute the sum of the segmented tensor products (STPs) in ds, all evaluated on the same input (the repeated input operands of each STP all receive the same array).
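As a rough mental model only (not the library's implementation), the single-input case can be sketched with dense coefficient arrays standing in for the STPs. A term whose coefficient array has k input axes becomes a degree-k polynomial in x, and the result is the sum of those polynomials. The helper name dense_symmetric_tensor_product and the dense layout are hypothetical; real STPs are segmented and are evaluated by the backend chosen with algorithm.

```python
import jax.numpy as jnp


def dense_symmetric_tensor_product(coeffs, x):
    """Hypothetical dense stand-in: sum over d of C_d(x, ..., x).

    coeffs[d] has shape (n,) * k_d + (m,): k_d input axes plus one output axis.
    x has shape (n,). Every input axis of every term is contracted with x.
    """
    out = None
    for c in coeffs:
        k = c.ndim - 1  # number of input operands of this term
        term = c
        for _ in range(k):
            # Contract the current leading input axis with the same input x.
            term = jnp.tensordot(x, term, axes=([0], [0]))
        out = term if out is None else out + term
    return out
```

With x = [1, 2, 3] and all-ones coefficients of degree 1 and 2 plus the constant term [1, 0], the terms contribute [6, 6], [36, 36] and [1, 0], for a total of [43, 42].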
- Parameters:
ds (list[stp.SegmentedTensorProduct]) – The segmented tensor product descriptors.
*inputs (jax.Array) – The input arrays. The last input is repeated as many times as needed to fill the remaining input operands of each STP.
dtype_output (jnp.dtype, optional) – The data type for the output array.
dtype_math (jnp.dtype, optional) – The data type used for intermediate computations.
precision (jax.lax.Precision, optional) – The precision for the computation. Defaults to jax.lax.Precision.HIGHEST.
algorithm (str, optional) – One of “sliced”, “stacked”, “compact_stacked”, “indexed_compact”, “indexed_vmap”, “indexed_for_loop”. Defaults to “sliced”.
use_custom_primitive (bool, optional) – Whether to use custom JVP rules. Defaults to True.
use_custom_kernels (bool, optional) – Whether to use custom kernels. Defaults to False.
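The parameter descriptions above can be illustrated with a hedged sketch that reuses a dense array as a stand-in for one STP. fill_operands and dense_term are hypothetical helpers. The first applies the documented rule that the last input is repeated to match the number of input operands. The second assumes that dtype_math is the dtype used while contracting, that the result is then cast to dtype_output, and that precision is forwarded to each contraction; the library's actual handling of these options may differ.

```python
import jax
import jax.numpy as jnp


def fill_operands(num_inputs, inputs):
    """Hypothetical helper: repeat the last input until num_inputs operands exist."""
    n_repeat = num_inputs - (len(inputs) - 1)
    if n_repeat < 0:
        raise ValueError("more leading inputs than input operands")
    return list(inputs[:-1]) + [inputs[-1]] * n_repeat


def dense_term(coeff, inputs, dtype_math=None, dtype_output=None,
               precision=jax.lax.Precision.HIGHEST):
    """Hypothetical evaluation of one dense term on the filled operand list.

    coeff has one leading axis per input operand plus a trailing output axis.
    Assumption: math runs in dtype_math, the result is cast to dtype_output.
    """
    operands = fill_operands(coeff.ndim - 1, inputs)
    if dtype_math is not None:
        coeff = coeff.astype(dtype_math)
        operands = [op.astype(dtype_math) for op in operands]
    out = coeff
    for op in operands:
        # Contract the current leading input axis with this operand.
        out = jnp.tensordot(op, out, axes=([0], [0]), precision=precision)
    if dtype_output is not None:
        out = out.astype(dtype_output)
    return out
```

For example, with inputs (w, x) and a coefficient array of shape (1, 2, 2, 3), fill_operands returns [w, x, x], so the term is quadratic in x and linear in w.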
- Returns:
The result of the tensor product computation.
- Return type:
See also
SegmentedTensorProduct