MACE Operations
In this notebook, we go through the blocks that make up the MACE architecture and show how they can be accelerated with cuEquivariance.
There are 4 operations in MACE that act on irreps:
Channelwise tensor product
Symmetric contraction
Linear layers
Indexed linear layers
Before we go through them one by one, a small remark about the structure of the irreps is necessary.
A note on layouts
Let us consider a collection of 10 \(l=1\) objects, or vectors.
In e3nn, this would be stored as a \(10\times3\) tensor with the \((x,y,z)\) components of each vector contiguous. This is what we refer to as the mul_ir layout.
For performance reasons, cuEquivariance adopts the transpose of this layout: the same object corresponds to a \(3\times10\) tensor, with all the \(x\) components contiguous. We refer to this as the ir_mul layout.
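The two layouts for the 10 vectors above can be compared in a minimal NumPy sketch. It uses plain arrays only (no e3nn or cuEquivariance calls), and the variable names are illustrative.

```python
import numpy as np

mul, ir_dim = 10, 3  # 10 copies of an l=1 irrep, 3 components each

# Vector i is (x_i, y_i, z_i); fill with recognizable values 0..29
vectors = np.arange(mul * ir_dim).reshape(mul, ir_dim)

# mul_ir: shape (10, 3), the (x, y, z) of each vector are contiguous in memory
mul_ir = np.ascontiguousarray(vectors)

# ir_mul: shape (3, 10), all x components first, then all y, then all z
ir_mul = np.ascontiguousarray(vectors.T)

print(mul_ir.ravel()[:6])  # x0, y0, z0, x1, y1, z1
print(ir_mul.ravel()[:6])  # x0, x1, x2, x3, x4, x5
```

Both arrays hold the same numbers; only their order in memory differs.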
In the following, we use the preferred layout for each library. Note, however, that the transposition can be quite expensive, so the best performance is obtained by keeping the data in the layout expected by the library throughout your code.
Since any e3nn operation can also be performed in cuEquivariance, this should generally be possible.
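For irreps with several segments (for example `2x0e + 2x1o`), the layout change is applied segment by segment. The hypothetical helper below sketches this conversion in NumPy, assuming irreps are given as `(mul, ir_dim)` pairs. It shows that the transposition is a pure permutation of the feature axis, so it requires reading and writing the whole tensor. It is not the implementation of `cuet.TransposeIrrepsLayout`.

```python
import numpy as np

def transpose_layout(x, irreps, to_ir_mul=True):
    """Hypothetical helper: convert [batch, dim] features between mul_ir and ir_mul.

    irreps is a list of (mul, ir_dim) pairs, e.g. [(2, 1), (2, 3)] for 2x0e + 2x1o.
    """
    blocks = []
    offset = 0
    for mul, ir_dim in irreps:
        size = mul * ir_dim
        seg = x[:, offset:offset + size]
        if to_ir_mul:
            # mul_ir segment viewed as (mul, ir_dim) -> stored as (ir_dim, mul)
            seg = seg.reshape(-1, mul, ir_dim).transpose(0, 2, 1)
        else:
            # ir_mul segment viewed as (ir_dim, mul) -> stored as (mul, ir_dim)
            seg = seg.reshape(-1, ir_dim, mul).transpose(0, 2, 1)
        blocks.append(seg.reshape(x.shape[0], size))
        offset += size
    return np.concatenate(blocks, axis=1)

irreps = [(2, 1), (2, 3)]      # 2x0e + 2x1o -> dim 2 + 6 = 8
x = np.arange(8.0)[None, :]    # one sample in mul_ir layout
y = transpose_layout(x, irreps)                       # ir_mul layout
x_back = transpose_layout(y, irreps, to_ir_mul=False)  # back to mul_ir
```

Scalar segments (`ir_dim == 1`) are unchanged, while each vector segment has its components regrouped.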
Let us now start by importing a few useful packages:
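import torch
import numpy as np
from typing import Tuple, List
import time
from e3nn import o3
import cuequivariance as cue
import cuequivariance_torch as cuet
from cuequivariance.group_theory.experimental.e3nn import O3_e3nn
# Reference MACE building blocks used below (from the mace-torch package)
from mace.modules.irreps_tools import tp_out_irreps_with_instructions
from mace.modules.symmetric_contraction import SymmetricContraction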
Channelwise tensor product
This is the main operation performed on the edges in a MACE model, typically found in the InteractionBlock modules.
It consists of the tensor product between the features of each neighbor and the spherical harmonics representing the edge, computed in a “channel-wise” fashion, meaning that the neighbor’s channels are not mixed.
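To make the “channels are not mixed” property concrete, here is a simplified NumPy sketch of a channel-wise product between neighbor features `C x 0e + C x 1o` and edge attributes `0e + 1o`, keeping only the paths whose output is `0e` or `1o`. The function name, the `[4, C]` weight shape, and the omission of e3nn's normalization constants and basis conventions are assumptions for illustration. This is not how e3nn or cuEquivariance implement the operation. Each path writes its own output block, so the output has twice as many channels per irrep as the input.

```python
import numpy as np

def channelwise_tp(x0, x1, y0, y1, w):
    """Hypothetical channel-wise TP of (C x 0e + C x 1o) with (0e + 1o) edge attributes.

    x0: [C] scalar channels, x1: [C, 3] vector channels of the neighbor
    y0: float, y1: [3] edge attributes (e.g. spherical harmonics)
    w:  [4, C] one weight per path and per channel (normalization omitted)
    """
    s_00 = w[0] * x0 * y0                            # 0e x 0e -> 0e
    v_01 = (w[1] * x0)[:, None] * y1[None, :]        # 0e x 1o -> 1o
    v_10 = w[2][:, None] * x1 * y0                   # 1o x 0e -> 1o
    s_11 = w[3] * (x1 @ y1) / np.sqrt(3.0)           # 1o x 1o -> 0e (dot product)
    scalars = np.concatenate([s_00, s_11])           # [2C]
    vectors = np.concatenate([v_01, v_10], axis=0)   # [2C, 3]
    return scalars, vectors

rng = np.random.default_rng(0)
C = 8
x0, x1 = rng.normal(size=C), rng.normal(size=(C, 3))
y0, y1 = rng.normal(), rng.normal(size=3)
w = rng.normal(size=(4, C))
scalars, vectors = channelwise_tp(x0, x1, y0, y1, w)

# Perturb only channel 3 of the input vectors
x1_pert = x1.copy()
x1_pert[3] += 1.0
scalars_p, vectors_p = channelwise_tp(x0, x1_pert, y0, y1, w)
changed_scalars = np.nonzero(~np.isclose(scalars, scalars_p))[0]
changed_vectors = np.nonzero(~np.all(np.isclose(vectors, vectors_p), axis=1))[0]
```

Only the output entries derived from channel 3 (index `C + 3` in each block) change, which is the channel-wise property described above.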
The original implementation in e3nn makes use of a custom tensor product (the following code is adapted from the MACE repository):
# Parameters
multiplicity = 128
num_nodes = 1000
num_edges = 10000
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float32
irreps_node_input = o3.Irreps(f"{multiplicity}x0e + {multiplicity}x1o")
irreps_edge_attr = o3.Irreps("1x0e + 1x1o")
target_irreps = irreps_edge_attr
# Create the instructions
irreps_mid, instructions = tp_out_irreps_with_instructions(
irreps_node_input,
irreps_edge_attr,
target_irreps,
)
# Create the TP module
conv_tp = o3.TensorProduct(
irreps_node_input,
irreps_edge_attr,
irreps_mid,
instructions=instructions,
shared_weights=False,
internal_weights=False
).to(device)
# Create input tensors
node_feats = torch.randn(num_nodes, irreps_node_input.dim, device=device, dtype=dtype)
senders = torch.randint(0, num_nodes, (num_edges,), device=device, dtype=torch.int64)
receivers = torch.randint(0, num_nodes, (num_edges,), device=device, dtype=torch.int64)
edge_attrs = torch.randn(num_edges, irreps_edge_attr.dim, device=device, dtype=dtype)
weights = torch.randn(num_edges, conv_tp.weight_numel, device=device, dtype=dtype)
# Perform TP on the gathered sender features
mji = conv_tp(
node_feats[senders], edge_attrs, weights
) # [num_edges, irreps_mid.dim]
# Perform scatter
m_tmp = torch.zeros(num_nodes, irreps_mid.dim, device=device, dtype=dtype)
message = m_tmp.scatter_add(0, receivers.unsqueeze(-1).expand_as(mji), mji)
# Output shape
print("Output shape:", message.shape)
Output shape: torch.Size([1000, 1024])
As you can see, besides the TensorProduct itself, this requires gathering all node features corresponding to the edges (node_feats[senders]), and scattering the output back to the correct nodes.
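The gather, per-edge operation, and scatter steps can be sketched in NumPy, with a toy elementwise product standing in for the tensor product (an assumption for brevity). The sketch makes explicit the `[num_edges, dim]` intermediate created by the gather, and uses `np.add.at` for the scatter-add so that repeated receiver indices accumulate correctly.

```python
import numpy as np

rng = np.random.default_rng(0)
num_nodes, num_edges, dim = 5, 12, 4
node_feats = rng.normal(size=(num_nodes, dim))
edge_scale = rng.normal(size=(num_edges, 1))   # toy stand-in for the per-edge TP
senders = rng.integers(0, num_nodes, size=num_edges)
receivers = rng.integers(0, num_nodes, size=num_edges)

# 1) gather: one copy of the sender features per edge -> [num_edges, dim]
gathered = node_feats[senders]
# 2) per-edge operation (a toy elementwise product instead of a TP)
mji = gathered * edge_scale
# 3) scatter-add: sum the messages arriving at each receiver -> [num_nodes, dim]
message = np.zeros((num_nodes, dim))
np.add.at(message, receivers, mji)
```

Unlike `message[receivers] += mji`, which keeps only one contribution per repeated index, `np.add.at` accumulates every edge.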
In cuEquivariance, not only can we perform the TP, but we can also perform the gather and scatter operations in the same call. For this operation, we use our uniform_1d kernel, since there is a single set of irreps in the channelwise structure.
We first do this explicitly, and then show a premade module for this operation.
For more information about building the descriptor itself, you can refer to the definition of cue.descriptors.channelwise_tensor_product.
# Cue version of the irreps
irreps_in1 = cue.Irreps("O3", irreps_node_input)
irreps_in2 = cue.Irreps("O3", irreps_edge_attr)
irreps_out = cue.Irreps("O3", target_irreps)
# Defining the operation
e = cue.descriptors.channelwise_tensor_product(
irreps_in1, irreps_in2, irreps_out
)
# The TP itself:
cue_tp = cuet.SegmentedPolynomial(
e.polynomial,
method="uniform_1d"
).to(device)
# Transposing inputs layout:
cue_node_feats = cuet.TransposeIrrepsLayout(
irreps_in1,
source=cue.mul_ir,
target=cue.ir_mul,
device=device,
use_fallback=device=="cpu",
)(node_feats)
cue_edge_attrs = cuet.TransposeIrrepsLayout(
irreps_in2,
source=cue.mul_ir,
target=cue.ir_mul,
device=device,
use_fallback=device=="cpu",
)(edge_attrs)
# Performing the TP
cue_message = cue_tp(
[weights, cue_node_feats, cue_edge_attrs],
input_indices={1: senders}, # indices for cue_node_feats
output_shapes={0: cue_node_feats}, # We only care about the first dimension being num_nodes
output_indices={0: receivers}, # Indices for the output
)
print("Output shape:", cue_message[0].shape)
# Transposing the output
cue_message_transp = cuet.TransposeIrrepsLayout(
e.outputs[0].irreps,
source=cue.ir_mul,
target=cue.mul_ir,
device=device,
use_fallback=device=="cpu",
)(cue_message[0])
# Comparing the result
print("Results match:", torch.allclose(message, cue_message_transp, atol=1e-5))
Output shape: torch.Size([1000, 1024])
Results match: True
Alternatively, we can use the premade function for this particular tensor product:
# Defining TP through the premade block
cue_cw = cuet.ChannelWiseTensorProduct(
irreps_in1,
irreps_in2,
irreps_out,
layout=cue.ir_mul,
shared_weights=False,
internal_weights=False,
device=device
)
# Performing the TP
cue_cw_message = cue_cw(
cue_node_feats,
cue_edge_attrs,
weights,
indices_1=senders,
indices_out=receivers,
size_out=num_nodes
)
# Transposing
cue_cw_message_transp = cuet.TransposeIrrepsLayout(
e.outputs[0].irreps,
source=cue.ir_mul,
target=cue.mul_ir,
device=device,
use_fallback=device=="cpu",
)(cue_cw_message)
# Comparing the results
print("Results match:", torch.allclose(message, cue_cw_message_transp, atol=1e-5))
Results match: True
We can also compare the speed of the two approaches (in their respective layouts):
throwaway = 10
repetitions = 1000 if device=="cuda" else 10
# Synchronize only on GPU: torch.cuda.synchronize() fails when CUDA is unavailable
sync = torch.cuda.synchronize if device == "cuda" else (lambda: None)
e3nn_times = []
for _ in range(throwaway):
    mji = conv_tp(node_feats[senders], edge_attrs, weights)
    m_tmp = torch.zeros(num_nodes, irreps_mid.dim, device=device, dtype=dtype)
    message = m_tmp.scatter_add(0, receivers.unsqueeze(-1).expand_as(mji), mji)
for _ in range(repetitions):
    sync()
    t1 = time.perf_counter()
    mji = conv_tp(node_feats[senders], edge_attrs, weights)
    m_tmp = torch.zeros(num_nodes, irreps_mid.dim, device=device, dtype=dtype)
    message = m_tmp.scatter_add(0, receivers.unsqueeze(-1).expand_as(mji), mji)
    sync()
    e3nn_times.append(time.perf_counter()-t1)
cuet_times = []
for _ in range(throwaway):
    cue_message = cue_tp(
        [weights, cue_node_feats, cue_edge_attrs],
        input_indices={1: senders},
        output_shapes={0: cue_node_feats},
        output_indices={0: receivers},
    )
for _ in range(repetitions):
    sync()
    t1 = time.perf_counter()
    cue_message = cue_tp(
        [weights, cue_node_feats, cue_edge_attrs],
        input_indices={1: senders},
        output_shapes={0: cue_node_feats},
        output_indices={0: receivers},
    )
    sync()
    cuet_times.append(time.perf_counter()-t1)
e3nn_avg = 1000*np.mean(e3nn_times)
cuet_avg = 1000*np.mean(cuet_times)
print(f"e3nn time: {e3nn_avg:.2} ms")
print(f"Cuequivariance time: {cuet_avg:.2} ms")
print(f"Speedup: {e3nn_avg/cuet_avg:.2}x")
e3nn time: 0.98 ms
Cuequivariance time: 0.17 ms
Speedup: 5.6x
A fair comparison would require benchmarking the module inside a real model, and our kernels tend to perform best for very large inputs. Still, even this simple example shows that cuEquivariance offers a substantial speedup for this operation.
The backward and double-backward passes are also supported and accelerated, but they are not shown in this example.
Symmetric Contraction
The Symmetric Contraction is the most distinctive TP in MACE. It consists of a tensor product with a single input that gets contracted with itself multiple times. It is typically used in the EquivariantProductBasisBlock.
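The idea of contracting a single input with itself can be illustrated with a heavily simplified, scalar-output NumPy sketch. Each channel carries a scalar `a` and a vector `v`; up to correlation 3, we build rotation-invariant products of the input with itself (`a`, `a**2`, `|v|**2`, `a**3`, `a * |v|**2`) and combine them per channel with weights selected by the species index. This basis and the function name are hypothetical stand-ins chosen for readability: the actual MACE basis is built from generalized Clebsch-Gordan coefficients, covers all irreps of the input, and also produces non-scalar outputs.

```python
import numpy as np

def toy_symmetric_contraction(a, v, species, W):
    """Hypothetical, simplified scalar-output symmetric contraction (correlation 3).

    a: [N, C] 0e features, v: [N, C, 3] 1o features (one vector per channel)
    species: [N] integer species index, W: [num_species, 5, C] weights
    """
    v2 = np.einsum("nci,nci->nc", v, v)   # |v|^2: v contracted with itself
    basis = np.stack([
        a,              # correlation 1
        a**2, v2,       # correlation 2
        a**3, a * v2,   # correlation 3
    ], axis=1)          # [N, 5, C]
    w = W[species]      # species-dependent weights selected by index, [N, 5, C]
    return np.einsum("nkc,nkc->nc", w, basis)   # channels are never mixed

rng = np.random.default_rng(0)
N, C, num_species = 6, 4, 3
a = rng.normal(size=(N, C))
v = rng.normal(size=(N, C, 3))
species = rng.integers(0, num_species, size=N)
W = rng.normal(size=(num_species, 5, C))
out = toy_symmetric_contraction(a, v, species, W)   # [N, C]
```

Selecting `W[species]` gives the same weights as multiplying a one-hot species encoding by the weight tensor, without computing products with zeros.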
As in the previous case, we will first consider the original MACE implementation:
# Parameters
num_species = 10
multiplicity = 128
correlation = 3
num_nodes = 1000
dtype = torch.float32
irreps_in = o3.Irreps(f"{multiplicity}x0e + {multiplicity}x1o + {multiplicity}x2e + {multiplicity}x3o")
irreps_out = o3.Irreps(f"{multiplicity}x0e + {multiplicity}x1o")
# Define operation
sc = SymmetricContraction(
irreps_in,
irreps_out,
correlation=correlation,
num_elements=num_species
).to(dtype).to(device)
# Create inputs
node_feats = torch.randn(num_nodes, multiplicity, irreps_in.dim // multiplicity, device=device, dtype=dtype)
species = torch.randint(0, num_species, (num_nodes,), device=device, dtype=torch.int64)
species_1hot = torch.nn.functional.one_hot(species, num_species).to(dtype).to(device)
# Perform operation
out_feats = sc(node_feats, species_1hot)
# Output shape
print("Output shape:", out_feats.shape)
Output shape: torch.Size([1000, 512])
We can now perform the same operation using the corresponding cuEquivariance module (you can check the module definition to see the descriptor used inside).
While the original module needs a one-hot encoding of the atomic species, we use the species indices directly, which allows for more efficient operations.
Note that, in order to match the weights used in the previous implementation, we need to copy the weights of the original module into the new one by hand. In a native setting, however, the weights can of course be used as they are. We also need to use the O3_e3nn group for compatibility with e3nn, but the standard “O3” group works in the general case.
cue_irreps_in = cue.Irreps(O3_e3nn, irreps_in)
cue_irreps_out = cue.Irreps(O3_e3nn, irreps_out)
# The SC module
cue_sc = cuet.SymmetricContraction(
cue_irreps_in,
cue_irreps_out,
contraction_degree=correlation,
num_elements=num_species,
layout_in=cue.ir_mul,
layout_out=cue.ir_mul,
original_mace=True,
device=device,
dtype=dtype,
)
# Modifying the weights by hand
cue_sc.weight.data = torch.concatenate([x for x in sc.parameters()], dim=1)
# All irreps share the same multiplicity, so swapping the last two axes of
# [num_nodes, multiplicity, irreps_dim] directly gives the ir_mul layout:
cue_node_feats = torch.transpose(node_feats, 1, 2).flatten(1)
cue_out_feats = cue_sc(cue_node_feats, species)
print("Output shape:", cue_out_feats.shape)
# Transposing the output
cue_out_feats_transp = cuet.TransposeIrrepsLayout(
cue_irreps_out,
source=cue.ir_mul,
target=cue.mul_ir,
device=device,
use_fallback=device=="cpu",
)(cue_out_feats)
# Comparing the result
print("Results match:", torch.allclose(out_feats, cue_out_feats_transp, atol=1e-5))
Output shape: torch.Size([1000, 512])
Results match: True
Here too we can compare the speed of the two approaches:
throwaway = 10 if device=="cuda" else 1
repetitions = 100 if device=="cuda" else 2
# Synchronize only on GPU: torch.cuda.synchronize() fails when CUDA is unavailable
sync = torch.cuda.synchronize if device == "cuda" else (lambda: None)
e3nn_times = []
for _ in range(throwaway):
    out_feats = sc(node_feats, species_1hot)
for _ in range(repetitions):
    sync()
    t1 = time.perf_counter()
    out_feats = sc(node_feats, species_1hot)
    sync()
    e3nn_times.append(time.perf_counter()-t1)
cuet_times = []
for _ in range(throwaway):
    cue_out_feats = cue_sc(cue_node_feats, species)
for _ in range(repetitions):
    sync()
    t1 = time.perf_counter()
    cue_out_feats = cue_sc(cue_node_feats, species)
    sync()
    cuet_times.append(time.perf_counter()-t1)
e3nn_avg = 1000*np.mean(e3nn_times)
cuet_avg = 1000*np.mean(cuet_times)
print(f"e3nn time: {e3nn_avg:.3} ms")
print(f"Cuequivariance time: {cuet_avg:.3} ms")
print(f"Speedup: {e3nn_avg/cuet_avg:.3}x")
e3nn time: 40.6 ms
Cuequivariance time: 3.54 ms
Speedup: 11.5x
Linear layers
The linear layers are the most basic e3nn operation, used in several blocks in MACE.
While we do not provide a large speedup for this operation, we can perform it natively in the ir_mul layout, for use in a complete cuEquivariance pipeline.
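In the ir_mul layout, each irrep segment of a sample can be viewed as an `[ir_dim, mul]` matrix, and an equivariant linear layer mixes channels with one `[mul_in, mul_out]` matrix per irrep type, shared by all components. The NumPy sketch below shows that the layer then reduces to a plain matrix product over the contiguous channel axis. The helper name and the `(mul, ir_dim)` irreps format are hypothetical, e3nn's weight normalization is omitted, and input and output irreps are assumed to match segment by segment.

```python
import numpy as np

def irreps_linear_ir_mul(x, irreps, weights):
    """Hypothetical equivariant linear layer acting on ir_mul features.

    x: [N, dim] in ir_mul layout; irreps: list of (mul, ir_dim) segments
    weights: list of [mul_in, mul_out] matrices, one per segment
    """
    out = []
    offset = 0
    for (mul, ir_dim), w in zip(irreps, weights):
        seg = x[:, offset:offset + mul * ir_dim].reshape(-1, ir_dim, mul)
        # Channels (last axis) are mixed; all ir_dim components share the weights
        out.append((seg @ w).reshape(x.shape[0], -1))
        offset += mul * ir_dim
    return np.concatenate(out, axis=1)

rng = np.random.default_rng(0)
irreps = [(4, 1), (4, 3)]                      # 4x0e + 4x1o, dim 16
weights = [rng.normal(size=(4, 4)) for _ in irreps]
x = rng.normal(size=(2, 16))
y = irreps_linear_ir_mul(x, irreps, weights)   # [2, 16]
```

Because the same matrix acts on every component, rotating the vector components before or after the layer gives the same result.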
Let us start again from the original implementation:
# Parameters
multiplicity = 128
num_nodes = 10000
dtype = torch.float32
irreps_in = o3.Irreps(f"{multiplicity}x0e + {multiplicity}x1o")
irreps_out = o3.Irreps(f"{multiplicity}x0e + {multiplicity}x1o")
# Define operation
lin = o3.Linear(
irreps_in,
irreps_out,
).to(dtype).to(device)
# Create inputs
in_feats = torch.randn(num_nodes, irreps_in.dim, device=device, dtype=dtype)
# Perform operation
out_feats = lin(in_feats)
# Output shape
print("Output shape:", out_feats.shape)
Output shape: torch.Size([10000, 512])
And the equivalent cuEquivariance code:
cue_irreps_in = cue.Irreps("O3", irreps_in)
cue_irreps_out = cue.Irreps("O3", irreps_out)
# The linear module
cue_lin = cuet.Linear(
cue_irreps_in,
cue_irreps_out,
internal_weights=False,
layout=cue.ir_mul,
device=device,
dtype=dtype,
)
# Transposing the input
cue_in_feats = cuet.TransposeIrrepsLayout(
cue_irreps_in,
source=cue.mul_ir,
target=cue.ir_mul,
device=device,
use_fallback=device=="cpu",
)(in_feats)
cue_out_feats = cue_lin(cue_in_feats, weight=lin.weight.unsqueeze(0))
print("Output shape:", cue_out_feats.shape)
# Transposing the output
cue_out_feats_transp = cuet.TransposeIrrepsLayout(
cue_irreps_out,
source=cue.ir_mul,
target=cue.mul_ir,
device=device,
use_fallback=device=="cpu",
)(cue_out_feats)
# Comparing the result
print("Results match:", torch.allclose(out_feats, cue_out_feats_transp, atol=1e-5))
Output shape: torch.Size([10000, 512])
Results match: True
Here too the results match.
We can compare the speed, although the difference will not be large in this case.
throwaway = 10
repetitions = 1000 if device=="cuda" else 10
# Synchronize only on GPU: torch.cuda.synchronize() fails when CUDA is unavailable
sync = torch.cuda.synchronize if device == "cuda" else (lambda: None)
e3nn_times = []
for _ in range(throwaway):
    out_feats = lin(in_feats)
for _ in range(repetitions):
    sync()
    t1 = time.perf_counter()
    out_feats = lin(in_feats)
    sync()
    e3nn_times.append(time.perf_counter()-t1)
cuet_times = []
for _ in range(throwaway):
    cue_lin(cue_in_feats, weight=lin.weight.unsqueeze(0))
for _ in range(repetitions):
    sync()
    t1 = time.perf_counter()
    cue_lin(cue_in_feats, weight=lin.weight.unsqueeze(0))
    sync()
    cuet_times.append(time.perf_counter()-t1)
e3nn_avg = 1000*np.mean(e3nn_times)
cuet_avg = 1000*np.mean(cuet_times)
print(f"e3nn time: {e3nn_avg:.3} ms")
print(f"Cuequivariance time: {cuet_avg:.3} ms")
print(f"Speedup: {e3nn_avg/cuet_avg:.3}x")
e3nn time: 0.577 ms
Cuequivariance time: 0.521 ms
Speedup: 1.11x
Skip_tp or Indexed Linear
The last operation is used in the MACE InteractionBlock and is typically called skip_tp, since it acts as a skip connection.
However, in the context of cuEquivariance we will typically refer to this operation as indexed linear, as it consists of a linear operation where the weight matrix is indexed on the species of each input.
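The equivalence between a product with one-hot species attributes and a species-indexed linear layer can be checked in a small NumPy sketch. For simplicity, the features here are plain channels (a single scalar irrep), and normalization constants are ignored; in the original code, a factor of this kind appears as the `1/np.sqrt(num_species)` used when the weights are rearranged. Array names and sizes are illustrative.

```python
import numpy as np

rng = np.random.default_rng(0)
N, mul_in, mul_out, num_species = 7, 4, 3, 5
x = rng.normal(size=(N, mul_in))
species = rng.integers(0, num_species, size=N)
W = rng.normal(size=(num_species, mul_in, mul_out))   # one weight matrix per species

# Indexed linear: pick the weight matrix of each node's species, then apply it
indexed = np.einsum("ni,nio->no", x, W[species])

# One-hot formulation: every node is multiplied with all species' weights,
# and the one-hot vector zeroes out all but one contribution
onehot = np.eye(num_species)[species]
via_onehot = np.einsum("ns,ni,sio->no", onehot, x, W)
```

In this sketch the one-hot version does work proportional to the number of species for every node, while the indexed version only touches the weights each node needs.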
We will first present the original implementation, which makes use of an expensive FullyConnectedTensorProduct.
# Parameters
num_species = 20
multiplicity = 128
num_nodes = 10000
dtype = torch.float32
irreps_in = o3.Irreps(f"{multiplicity}x0e + {multiplicity}x1o")
attr_irreps = o3.Irreps(f"{num_species}x0e")
irreps_out = o3.Irreps(f"{multiplicity}x0e + {multiplicity}x1o")
# Define operation
skip_tp = o3.FullyConnectedTensorProduct(
irreps_in,
attr_irreps,
irreps_out,
).to(dtype).to(device)
# Create inputs
in_feats = torch.randn(num_nodes, irreps_in.dim, device=device, dtype=dtype)
species = torch.randint(0, num_species, (num_nodes,), device=device, dtype=torch.int64)
species, _ = torch.sort(species)
species_1hot = torch.nn.functional.one_hot(species, num_species).to(dtype).to(device)
# Perform operation
out_feats = skip_tp(in_feats, species_1hot)
# Output shape
print("Output shape:", out_feats.shape)
Output shape: torch.Size([10000, 512])
We will now show the equivalent cuEquivariance implementation that makes use of a Linear block and its indexing capabilities.
We show two different backends: naive and indexed_linear. The first works in any setting, while the second can only be used when the atomic species are sorted (which is why we sorted them above), but it offers much better performance. In the code below, indexed_linear is only selected when running on GPU.
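One way to see why sorted species can help, offered as an assumption rather than a description of the indexed_linear kernel: when nodes of the same species are contiguous, the indexed linear layer reduces to one dense matrix product per species over a slice of rows. The NumPy sketch below (hypothetical helper name) implements this and checks it against the per-node weight lookup.

```python
import numpy as np

def indexed_linear_sorted(x, species, W):
    """Hypothetical species-indexed linear layer for nodes sorted by species.

    x: [N, mul_in], species: [N] sorted integer indices, W: [num_species, mul_in, mul_out]
    """
    if np.any(np.diff(species) < 0):
        raise ValueError("species must be sorted")
    num_species = W.shape[0]
    counts = np.bincount(species, minlength=num_species)
    bounds = np.concatenate([[0], np.cumsum(counts)])   # start/end row of each species
    out = np.empty((x.shape[0], W.shape[2]))
    for s in range(num_species):
        lo, hi = bounds[s], bounds[s + 1]
        out[lo:hi] = x[lo:hi] @ W[s]   # one dense matmul per species
    return out

rng = np.random.default_rng(0)
x = rng.normal(size=(9, 4))
species = np.sort(rng.integers(0, 3, size=9))
W = rng.normal(size=(3, 4, 2))
out = indexed_linear_sorted(x, species, W)   # [9, 2]
```

With unsorted species, the rows of a given species are scattered, and the per-species slices used above no longer exist.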
cue_irreps_in = cue.Irreps("O3", irreps_in)
cue_irreps_out = cue.Irreps("O3", irreps_out)
# The linear module
cue_lin = cuet.Linear(
cue_irreps_in,
cue_irreps_out,
internal_weights=False,
weight_classes=num_species,
layout=cue.ir_mul,
device=device,
dtype=dtype,
method='naive'
)
# The faster linear module
cue_indexed_lin = cuet.Linear(
cue_irreps_in,
cue_irreps_out,
internal_weights=False,
weight_classes=num_species,
layout=cue.ir_mul,
device=device,
dtype=dtype,
method='indexed_linear' if device=="cuda" else "naive"
)
# Transposing the input
cue_in_feats = cuet.TransposeIrrepsLayout(
cue_irreps_in,
source=cue.mul_ir,
target=cue.ir_mul,
device=device,
use_fallback=device=="cpu",
)(in_feats)
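# Rearranging the weights by hand: the FullyConnectedTensorProduct stores one
# [multiplicity, num_species, multiplicity] block per path (0e and 1o), so we move
# the species axis first to get one weight vector per species, and rescale to
# account for the different normalization of the two modules
cue_weight = skip_tp.weight.reshape(2*multiplicity, num_species, multiplicity
).transpose(0,1).reshape(num_species, -1)/np.sqrt(num_species)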
# Performing the operation
cue_out_feats = cue_lin(cue_in_feats, weight=cue_weight, weight_indices=species)
print("Output shape:", cue_out_feats.shape)
# Transposing the output
cue_out_feats_transp = cuet.TransposeIrrepsLayout(
cue_irreps_out,
source=cue.ir_mul,
target=cue.mul_ir,
device=device,
use_fallback=device=="cpu",
)(cue_out_feats)
# Comparing the result
print("Results match:", torch.allclose(out_feats, cue_out_feats_transp, atol=1e-3))
# Performing the operation with the other backend
cue_out_feats = cue_indexed_lin(cue_in_feats, weight=cue_weight, weight_indices=species)
print("Output shape:", cue_out_feats.shape)
# Transposing the output
cue_out_feats_transp = cuet.TransposeIrrepsLayout(
cue_irreps_out,
source=cue.ir_mul,
target=cue.mul_ir,
device=device,
use_fallback=device=="cpu",
)(cue_out_feats)
# Comparing the result
print("Results match:", torch.allclose(out_feats, cue_out_feats_transp, atol=1e-3))
Output shape: torch.Size([10000, 512])
Results match: True
Output shape: torch.Size([10000, 512])
Results match: True
And we can compare the speed of the three implementations:
throwaway = 10
repetitions = 100 if device=="cuda" else 10
# Synchronize only on GPU: torch.cuda.synchronize() fails when CUDA is unavailable
sync = torch.cuda.synchronize if device == "cuda" else (lambda: None)
e3nn_times = []
for _ in range(throwaway):
    out_feats = skip_tp(in_feats, species_1hot)
for _ in range(repetitions):
    sync()
    t1 = time.perf_counter()
    out_feats = skip_tp(in_feats, species_1hot)
    sync()
    e3nn_times.append(time.perf_counter()-t1)
cuet_times = []
for _ in range(throwaway):
    cue_out_feats = cue_lin(cue_in_feats, weight=cue_weight, weight_indices=species)
for _ in range(repetitions):
    sync()
    t1 = time.perf_counter()
    cue_out_feats = cue_lin(cue_in_feats, weight=cue_weight, weight_indices=species)
    sync()
    cuet_times.append(time.perf_counter()-t1)
cuet_v2_times = []
for _ in range(throwaway):
    cue_out_feats = cue_indexed_lin(cue_in_feats, weight=cue_weight, weight_indices=species)
for _ in range(repetitions):
    sync()
    t1 = time.perf_counter()
    cue_out_feats = cue_indexed_lin(cue_in_feats, weight=cue_weight, weight_indices=species)
    sync()
    cuet_v2_times.append(time.perf_counter()-t1)
e3nn_avg = 1000*np.mean(e3nn_times)
cuet_avg = 1000*np.mean(cuet_times)
cuet_v2_avg = 1000*np.mean(cuet_v2_times)
print(f"e3nn time: {e3nn_avg:.3} ms")
print(f"Cuequivariance naive time: {cuet_avg:.3} ms")
print(f"Speedup: {e3nn_avg/cuet_avg:.3}x")
print(f"Cuequivariance indexed linear time: {cuet_v2_avg:.3} ms")
print(f"Speedup: {e3nn_avg/cuet_v2_avg:.3}x")
e3nn time: 9.01 ms
Cuequivariance naive time: 10.4 ms
Speedup: 0.869x
Cuequivariance indexed linear time: 1.02 ms
Speedup: 8.81x
As you can see, with the indexed_linear kernel we obtain a large speedup in this case as well, while the naive backend is slightly slower than e3nn.
By combining all of these modules, it is possible to accelerate a model like MACE by up to 10x, depending on the model and input size.
These operations are supported by the official implementation of MACE.