Equivariant Tensor Product
The submodule cuequivariance.descriptors contains many descriptors of equivariant tensor products, represented by the class cuequivariance.EquivariantTensorProduct.
Examples
Linear layer
import cuequivariance as cue
cue.descriptors.linear(cue.Irreps("O3", "32x0e + 32x1o"), cue.Irreps("O3", "16x0e + 48x1o"))
EquivariantTensorProduct(2048x0e x 32x0e+32x1o -> 16x0e+48x1o)
In this example, the first operand holds the weights, which are always scalars. There are 32 * 16 = 512 weights connecting the 0e irreps and 32 * 48 = 1536 weights connecting the 1o irreps, for a total of 2048 weights.
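The weight count above can be reproduced with a few lines of plain Python. This is a minimal sketch, assuming the rule just described: each irrep type present in both input and output contributes (input multiplicity) * (output multiplicity) weights. The helpers parse_irreps and linear_weight_count are hypothetical illustrations, not part of cuequivariance.

```python
# Hypothetical helpers (not cuequivariance API): count the weights of a linear
# layer, assuming one weight per (input copy, output copy) of the same irrep type.

def parse_irreps(s):
    """Parse a string like "32x0e + 32x1o" into {"0e": 32, "1o": 32}."""
    out = {}
    for term in s.split("+"):
        mul, ir = term.strip().split("x")
        out[ir] = out.get(ir, 0) + int(mul)
    return out

def linear_weight_count(irreps_in, irreps_out):
    """Sum mul_in * mul_out over irrep types shared by input and output."""
    a, b = parse_irreps(irreps_in), parse_irreps(irreps_out)
    return sum(a[ir] * b[ir] for ir in a if ir in b)

print(linear_weight_count("32x0e + 32x1o", "16x0e + 48x1o"))  # 2048 = 512 + 1536
print(linear_weight_count("32x0e + 32x1o", "8x0e + 4x1o"))    # 384 = 256 + 128
```

The second call uses the irreps from the execution examples further down, which under the same rule need 384 weights.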
Spherical Harmonics
cue.descriptors.spherical_harmonics(cue.SO3(1), [0, 1, 2, 3])
EquivariantTensorProduct((1)^(0..3) -> 0+1+2+3)
Spherical harmonics are polynomials of an input vector, here a vector in the SO(3) irrep 1. This descriptor computes the polynomials of degree 0, 1, 2, and 3.
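To make "polynomials of degree 0 to 3" concrete, the sketch below counts the output components (degree l contributes 2l+1, so 0+1+2+3 gives 16) and checks that a degree-2 basis is a homogeneous quadratic in the input. The degree-2 basis used here is one common unnormalized real basis, assumed for illustration; cuequivariance's normalization and component ordering may differ. The helper names are hypothetical.

```python
import numpy as np

def sh_dim(ls):
    """Number of output components: each degree l contributes 2l+1."""
    return sum(2 * l + 1 for l in ls)

def solid_harmonics_l2(v):
    """Unnormalized real degree-2 harmonic polynomials (an assumed basis)."""
    x, y, z = v
    return np.array([x * y, y * z, 2 * z**2 - x**2 - y**2, x * z, x**2 - y**2])

v = np.array([0.3, -1.2, 0.5])
print(sh_dim([0, 1, 2, 3]))  # 16

# Degree-2 polynomials are homogeneous: scaling v by t scales the output by t**2.
t = 2.0
print(np.allclose(solid_harmonics_l2(t * v), t**2 * solid_harmonics_l2(v)))  # True
```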
Rotation
cue.descriptors.yxy_rotation(cue.Irreps("O3", "32x0e + 32x1o"))
EquivariantTensorProduct(3x0e x 3x0e x 3x0e x 32x0e+32x1o -> 32x0e+32x1o)
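This is something of an edge case: the descriptor rotates the input by angles that are not passed directly but encoded as \(\sin(\theta)\) and \(\cos(\theta)\) values, held in the three 3x0e operands. See the function cuet.encode_rotation_angle (from cuequivariance_torch) for details.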
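Why encode angles as sines and cosines? The entries of a rotation matrix are polynomials (here, linear functions) of \(\cos\theta\) and \(\sin\theta\), so once the angle is encoded the rotation becomes a polynomial map that a tensor product can express. The NumPy sketch below shows this for a plain 3D vector. Two assumptions, labeled as such: the y-x-y composition order is one possible Euler convention, and encode_angle mimics the layout [cos(lθ), ..., cos(θ), 1, sin(θ), ..., sin(lθ)] (2l+1 scalars, hence 3x0e for l = 1) that cuet.encode_rotation_angle is believed to use; it is a hypothetical stand-in, not that function.

```python
import numpy as np

def rot_x(a):
    c, s = np.cos(a), np.sin(a)
    return np.array([[1.0, 0.0, 0.0], [0.0, c, -s], [0.0, s, c]])

def rot_y(a):
    c, s = np.cos(a), np.sin(a)
    return np.array([[c, 0.0, s], [0.0, 1.0, 0.0], [-s, 0.0, c]])

def yxy_matrix(alpha, beta, gamma):
    # Assumed convention: apply Ry(gamma) first, then Rx(beta), then Ry(alpha).
    return rot_y(alpha) @ rot_x(beta) @ rot_y(gamma)

def encode_angle(theta, ell):
    # Hypothetical encoding: [cos(ell*t), ..., cos(t), 1, sin(t), ..., sin(ell*t)].
    k = np.arange(1, ell + 1)
    return np.concatenate([np.cos(k[::-1] * theta), [1.0], np.sin(k * theta)])

def rot_y_from_encoding(enc):
    # For ell = 1 the encoding is [cos t, 1, sin t]; every entry of Ry(t) is linear in it.
    c, one, s = enc
    return np.array([[c, 0.0, s], [0.0, one, 0.0], [-s, 0.0, c]])

enc = encode_angle(0.7, 1)
print(enc.shape)                                         # (3,)
print(np.allclose(rot_y_from_encoding(enc), rot_y(0.7)))  # True
```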
Execution on JAX
import jax
import jax.numpy as jnp
import cuequivariance as cue
import cuequivariance_jax as cuex
e = cue.descriptors.linear(
    cue.Irreps("O3", "32x0e + 32x1o"),
    cue.Irreps("O3", "8x0e + 4x1o")
)
w = cuex.randn(jax.random.key(0), e.inputs[0])
x = cuex.randn(jax.random.key(1), e.inputs[1])
cuex.equivariant_tensor_product(e, w, x)
{0: 8x0e+4x1o}
[-0.56710666 0.29934764 1.438811 -1.0761446 -0.16420852 0.6024779
-1.7548201 0.3891445 0.03765802 -0.03795518 -0.750764 -3.2584484
0.6283557 0.09663387 -0.42426407 -0.8612407 0.4686108 -0.9862214
-0.31201616 0.8071706 ]
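The function cuex.randn generates a random cuex.RepArray for a given representation, here the two inputs of the descriptor, e.inputs[0] (the weights) and e.inputs[1]. The function cuex.equivariant_tensor_product executes the tensor product. Its output is also a cuex.RepArray: the irreps 8x0e+4x1o are printed first, followed by its 20 values.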
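The array sizes follow from the irreps: an O3 irrep of degree l has dimension 2l+1, multiplied by its multiplicity. The sketch below, using a hypothetical helper irreps_dim (not part of cuequivariance), shows why x has 128 components and the output above has 20.

```python
# Hypothetical helper (not cuequivariance API): flat dimension of an O3 irreps
# string, assuming each irrep of degree l has dimension 2l+1.

def irreps_dim(s):
    total = 0
    for term in s.split("+"):
        mul, ir = term.strip().split("x")
        l = int(ir[:-1])  # drop the parity letter "e" or "o"
        total += int(mul) * (2 * l + 1)
    return total

print(irreps_dim("32x0e + 32x1o"))  # 128, the size of x
print(irreps_dim("8x0e + 4x1o"))    # 20, the size of the output
```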
Execution on PyTorch
We can also execute a cuequivariance.EquivariantTensorProduct with PyTorch, using the module cuet.EquivariantTensorProduct.
import torch
import cuequivariance as cue
import cuequivariance_torch as cuet
e = cue.descriptors.linear(
    cue.Irreps("O3", "32x0e + 32x1o"),
    cue.Irreps("O3", "8x0e + 4x1o")
)
module = cuet.EquivariantTensorProduct(e, layout=cue.ir_mul, use_fallback=True)
w = torch.randn(1, e.inputs[0].dim)
x = torch.randn(1, e.inputs[1].dim)
module(w, x)
tensor([[ 2.3891, 0.7179, 1.5534, 1.8922, -0.5494, 0.2414, -1.2353, 2.5942,
-0.3150, 1.1342, -0.5765, -0.9785, 0.3172, -0.5993, -0.4881, 0.9823,
-1.3407, -0.8895, -0.8482, 1.4145]])
Note that you have to specify the layout (here cue.ir_mul). If it differs from the layout used by the descriptor, the module transposes the inputs and the output to match.
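What "transposing to match the layout" means can be sketched in NumPy. Assuming the usual reading of the two layouts, each block "mul x ir" is stored as a (mul, ir.dim) array in mul_ir and as an (ir.dim, mul) array in ir_mul, both flattened. Converting between them transposes each block separately. The helpers below are hypothetical illustrations, not the module's internal code.

```python
import numpy as np

def mul_ir_to_ir_mul(x, blocks):
    """Convert a flat vector from mul_ir to ir_mul ordering.

    blocks: list of (mul, dim) pairs, e.g. [(8, 1), (4, 3)] for "8x0e + 4x1o".
    """
    out, i = [], 0
    for mul, dim in blocks:
        block = x[i : i + mul * dim].reshape(mul, dim)  # mul_ir: multiplicity is the outer index
        out.append(block.T.reshape(-1))                 # ir_mul: irrep component is the outer index
        i += mul * dim
    return np.concatenate(out)

def ir_mul_to_mul_ir(y, blocks):
    """Inverse conversion: transpose each (dim, mul) block back to (mul, dim)."""
    out, i = [], 0
    for mul, dim in blocks:
        out.append(y[i : i + mul * dim].reshape(dim, mul).T.reshape(-1))
        i += mul * dim
    return np.concatenate(out)

blocks = [(8, 1), (4, 3)]  # "8x0e + 4x1o", total dimension 20
x = np.arange(20.0)
y = mul_ir_to_ir_mul(x, blocks)
print(y[8:].tolist())  # [8.0, 11.0, 14.0, 17.0, 9.0, 12.0, 15.0, 18.0, 10.0, 13.0, 16.0, 19.0]
```

Scalar blocks (dim = 1) are unchanged by the conversion; only the 1o block is reordered.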