tensor_product

cuequivariance_jax.tensor_product(
    descriptors: list[tuple[Operation, SegmentedTensorProduct]],
    inputs: list[Array],
    outputs_shape_dtype: list[ShapeDtypeStruct],
    indices: list[Array | None] | None = None,
    *,
    math_dtype: dtype | None = None,
    name: str | None = None,
    impl: str = 'auto',
) -> list[Array]

Compute a polynomial described by a list of descriptors.
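To make "a polynomial described by a list of descriptors" concrete, here is a deliberately simplified plain-JAX model. It is an assumption for illustration only, not how cuequivariance_jax represents cue.Operation or cue.SegmentedTensorProduct: each hypothetical descriptor is a dense coefficient tensor paired with a tuple naming which input buffer feeds each operand. The helpers apply_descriptor and polynomial are hypothetical.

```python
import jax.numpy as jnp

# Simplified model for illustration only (NOT the cuequivariance_jax data structures):
# a descriptor is (coeffs, operand_buffers), where coeffs has shape (I, J, K) and
# operand_buffers names which input buffer feeds each of the two input operands.


def apply_descriptor(coeffs, operand_buffers, inputs):
    """Hypothetical helper: contract coeffs with two input buffers -> shape (K,)."""
    a = inputs[operand_buffers[0]]
    b = inputs[operand_buffers[1]]
    # out[k] = sum_{i,j} coeffs[i, j, k] * a[i] * b[j]
    return jnp.einsum("ijk,i,j->k", coeffs, a, b)


def polynomial(descriptors, inputs):
    """Hypothetical helper: sum the multilinear terms of all descriptors."""
    return sum(apply_descriptor(c, ops, inputs) for c, ops in descriptors)


x = jnp.array([1.0, 2.0])
y = jnp.array([3.0, 4.0])
coeffs = jnp.ones((2, 2, 1))

# The first term is bilinear in (x, y); the second reuses buffer 0, giving x ⊗ x.
out = polynomial([(coeffs, (0, 1)), (coeffs, (0, 0))], [x, y])
# out == [30.0]: (1 + 2) * (3 + 4) + (1 + 2) ** 2
```

Listing buffer 0 twice in the second descriptor turns a bilinear term in x and y into a quadratic term in x; summing such multilinear contractions is what yields a general polynomial. This kind of buffer repetition is also what the "repetition of the input buffers" optimization in the Features list refers to.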

Features:
  • Calls a CUDA kernel if:
    • each segmented tensor product (STP) has a single mode whose extent is a multiple of 32 (e.g., a channelwise tensor product with subscripts u,u,,u and u=128)

    • the math data type is float32 or float64

    • the input and output data types are any combination of float32, float64, float16, and bfloat16

    • the indices are int32

  • Supports derivatives of arbitrary order (the JVP and transpose rules each map to a single corresponding primitive, so they can be applied repeatedly)

  • Limited support for batching (a buffer that has indices cannot be batched, and non-trivial batching leads to poor performance)

  • Automatic optimizations based on the symmetries of the STPs and on the repetition of the input buffers

  • Automatic removal of unused buffers and indices
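The CUDA-kernel conditions above can be restated as a single predicate. The sketch below is a hypothetical helper (cuda_kernel_eligible is not part of cuequivariance_jax) that only encodes the four listed conditions; the library's real dispatch may check more, and modelling each STP as a dict from mode name to extent is an assumption of this sketch.

```python
import jax.numpy as jnp

# Hypothetical helper restating the CUDA-kernel conditions from the Features list.
# It is not part of cuequivariance_jax, and the library's real dispatch may check more.

MATH_DTYPES = (jnp.dtype(jnp.float32), jnp.dtype(jnp.float64))
IO_DTYPES = MATH_DTYPES + (jnp.dtype(jnp.float16), jnp.dtype(jnp.bfloat16))


def cuda_kernel_eligible(stp_modes, math_dtype, io_dtypes, index_dtypes):
    """Return True if every listed condition holds.

    stp_modes: one dict per STP mapping mode name to extent, e.g. [{"u": 128}]
    math_dtype: a concrete dtype (this sketch does not model math_dtype=None)
    io_dtypes: dtypes of all input and output buffers
    index_dtypes: dtypes of all index arrays that are passed
    """
    # Each STP must have exactly one mode, and its extent must be a multiple of 32.
    single_mode = all(
        len(modes) == 1 and all(extent % 32 == 0 for extent in modes.values())
        for modes in stp_modes
    )
    math_ok = jnp.dtype(math_dtype) in MATH_DTYPES
    io_ok = all(jnp.dtype(d) in IO_DTYPES for d in io_dtypes)
    indices_ok = all(jnp.dtype(d) == jnp.dtype(jnp.int32) for d in index_dtypes)
    return single_mode and math_ok and io_ok and indices_ok
```

For the channelwise example with subscripts u,u,,u and u=128, the call would be cuda_kernel_eligible([{"u": 128}], jnp.float32, [jnp.float16, jnp.bfloat16], [jnp.int32]), which returns True; changing u to 100, the math dtype to float16, or the indices to int64 makes it return False.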

Parameters:
  • descriptors (list of pairs) – The list of descriptors. Each descriptor is a pair (cue.Operation, cue.SegmentedTensorProduct).

  • inputs (list of jax.Array) – The input buffers.

  • outputs_shape_dtype (list of jax.ShapeDtypeStruct) – The output shapes and dtypes.

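  • indices (list of jax.Array or None, optional) – Optional index arrays for the input and output buffers; a None entry leaves the corresponding buffer unindexed. Indices must be int32 for the CUDA kernel to be used. Defaults to None.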

  • math_dtype (jnp.dtype, optional) – The data type for computational operations. Defaults to None.

  • name (str, optional) – The name of the operation. Defaults to None.

  • impl (str, optional) – The implementation to use. Defaults to “auto”, which uses the CUDA implementation if it is available and falls back to the JAX implementation otherwise. “cuda” selects the CUDA implementation and “jax” selects the JAX implementation.
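The selection rule for impl can be sketched as a small function. resolve_impl is hypothetical and not cuequivariance_jax code; how "available" is actually decided (for instance, whether the CUDA-kernel conditions in the Features list also matter) is not specified here, so it is passed in as a flag, and raising on an unknown value is this sketch's own choice.

```python
import jax

# Hypothetical restatement of the `impl` rule above; not cuequivariance_jax code.
VALID_IMPLS = ("auto", "cuda", "jax")


def resolve_impl(impl: str, cuda_available: bool) -> str:
    """Map the `impl` argument to the implementation that would run."""
    if impl not in VALID_IMPLS:
        # Raising here is this sketch's choice; the docs do not say what happens.
        raise ValueError(f"impl must be one of {VALID_IMPLS}, got {impl!r}")
    if impl == "auto":
        return "cuda" if cuda_available else "jax"
    # "cuda" and "jax" are explicit requests. What happens when "cuda" is
    # requested but unavailable is not specified by the docs.
    return impl


# Rough proxy for availability (an assumption): a GPU backend is active.
chosen = resolve_impl("auto", cuda_available=jax.default_backend() == "gpu")
```

Only "auto" involves a fallback; the other two values pass through unchanged, which keeps an explicit request visible to the caller instead of silently substituting another implementation.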

Returns:

The result of the tensor product: one array per entry of outputs_shape_dtype, with the corresponding shape and dtype.

Return type:

list of jax.Array
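The return value can be read together with outputs_shape_dtype and math_dtype. The sketch below is a hypothetical plain-JAX stand-in (elementwise_product_reference is not part of the library) that assumes one returned array per jax.ShapeDtypeStruct, with the arithmetic carried out in math_dtype and the result cast to the requested output dtype. The fallback used when math_dtype is None is this sketch's own assumption.

```python
import jax
import jax.numpy as jnp


def elementwise_product_reference(inputs, outputs_shape_dtype, math_dtype=None):
    """Hypothetical plain-JAX stand-in, not cuequivariance_jax.tensor_product.

    Multiplies two input buffers elementwise, computing in `math_dtype` and
    returning a list with one array per entry of `outputs_shape_dtype`.
    """
    (x, y), (out_spec,) = inputs, outputs_shape_dtype
    if math_dtype is None:
        # Assumption for this sketch only: compute in the output dtype.
        math_dtype = out_spec.dtype
    z = x.astype(math_dtype) * y.astype(math_dtype)  # arithmetic in math_dtype
    if z.shape != out_spec.shape:
        raise ValueError(f"expected shape {out_spec.shape}, got {z.shape}")
    return [z.astype(out_spec.dtype)]  # cast to the requested output dtype


x = jnp.array([1.0, 2.0], dtype=jnp.float16)
y = jnp.array([3.0, 4.0], dtype=jnp.float16)
spec = jax.ShapeDtypeStruct((2,), jnp.float16)

(result,) = elementwise_product_reference([x, y], [spec], math_dtype=jnp.float32)
# result is a float16 array of shape (2,) holding [3.0, 8.0]
```

Computing in float32 while storing float16 is the usual reason to set math_dtype separately from the buffer dtypes; the Features list likewise pairs float16 and bfloat16 buffers with a float32 or float64 math dtype for the CUDA kernel.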