SymmetricContraction
- class cuequivariance_torch.SymmetricContraction(
    irreps_in: Irreps,
    irreps_out: Irreps,
    contraction_degree: int,
    num_elements: int,
    *,
    layout: IrrepsLayout | None = None,
    layout_in: IrrepsLayout | None = None,
    layout_out: IrrepsLayout | None = None,
    device: device | None = None,
    dtype: dtype | None = None,
    math_dtype: dtype | None = None,
    original_mace: bool = False,
    use_fallback: bool | None = None,
)
Accelerated implementation of the symmetric contraction operation introduced in https://arxiv.org/abs/2206.07697.
- Parameters:
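irreps_in (Irreps) – The input irreps. All multiplicities (mul) within the irreps must be identical; that is, every irrep must appear the same number of times.
irreps_out (Irreps) – The output irreps. As with irreps_in, all multiplicities must be identical, and they must equal the multiplicity of irreps_in.
contraction_degree (int) – The degree of the symmetric contraction, i.e. the maximum degree of the polynomial it computes. It corresponds to the correlation argument of the original MACE implementation.
num_elements (int) – The number of independent weight sets (for example, one per chemical element). The indices passed to forward select one of them for each batch element and must lie in [0, num_elements).
layout (IrrepsLayout, optional) – The layout of both the input and output irreps. If not provided, a default layout is used.
layout_in, layout_out (IrrepsLayout, optional) – Separate layouts for the input and output irreps, used in place of layout when given (e.g. cue.ir_mul for the input and cue.mul_ir for the output in the MACE example below).
device (torch.device, optional), dtype (torch.dtype, optional) – The device and data type of the layer's weights.
math_dtype (torch.dtype, optional) – The data type for mathematical operations. If not specified, the default data type from the torch environment is used.
original_mace (bool, optional) – If True, emulate the original MACE implementation. Defaults to False.
use_fallback (bool, optional) – If None (default), the CUDA kernel is used when available and the PyTorch fallback otherwise. If False, the CUDA kernel is used and an exception is raised if it is not available. If True, the PyTorch fallback is used regardless of CUDA kernel availability.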
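To make contraction_degree and num_elements more concrete, the following toy sketch covers only the scalar-only special case, in which every input and output irrep is 0e. There, no Clebsch-Gordan coupling is needed and, in the MACE formulation, each output channel is a polynomial (without a constant term) of degree contraction_degree in the matching input channel, with one learned coefficient per (element, degree, channel). The function name and the [num_elements][degree][mul] weight layout are hypothetical and do not reflect how the library stores its weights; the real layer handles arbitrary irreps with fused kernels.

```python
from typing import List, Sequence


def toy_scalar_symmetric_contraction(
    x: Sequence[float],                            # one sample: one 0e scalar per channel
    weights: Sequence[Sequence[Sequence[float]]],  # hypothetical layout [num_elements][degree][mul]
    element: int,                                  # weight set to use, 0 <= element < num_elements
    contraction_degree: int,
) -> List[float]:
    """out[k] = sum_{nu=1}^{contraction_degree} weights[element][nu - 1][k] * x[k] ** nu."""
    if not 0 <= element < len(weights):
        raise IndexError("element must lie in [0, num_elements)")
    w = weights[element]
    if len(w) != contraction_degree:
        raise ValueError("expected one weight row per degree 1..contraction_degree")
    out = []
    for k, xk in enumerate(x):
        # Horner's rule: ((w_D * x + w_{D-1}) * x + ... + w_1) * x, with no constant term.
        acc = 0.0
        for nu in range(contraction_degree, 0, -1):
            acc = (acc + w[nu - 1][k]) * xk
        out.append(acc)
    return out


# num_elements = 2, contraction_degree = 2, mul = 2
weights = [
    [[1.0, 0.0], [0.0, 1.0]],  # element 0: channel 0 -> x, channel 1 -> x**2
    [[2.0, 1.0], [1.0, 1.0]],  # element 1: channel 0 -> 2x + x**2, channel 1 -> x + x**2
]
print(toy_scalar_symmetric_contraction([3.0, 2.0], weights, element=0, contraction_degree=2))  # [3.0, 4.0]
print(toy_scalar_symmetric_contraction([3.0, 2.0], weights, element=1, contraction_degree=2))  # [15.0, 6.0]
```

Horner's rule is used here only to evaluate the polynomial with contraction_degree multiply-adds per channel.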
Examples
>>> import torch
>>> import cuequivariance as cue
>>> import cuequivariance_torch as cuet
>>> device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
>>> irreps_in = cue.Irreps("O3", "32x0e + 32x1o")
>>> irreps_out = cue.Irreps("O3", "32x0e")
>>> layer = cuet.SymmetricContraction(
...     irreps_in,
...     irreps_out,
...     contraction_degree=3,
...     num_elements=5,
...     layout=cue.ir_mul,
...     dtype=torch.float32,
...     device=device,
... )
The resulting layer can now be used as part of a PyTorch model.
The argument original_mace can be set to True to emulate the original MACE implementation.
>>> feats_irreps = cue.Irreps("O3", "32x0e + 32x1o + 32x2e")
>>> target_irreps = cue.Irreps("O3", "32x0e + 32x1o")
>>> # OLD FUNCTION DEFINITION:
>>> # symmetric_contractions_old = SymmetricContraction(
>>> #     irreps_in=feats_irreps,
>>> #     irreps_out=target_irreps,
>>> #     correlation=3,
>>> #     num_elements=10,
>>> # )
>>> # NEW FUNCTION DEFINITION:
>>> symmetric_contractions_new = cuet.SymmetricContraction(
...     irreps_in=feats_irreps,
...     irreps_out=target_irreps,
...     contraction_degree=3,
...     num_elements=10,
...     layout_in=cue.ir_mul,
...     layout_out=cue.mul_ir,
...     original_mace=True,
...     dtype=torch.float64,
...     device=device,
... )
The layer is then called as follows:
>>> node_feats = torch.randn(128, 32, feats_irreps.dim // 32, dtype=torch.float64, device=device)
>>> # node_attrs_index is the index version of the one-hot node_attrs, for example:
>>> # node_attrs_index = torch.nonzero(node_attrs)[:, 1].int()
>>> node_attrs_index = torch.randint(0, 10, (128,), dtype=torch.int32, device=device)
>>> # OLD CALL:
>>> # symmetric_contractions_old(node_feats, node_attrs)
>>> # NEW CALL (convert MACE's (batch, mul, ir_dim) tensor to the flat ir_mul layout first):
>>> node_feats = torch.transpose(node_feats, 1, 2).flatten(1)
>>> symmetric_contractions_new(node_feats, node_attrs_index)
tensor([[...]])
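The transpose-and-flatten step is what turns MACE's (batch, mul, ir_dim) feature tensor into the flat ir_mul layout requested with layout_in=cue.ir_mul. The pure-Python sketch below spells out the resulting index mapping for a single sample, alongside the mul_ir ordering used for the output (layout_out=cue.mul_ir). It assumes uniform multiplicity and describes irreps as hypothetical (mul, l) pairs rather than cue.Irreps objects; only the ordering convention is illustrated, not the library's own indexing code.

```python
from typing import List, Sequence, Tuple

# Hypothetical description of irreps as (mul, l) pairs, e.g. [(32, 0), (32, 1), (32, 2)].
IrrepsSpec = Sequence[Tuple[int, int]]


def block_offset(irreps: IrrepsSpec, block: int) -> int:
    """Start of irrep block `block` in the flat vector; each block holds mul * (2l + 1) entries."""
    return sum(mul * (2 * l + 1) for mul, l in irreps[:block])


def ir_mul_index(irreps: IrrepsSpec, block: int, u: int, m: int) -> int:
    """ir_mul: within a block, component m varies slowest and channel u fastest."""
    mul, _ = irreps[block]
    return block_offset(irreps, block) + m * mul + u


def mul_ir_index(irreps: IrrepsSpec, block: int, u: int, m: int) -> int:
    """mul_ir: within a block, channel u varies slowest and component m fastest."""
    _, l = irreps[block]
    return block_offset(irreps, block) + u * (2 * l + 1) + m


def mace_to_ir_mul(feats: Sequence[Sequence[float]]) -> List[float]:
    """One sample of shape (mul, D) -> flat ir_mul vector, i.e. transpose then flatten."""
    mul, dim = len(feats), len(feats[0])
    return [feats[u][d] for d in range(dim) for u in range(mul)]


irreps = [(2, 0), (2, 1)]                                 # 2x0e + 2x1o, so D = 1 + 3 = 4
feats = [[0.0, 1.0, 2.0, 3.0], [10.0, 11.0, 12.0, 13.0]]  # feats[u][d], MACE-style (mul, D)
flat = mace_to_ir_mul(feats)
print(flat)  # [0.0, 10.0, 1.0, 11.0, 2.0, 12.0, 3.0, 13.0]
print(ir_mul_index(irreps, 1, 0, 1), mul_ir_index(irreps, 1, 0, 1))  # 4 3
```

Because every block shares the same mul, one global transpose of the (mul, D) array already yields the per-block ir_mul ordering; with mixed multiplicities this shortcut would not hold.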
Note
The term ‘mul’ refers to the multiplicity of an irrep, indicating how many times it appears in the representation. This layer requires that all input and output irreps have the same multiplicity for the symmetric contraction operation to be well-defined.
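The multiplicity requirement can be checked before constructing the layer. The sketch below is a hypothetical pure-Python stand-in that parses simple O3 strings such as "32x0e + 32x1o" (it is not cue.Irreps and only supports terms of the form <mul>x<l><e|o>), verifies that input and output share one multiplicity, and computes the flat dimension as mul * (2l + 1) summed over terms.

```python
import re
from typing import List, Tuple

# Hypothetical stand-in for cue.Irreps("O3", ...): only handles terms like "32x1o".
_TERM = re.compile(r"^\s*(\d+)x(\d+)([eo])\s*$")


def parse_o3_irreps(spec: str) -> List[Tuple[int, int, str]]:
    """Parse e.g. "32x0e + 32x1o" into [(32, 0, "e"), (32, 1, "o")]."""
    terms = []
    for term in spec.split("+"):
        match = _TERM.match(term)
        if match is None:
            raise ValueError(f"cannot parse irreps term {term!r}")
        mul, l, parity = match.groups()
        terms.append((int(mul), int(l), parity))
    return terms


def irreps_dim(spec: str) -> int:
    """Flat dimension: sum of mul * (2l + 1) over all terms."""
    return sum(mul * (2 * l + 1) for mul, l, _ in parse_o3_irreps(spec))


def common_mul(irreps_in: str, irreps_out: str) -> int:
    """Return the shared multiplicity, or raise if input and output multiplicities differ."""
    muls = {mul for mul, _, _ in parse_o3_irreps(irreps_in) + parse_o3_irreps(irreps_out)}
    if len(muls) != 1:
        raise ValueError(f"all input and output multiplicities must be identical, got {sorted(muls)}")
    return muls.pop()


mul = common_mul("32x0e + 32x1o + 32x2e", "32x0e + 32x1o")
print(mul, irreps_dim("32x0e + 32x1o + 32x2e") // mul)  # 32 9
```

The second number matches feats_irreps.dim // 32 in the MACE example, i.e. the size of the last axis of node_feats before it is transposed.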
Forward Pass
- forward(x, indices) → Tensor
Perform the forward pass of the symmetric contraction operation.
- Parameters:
x (torch.Tensor) – The input tensor. It should have shape (batch, irreps_in.dim).
indices (torch.Tensor) – The index of the weight set to use for each batch element, with values in [0, num_elements). It should have shape (batch,).
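Conceptually, indices works as a per-row lookup into num_elements weight sets; in the MACE example it is derived from a one-hot node_attrs tensor. The pure-Python sketch below uses hypothetical helper names and plain lists in place of tensors to show that conversion and the lookup, including a range check worth doing before calling the layer. It says nothing about how the CUDA kernel actually stores or reads its weights.

```python
from typing import List, Sequence, TypeVar

T = TypeVar("T")


def one_hot_to_index(one_hot_rows: Sequence[Sequence[float]]) -> List[int]:
    """Pure-Python analogue of torch.nonzero(node_attrs)[:, 1] when every row is strictly one-hot."""
    out = []
    for row in one_hot_rows:
        hits = [j for j, v in enumerate(row) if v != 0]
        if len(hits) != 1:
            raise ValueError("each row must contain exactly one non-zero entry")
        out.append(hits[0])
    return out


def select_weights(weight_table: Sequence[T], indices: Sequence[int]) -> List[T]:
    """Return weight_table[indices[b]] for every batch element b, after a range check."""
    num_elements = len(weight_table)
    bad = [i for i in indices if not 0 <= i < num_elements]
    if bad:
        raise IndexError(f"indices must lie in [0, {num_elements}), got {bad}")
    return [weight_table[i] for i in indices]


node_attrs = [[0, 1, 0], [1, 0, 0], [0, 0, 1], [0, 1, 0]]
indices = one_hot_to_index(node_attrs)              # [1, 0, 2, 1]
print(select_weights(["w0", "w1", "w2"], indices))  # ['w1', 'w0', 'w2', 'w1']
```

The explicit range check matters in this sketch because a negative Python index would otherwise silently wrap around to the end of the table.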
- Returns:
The output tensor. It has shape (batch, irreps_out.dim).
- Return type: