Segmented Polynomials

Here is how the library is organized:

  • cue.SegmentedOperand objects represent arrays of numbers split into segments

  • cue.SegmentedTensorProduct objects describe how to multiply operands together but have no notion of input/output

  • cue.Operation objects introduce the concept of inputs and outputs, allowing for repeated inputs when needed

  • cue.SegmentedPolynomial combines these elements to create polynomials, typically with one SegmentedTensorProduct per degree

  • cue.EquivariantPolynomial adds cue.Rep labels to each input/output to specify their representations, which is essential for equivariant polynomials

Examples

The submodule cue.descriptors contains many descriptors of equivariant polynomials. Each of them returns a cue.EquivariantPolynomial.

Linear layer

import cuequivariance as cue

irreps_in = cue.Irreps("O3", "32x0e + 32x1o")
irreps_out = cue.Irreps("O3", "16x0e + 48x1o")
cue.descriptors.linear(irreps_in, irreps_out)
╭ a=2048x0e b=32x0e+32x1o -> C=16x0e+48x1o
╰─ []·a[uv]·b[iu]➜C[iv] ─ num_paths=2 i={1, 3} u=32 v={16, 48}

In this example, the first operand a holds the weights, which are always scalars. There are 32 * 16 = 512 weights connecting the 0e irreps and 32 * 48 = 1536 weights connecting the 1o irreps, for a total of 2048 weights.
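The weight count can be reproduced with a few lines of plain Python. This is a sketch that assumes, as the paragraph above describes, one mul_in × mul_out block of weights for each irrep type present in both the input and the output. The helpers `parse_irreps` and `linear_weight_count` are hypothetical and not part of cuequivariance.

```python
from collections import Counter

def parse_irreps(s):
    """Parse a string like "32x0e + 32x1o" into {irrep: multiplicity} (hypothetical helper)."""
    mults = Counter()
    for term in s.split("+"):
        term = term.strip()
        mul, ir = term.split("x") if "x" in term else ("1", term)
        mults[ir] += int(mul)
    return mults

def linear_weight_count(irreps_in, irreps_out):
    """One mul_in * mul_out block of weights per irrep shared by input and output."""
    mi, mo = parse_irreps(irreps_in), parse_irreps(irreps_out)
    return {ir: mi[ir] * mo[ir] for ir in mi if ir in mo}

blocks = linear_weight_count("32x0e + 32x1o", "16x0e + 48x1o")
print(blocks)                # {'0e': 512, '1o': 1536}
print(sum(blocks.values()))  # 2048
```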

Spherical Harmonics

cue.descriptors.spherical_harmonics(cue.SO3(1), [0, 1, 2, 3])
╭ a=1 -> B=0+1+2+3
│  []➜B[] ───────────── num_paths=1
│  []·a[]➜B[] ───────── num_paths=3
│  []·a[]·a[]➜B[] ───── num_paths=11
╰─ []·a[]·a[]·a[]➜B[] ─ num_paths=41

The spherical harmonics are polynomials of an input vector. This descriptor specifies the polynomials of degree 0, 1, 2 and 3, one per row above: the row of degree ℓ uses the input a exactly ℓ times.
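The following NumPy sketch makes the "one term per degree, input repeated ℓ times" structure concrete. It evaluates a generic polynomial of a 3D vector whose degree-ℓ part contracts ℓ copies of x with a coefficient tensor. The coefficients are random placeholders, not the actual spherical harmonics coefficients used by cuequivariance, and `eval_by_degree` is a hypothetical name. The sketch shows only the shape of the computation, and that scaling x by t scales the degree-ℓ output by t^ℓ.

```python
import numpy as np

rng = np.random.default_rng(0)
degrees = [0, 1, 2, 3]
# Placeholder coefficients: degree l has l input axes of size 3 and one output axis of size 2l+1
coeffs = [rng.normal(size=(3,) * l + (2 * l + 1,)) for l in degrees]

def eval_by_degree(coeffs, x):
    """Return one output per degree; the degree-l term uses the input x exactly l times."""
    outs = []
    for c in coeffs:
        y = c
        while y.ndim > 1:  # contract one input axis at a time with x
            y = np.tensordot(x, y, axes=([0], [0]))
        outs.append(y)
    return outs

x = np.array([0.3, -1.2, 0.5])
outs = eval_by_degree(coeffs, x)
print([o.shape for o in outs])     # [(1,), (3,), (5,), (7,)]
print(np.concatenate(outs).shape)  # (16,), the dimension of 0+1+2+3
```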

Channel-Wise Tensor Product

irreps = cue.Irreps("O3", "0e + 1o + 2e")
cue.descriptors.channelwise_tensor_product(32 * irreps, irreps, irreps)
╭ a=352x0e b=32x0e+32x1o+32x2e c=0e+1o+2e -> D=32x0e+32x0e+32x0e+32x1o+32x1o+32x1o+32x1o+32x2e+32x2e+32x2e+32x2e
╰─ [ijk]·a[uv]·b[iu]·c[jv]➜D[kuv] ─ num_paths=11 i={1, 3, 5} j={1, 3, 5} k={1, 3, 5} u=32 v=1
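You can check num_paths=11 and the 352 weights by hand with the usual O(3) selection rule. A path (i, j, k) is allowed when the output degree satisfies the triangle inequality |l_i − l_j| ≤ l_k ≤ l_i + l_j and the output parity equals the product of the input parities. The plain-Python sketch below assumes one path per allowed triple, restricted to the irreps 0e, 1o and 2e, and 32 weights per path (u=32, v=1). It is a cross-check, not cuequivariance's implementation.

```python
from itertools import product

irreps = [(0, 1), (1, -1), (2, 1)]  # (l, parity) for 0e, 1o, 2e
names = {(0, 1): "0e", (1, -1): "1o", (2, 1): "2e"}

def allowed(ir1, ir2, ir3):
    """O(3) selection rule: triangle inequality on l and product rule on parity."""
    (l1, p1), (l2, p2), (l3, p3) = ir1, ir2, ir3
    return abs(l1 - l2) <= l3 <= l1 + l2 and p1 * p2 == p3

paths = [(i, j, k) for i, j, k in product(irreps, repeat=3) if allowed(i, j, k)]
print(len(paths))       # 11
print(32 * len(paths))  # 352 weights: 32 channels per path

# Output irreps sorted by irrep; this reproduces the layout of D above
outputs = sorted((k for _, _, k in paths), key=irreps.index)
print("+".join("32x" + names[k] for k in outputs))
```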

Rotation

cue.descriptors.yxy_rotation(cue.Irreps("O3", "32x0e + 32x1o"))
╭ a=3x0e b=3x0e c=3x0e d=32x0e+32x1o -> E=32x0e+32x1o
╰─ []·a[]·b[]·c[]·d[u]➜E[u] ─ num_paths=14 u=32

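This is something of an edge case. It rotates the input d by three angles (y, x and y rotations, as the name suggests). The angles are passed as the inputs a, b and c and are encoded using \(\sin(\theta)\) and \(\cos(\theta)\). See the function cuet.encode_rotation_angle for details of the encoding.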
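A rotation fits the segmented-polynomial framework because each entry of an elementary rotation matrix is a polynomial in \(\cos\theta\) and \(\sin\theta\). Composing three of them and applying the result to a vector therefore gives products with one factor per angle and one factor from the input. This is consistent with the single []·a[]·b[]·c[]·d[u]➜E[u] term above. The NumPy sketch below builds the rotation that acts on a 1o vector as R_y(γ) R_x(β) R_y(α). The order of the factors and the sign conventions are assumptions for illustration and may differ from cuequivariance's.

```python
import numpy as np

def rot_y(c, s):
    """Rotation about y, written in terms of c = cos(theta), s = sin(theta)."""
    return np.array([[c, 0.0, s], [0.0, 1.0, 0.0], [-s, 0.0, c]])

def rot_x(c, s):
    """Rotation about x, written in terms of c = cos(theta), s = sin(theta)."""
    return np.array([[1.0, 0.0, 0.0], [0.0, c, -s], [0.0, s, c]])

def yxy_rotation_matrix(alpha, beta, gamma):
    # Assumed convention: every entry is a polynomial in the cos/sin of the three angles
    return (
        rot_y(np.cos(gamma), np.sin(gamma))
        @ rot_x(np.cos(beta), np.sin(beta))
        @ rot_y(np.cos(alpha), np.sin(alpha))
    )

R = yxy_rotation_matrix(0.3, -1.1, 2.0)
v = np.array([1.0, 2.0, 3.0])  # one channel of a 1o input
print(np.allclose(R.T @ R, np.eye(3)), np.isclose(np.linalg.det(R), 1.0))  # True True
print(np.linalg.norm(R @ v), np.linalg.norm(v))  # the rotation preserves the norm
```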

Symmetric Contraction

irreps = 128 * cue.Irreps("O3", "0e + 1o + 2e")
e = cue.descriptors.symmetric_contraction(irreps, irreps, [0, 1, 2, 3])
e
╭ a=128x0e+384x0e+1024x0e+2048x0e b=128x0e+128x1o+128x2e -> C=128x0e+128x1o+128x2e
│  []·a[u]➜C[u] ──────────────── num_paths=1 u=128
│  []·a[u]·b[u]➜C[u] ─────────── num_paths=9 u=128
│  []·a[u]·b[u]·b[u]➜C[u] ────── num_paths=83 u=128
╰─ []·a[u]·b[u]·b[u]·b[u]➜C[u] ─ num_paths=715 u=128
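The weight operand a = 128x0e+384x0e+1024x0e+2048x0e carries 1, 3, 8 and 16 scalars per channel for degrees 0 to 3. One way to account for these numbers is to assume that degree n has one weight per copy of an output irrep (0e, 1o or 2e) inside the symmetric power Sym^n of the input 0e+1o+2e. That assumption is our reading, consistent with the numbers here. Under it, the multiplicities can be counted with O(3) characters. The NumPy sketch below does exactly that. It is an independent cross-check, not how cuequivariance builds the descriptor.

```python
import numpy as np

# Irreps of both input and output: (l, parity) for 0e, 1o, 2e
IRREPS = [(0, 1), (1, -1), (2, 1)]
# Uniform grid over one period of the rotation angle; exact for the
# low-degree trigonometric polynomials that appear here
THETA = np.linspace(0.0, 2.0 * np.pi, 256, endpoint=False)

def chi_irrep(l, p, theta, eps):
    """Character of the O(3) irrep (l, p) at a rotation by theta, times inversion if eps == -1."""
    chi = sum(np.cos(m * theta) for m in range(-l, l + 1))
    return chi * (p if eps == -1 else 1)

def chi_input(theta, eps):
    return sum(chi_irrep(l, p, theta, eps) for l, p in IRREPS)

def chi_sym_power(n, theta, eps):
    """Character of Sym^n(V) from Newton's identity m*h_m = sum_k p_k * h_{m-k}."""
    h = [np.ones_like(theta)]
    for m in range(1, n + 1):
        # p_k(g) = chi_V(g^k): rotation angle k*theta, inversion eps**k
        acc = sum(chi_input(k * theta, eps**k) * h[m - k] for k in range(1, m + 1))
        h.append(acc / m)
    return h[n]

def multiplicity(n, l, p):
    """Number of copies of irrep (l, p) in Sym^n(V), by averaging over O(3)."""
    density = 1.0 - np.cos(THETA)  # Haar measure density for the rotation angle
    total = sum(
        np.mean(chi_sym_power(n, THETA, eps) * chi_irrep(l, p, THETA, eps) * density)
        for eps in (1, -1)
    )
    return int(round(total / 2))

counts = [sum(multiplicity(n, l, p) for l, p in IRREPS) for n in range(4)]
print(counts)                     # [1, 3, 8, 16]
print([128 * c for c in counts])  # [128, 384, 1024, 2048], the segments of a
```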

Execution on JAX

import jax
import jax.numpy as jnp
import cuequivariance as cue
import cuequivariance_jax as cuex

w = cuex.randn(jax.random.key(0), e.inputs[0])
x = cuex.randn(jax.random.key(1), e.inputs[1])

cuex.equivariant_polynomial(e, [w, x])
{0: 128x0e+128x1o+128x2e}
[ 2.0469868  -3.7131066  -1.9872632  ... -0.1839652   0.38468438
 -3.3312216 ]

The function cuex.randn generates a random cuex.RepArray for a given representation. The function cuex.equivariant_polynomial evaluates the polynomial e on the inputs [w, x], and its output is a cuex.RepArray object.

Execution on PyTorch

The same descriptor can be used in PyTorch through the class cuet.SegmentedPolynomial, which is built from the underlying cue.SegmentedPolynomial, e.polynomial.

import torch
import cuequivariance as cue
import cuequivariance_torch as cuet

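# The torch module needs only the SegmentedPolynomial, not the cue.Rep labels
if torch.cuda.is_available():
    module = cuet.SegmentedPolynomial(e.polynomial)

    # A batch of one sample; each input is a flat array of size Rep.dim
    w = torch.randn(1, e.inputs[0].dim).cuda()
    x = torch.randn(1, e.inputs[1].dim).cuda()

    # Evaluate the polynomial on the inputs [w, x]
    module([w, x])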

Details

A cue.EquivariantPolynomial is composed of two main components:

  1. Lists of cue.Rep objects that define the inputs and outputs of the polynomial

  2. A cue.SegmentedPolynomial that describes how to compute the polynomial

The cue.SegmentedPolynomial itself consists of:

  • A list of cue.SegmentedOperand objects that represent the operands used in the computation

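  • A list of operations, where each operation is a pair containing:

      ◦ a cue.Operation, which maps the operands of the tensor product to the inputs and outputs of the polynomial

      ◦ a cue.SegmentedTensorProduct, which describes the computation itself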

This hierarchical structure allows for efficient representation and computation of equivariant polynomials. Below, we examine these components for the symmetric contraction e defined above:

e.inputs, e.outputs
((128x0e+384x0e+1024x0e+2048x0e, 128x0e+128x1o+128x2e),
 (128x0e+128x1o+128x2e,))
p = e.polynomial
p
╭ a=[3584:28⨯(128)] b=[1152:9⨯(128)] -> C=[1152:9⨯(128)]
│  []·a[u]➜C[u] ──────────────── num_paths=1 u=128
│  []·a[u]·b[u]➜C[u] ─────────── num_paths=9 u=128
│  []·a[u]·b[u]·b[u]➜C[u] ────── num_paths=83 u=128
╰─ []·a[u]·b[u]·b[u]·b[u]➜C[u] ─ num_paths=715 u=128
p.inputs, p.outputs
((Operand(ndim=1 num_segments=28 dims=0=128),
  Operand(ndim=1 num_segments=9 dims=0=128)),
 (Operand(ndim=1 num_segments=9 dims=0=128),))
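Both operands are flat arrays cut into equal segments of 128 channels: 3584 = 28 × 128 for a, and 1152 = 9 × 128 for b. For b = 128x0e+128x1o+128x2e, the 9 segments match the 1 + 3 + 5 components of 0e, 1o and 2e. This suggests one segment of 128 channels per irrep component. The NumPy sketch below splits a flat buffer that way. The component-by-component ordering is our assumption, read off the printed shapes rather than stated by the library.

```python
import numpy as np

mul = 128  # channels per segment (u)
irreps = [("0e", 1), ("1o", 3), ("2e", 5)]  # (irrep, dimension 2l+1) of b

# A flat buffer for b = 128x0e+128x1o+128x2e
b = np.arange(mul * sum(d for _, d in irreps), dtype=np.float32)

# Assumed layout: one segment of `mul` channels per irrep component
segments = np.split(b, b.size // mul)
labels = [(name, m) for name, d in irreps for m in range(d)]

print(b.size, len(segments), segments[0].shape)  # 1152 9 (128,)
print(labels[:4])  # [('0e', 0), ('1o', 0), ('1o', 1), ('1o', 2)]

# The weight operand a: 1 + 3 + 8 + 16 = 28 segments of 128 scalars
print((1 + 3 + 8 + 16) * mul)  # 3584
```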
p.operations
((Operation((0, 2)), u,u sizes=3584,1152 num_segments=28,9 num_paths=1 u=128),
 (Operation((0, 1, 2)),
  u,u,u sizes=3584,1152,1152 num_segments=28,9,9 num_paths=9 u=128),
 (Operation((0, 1, 1, 2)),
  u,u,u,u sizes=3584,1152,1152,1152 num_segments=28,9,9,9 num_paths=83 u=128),
 (Operation((0, 1, 1, 1, 2)),
  u,u,u,u,u sizes=3584,1152,1152,1152,1152 num_segments=28,9,9,9,9 num_paths=715 u=128))
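Each Operation gives, for every operand of its tensor product, the index of a polynomial input or output, with inputs numbered first: here 0 is a, 1 is b and 2 is C. This is how the degree-3 term reuses b three times, as in Operation((0, 1, 1, 1, 2)). The NumPy sketch below evaluates a toy polynomial of the same shape. Every tensor product in it has u,u,...,u subscripts (elementwise over channels) and a list of paths of the form (coefficient, segment index per operand). The path data are made up for illustration, and the `evaluate` helper is hypothetical. The sketch mirrors the structure printed above, not cuequivariance's kernels.

```python
import numpy as np

mul = 4  # channels per segment (u); 128 in the example above
rng = np.random.default_rng(0)

# Toy inputs as lists of segments: a (weights) has 3 segments, b has 2; the output C has 2
a = [rng.normal(size=mul) for _ in range(3)]
b = [rng.normal(size=mul) for _ in range(2)]
inputs = [a, b]  # operand indices: 0 = a, 1 = b (inputs), 2 = C (output)
num_out_segments = 2

# Each term: (operation, paths); a path is (coefficient, segment index per operand)
terms = [
    ((0, 2), [(1.0, (0, 0))]),                           # []·a[u]➜C[u]
    ((0, 1, 2), [(0.5, (1, 0, 0)), (-2.0, (1, 1, 1))]),  # []·a[u]·b[u]➜C[u]
    ((0, 1, 1, 2), [(0.25, (2, 0, 1, 1))]),              # []·a[u]·b[u]·b[u]➜C[u]
]

def evaluate(terms, inputs, num_out_segments, mul):
    """Sum every path of every term into the output segments (elementwise over u)."""
    out = [np.zeros(mul) for _ in range(num_out_segments)]
    for operation, paths in terms:
        in_ops = operation[:-1]  # this toy assumes the last operand is the single output C
        for coeff, seg_idx in paths:
            prod = coeff * np.ones(mul)
            for op, s in zip(in_ops, seg_idx[:-1]):
                prod = prod * inputs[op][s]
            out[seg_idx[-1]] += prod
    return out

C = evaluate(terms, inputs, num_out_segments, mul)
```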