Equivariant Tensor Product
The submodule cuequivariance.descriptors contains many descriptors of Equivariant Tensor Products, represented by the class cuequivariance.EquivariantTensorProduct.
Examples
Linear layer
import cuequivariance as cue
cue.descriptors.linear(cue.Irreps("O3", "32x0e + 32x1o"), cue.Irreps("O3", "16x0e + 48x1o"))
╭ a=2048x0e b=32x0e+32x1o -> C=16x0e+48x1o
╰─ []·a[uv]·b[iu]➜C[iv] ─ num_paths=2 i={1, 3} u=32 v={16, 48}
In this example, the first operand holds the weights, which are always scalars. There are 32 * 16 = 512 weights connecting the 0e inputs to the 0e outputs and 32 * 48 = 1536 weights connecting the 1o inputs to the 1o outputs. This gives a total of 2048 weights, which is the size of the operand a=2048x0e.
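The weight count follows a simple rule: for every irrep type present in both the input and the output, the layer holds one block of mul_in * mul_out scalar weights. The sketch below is a minimal illustration of that arithmetic in plain Python, assuming simplified irreps strings; `parse_irreps` and `linear_weight_count` are hypothetical helpers, not part of cuequivariance (which parses irreps with cue.Irreps).

```python
# Hypothetical helper: parse "32x0e + 32x1o" into (multiplicity, irrep) pairs.
# cue.Irreps does this properly; this only mirrors the counting rule.
def parse_irreps(s):
    pairs = []
    for term in s.split("+"):
        mul, ir = term.strip().split("x")
        pairs.append((int(mul), ir))
    return pairs


def linear_weight_count(irreps_in, irreps_out):
    """Hypothetical helper: one mul_in x mul_out weight block for every
    irrep type that appears in both the input and the output."""
    mul_in = {}
    for mul, ir in parse_irreps(irreps_in):
        mul_in[ir] = mul_in.get(ir, 0) + mul
    total = 0
    for mul, ir in parse_irreps(irreps_out):
        total += mul_in.get(ir, 0) * mul  # irreps absent from the input get no weights
    return total


print(linear_weight_count("32x0e + 32x1o", "16x0e + 48x1o"))  # 2048 = 32*16 + 32*48
```

With the 8x0e + 4x1o output used in the JAX and PyTorch examples, the same rule gives 32 * 8 + 32 * 4 = 384 weights.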
Spherical Harmonics
cue.descriptors.spherical_harmonics(cue.SO3(1), [0, 1, 2, 3])
╭ a=1 -> B=0+1+2+3
│ []➜B[] ───────────── num_paths=1
│ []·a[]➜B[] ───────── num_paths=3
│ []·a[]·a[]➜B[] ───── num_paths=11
╰─ []·a[]·a[]·a[]➜B[] ─ num_paths=41
The spherical harmonics are polynomials of an input vector. This descriptor specifies the polynomials of degree 0, 1, 2 and 3: the term of degree k is the line that contracts k copies of the input a.
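To make the "polynomial of the input vector" structure concrete, the NumPy sketch below writes the degree-2 term as a contraction of a coefficient tensor with two copies of the input, in the spirit of the []·a[]·a[]➜B[] line. It assumes a hypothetical, unnormalized real basis (xy, yz, 2z^2 - x^2 - y^2, xz, x^2 - y^2, with z as the polar axis); cuequivariance uses its own ordering and normalization, so the values will differ from the library's.

```python
import numpy as np

# Hypothetical unnormalized degree-2 basis (not cuequivariance's convention):
#   xy, yz, 2z^2 - x^2 - y^2, xz, x^2 - y^2
# C[k] is the symmetric matrix with a^T C[k] a = k-th polynomial.
C = np.zeros((5, 3, 3))
C[0, 0, 1] = C[0, 1, 0] = 0.5  # xy
C[1, 1, 2] = C[1, 2, 1] = 0.5  # yz
C[2, 0, 0], C[2, 1, 1], C[2, 2, 2] = -1.0, -1.0, 2.0  # 2z^2 - x^2 - y^2
C[3, 0, 2] = C[3, 2, 0] = 0.5  # xz
C[4, 0, 0], C[4, 1, 1] = 1.0, -1.0  # x^2 - y^2


def sh_degree2(a):
    """Degree-2 term: contract two copies of the input vector a with C."""
    return np.einsum("kij,i,j->k", C, a, a)


a = np.array([0.3, -1.2, 0.8])
x, y, z = a
direct = np.array([x * y, y * z, 2 * z**2 - x**2 - y**2, x * z, x**2 - y**2])
print(np.allclose(sh_degree2(a), direct))  # True

# Homogeneous of degree 2: scaling the input by s scales the output by s**2.
print(np.allclose(sh_degree2(2.0 * a), 4.0 * sh_degree2(a)))  # True

# Harmonic: the Laplacian of a^T C[k] a equals 2 * trace(C[k]), here all zeros.
print(np.trace(C, axis1=1, axis2=2))
```

The degree-1 and degree-3 terms follow the same pattern with one and three copies of the input, respectively.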
Rotation
cue.descriptors.yxy_rotation(cue.Irreps("O3", "32x0e + 32x1o"))
╭ a=3x0e b=3x0e c=3x0e d=32x0e+32x1o -> E=32x0e+32x1o
╰─ []·a[]·b[]·c[]·d[u]➜E[u] ─ num_paths=14 u=32
This is something of an edge case: it rotates the input by three angles (the operands a, b and c), each encoded through \(\sin(\theta)\) and \(\cos(\theta)\). See the function cuet.encode_rotation_angle for more details.
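Why can a rotation be written as a polynomial tensor product? The entries of a rotation matrix about a single axis involve only 0, 1, \(\sin(\theta)\) and \(\cos(\theta)\), so the composite y-x-y rotation is a product of three such factors applied linearly to the features, matching the shape of the []·a[]·b[]·c[]·d[u]➜E[u] line. The NumPy sketch below illustrates this for 32x0e + 32x1o features. The helpers `rot_y`, `rot_x` and `yxy_rotate`, the order in which the angles are applied, and the ir_mul storage of the vectors are all assumptions for illustration; the sketch does not reproduce the exact encoding produced by cuet.encode_rotation_angle.

```python
import numpy as np


def rot_y(s, c):
    # Hypothetical helper: rotation about y, using only s = sin(theta), c = cos(theta).
    return np.array([[c, 0.0, s], [0.0, 1.0, 0.0], [-s, 0.0, c]])


def rot_x(s, c):
    # Hypothetical helper: rotation about x, same encoding.
    return np.array([[1.0, 0.0, 0.0], [0.0, c, -s], [0.0, s, c]])


def yxy_rotate(angles, x0, x1):
    """Hypothetical sketch: rotate 32x0e + 32x1o features stored as
    x0 (32,) scalars and x1 (3, 32) vectors (ir_mul layout, assumed)."""
    (sa, ca), (sb, cb), (sg, cg) = [(np.sin(t), np.cos(t)) for t in angles]
    R = rot_y(sa, ca) @ rot_x(sb, cb) @ rot_y(sg, cg)
    return x0, R @ x1  # 0e scalars are invariant; 1o vectors rotate


rng = np.random.default_rng(0)
x0, x1 = rng.normal(size=32), rng.normal(size=(3, 32))
y0, y1 = yxy_rotate([0.1, 0.7, -1.3], x0, x1)
print(np.allclose(y0, x0))  # True: scalars unchanged
print(np.allclose(np.linalg.norm(y1, axis=0), np.linalg.norm(x1, axis=0)))  # True
```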
Execution on JAX
import jax
import jax.numpy as jnp
import cuequivariance as cue
import cuequivariance_jax as cuex
e = cue.descriptors.linear(
cue.Irreps("O3", "32x0e + 32x1o"),
cue.Irreps("O3", "8x0e + 4x1o")
)
w = cuex.randn(jax.random.key(0), e.inputs[0])
x = cuex.randn(jax.random.key(1), e.inputs[1])
cuex.equivariant_polynomial(e, [w, x])
{0: 8x0e+4x1o}
[-0.5671067 0.2993476 1.438811 -1.0761446 -0.16420853 0.60247785
-1.7548201 0.38914463 0.03765786 -0.03795541 -0.7507639 -3.2584481
0.628356 0.09663385 -0.42426404 -0.86124074 0.46861085 -0.9862213
-0.3120161 0.8071707 ]
The function cuex.randn generates random cuex.RepArray objects, and the function cuex.equivariant_polynomial executes the tensor product. The output is a cuex.RepArray object.
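For intuition about what this call computes, the NumPy sketch below applies the subscripts of the linear descriptor, []·a[uv]·b[iu]➜C[iv], by hand for 32x0e + 32x1o -> 8x0e + 4x1o. It is a hypothetical reference, not the library's implementation: it assumes the weights are stored block by block (0e first, then 1o), that the features use the ir_mul layout, and it omits any normalization constant, so its numbers will not match the output printed above. The final check shows the property that makes the layer equivariant: the weights only mix the multiplicity index u, so rotating the input vectors rotates the output vectors.

```python
import numpy as np


def linear_ir_mul(w, x, mul_in=32, mul_out=(8, 4)):
    """Hypothetical sketch of 32x0e + 32x1o -> 8x0e + 4x1o via a[uv]·b[iu]->C[iv].
    Assumes block-wise weight storage (0e then 1o) and ir_mul features;
    normalization constants are omitted."""
    n0 = mul_in * mul_out[0]
    w0 = w[:n0].reshape(mul_in, mul_out[0])
    w1 = w[n0:].reshape(mul_in, mul_out[1])
    x0 = x[:mul_in].reshape(1, mul_in)  # i has size 1 for 0e
    x1 = x[mul_in:].reshape(3, mul_in)  # i has size 3 for 1o
    y0 = np.einsum("uv,iu->iv", w0, x0)
    y1 = np.einsum("uv,iu->iv", w1, x1)
    return np.concatenate([y0.ravel(), y1.ravel()])


rng = np.random.default_rng(0)
w = rng.normal(size=32 * 8 + 32 * 4)  # 384 weights
x = rng.normal(size=32 + 3 * 32)  # 128 input components
y = linear_ir_mul(w, x)
print(y.shape)  # (20,): 8 scalars + 4 vectors of 3 components

# Equivariance check: rotate the 1o part of the input by R.
theta = 0.4
R = np.array([[np.cos(theta), -np.sin(theta), 0.0],
              [np.sin(theta), np.cos(theta), 0.0],
              [0.0, 0.0, 1.0]])
x_rot = np.concatenate([x[:32], (R @ x[32:].reshape(3, 32)).ravel()])
y_rot = linear_ir_mul(w, x_rot)
print(np.allclose(y_rot[8:].reshape(3, 4), R @ y[8:].reshape(3, 4)))  # True
```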
Execution on PyTorch
We can also execute a cuequivariance.EquivariantTensorProduct with PyTorch.
import torch
import cuequivariance as cue
import cuequivariance_torch as cuet
e = cue.descriptors.linear(
cue.Irreps("O3", "32x0e + 32x1o"),
cue.Irreps("O3", "8x0e + 4x1o")
)
module = cuet.EquivariantTensorProduct(e, layout=cue.ir_mul, use_fallback=True)
w = torch.randn(1, e.inputs[0].dim)
x = torch.randn(1, e.inputs[1].dim)
module(w, x)
tensor([[-0.4634, -1.1188, 0.2747, -0.5108, 0.3437, 0.2699, 0.8779, 0.1519,
0.0092, -0.3050, -2.7978, -0.6944, -1.1833, 0.3065, 0.9834, -0.7099,
-1.2396, 0.2767, 2.0055, 0.4753]])
Note that you have to specify the layout of the inputs and output. If it differs from the layout used in the descriptor, the module transposes the inputs and output to match.
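The transposition mentioned above only reorders components within each irrep block. The NumPy sketch below illustrates it, assuming that mul_ir stores each block as [multiplicity, component] and ir_mul as [component, multiplicity]; `mul_ir_to_ir_mul` and `ir_mul_to_mul_ir` are hypothetical helpers, not cuequivariance functions.

```python
import numpy as np


def mul_ir_to_ir_mul(x, blocks):
    """Hypothetical helper: reorder a flat vector from mul_ir to ir_mul.
    `blocks` lists (mul, dim) per irrep block, e.g. [(32, 1), (32, 3)]."""
    out, start = [], 0
    for mul, dim in blocks:
        block = x[start:start + mul * dim].reshape(mul, dim)  # mul_ir: [u, i]
        out.append(block.T.ravel())  # ir_mul: [i, u]
        start += mul * dim
    return np.concatenate(out)


def ir_mul_to_mul_ir(x, blocks):
    """Hypothetical helper: the inverse reordering, ir_mul back to mul_ir."""
    out, start = [], 0
    for mul, dim in blocks:
        block = x[start:start + mul * dim].reshape(dim, mul)  # ir_mul: [i, u]
        out.append(block.T.ravel())  # mul_ir: [u, i]
        start += mul * dim
    return np.concatenate(out)


x = np.arange(8)
blocks = [(2, 1), (2, 3)]  # 2x0e + 2x1o
print(mul_ir_to_ir_mul(x, blocks))  # [0 1 2 5 3 6 4 7]
print(np.array_equal(ir_mul_to_mul_ir(mul_ir_to_ir_mul(x, blocks), blocks), x))  # True
```

Scalar blocks (dim 1) are unchanged by the reordering; only blocks with more than one component, such as 1o, are actually permuted.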