Equivariant Tensor Product

The submodule cuequivariance.descriptors contains functions that build descriptors of common equivariant tensor products. Each descriptor is an instance of the class cuequivariance.EquivariantTensorProduct.

Examples

Linear layer

import cuequivariance as cue

cue.descriptors.linear(cue.Irreps("O3", "32x0e + 32x1o"), cue.Irreps("O3", "16x0e + 48x1o"))
╭ a=2048x0e b=32x0e+32x1o -> C=16x0e+48x1o
╰─ []·a[uv]·b[iu]➜C[iv] ─ num_paths=2 i={1, 3} u=32 v={16, 48}

In this example, the first operand holds the weights, which are always scalars. There are 32 * 16 = 512 weights connecting the 0e irreps and 32 * 48 = 1536 weights connecting the 1o irreps, for a total of 2048 weights.
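The weight count can be checked by hand: a linear layer only connects input and output irreps of the same type, with one weight per pair of (input copy, output copy). Below is a minimal pure-Python sketch. It assumes irreps strings of the simple form "32x0e + 32x1o", and the helpers parse_irreps and count_linear_weights are hypothetical, not part of cuequivariance.

```python
def parse_irreps(s):
    # Hypothetical helper: map each irrep label (e.g. "1o") to its total multiplicity.
    muls = {}
    for term in s.split("+"):
        mul, ir = term.strip().split("x")
        muls[ir] = muls.get(ir, 0) + int(mul)
    return muls


def count_linear_weights(irreps_in, irreps_out):
    # One weight per (input copy, output copy) pair of the same irrep type.
    m_in, m_out = parse_irreps(irreps_in), parse_irreps(irreps_out)
    return sum(m_in[ir] * m_out[ir] for ir in m_in if ir in m_out)


print(count_linear_weights("32x0e + 32x1o", "16x0e + 48x1o"))  # 2048
```

Irreps types that appear only on one side contribute no weights, which is why the sum skips them.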

Spherical Harmonics

cue.descriptors.spherical_harmonics(cue.SO3(1), [0, 1, 2, 3])
╭ a=1 -> B=0+1+2+3
│  []➜B[] ───────────── num_paths=1
│  []·a[]➜B[] ───────── num_paths=3
│  []·a[]·a[]➜B[] ───── num_paths=11
╰─ []·a[]·a[]·a[]➜B[] ─ num_paths=41

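The spherical harmonics are polynomials of an input vector. This descriptor specifies the polynomials of degree 0, 1, 2, and 3; each line of the output corresponds to one degree, the degree-k term containing k copies of the input a.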
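To make "polynomials of an input vector" concrete, here is a NumPy sketch of the degree 0, 1, and 2 real spherical harmonics written out as explicit polynomials in (x, y, z). The component ordering and normalization are an assumption (one common convention with y as the polar axis). They are not claimed to match what cue.descriptors.spherical_harmonics produces.

```python
import numpy as np


def spherical_harmonics_l012(v):
    # Real spherical harmonics of degree 0, 1 and 2 as explicit polynomials.
    # Assumed convention (y as polar axis); may differ from cuequivariance.
    x, y, z = v
    s3 = np.sqrt(3.0)
    y0 = np.array([1.0])                     # degree 0: constant
    y1 = np.array([x, y, z])                 # degree 1: linear
    y2 = np.array([                          # degree 2: quadratic, 5 components
        s3 * x * z,
        s3 * x * y,
        y**2 - 0.5 * (x**2 + z**2),
        s3 * y * z,
        0.5 * s3 * (z**2 - x**2),
    ])
    return y0, y1, y2


v = np.array([0.3, -1.2, 0.7])
y0, y1, y2 = spherical_harmonics_l012(v)
# The degree-2 block is homogeneous: scaling v by 2 scales it by 4.
print(np.allclose(spherical_harmonics_l012(2 * v)[2], 4 * y2))  # True
# With this normalization |Y^2(v)|^2 = |v|^4, a rotation-invariant quantity.
print(np.isclose(y2 @ y2, (v @ v) ** 2))  # True
```

Expanding the sum of squares of y2 gives exactly (x² + y² + z²)², which is why its norm depends only on the length of v.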

Rotation

cue.descriptors.yxy_rotation(cue.Irreps("O3", "32x0e + 32x1o"))
╭ a=3x0e b=3x0e c=3x0e d=32x0e+32x1o -> E=32x0e+32x1o
╰─ []·a[]·b[]·c[]·d[u]➜E[u] ─ num_paths=14 u=32

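This is a bit of an edge case: the input is rotated by three angles (a y, an x, and a y rotation, as the name yxy_rotation suggests), passed through the scalar operands a, b, and c and encoded in terms of \(\sin(\theta)\) and \(\cos(\theta)\). See the function cuet.encode_rotation_angle for more details.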
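Angles are passed as sines and cosines because, for a vector (l = 1) block, every entry of a rotation matrix about a fixed axis is a polynomial in \(\cos\theta\) and \(\sin\theta\). A composition of three such rotations is therefore a polynomial in the encoded angles. The NumPy sketch below illustrates this for a single 3-vector. The axis conventions, the composition order, and the encode_angle helper are assumptions and do not reproduce cuet.encode_rotation_angle.

```python
import numpy as np


def encode_angle(theta):
    # Hypothetical encoding: represent an angle by its cosine and sine.
    return np.cos(theta), np.sin(theta)


def rot_y(c, s):
    # Rotation about the y axis; every entry is a polynomial in (c, s).
    return np.array([[c, 0.0, s], [0.0, 1.0, 0.0], [-s, 0.0, c]])


def rot_x(c, s):
    # Rotation about the x axis; every entry is a polynomial in (c, s).
    return np.array([[1.0, 0.0, 0.0], [0.0, c, -s], [0.0, s, c]])


def yxy_rotate(alpha, beta, gamma, v):
    # Assumed composition order: R = R_y(alpha) @ R_x(beta) @ R_y(gamma).
    R = rot_y(*encode_angle(alpha)) @ rot_x(*encode_angle(beta)) @ rot_y(*encode_angle(gamma))
    return R @ v


v = np.array([1.0, 2.0, 3.0])
w = yxy_rotate(0.1, 0.2, 0.3, v)
print(np.isclose(np.linalg.norm(w), np.linalg.norm(v)))  # True: length is preserved
```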

Execution on JAX

import jax
import jax.numpy as jnp
import cuequivariance as cue
import cuequivariance_jax as cuex

e = cue.descriptors.linear(
    cue.Irreps("O3", "32x0e + 32x1o"),
    cue.Irreps("O3", "8x0e + 4x1o")
)
w = cuex.randn(jax.random.key(0), e.inputs[0])  # random weights
x = cuex.randn(jax.random.key(1), e.inputs[1])  # random input features

cuex.equivariant_polynomial(e, [w, x])
{0: 8x0e+4x1o}
[-0.5671067   0.2993476   1.438811   -1.0761446  -0.16420853  0.60247785
 -1.7548201   0.38914463  0.03765786 -0.03795541 -0.7507639  -3.2584481
  0.628356    0.09663385 -0.42426404 -0.86124074  0.46861085 -0.9862213
 -0.3120161   0.8071707 ]

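The function cuex.randn generates a random cuex.RepArray for a given representation, and the function cuex.equivariant_polynomial executes the tensor product. The output is also a cuex.RepArray; its 20 components are the 8 scalar (0e) components followed by the 4 * 3 = 12 vector (1o) components.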
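To see what the linear descriptor computes, here is a NumPy sketch that evaluates the path notation a[uv]·b[iu]➜C[iv] block by block with einsum, for 32x0e + 32x1o -> 8x0e + 4x1o in the ir_mul layout. The ordering of the flat weight vector and the absence of any normalization factor are assumptions. Its numbers will therefore not match the output above; it only reproduces the structure (shapes, and the fact that each irrep type is mixed separately).

```python
import numpy as np


def linear_ir_mul(w, x, mul_in=(32, 32), mul_out=(8, 4), ir_dims=(1, 3)):
    # Sketch of a[uv]·b[iu]➜C[iv] for 32x0e + 32x1o -> 8x0e + 4x1o.
    # x and the output use the ir_mul layout: each block is stored as (ir_dim, mul).
    # Assumption: no normalization factor is applied.
    out, w_off, x_off = [], 0, 0
    for u, v, d in zip(mul_in, mul_out, ir_dims):
        wb = w[w_off:w_off + u * v].reshape(u, v)   # weights a[uv]
        xb = x[x_off:x_off + d * u].reshape(d, u)   # input block b[iu]
        out.append(np.einsum("uv,iu->iv", wb, xb).reshape(-1))  # output block C[iv]
        w_off += u * v
        x_off += d * u
    return np.concatenate(out)


rng = np.random.default_rng(0)
w = rng.normal(size=32 * 8 + 32 * 4)   # 384 weights
x = rng.normal(size=32 * 1 + 32 * 3)   # 128 input components
y = linear_ir_mul(w, x)
print(y.shape)  # (20,): 8 scalars followed by 4 vectors of 3 components
```

Because each 1o block is multiplied by the weights only along the multiplicity axis, rotating the input vectors rotates the output vectors in the same way.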

Execution on PyTorch

We can execute a cuequivariance.EquivariantTensorProduct with PyTorch.

import torch
import cuequivariance as cue
import cuequivariance_torch as cuet

e = cue.descriptors.linear(
    cue.Irreps("O3", "32x0e + 32x1o"),
    cue.Irreps("O3", "8x0e + 4x1o")
)
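# use_fallback=True selects the pure-PyTorch implementation instead of the CUDA kernel
module = cuet.EquivariantTensorProduct(e, layout=cue.ir_mul, use_fallback=True)

# the leading dimension of size 1 is the batch dimension
w = torch.randn(1, e.inputs[0].dim)
x = torch.randn(1, e.inputs[1].dim)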

module(w, x)
tensor([[-0.4634, -1.1188,  0.2747, -0.5108,  0.3437,  0.2699,  0.8779,  0.1519,
          0.0092, -0.3050, -2.7978, -0.6944, -1.1833,  0.3065,  0.9834, -0.7099,
         -1.2396,  0.2767,  2.0055,  0.4753]])

Note that you have to specify the layout (for example cue.ir_mul or cue.mul_ir). If the specified layout differs from the one in the descriptor, the module transposes the inputs and the output between the two layouts.
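The transposition between layouts is a per-block reshape. For a block of mul copies of an irrep of dimension d, the mul_ir layout stores a (mul, d) array and the ir_mul layout stores its transpose (d, mul). Here is a minimal NumPy sketch of the conversion; the helper mul_ir_to_ir_mul is hypothetical and not part of cuequivariance.

```python
import numpy as np


def mul_ir_to_ir_mul(x, blocks):
    # blocks: list of (mul, ir_dim) pairs, e.g. [(32, 1), (32, 3)] for 32x0e + 32x1o.
    # mul_ir stores each block as (mul, ir_dim); ir_mul stores it as (ir_dim, mul).
    out, off = [], 0
    for mul, d in blocks:
        block = x[off:off + mul * d].reshape(mul, d)
        out.append(block.T.reshape(-1))
        off += mul * d
    return np.concatenate(out)


x = np.arange(6)  # 2x1o in mul_ir: copy 0 = [0, 1, 2], copy 1 = [3, 4, 5]
print(mul_ir_to_ir_mul(x, [(2, 3)]))  # [0 3 1 4 2 5]
```

Scalar (0e) blocks have d = 1, so they are identical in both layouts. Calling the same function with each (mul, d) pair swapped converts back from ir_mul to mul_ir.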