Equivariant Tensor Product

The submodule cuequivariance.descriptors provides functions that build descriptors of common equivariant tensor products. Each descriptor is an instance of the class cuequivariance.EquivariantTensorProduct.

Examples

Linear layer

import cuequivariance as cue

cue.descriptors.linear(cue.Irreps("O3", "32x0e + 32x1o"), cue.Irreps("O3", "16x0e + 48x1o"))
EquivariantTensorProduct(2048x0e x 32x0e+32x1o -> 16x0e+48x1o)

In this example, the first operand holds the weights, which are always scalars (0e). There are 32 * 16 = 512 weights connecting the 0e irreps and 32 * 48 = 1536 weights connecting the 1o irreps, for a total of 2048 weights.
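The weight count follows a simple rule: an equivariant linear map connects every input copy of an irrep to every output copy of the same irrep. The sketch below applies that rule in plain Python. The helpers parse_irreps and count_linear_weights are hypothetical illustrations and are not part of cuequivariance, which does the real parsing in cue.Irreps. It reproduces the 2048 weights above. It also gives 384 weights for the 8x0e + 4x1o output used in the execution examples.

```python
from __future__ import annotations


def parse_irreps(spec: str) -> dict[str, int]:
    """Hypothetical helper: parse "32x0e + 32x1o" into {"0e": 32, "1o": 32}."""
    muls: dict[str, int] = {}
    for term in spec.split("+"):
        term = term.strip()
        # A term without "x" (e.g. "0e") has multiplicity 1.
        mul, ir = term.split("x") if "x" in term else ("1", term)
        muls[ir] = muls.get(ir, 0) + int(mul)
    return muls


def count_linear_weights(irreps_in: str, irreps_out: str) -> int:
    """One weight per (input copy, output copy) pair of the same irrep."""
    muls_in = parse_irreps(irreps_in)
    muls_out = parse_irreps(irreps_out)
    return sum(m * muls_out.get(ir, 0) for ir, m in muls_in.items())


print(count_linear_weights("32x0e + 32x1o", "16x0e + 48x1o"))  # 2048
print(count_linear_weights("32x0e + 32x1o", "8x0e + 4x1o"))    # 384
```

Irreps that appear only on one side (for example a 2e in the output with no 2e in the input) contribute no weights, because no equivariant linear path connects different irreps.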

Spherical Harmonics

cue.descriptors.spherical_harmonics(cue.SO3(1), [0, 1, 2, 3])
EquivariantTensorProduct((1)^(0..3) -> 0+1+2+3)

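Spherical harmonics are polynomials of an input vector. This descriptor computes those of degree 0, 1, 2 and 3 and concatenates them into the output 0+1+2+3, one irrep per degree. In the printed form, (1)^(0..3) indicates that the input vector (irrep 1) enters 0 to 3 times, once per polynomial degree.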
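A quick arithmetic check of the output size uses only the standard fact that the SO(3) irrep of degree ell has 2 * ell + 1 components. This plain-Python sketch does not depend on cuequivariance or its normalization. It also shows why only 2 * ell + 1 of the homogeneous degree-ell polynomials are harmonics: the rest are |r|^2 times polynomials of degree ell - 2.

```python
def so3_dim(ell: int) -> int:
    """Number of spherical harmonics of degree ell (dimension of the SO(3) irrep ell)."""
    return 2 * ell + 1


def num_monomials(ell: int) -> int:
    """Number of homogeneous monomials x^a y^b z^c with a + b + c = ell."""
    return (ell + 1) * (ell + 2) // 2


degrees = [0, 1, 2, 3]
blocks = [so3_dim(ell) for ell in degrees]
print(blocks, sum(blocks))  # [1, 3, 5, 7] 16

# Degree-ell polynomials = harmonics of degree ell + |r|^2 * (degree ell - 2 polynomials).
for ell in degrees:
    lower = num_monomials(ell - 2) if ell >= 2 else 0
    assert num_monomials(ell) == so3_dim(ell) + lower
```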

Rotation

cue.descriptors.yxy_rotation(cue.Irreps("O3", "32x0e + 32x1o"))
EquivariantTensorProduct(3x0e x 3x0e x 3x0e x 32x0e+32x1o -> 32x0e+32x1o)

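This is an edge case. The descriptor rotates the input by three angles, composed as rotations about the y, x and y axes. Each angle is not passed directly but encoded through \(\sin(\theta)\) and \(\cos(\theta)\), which is why the descriptor has three extra scalar operands (3x0e each). See the function cuet.encode_rotation_angle (in cuequivariance_torch) for details.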
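The sketch below shows why encoding the angle this way lets a rotation fit the tensor-product form. Once \(\cos(\theta)\) and \(\sin(\theta)\) are inputs, rotating a vector is bilinear in the encoded angle and the vector. The helper encode_angle is hypothetical. We assume, without having verified the exact ordering, that cuet.encode_rotation_angle produces a similar vector of 2 * ell + 1 values. That length is consistent with the 3x0e operands above for ell = 1. The sign convention of rotate_about_y is one common choice.

```python
from __future__ import annotations

import math


def encode_angle(theta: float, ell: int) -> list[float]:
    """Hypothetical encoding: [cos(ell*t), ..., cos(t), 1, sin(t), ..., sin(ell*t)].

    The result has 2 * ell + 1 entries, i.e. 3 scalars for ell = 1.
    """
    cos_part = [math.cos(m * theta) for m in range(ell, 0, -1)]
    sin_part = [math.sin(m * theta) for m in range(1, ell + 1)]
    return cos_part + [1.0] + sin_part


def rotate_about_y(enc: list[float], v: tuple[float, float, float]) -> tuple[float, float, float]:
    """Rotate v = (x, y, z) about the y axis using only the encoded scalars.

    Each output component is a sum of (encoded scalar) * (input component),
    so the map is bilinear in (enc, v), the form a tensor product computes.
    """
    c, one, s = enc  # ell = 1 encoding: [cos t, 1, sin t]
    x, y, z = v
    return (c * x + s * z, one * y, -s * x + c * z)


enc = encode_angle(math.pi / 2, ell=1)
print(len(enc))                              # 3
print(rotate_about_y(enc, (1.0, 2.0, 0.0)))  # approximately (0.0, 2.0, -1.0)
```

For higher degrees, the rotation matrices contain \(\cos(m\theta)\) and \(\sin(m\theta)\) for larger m. This is presumably why the encoding in the sketch includes multiples of the angle.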

Execution on JAX

import jax
import jax.numpy as jnp
import cuequivariance as cue
import cuequivariance_jax as cuex

e = cue.descriptors.linear(
    cue.Irreps("O3", "32x0e + 32x1o"),
    cue.Irreps("O3", "8x0e + 4x1o")
)
# Random weights (operand 0) and input (operand 1) with the representations the descriptor expects
w = cuex.randn(jax.random.key(0), e.inputs[0])
x = cuex.randn(jax.random.key(1), e.inputs[1])

cuex.equivariant_tensor_product(e, w, x)
{0: 8x0e+4x1o}
[-0.56710666  0.29934764  1.438811   -1.0761446  -0.16420852  0.6024779
 -1.7548201   0.3891445   0.03765802 -0.03795518 -0.750764   -3.2584484
  0.6283557   0.09663387 -0.42426407 -0.8612407   0.4686108  -0.9862214
 -0.31201616  0.8071706 ]

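The function cuex.randn generates a random cuex.RepArray for a given representation, here the weight operand e.inputs[0] and the input operand e.inputs[1]. The function cuex.equivariant_tensor_product executes the tensor product and returns a cuex.RepArray. Its representation 8x0e+4x1o has 8 + 4 * 3 = 20 components, matching the 20 values printed above.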
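The plain-Python sketch below spells out what an equivariant linear map such as this one computes. It makes several assumptions: one weight matrix per irrep type, inputs grouped as a list of scalars and a list of 3-vectors, and no normalization factor, which the library may apply. The function linear_equivariant and this weight layout are hypothetical, not cuequivariance's internal format. The weights only mix copies (multiplicities) and never the x/y/z components. As a result, rotating the input vectors rotates the output vectors in the same way and leaves the output scalars unchanged.

```python
import math
import random


def linear_equivariant(w0, w1, scalars, vectors):
    """Hypothetical equivariant linear layer on a 0e block and a 1o block.

    w0[u][v] maps input scalar copy u to output scalar copy v;
    w1[u][v] maps input vector copy u to output vector copy v.
    """
    out_scalars = [sum(w0[u][v] * scalars[u] for u in range(len(scalars)))
                   for v in range(len(w0[0]))]
    out_vectors = [[sum(w1[u][v] * vectors[u][i] for u in range(len(vectors)))
                    for i in range(3)]
                   for v in range(len(w1[0]))]
    return out_scalars, out_vectors


def rotate_z(v, theta):
    """Rotate a 3-vector about the z axis by theta."""
    c, s = math.cos(theta), math.sin(theta)
    return [c * v[0] - s * v[1], s * v[0] + c * v[1], v[2]]


rng = random.Random(0)
mul_in, out0, out1 = 32, 8, 4  # 32x0e + 32x1o -> 8x0e + 4x1o
w0 = [[rng.gauss(0, 1) for _ in range(out0)] for _ in range(mul_in)]
w1 = [[rng.gauss(0, 1) for _ in range(out1)] for _ in range(mul_in)]
scalars = [rng.gauss(0, 1) for _ in range(mul_in)]
vectors = [[rng.gauss(0, 1) for _ in range(3)] for _ in range(mul_in)]

theta = 0.7
# Rotate before the layer ...
s_a, v_a = linear_equivariant(w0, w1, scalars, [rotate_z(v, theta) for v in vectors])
# ... or after it: equivariance means both agree.
s_b, v_b = linear_equivariant(w0, w1, scalars, vectors)
v_b = [rotate_z(v, theta) for v in v_b]

same = all(
    math.isclose(a, b, abs_tol=1e-9)
    for va, vb in zip(v_a, v_b)
    for a, b in zip(va, vb)
)
print(mul_in * out0 + mul_in * out1)  # 384 weights
print(s_a == s_b, same)               # True True
```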

Execution on PyTorch

We can execute a cuequivariance.EquivariantTensorProduct with PyTorch.

import torch
import cuequivariance as cue
import cuequivariance_torch as cuet

e = cue.descriptors.linear(
    cue.Irreps("O3", "32x0e + 32x1o"),
    cue.Irreps("O3", "8x0e + 4x1o")
)
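# layout=cue.ir_mul declares how the inputs and output are laid out in memory;
# use_fallback=True selects the pure-PyTorch implementation instead of the CUDA kernels
module = cuet.EquivariantTensorProduct(e, layout=cue.ir_mul, use_fallback=True)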

# Batch of size 1: each operand is a flat vector of dimension e.inputs[i].dim
w = torch.randn(1, e.inputs[0].dim)
x = torch.randn(1, e.inputs[1].dim)

module(w, x)
tensor([[ 2.3891,  0.7179,  1.5534,  1.8922, -0.5494,  0.2414, -1.2353,  2.5942,
         -0.3150,  1.1342, -0.5765, -0.9785,  0.3172, -0.5993, -0.4881,  0.9823,
         -1.3407, -0.8895, -0.8482,  1.4145]])

Note that you must specify the layout of the inputs and output (here cue.ir_mul). If it differs from the layout used by the descriptor, the module transposes the inputs and the output to match.
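The sketch below shows what such a transposition amounts to. It assumes the usual convention: an irrep block of multiplicity mul and dimension d is stored as a mul × d array in mul_ir and as a d × mul array in ir_mul. The helper mul_ir_to_ir_mul is hypothetical plain Python; cuequivariance performs the conversion internally. Scalar blocks (d = 1) come out unchanged, and only the 1o block is reordered.

```python
def mul_ir_to_ir_mul(flat, blocks):
    """Reorder a flat vector from mul_ir to ir_mul layout.

    blocks lists (mul, dim) pairs, e.g. [(8, 1), (4, 3)] for 8x0e+4x1o.
    """
    out, offset = [], 0
    for mul, dim in blocks:
        block = flat[offset:offset + mul * dim]
        # In mul_ir, element (u, i) sits at u * dim + i; emit it in (i, u) order.
        out.extend(block[u * dim + i] for i in range(dim) for u in range(mul))
        offset += mul * dim
    return out


blocks = [(8, 1), (4, 3)]  # 8x0e + 4x1o: 8 + 4 * 3 = 20 components
print(mul_ir_to_ir_mul(list(range(20)), blocks))
# [0, 1, 2, 3, 4, 5, 6, 7, 8, 11, 14, 17, 9, 12, 15, 18, 10, 13, 16, 19]
```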