FullyConnectedTensorProduct

class cuequivariance_torch.FullyConnectedTensorProduct

Fully connected tensor product layer.

Parameters:
  • irreps_in1 (Irreps) – Input irreps for the first operand.

  • irreps_in2 (Irreps) – Input irreps for the second operand.

  • irreps_out (Irreps) – Output irreps.

  • layout (IrrepsLayout, optional) – The layout of the input and output irreps. Default is cue.mul_ir, the layout used by e3nn.

  • layout_in1 (IrrepsLayout, optional) – The layout of the first input irreps. Defaults to layout.

  • layout_in2 (IrrepsLayout, optional) – The layout of the second input irreps. Defaults to layout.

  • layout_out (IrrepsLayout, optional) – The layout of the output irreps. Defaults to layout.

  • shared_weights (bool, optional) – Whether to share weights across the batch dimension. Default is True.

  • internal_weights (bool, optional) – Whether to create module parameters for weights. Default is None.

  • device (torch.device, optional) – The device to use for the operation.

  • dtype (torch.dtype, optional) – The dtype of the weights. Defaults to the torch default dtype.

  • math_dtype (torch.dtype, optional) – The dtype used for the math operations. Defaults to dtype.

  • method (str, optional) – The method to use for the tensor product. Defaults to “fused_tp” (a CUDA kernel).

  • use_fallback (bool, optional, deprecated) – Whether to use a “fallback” implementation. Deprecated in favor of method: if True, the “naive” method is used; if False or None (default), the “fused_tp” method is used.

Note

In e3nn, tensor products take irrep_normalization and path_normalization parameters. This module currently supports only “component” irrep normalization and “element” path normalization.
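For reference, e3nn computes one constant per path. With irrep_normalization="component", path_normalization="element", and unit input and output variances, e3nn's formula reduces to sqrt(dim(ir_out) / F), where F sums mul1 * mul2 over every "uvw" path that writes to the same output segment. The sketch below restates that formula in plain Python; the path tuple format is a hypothetical convention, and cuequivariance may fold these constants into its coefficients differently, so read it as an illustration of the normalization rather than of the library's internals.

```python
import math


def component_element_coefficients(paths):
    """Per-path constants in the style of e3nn's "component" irrep
    normalization and "element" path normalization (sketch).

    Assumes unit input/output variances and "uvw" paths given as
    hypothetical (mul1, mul2, out_index, out_ir_dim) tuples.
    """
    # "element": all paths writing to the same output segment share one
    # fan-in, the total number of (u, v) input pairs over those paths.
    fan_in = {}
    for mul1, mul2, out_index, _ in paths:
        fan_in[out_index] = fan_in.get(out_index, 0) + mul1 * mul2

    coefficients = []
    for _, _, out_index, out_ir_dim in paths:
        alpha = float(out_ir_dim)  # "component": scale by the output irrep dimension
        if fan_in[out_index] > 0:
            alpha /= fan_in[out_index]
        coefficients.append(math.sqrt(alpha))
    return coefficients


# "4x0e + 4x1o" x "0e + 1o" -> "4x0e + 4x1o"; output segment 0 is 4x0e, 1 is 4x1o.
paths = [
    (4, 1, 0, 1),  # 0e x 0e -> 0e
    (4, 1, 1, 3),  # 0e x 1o -> 1o
    (4, 1, 1, 3),  # 1o x 0e -> 1o
    (4, 1, 0, 1),  # 1o x 1o -> 0e
]
print(component_element_coefficients(paths))
```

Both output segments receive two paths with 4 input pairs each, so F = 8 for both, giving sqrt(1/8) for the scalar outputs and sqrt(3/8) for the vector outputs.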

Forward Pass

forward(x1, x2, weight=None)

Perform the forward pass of the fully connected tensor product operation.

Parameters:
  • x1 (torch.Tensor) – Input tensor for the first operand. It should have the shape (batch_size, irreps_in1.dim).

  • x2 (torch.Tensor) – Input tensor for the second operand. It should have the shape (batch_size, irreps_in2.dim).

  • weight (torch.Tensor, optional) – Weights for the tensor product. It should have the shape (batch_size, weight_numel) if shared_weights is False, or (weight_numel,) if shared_weights is True. If None, the internal weights are used.
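The weight_numel referenced in the weight description is the total number of learnable weights. A minimal sketch for computing it and the irreps dimensions, assuming O(3) irreps strings such as "4x0e + 4x1o" and assuming, as in e3nn's fully connected tensor product, that each allowed path (l_out between |l1 - l2| and l1 + l2, parity p_out = p1 * p2) owns a block of mul1 * mul2 * mul_out weights. The helper names are hypothetical, so compare the result with the module itself before relying on it.

```python
def parse_irreps(text):
    """Parse "4x0e + 4x1o" into [(4, 0, 1), (4, 1, -1)] as (mul, l, parity)."""
    irreps = []
    for term in text.split("+"):
        term = term.strip()
        mul_str, ir = term.split("x") if "x" in term else ("1", term)
        l, p = int(ir[:-1]), ir[-1]
        if p not in ("e", "o"):
            raise ValueError(f"unknown parity in {term!r}")
        irreps.append((int(mul_str), l, 1 if p == "e" else -1))
    return irreps


def irreps_dim(irreps):
    """Total dimension: each copy of an irrep of degree l has 2l + 1 components."""
    return sum(mul * (2 * l + 1) for mul, l, _ in irreps)


def fully_connected_weight_numel(irreps_in1, irreps_in2, irreps_out):
    """Count "uvw" weights: mul1 * mul2 * mul_out for every allowed path."""
    total = 0
    for mul1, l1, p1 in irreps_in1:
        for mul2, l2, p2 in irreps_in2:
            for mul3, l3, p3 in irreps_out:
                # O(3) selection rule: triangle inequality on l, product of parities.
                if abs(l1 - l2) <= l3 <= l1 + l2 and p3 == p1 * p2:
                    total += mul1 * mul2 * mul3
    return total


in1 = parse_irreps("4x0e + 4x1o")
in2 = parse_irreps("0e + 1o")
out = parse_irreps("4x0e + 4x1o")
print(irreps_dim(in1), irreps_dim(in2), irreps_dim(out))
print(fully_connected_weight_numel(in1, in2, out))
```

Under these assumptions, x1 has shape (batch_size, 16), x2 has shape (batch_size, 4), the output has shape (batch_size, 16), and the weight has shape (64,) with shared weights or (batch_size, 64) without.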

Returns:

Output tensor resulting from the fully connected tensor product operation. It will have the shape (batch_size, irreps_out.dim).

Return type:

torch.Tensor

Raises:

ValueError – If any of the following holds:

  • internal weights are used and weight is not None;

  • shared weights are used and weight is not a 1D tensor;

  • shared weights are not used and weight is not a 2D tensor.
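The weight rules for forward can be restated on plain shape tuples. The hypothetical helpers below encode the documented expected shapes and the three ValueError conditions; they are a sketch for checking your own inputs, not the module's code, and they do not check cases the documentation leaves unspecified (such as weight=None without internal weights).

```python
def expected_weight_shape(shared_weights, weight_numel, batch_size):
    """Shape forward() expects for an explicit weight, per the documentation."""
    return (weight_numel,) if shared_weights else (batch_size, weight_numel)


def validate_weight(weight_shape, *, shared_weights, internal_weights):
    """Raise ValueError on the documented conditions (sketch).

    weight_shape is None when no weight is passed, otherwise a tuple such
    as (64,) or (8, 64) standing in for weight.shape.
    """
    if internal_weights:
        if weight_shape is not None:
            raise ValueError("weight must be None when internal weights are used")
        return
    if weight_shape is None:
        return  # Not covered by the documented conditions; not checked here.
    if shared_weights and len(weight_shape) != 1:
        raise ValueError("shared weights must be 1D: (weight_numel,)")
    if not shared_weights and len(weight_shape) != 2:
        raise ValueError("per-sample weights must be 2D: (batch_size, weight_numel)")


print(expected_weight_shape(True, 64, 8), expected_weight_shape(False, 64, 8))
validate_weight((64,), shared_weights=True, internal_weights=False)  # accepted
```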

__init__(
irreps_in1,
irreps_in2,
irreps_out,
*,
layout=None,
layout_in1=None,
layout_in2=None,
layout_out=None,
shared_weights=True,
internal_weights=None,
device=None,
dtype=None,
math_dtype=None,
use_fallback=None,
method=None,
)

Initialize internal Module state, shared by both nn.Module and ScriptModule.

classmethod __new__(*args, **kwargs)