In depth example: MACE#

Data Structure#

The layout#

Check the Data Layouts guide for more information about this topic.

cuEquivariance offers the possibility to use a more efficient layout for the irreps. The old layout, compatible with e3nn operations, is called cue.mul_ir, while the new layout is called cue.ir_mul.
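
To make the difference concrete, here is a minimal sketch (plain PyTorch, not part of the MACE code) of how the same "2x1o" data is ordered under the two layouts:

import torch

mul, ir_dim = 2, 3  # "2x1o": multiplicity 2, and a 1o irrep has dimension 3
x_mul_ir = torch.arange(mul * ir_dim).reshape(mul, ir_dim)  # cue.mul_ir: channel-major ordering
x_ir_mul = x_mul_ir.transpose(0, 1).contiguous()            # cue.ir_mul: component-major ordering

print(x_mul_ir.flatten())  # tensor([0, 1, 2, 3, 4, 5])
print(x_ir_mul.flatten())  # tensor([0, 3, 1, 4, 2, 5])
# For irreps with several segments (e.g. "32x0e + 32x1o"), this transposition is applied segment by segment.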

Irreps#

Check the Group representations guide for more information about this topic.

If we stick to the old layout, we can equivalently define e3nn and cue Irreps as:

from e3nn import o3
import cuequivariance as cue

old_irreps = o3.Irreps('1x0e+1x1o')
new_irreps = cue.Irreps(cue.O3, '1x0e+1x1o')
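
# quick check (for illustration): both objects describe the same 4-dimensional space (one scalar plus one vector)
assert old_irreps.dim == new_irreps.dim == 4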

A note about the O3 group: The official MACE implementation uses e3nn version 0.4.4, which employs a slightly different group definition than more recent e3nn versions (and cuEquivariance). If compatibility with old models is desired, you can enforce the old convention by defining the following group:

from typing import Iterator
import numpy as np
import itertools

class O3_e3nn(cue.O3):
    def __mul__(rep1: "O3_e3nn", rep2: "O3_e3nn") -> Iterator["O3_e3nn"]:
        return [O3_e3nn(l=ir.l, p=ir.p) for ir in cue.O3.__mul__(rep1, rep2)]

    @classmethod
    def clebsch_gordan(
        cls, rep1: "O3_e3nn", rep2: "O3_e3nn", rep3: "O3_e3nn"
    ) -> np.ndarray:
        rep1, rep2, rep3 = cls._from(rep1), cls._from(rep2), cls._from(rep3)

        if rep1.p * rep2.p == rep3.p:
            return o3.wigner_3j(rep1.l, rep2.l, rep3.l).numpy()[None] * np.sqrt(rep3.dim)
        else:
            return np.zeros((0, rep1.dim, rep2.dim, rep3.dim))

    def __lt__(rep1: "O3_e3nn", rep2: "O3_e3nn") -> bool:
        rep2 = rep1._from(rep2)
        return (rep1.l, rep1.p) < (rep2.l, rep2.p)

    @classmethod
    def iterator(cls) -> Iterator["O3_e3nn"]:
        for l in itertools.count(0):
            yield O3_e3nn(l=l, p=1 * (-1) ** l)
            yield O3_e3nn(l=l, p=-1 * (-1) ** l)

O3_e3nn should then be used throughout the code in place of cue.O3, like in the following example:

from cuequivariance.experimental.e3nn import O3_e3nn  # the class defined above is also available here
new_irreps = cue.Irreps(O3_e3nn, '1x0e+1x1o')

Below are some snippets useful for accelerating MACE, both in PyTorch and in JAX.

Using PyTorch#

Note: In the following, we will refer to the cuequivariance library as cue and to the cuequivariance_torch library as cuet.

To accelerate MACE, we want to substitute the e3nn operations with the equivalent cuEquivariance operations.

In particular, there are four operations within MACE to replace:

  • SymmetricContraction → cuet.SymmetricContraction

  • tp_out_irreps_with_instructions + e3nn.o3.TensorProduct → cuet.ChannelWiseTensorProduct

  • e3nn.o3.Linear → cuet.Linear

  • e3nn.o3.FullyConnectedTensorProduct → cuet.FullyConnectedTensorProduct

All of these have a cuEquivariance counterpart, but for now only the first two result in a significant performance improvement. The layout can be changed throughout the code, but this requires all operations to be upgraded to their cuet counterparts. For the operations where the kernel is not yet optimized, we can fall back to the FX implementation (an implementation using torch.fx, as in e3nn) with a simple flag.

Common options#

These general rules are valid for most operations:

  • layout: cue.mul_ir or cue.ir_mul, as explained above

  • dtype: torch.float32 or torch.float64

  • use_fallback: bool; pass this when calling the module to select the FX implementation instead of the kernel

We can thus set some of these common options:

import torch
import cuequivariance as cue
import cuequivariance_torch as cuet

if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

dtype = torch.float32  # or torch.float64

SymmetricContraction#

The original SymmetricContraction is an operation written specifically for MACE. It contracts a single input feature tensor with itself multiple times, and uses a second input (the node attributes, 1-hot encoded) to select the weights depending on the atomic species.

For performance reasons, the cuequivariance implementation uses direct indexing in place of 1-hot vectors, i.e. the attributes are now integers, indicating the index of each atom in the species list.
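
For reference, here is a minimal sketch (using a small, hypothetical node_attrs tensor) of how the integer indices can be obtained from the existing 1-hot attributes:

import torch

# hypothetical 1-hot attributes for 3 atoms drawn from 10 species
node_attrs = torch.nn.functional.one_hot(torch.tensor([4, 0, 7]), num_classes=10).to(torch.float32)
# the column index of the non-zero entry in each row is the species index
node_attrs_index = torch.nonzero(node_attrs)[:, 1].int()  # tensor([4, 0, 7], dtype=torch.int32)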

The SymmetricContraction code should look like this:

feats_irreps = cue.Irreps("O3", "32x0e + 32x1o + 32x2e")
target_irreps = cue.Irreps("O3", "32x0e + 32x1o")

# OLD FUNCTION DEFINITION:
# symmetric_contractions = SymmetricContraction(
#     irreps_in=feats_irreps,
#     irreps_out=target_irreps,
#     correlation=3,
#     num_elements=10,
# )

# NEW FUNCTION DEFINITION:
symmetric_contractions = cuet.SymmetricContraction(
    irreps_in=feats_irreps,
    irreps_out=target_irreps,
    contraction_degree=3,
    num_elements=10,
    layout_in=cue.ir_mul,
    layout_out=cue.mul_ir,
    original_mace=True,
    dtype=dtype,
    device=device,
)

node_feats = torch.randn(128, 32, feats_irreps.dim // 32, dtype=dtype, device=device)
# with node_attrs_index being the index version of node_attrs, something like:
# node_attrs_index = torch.nonzero(node_attrs)[:, 1].int()
node_attrs_index = torch.randint(0, 10, (128,), dtype=torch.int32, device=device)

# OLD CALL:
# symmetric_contractions(node_feats, node_attrs)

# NEW CALL:
node_feats = torch.transpose(node_feats, 1, 2).flatten(1)
symmetric_contractions(node_feats, node_attrs_index)
tensor([[ -1.2508,   1.3609,   0.3291,  ...,   0.7374,   1.9232,  -0.7534],
        [ 31.6724,  -1.4054,   2.4446,  ...,  -0.2629,  -0.5113,  -2.7330],
        [  3.3623,  -0.5758,  -5.0324,  ...,   9.5122,   6.0776,  -0.5560],
        ...,
        [  1.9003,  -8.7942,  16.6021,  ..., -15.1150,   1.3935,  17.6860],
        [  1.6112,  -6.2048,  -4.8934,  ...,  13.9444,  -8.0758, -12.5673],
        [ -3.4384,   0.1033,  -8.1656,  ...,   3.3409,   0.9181,  -6.4427]],
       device='cuda:0',
       grad_fn=<GeneratedBackwardFor_cuequivariance_ops_torch_segmented_transpose_primitive_defaultBackward>)

We can see that in this case layout_in and layout_out can be specified separately. Here we have chosen cue.ir_mul for the input, and have explicitly performed the corresponding transposition before calling the operation; if you were using this layout throughout the code, this step would not be needed.

The flag original_mace ensures compatibility with the old SymmetricContraction, in which the operations were performed in a slightly different order than in the new version.

ChannelWiseTensorProduct#

The ChannelWiseTensorProduct replaces the custom operation that MACE builds by defining custom instructions and calling o3.TensorProduct. This particular operation is also called with external weights computed through an MLP; the same can be done in cuet.

The new version for this part of the code will thus read:

feats_irreps = cue.Irreps("O3", "32x0e + 32x1o + 32x2e")
edge_attrs_irreps = target_irreps = "0e + 1o + 2e + 3o"
edge_feats = torch.randn(128, feats_irreps.dim, dtype=dtype, device=device)
edge_vectors = torch.randn(128, 3, dtype=dtype, device=device)

edge_sh = cuet.spherical_harmonics([0, 1, 2, 3], edge_vectors)

# OLD FUNCTION DEFINITION
# irreps_mid, instructions = tp_out_irreps_with_instructions(
#     feats_irreps,
#     edge_attrs_irreps,
#     target_irreps,
# )
# conv_tp = o3.TensorProduct(
#     feats_irreps,
#     edge_attrs_irreps,
#     irreps_mid,
#     instructions=instructions,
#     shared_weights=False,
#     internal_weights=False,
# )

# NEW FUNCTION DEFINITION (single function)
conv_tp = cuet.ChannelWiseTensorProduct(
    irreps_in1=feats_irreps,
    irreps_in2=cue.Irreps("O3", edge_attrs_irreps),
    filter_irreps_out=cue.Irreps("O3", target_irreps),
    shared_weights=False,
    internal_weights=False,
    layout=cue.mul_ir,
    math_dtype=dtype,
    device=device,
)

# Weights (would normally come from conv_tp_weights)
tp_weights = torch.randn(128, conv_tp.weight_numel, dtype=dtype, device=device)

# OLD CALL:
# mji = conv_tp(edge_feats, edge_sh, tp_weights)

# NEW CALL: (unchanged)
conv_tp(edge_feats, edge_sh, tp_weights)
tensor([[ 8.9842e-01,  5.8307e-02, -2.3750e-01,  ..., -6.3765e-01,
         -8.0892e-01, -1.4038e+00],
        [-6.9469e-01, -2.1151e-01, -4.1547e-01,  ...,  5.5980e-01,
          2.0222e-01, -1.9801e-01],
        [-7.9954e-01, -7.5550e-02, -1.8983e-02,  ...,  1.0786e+00,
         -5.2210e-02,  9.8478e-01],
        ...,
        [ 1.2721e-01,  1.1503e+00, -4.4172e-02,  ...,  1.4831e-01,
         -2.4400e-01,  2.0115e-01],
        [-2.4462e-03,  1.1809e+00,  1.6359e+00,  ..., -1.0630e+00,
         -1.4110e+00,  2.2241e+00],
        [ 1.2139e-02,  2.5792e-01,  2.6371e+00,  ..., -3.3787e-02,
         -6.8150e-02, -3.8043e-02]], device='cuda:0')

Linear#

This is one of the simplest operations, and it is essentially unchanged. Depending on the size of the irreps, the kernel might not improve over the naive implementation; we thus show an example where the fallback is employed.

feats_irreps = cue.Irreps("O3", "32x0e + 32x1o + 32x2e")

# OLD FUNCTION DEFINITION:
# linear = o3.Linear(
#     feats_irreps,
#     feats_irreps,
#     internal_weights=True,
#     shared_weights=True,
# )

# NEW FUNCTION DEFINITION:
linear = cuet.Linear(
    irreps_in=feats_irreps,
    irreps_out=feats_irreps,
    layout=cue.mul_ir,
    internal_weights=True,
    shared_weights=True,
    dtype=dtype,
    device=device,
)

node_feats = torch.randn(128, feats_irreps.dim, dtype=dtype, device=device)

# OLD CALL:
# linear(node_feats)

# NEW CALL: (unchanged, using fallback)
linear(node_feats, use_fallback=True)
tensor([[ 1.3389e+00,  6.2552e-01, -3.0618e-01,  ..., -1.3058e+00,
         -7.3051e-01,  1.6218e+00],
        [ 1.1776e+00,  2.5962e-01,  1.3329e+00,  ...,  1.1841e+00,
          1.8994e-01, -3.5679e-01],
        [-8.9027e-01, -1.9882e-01,  1.3197e+00,  ..., -2.5388e-01,
          1.9484e+00, -2.4207e-01],
        ...,
        [-9.9497e-01,  1.7081e-03, -1.3681e+00,  ..., -1.9742e+00,
          8.5468e-01, -1.2922e+00],
        [-7.7185e-01, -1.7311e-02, -9.1431e-01,  ...,  5.8499e-01,
         -6.7907e-01,  1.3744e+00],
        [-1.1650e+00, -1.1011e-01,  1.3415e-01,  ..., -6.3652e-01,
          9.4782e-01,  8.1117e-01]], device='cuda:0', grad_fn=<CatBackward0>)

FullyConnectedTensorProduct#

The FullyConnectedTensorProduct is used in MACE for the skip-tp operation. In this case, the “node attributes” used to select the weights are still accepted as 1-hot. This operation is also essentially unchanged, and we show a version using the fallback.

feats_irreps = cue.Irreps("O3", "32x0e + 32x1o + 32x2e")
attrs_irreps = cue.Irreps("O3", "10x0e")

# OLD FUNCTION DEFINITION:
# skip_tp = o3.FullyConnectedTensorProduct(
#     feats_irreps,
#     attrs_irreps,
#     feats_irreps,
#     internal_weights=True,
#     shared_weights=True,
# )

# NEW FUNCTION DEFINITION:
skip_tp = cuet.FullyConnectedTensorProduct(
    feats_irreps,
    attrs_irreps,
    feats_irreps,
    layout=cue.mul_ir,
    internal_weights=True,
    shared_weights=True,
    dtype=dtype,
    device=device,
)

node_feats = torch.randn(128, feats_irreps.dim, dtype=dtype, device=device)
node_attrs = torch.nn.functional.one_hot(torch.randint(0, 10, (128,), dtype=torch.int64, device=device), 10).to(dtype)

# OLD CALL:
# skip_tp(node_feats, node_attrs)

# NEW CALL: (unchanged, using fallback)
skip_tp(node_feats, node_attrs, use_fallback=True)
tensor([[-4.3302e-02,  2.9739e-03, -6.5229e-01,  ..., -1.7563e-01,
          2.7892e-01, -2.2812e-01],
        [-5.0918e-01,  6.8811e-02,  1.4868e-01,  ..., -2.0992e-01,
         -1.4190e-01, -1.1486e-04],
        [ 4.4665e-01,  2.9435e-01, -9.6858e-02,  ...,  5.4243e-01,
         -2.2677e-01, -5.6310e-01],
        ...,
        [ 6.6191e-02, -8.5394e-02, -1.3618e-01,  ...,  3.4933e-01,
         -4.9362e-01,  3.8864e-02],
        [ 8.3473e-01, -2.5527e-01,  9.2828e-02,  ..., -2.2238e-01,
         -2.2132e-01,  4.1267e-01],
        [ 3.9997e-01,  5.3574e-01,  3.2328e-01,  ...,  3.5927e-01,
          5.1885e-01, -3.0160e-01]], device='cuda:0', grad_fn=<CatBackward0>)

Using JAX#

Note: In the following, we will refer to the cuequivariance library as cue and the cuequivariance_jax library as cuex.

The following code snippets demonstrate the main components of a MACE layer implemented in JAX. For the sake of simplicity, we will not implement the entire layer, but rather focus on its main building blocks. First, we import the necessary libraries.

import cuequivariance as cue
import cuequivariance_jax as cuex
import jax
import jax.numpy as jnp
from cuequivariance import descriptors
from cuequivariance.experimental.mace import symmetric_contraction
from cuequivariance_jax.experimental.utils import MultiLayerPerceptron, gather

The input data consists of node features, edge vectors, radial embeddings, and sender and receiver indices.

num_species = 3
num_nodes = 12
num_edges = 26
vectors = cuex.randn(
    jax.random.key(0), cue.Irreps("O3", "1o"), (num_edges,), cue.ir_mul
)
node_feats = cuex.randn(
    jax.random.key(0), cue.Irreps("O3", "16x0e + 16x1o"), (num_nodes,), cue.ir_mul
)
node_species = jax.random.randint(jax.random.key(0), (num_nodes,), 0, num_species)
radial_embeddings = jax.random.normal(jax.random.key(0), (num_edges, 4))
senders, receivers = jax.random.randint(jax.random.key(0), (2, num_edges), 0, num_nodes)

def param(name: str, init_fn, shape, dtype):
    # dummy function to obtain parameters (when using flax, one should use self.param instead)
    print(f"param({name!r}, {init_fn!r}, {shape!r}, {dtype!r})")
    return init_fn(jax.random.key(0), shape, dtype)

Next, we define the layer’s hyperparameters.

num_features = 32
interaction_irreps = cue.Irreps("O3", "0e + 1o + 2e + 3o")
hidden_out = cue.Irreps("O3", "0e + 1o")
max_ell = 3
dtype = node_feats.dtype

The MACE layer is composed of two types of linear layers: those that depend on the species and those that do not.

def lin(irreps: cue.Irreps, input: cuex.IrrepsArray, name: str):
    e = descriptors.linear(input.irreps(), irreps)
    w = param(name, jax.random.normal, (e.inputs[0].irreps.dim,), dtype)
    return cuex.equivariant_tensor_product(e, w, input, precision="HIGH")


def linZ(irreps: cue.Irreps, input: cuex.IrrepsArray, name: str):
    e = descriptors.linear(input.irreps(), irreps)
    w = param(
        name,
        jax.random.normal,
        (num_species, e.inputs[0].irreps.dim),
        dtype,
    )
    return cuex.equivariant_tensor_product(
        e, w[node_species], input, precision="HIGH"
    ) / jnp.sqrt(num_species)

The first block consists of the operations performed before the convolution.

self_connection = linZ(num_features * hidden_out, node_feats, "linZ_skip_tp")
node_feats = lin(node_feats.irreps(), node_feats, "linear_up")
param('linZ_skip_tp', <function normal at 0x71db260e0cc0>, (3, 1024), dtype('float32'))
param('linear_up', <function normal at 0x71db260e0cc0>, (512,), dtype('float32'))

Next, we implement the convolutional part.

messages = node_feats[senders]
sph = cuex.spherical_harmonics(range(max_ell + 1), vectors)
e = descriptors.channelwise_tensor_product(messages.irreps(), sph.irreps(), interaction_irreps)
e = e.squeeze_modes().flatten_coefficient_modes()

mlp = MultiLayerPerceptron(
    [64, 64, 64, e.inputs[0].irreps.dim],
    jax.nn.silu,
    output_activation=False,
    with_bias=False,
)
w = mlp.init(jax.random.key(0), radial_embeddings)  # dummy parameters
mix = mlp.apply(w, radial_embeddings)

messages = cuex.equivariant_tensor_product(e, mix, messages, sph)

avg_num_neighbors = num_edges / num_nodes  # you should use a constant here
node_feats = gather(receivers, messages, node_feats.shape[0]) / avg_num_neighbors

Now, we perform the symmetric contraction part.

node_feats = lin(num_features * interaction_irreps, node_feats, "linear_down")
e, projection = symmetric_contraction(
    node_feats.irreps(),
    num_features * hidden_out,
    [1, 2, 3],
)
n = projection.shape[0]
w = param(
    "symmetric_contraction", jax.random.normal, (num_species, n, num_features), dtype
)
w = jnp.einsum("zau,ab->zbu", w, projection)
w = jnp.reshape(w, (num_species, -1))
node_feats = cuex.equivariant_tensor_product(e, w[node_species], node_feats)
param('linear_down', <function normal at 0x71db260e0cc0>, (5120,), dtype('float32'))
param('symmetric_contraction', <function normal at 0x71db260e0cc0>, (3, 86, 32), dtype('float32'))

Finally, we apply the remaining linear layers.

node_feats = lin(num_features * hidden_out, node_feats, "linear_post_sc")
node_feats = node_feats + self_connection  # [n_nodes, feature * hidden_irreps]

node_outputs = lin(cue.Irreps("O3", "0e"), node_feats, "linear_readout")

print(node_outputs)
param('linear_post_sc', <function normal at 0x71db260e0cc0>, (2048,), dtype('float32'))
param('linear_readout', <function normal at 0x71db260e0cc0>, (32,), dtype('float32'))
{1: 0e}
[[10.92605   ]
 [-2.8837109 ]
 [ 0.595831  ]
 [ 8.578955  ]
 [ 1.5201262 ]
 [ 4.8192573 ]
 [ 0.93503827]
 [ 0.54991317]
 [ 5.837841  ]
 [ 0.57969433]
 [19.84022   ]
 [ 0.13047403]]