SymmetricContraction#

class cuequivariance_torch.SymmetricContraction#

Accelerated implementation of the symmetric contraction operation introduced in MACE (https://arxiv.org/abs/2206.07697).

Parameters:
  • irreps_in (Irreps) – The input irreps. All multiplicities (mul) within the irreps must be identical, indicating that each irrep appears the same number of times.

  • irreps_out (Irreps) – The output irreps. Similar to irreps_in, all multiplicities must be the same.

  • contraction_degree (int) – The degree of the symmetric contraction, specifying the maximum degree of the polynomial in the symmetric contraction.

  • num_elements (int) – The number of weight sets in the weight tensor; the indices argument of forward() selects one set per batch element.

  • layout (IrrepsLayout, optional) – The layout of the input and output irreps. If not provided, a default layout is used.

  • layout_in (IrrepsLayout, optional) – The layout of the input irreducible representations; defaults to layout.

  • layout_out (IrrepsLayout, optional) – The layout of the output irreducible representations; defaults to layout.

  • device (torch.device, optional) – The device to use for the operation.

  • dtype (torch.dtype, optional) – The dtype to use for the operation weights, by default torch.float32.

  • math_dtype (torch.dtype, optional) – The dtype to use for the math operations, by default it follows the dtype of the input tensors.

  • original_mace (bool, optional) – Whether to use the original MACE implementation, by default False.

  • method (str, optional) – The method to use for the operation, by default “uniform_1d” (using a CUDA kernel) if all segments have the same shape, otherwise “naive” (using a PyTorch implementation).

  • use_fallback (bool, optional, deprecated) – Deprecated; maps to method. If True, the “naive” method is used; if False, the “uniform_1d” method is used (which requires all segments to have the same shape).

Examples

>>> import torch
>>> import cuequivariance as cue
>>> import cuequivariance_torch as cuet
>>> from cuequivariance_torch import SymmetricContraction
>>> device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
>>> irreps_in = cue.Irreps("O3", "32x0e + 32x1o")
>>> irreps_out = cue.Irreps("O3", "32x0e")
>>> layer = SymmetricContraction(irreps_in, irreps_out, contraction_degree=3, num_elements=5, layout=cue.ir_mul, dtype=torch.float32, device=device)
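The method argument selects the implementation. When the CUDA kernel cannot be used (for example on a CPU-only machine, or when the segments do not all share the same shape), the pure-PyTorch path can be requested explicitly. A minimal sketch reusing the irreps defined above:

>>> layer_naive = SymmetricContraction(
...     irreps_in,
...     irreps_out,
...     contraction_degree=3,
...     num_elements=5,
...     layout=cue.ir_mul,
...     method="naive",  # force the pure-PyTorch implementation
...     dtype=torch.float32,
...     device=device,
... )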

The argument original_mace can be set to True to emulate the original MACE implementation.

>>> feats_irreps = cue.Irreps("O3", "32x0e + 32x1o + 32x2e")
>>> target_irreps = cue.Irreps("O3", "32x0e + 32x1o")
>>> # OLD FUNCTION DEFINITION:
>>> # symmetric_contractions_old = SymmetricContraction(
>>> #     irreps_in=feats_irreps,
>>> #     irreps_out=target_irreps,
>>> #     correlation=3,
>>> #     num_elements=10,
>>> # )
>>> # NEW FUNCTION DEFINITION:
>>> symmetric_contractions_new = cuet.SymmetricContraction(
...     irreps_in=feats_irreps,
...     irreps_out=target_irreps,
...     contraction_degree=3,
...     num_elements=10,
...     layout_in=cue.ir_mul,
...     layout_out=cue.mul_ir,
...     original_mace=True,
...     dtype=torch.float64,
...     device=device,
... )

The layer is then called as follows:

>>> node_feats = torch.randn(128, 32, feats_irreps.dim // 32, dtype=torch.float64, device=device)
>>> # with node_attrs_index being the index version of node_attrs, something like:
>>> # node_attrs_index = torch.nonzero(node_attrs)[:, 1].int()
>>> node_attrs_index = torch.randint(0, 10, (128,), dtype=torch.int32, device=device)
>>> # OLD CALL:
>>> # symmetric_contractions_old(node_feats, node_attrs)
>>> # NEW CALL:
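>>> # reshape from (batch, mul, ir) to the flattened ir_mul layout expected by layout_in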
>>> node_feats = torch.transpose(node_feats, 1, 2).flatten(1)
>>> symmetric_contractions_new(node_feats, node_attrs_index)
tensor([[...)

Note

The term ‘mul’ refers to the multiplicity of an irrep, indicating how many times it appears in the representation. This layer requires that all input and output irreps have the same multiplicity for the symmetric contraction operation to be well-defined.
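As an illustration of this constraint, the irreps used in the examples above all carry multiplicity 32:

>>> valid_irreps = cue.Irreps("O3", "32x0e + 32x1o + 32x2e")  # every multiplicity is 32
>>> # "32x0e + 16x1o" mixes multiplicities 32 and 16 and cannot be used with this layer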

Forward Pass

forward(x, indices)#

Perform the forward pass of the symmetric contraction operation.

Parameters:
  • x (torch.Tensor) – The input tensor. It should have shape (batch, irreps_in.dim).

  • indices (torch.Tensor) – The index of the weight to use for each batch element. It should have shape (batch,).

Returns:

The output tensor. It has shape (batch, irreps_out.dim).

Return type:

torch.Tensor
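A minimal call sketch, reusing the layer, irreps_in, and device constructed in the first example above (the batch size of 12 is arbitrary):

>>> x = torch.randn(12, irreps_in.dim, dtype=torch.float32, device=device)
>>> indices = torch.randint(0, 5, (12,), dtype=torch.int32, device=device)
>>> layer(x, indices).shape  # (batch, irreps_out.dim)
torch.Size([12, 32])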

__init__(
    irreps_in,
    irreps_out,
    contraction_degree,
    num_elements,
    *,
    layout=None,
    layout_in=None,
    layout_out=None,
    device=None,
    dtype=None,
    math_dtype=None,
    original_mace=False,
    use_fallback=None,
    method=None,
)#

Initialize internal Module state, shared by both nn.Module and ScriptModule.

classmethod __new__(*args, **kwargs)#