equivariant_tensor_product
cuequivariance_jax.equivariant_tensor_product(
    e: EquivariantTensorProduct,
    *inputs: RepArray | Array,
    indices: list[Array | None] | None = None,
    output_batch_shape: tuple[int, ...] | None = None,
    output_dtype: dtype | None = None,
    math_dtype: dtype | None = None,
    name: str | None = None,
    impl: str = 'auto',
)
Compute the equivariant tensor product of the input arrays.
- Parameters:
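e (cue.EquivariantTensorProduct) – The equivariant tensor product descriptor.
*inputs (RepArray or jax.Array) – The input arrays.
indices (list of jax.Array or None, optional) – Optional int32 index arrays: one entry per input, followed by one entry for the output (indices[-1]). A None entry means that operand is not indexed. Defaults to None.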
output_batch_shape (tuple of int, optional) – The batch shape of the output array.
output_dtype (jnp.dtype, optional) – The data type for the output array. Defaults to None.
math_dtype (jnp.dtype, optional) – The data type for computational operations. Defaults to None.
name (str, optional) – The name of the operation. Defaults to None.
impl (str, optional) – The implementation to use. Defaults to 'auto'.
- Returns:
The result of the equivariant tensor product.
- Return type:
RepArray
Examples
Let’s create a descriptor for the spherical harmonics of degree 0, 1, and 2.
>>> import cuequivariance as cue
>>> import cuequivariance_jax as cuex
>>> import jax.numpy as jnp
>>> e = cue.descriptors.spherical_harmonics(cue.SO3(1), [0, 1, 2])
>>> e
EquivariantTensorProduct((1)^(0..2) -> 0+1+2)
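The output irreps 0+1+2 determine the size of the last axis of the result. Under SO(3), an irrep of degree l has dimension 2l + 1, so the output carries 1 + 3 + 5 = 9 components per batch element. The helper so3_irreps_dim below is hypothetical (not part of cuequivariance) and only parses simple strings such as "0+1+2" or "2x1+0".

```python
def so3_irreps_dim(irreps: str) -> int:
    """Hypothetical helper: total dimension of a simple SO(3) irreps string.

    Each term is either "l" (multiplicity 1) or "mulxl", and an irrep of
    degree l contributes 2l + 1 components per copy.
    """
    total = 0
    for term in irreps.split("+"):
        mul, _, degree = term.strip().rpartition("x")
        multiplicity = int(mul) if mul else 1
        total += multiplicity * (2 * int(degree) + 1)
    return total


print(so3_irreps_dim("0+1+2"))  # 1 + 3 + 5 = 9
print(so3_irreps_dim("1"))      # 3, the size of the input vector below
```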
We need some input data.
>>> with cue.assume(cue.SO3, cue.ir_mul):
...     x = cuex.RepArray("1", jnp.array([0.0, 1.0, 0.0]))
>>> x
{0: 1} [0. 1. 0.]
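The example assumes the cue.ir_mul layout. Our understanding (an assumption worth checking against the cuequivariance layout documentation) is that ir_mul stores the components of each irrep dimension-major and multiplicity-minor, whereas cue.mul_ir stores each copy of the irrep contiguously. The hypothetical function mul_ir_to_ir_mul below converts one irrep block between the two orderings in plain NumPy. With multiplicity 1, as for the irreps "1" above, both layouts coincide, so the choice does not change this example.

```python
import numpy as np


def mul_ir_to_ir_mul(data, mul, ir_dim):
    """Hypothetical helper: reorder one irrep block from mul_ir to ir_mul.

    mul_ir is taken to mean the flat data viewed as (mul, ir_dim);
    ir_mul is the same values viewed as (ir_dim, mul).
    """
    return np.asarray(data).reshape(mul, ir_dim).T.reshape(-1)


# Two copies (mul=2) of a degree-1 irrep (dimension 3), laid out as mul_ir.
u = np.array([1.0, 2.0, 3.0, 10.0, 20.0, 30.0])
print(mul_ir_to_ir_mul(u, mul=2, ir_dim=3))  # components 1, 10, 2, 20, 3, 30

# Multiplicity 1, as in the RepArray above: the ordering is unchanged.
x = np.array([0.0, 1.0, 0.0])
print(mul_ir_to_ir_mul(x, mul=1, ir_dim=3))  # components 0, 1, 0
```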
Now we can execute the equivariant tensor product.
>>> cuex.equivariant_tensor_product(e, x)
{0: 0+1+2} [1. ... ]
The indices argument accepts a list of optional int32 arrays: one entry per input, followed by one entry for the output (indices[-1] is the output index). A None entry means that operand is not indexed. Input indices select which elements of the input arrays are used, and the output index specifies where each result is written in the output array. The following example indexes only the output: the input has a batch shape of (3,) and the output has a batch shape of (2,).
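The index semantics described above can be modeled as a gather on the inputs followed by a scatter-add on the output. The sketch below is a plain-NumPy reference model under that assumption, not the cuequivariance_jax implementation; indexed_apply and its arguments are hypothetical. It assumes computation k reads input row i_in[k] and adds its result into output row i_out[k], so duplicate output indices accumulate. In JAX, the same scatter-add can be written as jnp.zeros(shape).at[i_out].add(ys).

```python
import numpy as np


def indexed_apply(f, x, i_in=None, i_out=None, out_size=None):
    """Hypothetical reference model of an indexed, batched operation.

    Gathers input rows with i_in, applies f to each row, then scatter-adds
    each result into row i_out[k] of an output with out_size rows.
    out_size must be given whenever i_out is given.
    """
    xs = x if i_in is None else x[i_in]      # gather: select input rows
    ys = np.stack([f(row) for row in xs])    # per-row computation
    if i_out is None:
        return ys
    out = np.zeros((out_size,) + ys.shape[1:], dtype=ys.dtype)
    np.add.at(out, i_out, ys)                # scatter-add: repeated indices accumulate
    return out


x = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
out = indexed_apply(lambda r: 2.0 * r, x, i_out=np.array([0, 1, 1]), out_size=2)
print(out)  # row 0: [2, 4]; row 1: [6 + 10, 8 + 12] = [16, 20]
```

np.add.at is used instead of out[i_out] += ys because plain fancy-index assignment does not accumulate repeated indices.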
>>> i_out = jnp.array([0, 1, 1], dtype=jnp.int32)
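Entry k of i_out is the output row that receives the result computed from input row k. Input rows 1 and 2 both map to output row 1, so their results are summed there.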
>>> with cue.assume(cue.SO3, cue.ir_mul):
...     x = cuex.RepArray("1", jnp.array([
...         [0.0, 1.0, 0.0],
...         [0.0, 0.0, 1.0],
...         [1.0, 0.0, 0.0],
...     ]))
>>> cuex.equivariant_tensor_product(
...     e,
...     x,
...     indices=[None, i_out],
...     output_batch_shape=(2,),
... )
{1: 0+1+2}
[[ 1. ... ]
 [ 2. ... ]]
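The leading entries 1. and 2. in the indexed output can be checked by hand. The degree-0 spherical harmonic is a constant, and the unindexed example shows it equals 1.0, so each of the three input rows contributes 1.0 to that component. Scatter-adding these contributions over i_out = [0, 1, 1] gives 1 for output row 0 and 1 + 1 = 2 for output row 1. The NumPy sketch below reproduces only this arithmetic; it does not call cuequivariance.

```python
import numpy as np

i_out = np.array([0, 1, 1], dtype=np.int32)

# Degree-0 component contributed by each of the three input rows
# (constant, equal to 1.0 per the unindexed example).
y0_per_row = np.ones(3)

# bincount with weights sums the weights of all rows that share an output index.
y0_out = np.bincount(i_out, weights=y0_per_row, minlength=2)
print(y0_out)  # [1. 2.]
```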