Segmented Polynomials#
Here is how the library is organized:
- cue.SegmentedOperand objects represent arrays of numbers split into segments
- cue.SegmentedTensorProduct objects describe how to multiply operands together but have no notion of input/output
- cue.Operation objects introduce the concept of inputs and outputs, allowing for repeated inputs when needed
- cue.SegmentedPolynomial combines these elements to create polynomials, typically with one SegmentedTensorProduct per degree
- cue.EquivariantPolynomial adds cue.Rep labels to each input/output to specify their representations, which is essential for equivariant polynomials
Examples#
The submodule cue.descriptors contains many descriptors of equivariant polynomials. Each of them returns a cue.EquivariantPolynomial.
Linear layer#
import cuequivariance as cue
irreps_in = cue.Irreps("O3", "32x0e + 32x1o")
irreps_out = cue.Irreps("O3", "16x0e + 48x1o")
cue.descriptors.linear(irreps_in, irreps_out)
╭ a=2048x0e b=32x0e+32x1o -> C=16x0e+48x1o
╰─ []·a[uv]·b[iu]➜C[iv] ─ num_paths=2 i={1, 3} u=32 v={16, 48}
In this example, the first operand holds the weights, which are always scalars.
There are 32 * 16 = 512 weights connecting the 0e irreps and 32 * 48 = 1536 weights connecting the 1o irreps, for a total of 2048 weights.
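The weight count can be checked with a few lines of plain Python. This sketch does not call the library; the names `mul_in` and `mul_out` are introduced here for illustration, and it assumes, as the output above suggests, that the linear layer only connects irreps of the same type.

```python
# Multiplicities of each irrep in the input and output, copied from the example above.
mul_in = {"0e": 32, "1o": 32}
mul_out = {"0e": 16, "1o": 48}

# One weight per (input channel, output channel) pair of matching irreps.
weights_per_irrep = {ir: mul_in[ir] * mul_out[ir] for ir in mul_in if ir in mul_out}
total_weights = sum(weights_per_irrep.values())

print(weights_per_irrep)  # {'0e': 512, '1o': 1536}
print(total_weights)      # 2048
```

The total matches the `a=2048x0e` operand printed by the descriptor.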
Spherical Harmonics#
cue.descriptors.spherical_harmonics(cue.SO3(1), [0, 1, 2, 3])
╭ a=1 -> B=0+1+2+3
│ []➜B[] ───────────── num_paths=1
│ []·a[]➜B[] ───────── num_paths=3
│ []·a[]·a[]➜B[] ───── num_paths=11
╰─ []·a[]·a[]·a[]➜B[] ─ num_paths=41
The spherical harmonics are polynomials of an input vector. This descriptor specifies the polynomials of degree 0, 1, 2 and 3.
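To make "polynomials of an input vector" concrete, here is a minimal sketch in plain Python that does not use the library. It builds unnormalized polynomials of degree 0, 1 and 2 in a vector and checks that each one is homogeneous, meaning that scaling the input by `lam` scales the degree-l output by `lam**l`. The function name `low_degree_polynomials` is hypothetical, and the library's actual normalization and component ordering are not reproduced here.

```python
def low_degree_polynomials(v):
    """Return unnormalized polynomials of degree 0, 1 and 2 in v = (x, y, z)."""
    x, y, z = v
    r2 = x * x + y * y + z * z
    degree0 = [1.0]      # constant, 1 component
    degree1 = [x, y, z]  # the vector itself, 3 components
    # 5 independent components of the symmetric traceless part of v (x) v
    degree2 = [x * y, y * z, x * z, x * x - y * y, 3 * z * z - r2]
    return degree0, degree1, degree2

v = (0.5, -1.0, 2.0)
lam = 3.0
base = low_degree_polynomials(v)
scaled = low_degree_polynomials(tuple(lam * c for c in v))

for degree, (p, q) in enumerate(zip(base, scaled)):
    # Homogeneity: P_l(lam * v) == lam**l * P_l(v)
    assert all(abs(qi - lam**degree * pi) < 1e-9 for pi, qi in zip(p, q))

# Number of output components for degrees 0..3, matching B=0+1+2+3 above
output_dim = sum(2 * l + 1 for l in [0, 1, 2, 3])
print(output_dim)  # 16
```

Each line of the descriptor output corresponds to one degree: the number of `a[]` factors in a line is the degree of that polynomial in the input vector.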
Channel Wise Tensor Product#
irreps = cue.Irreps("O3", "0e + 1o + 2e")
cue.descriptors.channelwise_tensor_product(32 * irreps, irreps, irreps)
╭ a=352x0e b=32x0e+32x1o+32x2e c=0e+1o+2e -> D=32x0e+32x0e+32x0e+32x1o+32x1o+32x1o+32x1o+32x2e+32x2e+32x2e+32x2e
╰─ [ijk]·a[uv]·b[iu]·c[jv]➜D[kuv] ─ num_paths=11 i={1, 3, 5} j={1, 3, 5} k={1, 3, 5} u=32 v=1
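The `num_paths=11` and `a=352x0e` values above can be reproduced by enumerating the allowed couplings by hand. The sketch below is plain Python and does not call the library; it assumes the standard O(3) selection rule (triangle inequality on the degree, product of parities) and that each path carries `u * v` weights, which is one reading consistent with the printed output. The names `is_allowed`, `paths` and `per_output` are illustrative.

```python
from collections import Counter
from itertools import product

# Irreps of O(3) as (l, parity), with parity +1 for "e" and -1 for "o".
irreps = [(0, +1), (1, -1), (2, +1)]  # 0e + 1o + 2e
u, v = 32, 1                          # multiplicities of the two inputs

def is_allowed(ir1, ir2, ir3):
    """O(3) selection rule: triangle inequality on l and product of parities."""
    (l1, p1), (l2, p2), (l3, p3) = ir1, ir2, ir3
    return abs(l1 - l2) <= l3 <= l1 + l2 and p1 * p2 == p3

# Every (input 1, input 2, output) triple, with outputs filtered to 0e + 1o + 2e.
paths = [(a, b, c) for a, b, c in product(irreps, repeat=3) if is_allowed(a, b, c)]

num_paths = len(paths)
num_weights = num_paths * u * v
per_output = Counter(c for _, _, c in paths)

print(num_paths, num_weights)  # 11 352
print(dict(per_output))        # {(0, 1): 3, (1, -1): 4, (2, 1): 4}
```

The per-output counts line up with the three `32x0e`, four `32x1o` and four `32x2e` blocks in the output operand `D`.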
Rotation#
cue.descriptors.yxy_rotation(cue.Irreps("O3", "32x0e + 32x1o"))
╭ a=3x0e b=3x0e c=3x0e d=32x0e+32x1o -> E=32x0e+32x1o
╰─ []·a[]·b[]·c[]·d[u]➜E[u] ─ num_paths=14 u=32
This is something of an edge case: the descriptor rotates the input by angles encoded as \(\sin(\theta)\) and \(\cos(\theta)\). See the function cuet.encode_rotation_angle for more details.
Symmetric Contraction#
irreps = 128 * cue.Irreps("O3", "0e + 1o + 2e")
e = cue.descriptors.symmetric_contraction(irreps, irreps, [0, 1, 2, 3])
e
╭ a=128x0e+384x0e+1024x0e+2048x0e b=128x0e+128x1o+128x2e -> C=128x0e+128x1o+128x2e
│ []·a[u]➜C[u] ──────────────── num_paths=1 u=128
│ []·a[u]·b[u]➜C[u] ─────────── num_paths=9 u=128
│ []·a[u]·b[u]·b[u]➜C[u] ────── num_paths=83 u=128
╰─ []·a[u]·b[u]·b[u]·b[u]➜C[u] ─ num_paths=715 u=128
Execution on JAX#
import jax
import jax.numpy as jnp
import cuequivariance as cue
import cuequivariance_jax as cuex
w = cuex.randn(jax.random.key(0), e.inputs[0])
x = cuex.randn(jax.random.key(1), e.inputs[1])
cuex.equivariant_polynomial(e, [w, x], method="uniform_1d")
{0: 128x0e+128x1o+128x2e}
[ 2.0469868 -3.7131066 -1.9872632 ... -0.1839652 0.38468438
-3.3312216 ]
The function cuex.randn generates random cuex.RepArray objects.
The function cuex.equivariant_polynomial evaluates the polynomial on these inputs.
The output is a cuex.RepArray object.
Execution on PyTorch#
The same descriptor can be used in PyTorch through the class cuet.SegmentedPolynomial.
import torch
import cuequivariance as cue
import cuequivariance_torch as cuet
module = cuet.SegmentedPolynomial(e.polynomial, method="uniform_1d")
w = torch.randn(1, e.inputs[0].dim)
x = torch.randn(1, e.inputs[1].dim)
module([w, x])
[tensor([[-0.0259, -1.1032, -2.8918, ..., -0.0537, 1.8772, -0.2967]])]
Details#
A cue.EquivariantPolynomial is composed of two main components:
- Lists of cue.Rep objects that define the inputs and outputs of the polynomial
- A cue.SegmentedPolynomial that describes how to compute the polynomial
The cue.SegmentedPolynomial itself consists of:
- A list of cue.SegmentedOperand objects that represent the operands used in the computation
- A list of operations, where each operation is a pair containing:
  - A cue.Operation object that defines what operation to perform
  - A cue.SegmentedTensorProduct that specifies how to perform the tensor product
This hierarchical structure allows for efficient representation and computation of equivariant polynomials. Below we can examine these components for a specific example:
e.inputs, e.outputs
((128x0e+384x0e+1024x0e+2048x0e, 128x0e+128x1o+128x2e),
(128x0e+128x1o+128x2e,))
p = e.polynomial
p
╭ a=[3584:28⨯(128)] b=[1152:9⨯(128)] -> C=[1152:9⨯(128)]
│ []·a[u]➜C[u] ──────────────── num_paths=1 u=128
│ []·a[u]·b[u]➜C[u] ─────────── num_paths=9 u=128
│ []·a[u]·b[u]·b[u]➜C[u] ────── num_paths=83 u=128
╰─ []·a[u]·b[u]·b[u]·b[u]➜C[u] ─ num_paths=715 u=128
p.inputs, p.outputs
((Operand(ndim=1 num_segments=28 dims=0=128),
Operand(ndim=1 num_segments=9 dims=0=128)),
(Operand(ndim=1 num_segments=9 dims=0=128),))
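The segment counts printed above follow from the irreps listed in `e.inputs`. The arithmetic below is a plain-Python check, not a library call. It assumes one reading that is consistent with the output: every segment has size `u = 128`, and an irrep of dimension d with multiplicity 128 contributes d segments.

```python
u = 128  # segment size shared by all operands

# Weights: multiplicities of the scalar (0e) blocks of the first input.
weight_muls = [128, 384, 1024, 2048]
weight_size = sum(weight_muls)  # each 0e irrep has dimension 1
weight_segments = weight_size // u

# Features: 128 copies each of 0e (dim 1), 1o (dim 3) and 2e (dim 5).
irrep_dims = [1, 3, 5]
feature_size = sum(u * d for d in irrep_dims)
feature_segments = feature_size // u

print(weight_size, weight_segments)    # 3584 28
print(feature_size, feature_segments)  # 1152 9
```

These values match `a=[3584:28⨯(128)]` and `b=[1152:9⨯(128)]` in the printed polynomial.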
p.operations
((Operation((0, 2)), u,u sizes=3584,1152 num_segments=28,9 num_paths=1 u=128),
(Operation((0, 1, 2)),
u,u,u sizes=3584,1152,1152 num_segments=28,9,9 num_paths=9 u=128),
(Operation((0, 1, 1, 2)),
u,u,u,u sizes=3584,1152,1152,1152 num_segments=28,9,9,9 num_paths=83 u=128),
(Operation((0, 1, 1, 1, 2)),
u,u,u,u,u sizes=3584,1152,1152,1152,1152 num_segments=28,9,9,9,9 num_paths=715 u=128))
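To illustrate how an operation and its tensor product fit together, here is a toy model in plain Python. It is a sketch under stated assumptions, not the library's implementation: the function `apply_operation`, the buffer layout and the example path are all made up for illustration. It assumes that the tuple in `Operation((0, 1, 1, 2))` lists buffer indices with the output last, and that each path multiplies one segment per operand elementwise over the shared mode `u`, scaled by a scalar coefficient.

```python
def apply_operation(buffers, operation, paths, u):
    """Toy evaluation of one (operation, tensor product) pair with subscripts u,...,u.

    buffers:   list of buffers; each buffer is a list of segments of u floats
    operation: tuple of buffer indices, the last one being the output buffer
    paths:     list of (segment_indices, coefficient), one segment index per operand
    """
    *input_ids, output_id = operation
    out = buffers[output_id]
    for segment_indices, coefficient in paths:
        *in_segs, out_seg = segment_indices
        for k in range(u):  # the shared mode "u" is multiplied elementwise
            term = coefficient
            for buffer_id, seg in zip(input_ids, in_segs):
                term *= buffers[buffer_id][seg][k]
            out[out_seg][k] += term
    return out

u = 2
weights = [[1.0, 2.0]]               # buffer 0: one segment
features = [[3.0, 4.0], [5.0, 6.0]]  # buffer 1: two segments
output = [[0.0, 0.0]]                # buffer 2: one segment

# Degree-2 term in the style of Operation((0, 1, 1, 2)): buffer 1 appears twice.
operation = (0, 1, 1, 2)
paths = [((0, 0, 1, 0), 0.5)]        # made-up path: w[0] * x[0] * x[1] -> out[0]
result = apply_operation([weights, features, output], operation, paths, u)
print(result)  # [[7.5, 24.0]]
```

Repeating buffer 1 in the operation is what turns a multilinear tensor product into a polynomial of degree 2 in the same input.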