FullyConnectedTensorProduct
- class cuequivariance_torch.FullyConnectedTensorProduct
Fully connected tensor product layer.
- Parameters:
irreps_in1 (Irreps) – Input irreps for the first operand.
irreps_in2 (Irreps) – Input irreps for the second operand.
irreps_out (Irreps) – Output irreps.
layout (IrrepsLayout, optional) – The layout of the input and output irreps. Default is cue.mul_ir, which is the layout corresponding to e3nn.
layout_in1 (IrrepsLayout, optional) – The layout of the first input irreducible representations, by default layout.
layout_in2 (IrrepsLayout, optional) – The layout of the second input irreducible representations, by default layout.
layout_out (IrrepsLayout, optional) – The layout of the output irreducible representations, by default layout.
shared_weights (bool, optional) – Whether to share weights across the batch dimension. Default is True.
internal_weights (bool, optional) – Whether to create module parameters for weights. Default is None.
device (torch.device, optional) – The device to use for the operation.
dtype (torch.dtype, optional) – The dtype to use for the operation weights, by default the torch default dtype.
math_dtype (torch.dtype, optional) – The dtype to use for the math operations, by default dtype.
method (str, optional) – The method to use for the tensor product, by default “fused_tp” (using a CUDA kernel).
use_fallback (bool, optional, deprecated) – Whether to use a “fallback” implementation. This parameter is now mapped to method: if True, the “naive” method is used; if False or None (default), the “fused_tp” method is used.
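The layout parameters above decide how each `mul x ir` block is flattened into the feature dimension. As a rough sketch, assuming `mul_ir` stores a block as a `(mul, ir_dim)` array and `ir_mul` stores it as `(ir_dim, mul)`, both flattened in row-major order, the hypothetical helper `mul_ir_to_ir_mul` below reorders a flat feature vector from one layout to the other. Irreps are given as plain `(mul, l)` pairs instead of `cue.Irreps` objects, which is a simplification of this sketch.

```python
from typing import List, Sequence, Tuple


def mul_ir_to_ir_mul(x: Sequence[float], irreps: Sequence[Tuple[int, int]]) -> List[float]:
    """Reorder a flat feature vector from mul_ir to ir_mul layout.

    `irreps` is a simplified, hypothetical description: a list of (mul, l)
    pairs, where an irrep of degree l has dimension 2l + 1.
    """
    expected = sum(mul * (2 * l + 1) for mul, l in irreps)
    if len(x) != expected:
        raise ValueError(f"expected {expected} features, got {len(x)}")
    out: List[float] = []
    offset = 0
    for mul, l in irreps:
        ir_dim = 2 * l + 1
        for m in range(ir_dim):
            for u in range(mul):
                # mul_ir keeps element (u, m) at u * ir_dim + m inside the block;
                # ir_mul keeps it at m * mul + u, which is the order appended here.
                out.append(x[offset + u * ir_dim + m])
        offset += mul * ir_dim
    return out


# "1x0e + 2x1o": one scalar, then two vectors a = (1, 2, 3) and b = (4, 5, 6).
x = [0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
print(mul_ir_to_ir_mul(x, [(1, 0), (2, 1)]))  # → [0.0, 1.0, 4.0, 2.0, 5.0, 3.0, 6.0]
```

Scalar-only irreps (`l = 0`) come out unchanged, since both layouts coincide when `ir_dim` is 1.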
Note
In e3nn there are irrep_normalization and path_normalization parameters. This module currently supports only “component” irrep normalization and “element” path normalization.
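To make "fully connected" concrete: when every irrep is a scalar (`0e`), there is a single path `0e ⊗ 0e → 0e` whose coupling coefficient is a constant, and the operation reduces to a bilinear map in which every pair of input channels reaches every output channel through its own weight. The hypothetical function below sketches that special case only. It deliberately omits the normalization factor implied by the “component”/“element” conventions, so its output is only proportional to what the layer returns, and the `(u, v, w)` ordering of the flat weight is an assumption.

```python
from typing import List, Sequence


def scalar_fc_tensor_product(
    x1: Sequence[float],      # shape (mul1,)
    x2: Sequence[float],      # shape (mul2,)
    weight: Sequence[float],  # flat, shape (mul1 * mul2 * mul3,)
    mul3: int,
) -> List[float]:
    """Unnormalized product of scalar channels:
    out[w] = sum over u, v of W[u, v, w] * x1[u] * x2[v].
    """
    mul1, mul2 = len(x1), len(x2)
    if len(weight) != mul1 * mul2 * mul3:
        raise ValueError("weight must have mul1 * mul2 * mul3 elements")
    out = [0.0] * mul3
    for u in range(mul1):
        for v in range(mul2):
            for w in range(mul3):
                # Every (u, v) input pair contributes to every output channel w.
                out[w] += weight[(u * mul2 + v) * mul3 + w] * x1[u] * x2[v]
    return out


print(scalar_fc_tensor_product([1.0, 2.0], [3.0], [1.0, 0.0, 0.0, 1.0], mul3=2))  # → [3.0, 6.0]
```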
Forward Pass
- forward(x1, x2, weight=None)
Perform the forward pass of the fully connected tensor product operation.
- Parameters:
x1 (torch.Tensor) – Input tensor for the first operand. It should have the shape (batch_size, irreps_in1.dim).
x2 (torch.Tensor) – Input tensor for the second operand. It should have the shape (batch_size, irreps_in2.dim).
weight (torch.Tensor, optional) – Weights for the tensor product. It should have the shape (batch_size, weight_numel) if shared_weights is False, or (weight_numel,) if shared_weights is True. If None, the internal weights are used.
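The size `weight_numel` depends only on the three irreps. One way to estimate it, assuming O(3) irreps written as e3nn-style strings such as `"2x0e + 1x1o"` and one weight per (input-1 channel, input-2 channel, output channel) triple on every allowed path (the convention e3nn uses for its fully connected tensor product), is the hypothetical helper below. A path `l1 ⊗ l2 → l3` is kept when `|l1 - l2| <= l3 <= l1 + l2` and the output parity equals the product of the input parities. Compare the result with the layer's own `weight_numel` before relying on it; `irreps_dim` likewise gives the expected feature sizes `irreps_in1.dim`, `irreps_in2.dim` and `irreps_out.dim`.

```python
from typing import List, Tuple

Irrep = Tuple[int, int, int]  # (mul, l, parity), parity +1 for "e" and -1 for "o"


def parse_irreps(s: str) -> List[Irrep]:
    """Parse a simplified e3nn-style irreps string such as "2x0e + 1x1o"."""
    result: List[Irrep] = []
    for term in s.split("+"):
        term = term.strip()
        mul_str, _, ir = term.rpartition("x")
        mul = int(mul_str) if mul_str else 1
        l, p = int(ir[:-1]), ir[-1]
        if p not in ("e", "o"):
            raise ValueError(f"bad parity in {term!r}")
        result.append((mul, l, 1 if p == "e" else -1))
    return result


def irreps_dim(irreps: List[Irrep]) -> int:
    """Total feature size: each irrep of degree l has dimension 2l + 1."""
    return sum(mul * (2 * l + 1) for mul, l, _ in irreps)


def fc_weight_numel(in1: str, in2: str, out: str) -> int:
    """Count weights of a fully connected product with mul1 * mul2 * mul3 per path."""
    total = 0
    for mul1, l1, p1 in parse_irreps(in1):
        for mul2, l2, p2 in parse_irreps(in2):
            for mul3, l3, p3 in parse_irreps(out):
                if abs(l1 - l2) <= l3 <= l1 + l2 and p3 == p1 * p2:
                    total += mul1 * mul2 * mul3
    return total


print(fc_weight_numel("2x0e + 1x1o", "0e + 1o", "3x0e + 2x1o"))  # → 15
print(irreps_dim(parse_irreps("3x0e + 2x1o")))  # → 9
```

In this example the allowed paths are `0e⊗0e→0e` (6 weights), `0e⊗1o→1o` (4), `1o⊗0e→1o` (2) and `1o⊗1o→0e` (3).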
- Returns:
Output tensor resulting from the fully connected tensor product operation. It will have the shape (batch_size, irreps_out.dim).
- Return type:
torch.Tensor
- Raises:
ValueError – If internal weights are used and weight is not None, or if shared weights are used and weight is not a 1D tensor, or if shared weights are not used and weight is not a 2D tensor.
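The rules in the Raises entry amount to a shape check on `weight`. The hypothetical function below mirrors them on plain shape tuples as a sketch of the documented behavior, not the module's own code. Two details are assumptions of this sketch: it also checks that the last dimension equals `weight_numel`, and it rejects `weight=None` when no internal weights exist.

```python
from typing import Optional, Tuple


def check_weight_shape(
    weight_shape: Optional[Tuple[int, ...]],
    *,
    shared_weights: bool,
    internal_weights: bool,
    weight_numel: int,
) -> None:
    """Raise ValueError if a weight of this shape would be rejected."""
    if internal_weights:
        if weight_shape is not None:
            raise ValueError("internal weights are used; pass weight=None")
        return
    if weight_shape is None:
        # Assumption: without internal weights an explicit weight is required.
        raise ValueError("an external weight tensor is required")
    if shared_weights:
        # Shared weights: a single 1D vector of length weight_numel.
        if weight_shape != (weight_numel,):
            raise ValueError(f"expected shape ({weight_numel},), got {weight_shape}")
    elif len(weight_shape) != 2 or weight_shape[1] != weight_numel:
        # Per-sample weights: a 2D tensor of shape (batch_size, weight_numel).
        raise ValueError(f"expected shape (batch_size, {weight_numel}), got {weight_shape}")


check_weight_shape((15,), shared_weights=True, internal_weights=False, weight_numel=15)
check_weight_shape((4, 15), shared_weights=False, internal_weights=False, weight_numel=15)
```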
- __init__(irreps_in1, irreps_in2, irreps_out, *, layout=None, layout_in1=None, layout_in2=None, layout_out=None, shared_weights=True, internal_weights=None, device=None, dtype=None, math_dtype=None, use_fallback=None, method=None)
Initialize the fully connected tensor product layer. See the class description above for the meaning of each parameter.
- Parameters:
irreps_in1 (Irreps)
irreps_in2 (Irreps)
irreps_out (Irreps)
layout (IrrepsLayout | None)
layout_in1 (IrrepsLayout | None)
layout_in2 (IrrepsLayout | None)
layout_out (IrrepsLayout | None)
shared_weights (bool)
internal_weights (bool)
device (device | None)
dtype (dtype | None)
math_dtype (dtype | None)
use_fallback (bool | None)
method (str | None)
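The defaults described in the class parameters chain together: `layout` falls back to `cue.mul_ir`, each of `layout_in1`, `layout_in2` and `layout_out` falls back to `layout`, and the deprecated `use_fallback` is translated into `method`. The hypothetical helper `resolve_defaults` below sketches that resolution, with plain strings standing in for the `cue` layout objects. Treating a call that sets both `method` and `use_fallback` as an error is an assumption of this sketch, not documented behavior.

```python
from typing import Dict, Optional


def resolve_defaults(
    layout: Optional[str] = None,
    layout_in1: Optional[str] = None,
    layout_in2: Optional[str] = None,
    layout_out: Optional[str] = None,
    use_fallback: Optional[bool] = None,
    method: Optional[str] = None,
) -> Dict[str, str]:
    """Resolve layout and method defaults; strings stand in for cue layouts."""
    layout = layout if layout is not None else "mul_ir"
    if use_fallback is not None:
        if method is not None:
            # Assumption: passing both is rejected in this sketch.
            raise ValueError("pass either method or use_fallback, not both")
        method = "naive" if use_fallback else "fused_tp"
    return {
        "layout_in1": layout_in1 if layout_in1 is not None else layout,
        "layout_in2": layout_in2 if layout_in2 is not None else layout,
        "layout_out": layout_out if layout_out is not None else layout,
        "method": method if method is not None else "fused_tp",
    }


print(resolve_defaults(layout="ir_mul", use_fallback=True))
```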
- classmethod __new__(*args, **kwargs)