FullyConnectedTensorProduct
- class cuequivariance_torch.FullyConnectedTensorProduct
Fully connected tensor product layer.
- Parameters:
irreps_in1 (Irreps) – Input irreps for the first operand.
irreps_in2 (Irreps) – Input irreps for the second operand.
irreps_out (Irreps) – Output irreps.
layout (IrrepsLayout, optional) – The layout of the input and output irreps. Default is cue.mul_ir, which is the layout corresponding to e3nn.
layout_in1 (IrrepsLayout, optional) – The layout of the first input irreducible representations, by default layout.
layout_in2 (IrrepsLayout, optional) – The layout of the second input irreducible representations, by default layout.
layout_out (IrrepsLayout, optional) – The layout of the output irreducible representations, by default layout.
shared_weights (bool, optional) – Whether to share weights across the batch dimension. Default is True.
internal_weights (bool, optional) – Whether to create module parameters for weights. Default is None.
device (torch.device, optional) – The device to use for the operation.
dtype (torch.dtype, optional) – The dtype to use for the operation weights, by default the torch default dtype.
math_dtype (torch.dtype or string, optional) – The dtype to use for the math operations. By default, it follows the dtype of the input tensors if possible, and otherwise falls back to the torch default dtype (see SegmentedPolynomial for more details).
method (str, optional) – The method to use for the tensor product, by default “fused_tp” (using a CUDA kernel).
use_fallback (bool, optional, deprecated) – Whether to use a “fallback” implementation. It now maps to method: if True, the “naive” method is used; if False or None (default), the “fused_tp” method is used.
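To make “fully connected” concrete, the plain-Python sketch below enumerates the paths such a layer can contain and counts its weights. It assumes e3nn-style O(3) selection rules and a “uvw” connection in which every allowed (input 1, input 2, output) segment triple gets its own mul1 × mul2 × mul_out weight block. Irreps are written as hypothetical (mul, l, parity) tuples instead of cue.Irreps, and fully_connected_paths and weight_numel are illustrative helpers, not part of cuequivariance_torch.

```python
from typing import List, Tuple

# Hypothetical plain-Python stand-in for cue.Irreps: (mul, l, parity) tuples,
# e.g. "2x0e + 2x1o" -> [(2, 0, 1), (2, 1, -1)].
IrrepsList = List[Tuple[int, int, int]]


def fully_connected_paths(irreps_in1: IrrepsList, irreps_in2: IrrepsList, irreps_out: IrrepsList):
    """Return every (i1, i2, io) segment triple allowed by the O(3) selection rules."""
    paths = []
    for i1, (_, l1, p1) in enumerate(irreps_in1):
        for i2, (_, l2, p2) in enumerate(irreps_in2):
            for io, (_, lo, po) in enumerate(irreps_out):
                # Triangle rule on angular momentum and product rule on parity.
                if abs(l1 - l2) <= lo <= l1 + l2 and p1 * p2 == po:
                    paths.append((i1, i2, io))
    return paths


def weight_numel(irreps_in1: IrrepsList, irreps_in2: IrrepsList, irreps_out: IrrepsList) -> int:
    """Count weights, assuming one mul1 x mul2 x mul_out block per path ("uvw")."""
    return sum(
        irreps_in1[i1][0] * irreps_in2[i2][0] * irreps_out[io][0]
        for i1, i2, io in fully_connected_paths(irreps_in1, irreps_in2, irreps_out)
    )


irreps_in1 = [(2, 0, 1), (2, 1, -1)]  # 2x0e + 2x1o
irreps_in2 = [(1, 0, 1), (1, 1, -1)]  # 0e + 1o
irreps_out = [(4, 0, 1), (4, 1, -1)]  # 4x0e + 4x1o

print(fully_connected_paths(irreps_in1, irreps_in2, irreps_out))
# [(0, 0, 0), (0, 1, 1), (1, 0, 1), (1, 1, 0)]
print(weight_numel(irreps_in1, irreps_in2, irreps_out))  # 32
```

Note that 1o ⊗ 1o contributes only to 0e here: it has even parity, so it cannot reach the 1o output. Under these assumptions the count should match the weight_numel expected by the forward pass, but the module itself is the authoritative source.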
Note
In e3nn there were irrep_normalization and path_normalization parameters. This module currently supports only “component” irrep normalization and “element” path normalization.
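As we understand e3nn's conventions, “component” normalization means that each component of an input irrep has unit second moment, so an irrep of degree l has expected squared norm 2l + 1. “Element” path normalization divides each output segment by the square root of its total fan-in. A sketch of the resulting per-path coefficient α_p, assuming unit input and output variances and no extra per-path weights, where m^(1)_q and m^(2)_q are the multiplicities of the two input segments of path q:

```latex
\mathbb{E}\left[\lVert x^{(l)} \rVert^{2}\right] = 2l + 1,
\qquad
\alpha_{p} = \sqrt{\frac{2\, l_{\mathrm{out}}(p) + 1}{\sum_{q \,:\, \mathrm{out}(q) = \mathrm{out}(p)} m^{(1)}_{q}\, m^{(2)}_{q}}}
```

For example, with (2x0e + 2x1o) ⊗ (0e + 1o) → 4x0e + 4x1o, each output segment receives two paths with m^(1) m^(2) = 2, so the denominator is 4 and α = √(1/4) = 1/2 for the 0e output and α = √(3/4) for the 1o output.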
Forward Pass
- forward(x1, x2, weight=None)
Perform the forward pass of the fully connected tensor product operation.
- Parameters:
x1 (torch.Tensor) – Input tensor for the first operand. It should have the shape (batch_size, irreps_in1.dim).
x2 (torch.Tensor) – Input tensor for the second operand. It should have the shape (batch_size, irreps_in2.dim).
weight (torch.Tensor, optional) – Weights for the tensor product. It should have the shape (batch_size, weight_numel) if shared_weights is False, or (weight_numel,) if shared_weights is True. If None, the internal weights are used.
- Returns:
Output tensor resulting from the fully connected tensor product operation. It will have the shape (batch_size, irreps_out.dim).
- Return type:
torch.Tensor
- Raises:
ValueError – If internal weights are used and weight is not None, or if shared weights are used and weight is not a 1D tensor, or if shared weights are not used and weight is not a 2D tensor.
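The shape rules for x1, x2 and weight can be mirrored in a few lines. The sketch below works on plain shape tuples instead of tensors. irreps_dim and check_weight_shape are hypothetical helpers, irreps are written as (mul, l, parity) tuples, and a weight_numel of 32 is simply assumed. The check is stricter than the documented ValueError conditions because it compares full shapes, not just the number of dimensions. Raising when internal_weights is False and no weight is passed is an assumption, not documented behavior.

```python
from typing import Optional, Sequence, Tuple


def irreps_dim(irreps: Sequence[Tuple[int, int, int]]) -> int:
    """Total dimension of (mul, l, parity) irreps: the sum of mul * (2l + 1)."""
    return sum(mul * (2 * l + 1) for mul, l, _ in irreps)


def check_weight_shape(
    weight_shape: Optional[Tuple[int, ...]],
    batch_size: int,
    weight_numel: int,
    shared_weights: bool,
    internal_weights: bool,
) -> None:
    """Hypothetical shape-only mirror of the documented ValueError conditions."""
    if internal_weights:
        # Internal weights are module parameters, so no weight may be passed.
        if weight_shape is not None:
            raise ValueError("internal weights are used, so weight must be None")
        return
    if weight_shape is None:
        # Assumption: external weights are required when internal_weights is False.
        raise ValueError("weight is required when internal_weights is False")
    # Shared weights: one 1D vector for the whole batch; otherwise one row per sample.
    expected = (weight_numel,) if shared_weights else (batch_size, weight_numel)
    if tuple(weight_shape) != expected:
        raise ValueError(f"expected weight shape {expected}, got {tuple(weight_shape)}")


batch_size = 8
irreps_in1 = [(2, 0, 1), (2, 1, -1)]  # 2x0e + 2x1o
irreps_in2 = [(1, 0, 1), (1, 1, -1)]  # 0e + 1o
irreps_out = [(4, 0, 1), (4, 1, -1)]  # 4x0e + 4x1o

print((batch_size, irreps_dim(irreps_in1)))  # x1 shape: (8, 8)
print((batch_size, irreps_dim(irreps_in2)))  # x2 shape: (8, 4)
print((batch_size, irreps_dim(irreps_out)))  # output shape: (8, 16)

# Both calls pass: a 1D weight for shared weights, a 2D weight otherwise.
check_weight_shape((32,), batch_size, 32, shared_weights=True, internal_weights=False)
check_weight_shape((8, 32), batch_size, 32, shared_weights=False, internal_weights=False)
```

Passing (8, 32) with shared_weights=True, or any weight together with internal_weights=True, raises ValueError in this sketch, matching the documented conditions.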
- __init__(irreps_in1, irreps_in2, irreps_out, *, layout=None, layout_in1=None, layout_in2=None, layout_out=None, shared_weights=True, internal_weights=None, device=None, dtype=None, math_dtype=None, use_fallback=None, method=None)
- Parameters:
irreps_in1 (Irreps)
irreps_in2 (Irreps)
irreps_out (Irreps)
layout (IrrepsLayout | None)
layout_in1 (IrrepsLayout | None)
layout_in2 (IrrepsLayout | None)
layout_out (IrrepsLayout | None)
shared_weights (bool)
internal_weights (bool)
device (device | None)
dtype (dtype | None)
math_dtype (dtype | str | None)
use_fallback (bool | None)
method (str | None)
- classmethod __new__(*args, **kwargs)