ChannelWiseTensorProduct
- class cuequivariance_torch.ChannelWiseTensorProduct(
    irreps_in1: Irreps,
    irreps_in2: Irreps,
    filter_irreps_out: Sequence[Irrep] = None,
    *,
    layout: IrrepsLayout | None = None,
    layout_in1: IrrepsLayout | None = None,
    layout_in2: IrrepsLayout | None = None,
    layout_out: IrrepsLayout | None = None,
    shared_weights: bool = True,
    internal_weights: bool = None,
    device: device | None = None,
    dtype: dtype | None = None,
    math_dtype: dtype | None = None,
    use_fallback: bool | None = None,
)
Channel-wise tensor product layer.
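To make the output side concrete, the sketch below enumerates the output irreps that a channel-wise product of O(3) irreps can produce, and shows how filter_irreps_out prunes them. It relies only on the standard O(3) selection rules (|l1 - l2| <= l3 <= l1 + l2, output parity p1 * p2). Representing irreps as plain (mul, l, parity) tuples, the helper names, and the mul1 * mul2 output multiplicity are assumptions made for illustration. The module's own irreps_out is authoritative and may order or group the result differently.

```python
from itertools import product

# Hypothetical sketch, not the cuequivariance implementation.
# An O(3) irrep is (l, parity) with parity +1 ("e") or -1 ("o");
# an irreps list holds (mul, l, parity) entries.


def channelwise_output_irreps(irreps_in1, irreps_in2, filter_irreps_out=None):
    """List the (mul, l, parity) output paths of a channel-wise product.

    Assumption: each path has multiplicity mul1 * mul2, which is simply mul1
    when every irrep of the second operand has multiplicity 1.
    """
    paths = []
    for (mul1, l1, p1), (mul2, l2, p2) in product(irreps_in1, irreps_in2):
        p3 = p1 * p2  # parities multiply
        for l3 in range(abs(l1 - l2), l1 + l2 + 1):  # triangle rule
            if filter_irreps_out is not None and (l3, p3) not in filter_irreps_out:
                continue  # path removed by the filter
            paths.append((mul1 * mul2, l3, p3))
    return paths


def fmt(irreps):
    """Render (mul, l, parity) entries in the familiar '32x1o' notation."""
    return " + ".join(f"{m}x{l}{'e' if p == 1 else 'o'}" for m, l, p in irreps)


in1 = [(32, 0, 1), (32, 1, -1)]  # 32x0e + 32x1o
in2 = [(1, 0, 1), (1, 1, -1)]    # 0e + 1o
print(fmt(channelwise_output_irreps(in1, in2)))
# -> 32x0e + 32x1o + 32x1o + 32x0e + 32x1e + 32x2e
print(fmt(channelwise_output_irreps(in1, in2, filter_irreps_out=[(0, 1), (1, -1)])))
# -> 32x0e + 32x1o + 32x1o + 32x0e
```

Passing a filter only removes paths; it never adds irreps that the selection rules forbid.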
- Parameters:
irreps_in1 (Irreps) – Input irreps for the first operand.
irreps_in2 (Irreps) – Input irreps for the second operand.
filter_irreps_out (Sequence of Irrep, optional) – Filter for the output irreps: only output irreps contained in this sequence are kept. Default is None (no filtering).
layout (IrrepsLayout, optional) – The layout of the input and output irreps. Default is cue.mul_ir, which is the layout corresponding to e3nn.
shared_weights (bool, optional) – Whether to share weights across the batch dimension. Default is True.
internal_weights (bool, optional) – Whether to create module parameters for weights. Default is None.
use_fallback (bool, optional) – If None (default), a CUDA kernel is used if available, and the PyTorch fallback otherwise. If False, a CUDA kernel is used, and an exception is raised if it is not available. If True, the PyTorch fallback method is used regardless of CUDA kernel availability.
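The layout parameter decides how the components of each mul x irrep block are ordered in a flat feature vector. As an illustration, assuming cue.mul_ir stores a block as mul consecutive copies of the irrep's components (the e3nn convention) and the alternative layout, cue.ir_mul, stores the transposed block, a plain-Python reordering between the two looks like this. The function name is hypothetical and not part of cuequivariance.

```python
# Hypothetical sketch: reorder a flat feature vector from a mul_ir layout to
# an ir_mul layout. Each block is described by (mul, ir_dim).
def mul_ir_to_ir_mul(x, blocks):
    out = []
    offset = 0
    for mul, dim in blocks:
        block = x[offset:offset + mul * dim]
        # mul_ir: element (u, i) sits at u * dim + i
        # ir_mul: element (i, u) sits at i * mul + u
        for i in range(dim):
            for u in range(mul):
                out.append(block[u * dim + i])
        offset += mul * dim
    if offset != len(x):
        raise ValueError("blocks do not cover the whole feature vector")
    return out


blocks = [(2, 1), (2, 3)]  # 2x0e + 2x1o: irrep dims 1 and 3
x = ["a0", "a1", "b0x", "b0y", "b0z", "b1x", "b1y", "b1z"]
print(mul_ir_to_ir_mul(x, blocks))
# -> ['a0', 'a1', 'b0x', 'b1x', 'b0y', 'b1y', 'b0z', 'b1z']
```

Scalar blocks (dimension 1) look the same in both layouts; only blocks of higher-dimensional irreps are interleaved differently.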
Note
In e3nn there were irrep_normalization and path_normalization parameters. This module currently supports only “component” irrep normalization and “element” path normalization.
Forward Pass
- forward(x1, x2, weight=None) → Tensor
Perform the forward pass of the channel-wise tensor product operation.
- Parameters:
x1 (torch.Tensor) – Input tensor for the first operand. It should have the shape (batch_size, irreps_in1.dim).
x2 (torch.Tensor) – Input tensor for the second operand. It should have the shape (batch_size, irreps_in2.dim).
weight (torch.Tensor, optional) – Weights for the tensor product. It should have the shape (batch_size, weight_numel) if shared_weights is False, or (weight_numel,) if shared_weights is True. If None, the internal weights are used.
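The shape requirements for x1, x2 and weight can be worked out ahead of time. This hypothetical helper parses O(3) irreps strings such as "32x0e + 32x1o" (an irrep of degree l has 2l + 1 components) and returns the expected shapes. The value weight_numel=192 below is only a placeholder, since the real number depends on how the module was configured.

```python
# Hypothetical helpers (not part of cuequivariance_torch).
def irreps_dim(irreps: str) -> int:
    """Total dimension of an O(3) irreps string like '32x0e + 32x1o'."""
    total = 0
    for term in irreps.split("+"):
        term = term.strip()
        # '32x1o' -> ('32', 'x', '1o'); a bare '1o' means multiplicity 1
        mul, _, ir = term.partition("x") if "x" in term else ("1", "", term)
        l = int(ir[:-1])  # drop the parity letter 'e' or 'o'
        total += int(mul) * (2 * l + 1)
    return total


def expected_shapes(batch_size, irreps_in1, irreps_in2, weight_numel, shared_weights):
    """Shapes of x1, x2 and weight as described in the parameter list."""
    x1 = (batch_size, irreps_dim(irreps_in1))
    x2 = (batch_size, irreps_dim(irreps_in2))
    weight = (weight_numel,) if shared_weights else (batch_size, weight_numel)
    return x1, x2, weight


print(expected_shapes(10, "32x0e + 32x1o", "0e + 1o", weight_numel=192, shared_weights=False))
# -> ((10, 128), (10, 4), (10, 192))
```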
- Returns:
Output tensor resulting from the channel-wise tensor product operation. It will have the shape (batch_size, irreps_out.dim).
- Return type:
torch.Tensor
- Raises:
ValueError – If internal weights are used and weight is not None, or if shared weights are used and weight is not a 1D tensor, or if shared weights are not used and weight is not a 2D tensor.
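The three ValueError conditions can be mirrored on plain shape tuples, which is handy for validating arguments before calling the module. This is a hypothetical sketch of those documented rules only, not the module's actual code; a weight_shape of None stands for weight=None.

```python
# Hypothetical sketch of the documented weight checks, on shape tuples.
def check_weight(weight_shape, *, internal_weights, shared_weights):
    if weight_shape is None:
        return  # no explicit weight passed: nothing to validate here
    if internal_weights:
        raise ValueError("internal weights are used, so weight must be None")
    if shared_weights and len(weight_shape) != 1:
        raise ValueError("shared weights require a 1D weight tensor")
    if not shared_weights and len(weight_shape) != 2:
        raise ValueError("per-sample weights require a 2D weight tensor")


check_weight((192,), internal_weights=False, shared_weights=True)      # OK
check_weight((10, 192), internal_weights=False, shared_weights=False)  # OK
```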