Segmented Tensor Product

This example shows how to create a custom tensor product descriptor and execute it. First, we import the necessary modules.

import itertools
import numpy as np
import torch
import jax
import jax.numpy as jnp

import cuequivariance as cue
import cuequivariance_torch as cuet  # to execute the tensor product with PyTorch
import cuequivariance_jax as cuex    # to execute the tensor product with JAX

Basic Tools

Creating a tensor product descriptor using the cue.SegmentedTensorProduct class. In the subscripts string, the modes of each operand are separated by commas, and the part after “+” lists the modes of the path coefficients.

d = cue.SegmentedTensorProduct.from_subscripts("a,ia,ja+ij")
print(d.to_text())
a,ia,ja+ij sizes=0,0,0 num_segments=0,0,0 num_paths=0 a= i= j=
operand #0 subscripts=a
  | a: []
operand #1 subscripts=ia
  | i: []
  | a: []
operand #2 subscripts=ja
  | j: []
  | a: []
Flop cost: 0->0 1->0 2->0
Memory cost: 0
No paths.

This descriptor has 3 operands.

d.num_operands
3

Its coefficients have indices “ij”.

d.coefficient_subscripts
'ij'

Adding segments to the three operands.

d.add_segment(0, (200,))
d.add_segments(1, [(3, 100), (5, 200)])
d.add_segments(2, [(1, 200), (1, 100)])
print(d.to_text())
a,ia,ja+ij sizes=200,1300,300 num_segments=1,2,2 num_paths=0 a={100, 200} i={3, 5} j=1
operand #0 subscripts=a
  | a: [200]
operand #1 subscripts=ia
  | i: [3, 5]
  | a: [100, 200]
operand #2 subscripts=ja
  | j: [1] * 2
  | a: [200, 100]
Flop cost: 0->0 1->0 2->0
Memory cost: 1800
No paths.
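
The sizes reported above are the flattened lengths of the operands: each segment contributes the product of its dimensions, e.g. 3*100 + 5*200 = 1300 for operand 1. As a quick check, they can also be queried via the size attribute (used again at the end of this example).

print(d.operands[0].size, d.operands[1].size, d.operands[2].size)
200 1300 300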

Observing that “j” always has size 1, we can squeeze it out.

d = d.squeeze_modes("j")
print(d.to_text())
a,ia,a+i sizes=200,1300,300 num_segments=1,2,2 num_paths=0 a={100, 200} i={3, 5}
operand #0 subscripts=a
  | a: [200]
operand #1 subscripts=ia
  | i: [3, 5]
  | a: [100, 200]
operand #2 subscripts=a
  | a: [200, 100]
Flop cost: 0->0 1->0 2->0
Memory cost: 1800
No paths.

Adding paths between the segments. The first three arguments of add_path are the segment indices in each operand (segment 0 of operand 0, segment 1 of operand 1, segment 0 of operand 2), and c is the coefficient tensor, whose shape matches the coefficient subscripts “i” (here of size 5, the “i” dimension of the selected segment).

d.add_path(0, 1, 0, c=np.array([1.0, 2.0, 0.0, 0.0, 0.0]))
print(d.to_text())
a,ia,a+i sizes=200,1300,300 num_segments=1,2,2 num_paths=1 a={100, 200} i={3, 5}
operand #0 subscripts=a
  | a: [200]
operand #1 subscripts=ia
  | i: [3, 5]
  | a: [100, 200]
operand #2 subscripts=a
  | a: [200, 100]
Flop cost: 0->2200 1->1200 2->2200
Memory cost: 1800
Path indices: 0 1 0
Path coefficients:
[1.0 2.0 0.0 0.0 0.0]
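
To make the path semantics concrete, here is a minimal NumPy sketch (our own reference computation, not a library call): for path indices 0 1 0 and subscripts “a,ia,a+i”, segment 0 of the output receives sum_i c[i] * x0[a] * x1[i, a].

rng = np.random.default_rng(0)
x0 = rng.normal(size=(200,))    # operand 0, segment 0: modes "a"
x1 = rng.normal(size=(5, 200))  # operand 1, segment 1: modes "ia"
c = np.array([1.0, 2.0, 0.0, 0.0, 0.0])
out_seg0 = np.einsum("i,a,ia->a", c, x0, x1)  # operand 2, segment 0: modes "a"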

Flattening the index “i” of the coefficients. Each segment of operand 1 is split along “i”, each coefficient becomes a scalar path, and paths with a zero coefficient are dropped.

d = d.flatten_modes("i")
# d = d.flatten_coefficient_modes()
print(d.to_text())
a,a,a sizes=200,1300,300 num_segments=1,8,2 num_paths=2 a={100, 200}
operand #0 subscripts=a
  | a: [200]
operand #1 subscripts=a
  | a: [100, 100, 100, 200, 200, 200, 200, 200]
operand #2 subscripts=a
  | a: [200, 100]
Flop cost: 0->800 1->800 2->800
Memory cost: 1800
Path indices: 0 3 0, 0 4 0
Path coefficients: [1.0 2.0]

Equivalently, flatten_coefficient_modes can be used.
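
Continuing the NumPy sketch from above, we can check that the two surviving scalar paths (with coefficients 1.0 and 2.0) reproduce the original contribution:

out_flat = 1.0 * x0 * x1[0] + 2.0 * x0 * x1[1]  # paths 0 3 0 and 0 4 0
print(np.allclose(out_seg0, out_flat))  # True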

Equivariant Linear Layer

Now, we are creating a custom tensor product descriptor that represents an equivariant linear layer between two representations. See Groups and Representations for more information on irreps.

irreps1 = cue.Irreps("O3", "32x0e + 32x1o")
irreps2 = cue.Irreps("O3", "16x0e + 48x1o")
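
The dimensions of these irreps are 32*1 + 32*3 = 128 and 16*1 + 48*3 = 160:

print(irreps1.dim, irreps2.dim)
128 160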

The tensor product descriptor is created step by step. First, we create an empty descriptor given its subscripts. In the case of the linear layer, we have 3 operands: the weight, the input, and the output. The subscripts of this tensor product are “uv,iu,iv”, where “i” is the dimension of the irrep, “u” is the multiplicity of the input, and “v” is the multiplicity of the output: the weight has modes “uv”, the input “iu”, and the output “iv”.

d = cue.SegmentedTensorProduct.from_subscripts("uv,iu,iv")
d
uv,iu,iv sizes=0,0,0 num_segments=0,0,0 num_paths=0 i= u= v=

Each operand of the tensor product descriptor has a list of segments. We add the segments of the input and output representations using the add_segment method; each segment has shape (ir.dim, mul), matching the subscripts “iu” and “iv”.

for mul, ir in irreps1:
    d.add_segment(1, (ir.dim, mul))
for mul, ir in irreps2:
    d.add_segment(2, (ir.dim, mul))

d
uv,iu,iv sizes=0,128,160 num_segments=0,2,2 num_paths=0 i={1, 3} u=32 v={16, 48}

Enumerating all possible pairs of irreps and adding weight segments and paths between them when the irreps are the same.

for (i1, (mul1, ir1)), (i2, (mul2, ir2)) in itertools.product(
    enumerate(irreps1), enumerate(irreps2)
):
    if ir1 == ir2:
        # passing None as the weight segment lets add_path create it automatically
        d.add_path(None, i1, i2, c=1.0)

d
uv,iu,iv sizes=2048,128,160 num_segments=2,2,2 num_paths=2 i={1, 3} u=32 v={16, 48}
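
Each call to add_path created the corresponding weight segment, of shape (u, v) = (32, 16) for the “0e” pair and (32, 48) for the “1o” pair, hence 32*16 + 32*48 = 2048 weights in total:

print(d.operands[0].size)
2048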

Printing the paths.

d.paths
(op0[0]*op1[0]*op2[0]*1., op0[1]*op1[1]*op2[1]*1.)

Normalizing the paths for the last operand so that the output has unit variance (assuming unit-variance, uncorrelated inputs).

d = d.normalize_paths_for_operand(-1)
d.paths
(op0[0]*op1[0]*op2[0]*0.18, op0[1]*op1[1]*op2[1]*0.18)

As we can see, the path coefficients have been rescaled: each output element is a sum of u = 32 terms, so each coefficient becomes 1/sqrt(32) ≈ 0.18.

Now we are creating a tensor product from the descriptor and executing it. In PyTorch, we can use the cuet.TensorProduct class. Setting use_fallback=True allows it to run with a pure PyTorch implementation when no CUDA kernel is available (hence the “without CUDA kernel” in the repr below).

linear_torch = cuet.TensorProduct(d, use_fallback=True)
linear_torch
TensorProduct(uv,iu,iv sizes=2048,128,160 num_segments=2,2,2 num_paths=2 i={1, 3} u=32 v={16, 48} (without CUDA kernel))

Now we can execute the linear layer with random input and weight tensors.

w = torch.randn(1, d.operands[0].size)
x1 = torch.randn(3000, irreps1.dim)

x2 = linear_torch(w, x1)

assert x2.shape == (3000, irreps2.dim)
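
As a sanity check, we can reproduce the computation with plain einsums. The following is a hand-written sketch based on the segment layout above, not a library API; it assumes segments are flattened in row-major order and that the 0.18 coefficient printed earlier is exactly 1/sqrt(32).

c = 1.0 / 32**0.5  # the normalized path coefficient (printed as 0.18 above)
w0e = w[0, :32 * 16].reshape(32, 16)  # weight segment for "0e": modes "uv"
w1o = w[0, 32 * 16:].reshape(32, 48)  # weight segment for "1o": modes "uv"
y0e = c * torch.einsum("bu,uv->bv", x1[:, :32], w0e)  # i = 1 is implicit
y1o = c * torch.einsum("biu,uv->biv", x1[:, 32:].reshape(-1, 3, 32), w1o)
y = torch.cat([y0e, y1o.reshape(-1, 3 * 48)], dim=1)
print((y - x2).abs().max())  # ~0 if our layout assumptions hold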

Now we verify that the output is indeed normalized: its variance should be close to 1.

x2.var()
tensor(0.9795)

In JAX, we can use the cuex.segmented_polynomial function.

w = jax.random.normal(jax.random.key(0), (d.operands[0].size,))
x1 = jax.random.normal(jax.random.key(1), (3000, irreps1.dim))

[x2] = cuex.segmented_polynomial(
    cue.SegmentedPolynomial.eval_last_operand(d),
    [w, x1],
    [jax.ShapeDtypeStruct((3000, irreps2.dim), jnp.float32)],
)

assert x2.shape == (3000, irreps2.dim)
x2.var()
Array(1.0266902, dtype=float32)
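
Since cuex.segmented_polynomial is called like a regular JAX function, it should compose with the usual JAX transformations; as a sketch (assuming the call is traceable, which we have not verified here), it can be wrapped in jax.jit:

@jax.jit
def linear_jax(w, x):
    [y] = cuex.segmented_polynomial(
        cue.SegmentedPolynomial.eval_last_operand(d),
        [w, x],
        [jax.ShapeDtypeStruct((x.shape[0], irreps2.dim), jnp.float32)],
    )
    return y

assert linear_jax(w, x1).shape == (3000, irreps2.dim)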