ChannelWiseTensorProduct
- class cuequivariance_torch.ChannelWiseTensorProduct
Channel-wise tensor product layer.
- Parameters:
irreps_in1 (Irreps) – Input irreps for the first operand.
irreps_in2 (Irreps) – Input irreps for the second operand.
filter_irreps_out (Sequence of Irrep, optional) – Filter for the output irreps. Default is None.
layout (IrrepsLayout, optional) – The layout of the input and output irreps. Default is cue.mul_ir, which is the layout corresponding to e3nn.
layout_in1 (IrrepsLayout, optional) – The layout of the first input irreducible representations, by default layout.
layout_in2 (IrrepsLayout, optional) – The layout of the second input irreducible representations, by default layout.
layout_out (IrrepsLayout, optional) – The layout of the output irreducible representations, by default layout.
shared_weights (bool, optional) – Whether to share weights across the batch dimension. Default is True.
internal_weights (bool, optional) – Whether to create module parameters for weights. Default is None.
device (torch.device, optional) – The device to use for the operation.
dtype (torch.dtype, optional) – The dtype to use for the operation weights, by default torch.float32.
math_dtype (torch.dtype or string, optional) – The dtype to use for the math operations. By default it follows the dtype of the input tensors if possible, otherwise the torch default dtype (see SegmentedPolynomial for more details).
method (str, optional) – The method to use for the operation, by default “uniform_1d” (using a CUDA kernel) if all segments have the same shape, otherwise “naive” (using a PyTorch implementation).
use_fallback (bool, optional, deprecated) – Whether to use a “fallback” implementation, now maps to method: If True the “naive” method is used. If False the “uniform_1d” method is used (make sure all segments have the same shape).
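The defaults described for method and use_fallback can be summarized with a small sketch. The helper resolve_method below is hypothetical and is not part of cuequivariance_torch; it only mirrors the rules stated above. The precedence of an explicit method over use_fallback is an assumption, since the parameter descriptions do not say what happens when both are given.

```python
from typing import Optional, Sequence, Tuple


def resolve_method(
    segment_shapes: Sequence[Tuple[int, ...]],
    method: Optional[str] = None,
    use_fallback: Optional[bool] = None,
) -> str:
    """Hypothetical helper mirroring the documented defaults (not library code)."""
    if method is not None:
        # Assumption: an explicit method takes precedence over use_fallback.
        return method
    if use_fallback is True:
        # Deprecated flag: True maps to the PyTorch implementation.
        return "naive"
    if use_fallback is False:
        # Deprecated flag: False maps to the CUDA kernel.
        return "uniform_1d"
    # Default: use the CUDA kernel only when every segment has the same shape.
    all_same_shape = len(set(segment_shapes)) <= 1
    return "uniform_1d" if all_same_shape else "naive"


print(resolve_method([(3, 32), (3, 32)]))
print(resolve_method([(1, 32), (3, 32)]))
print(resolve_method([(1, 32), (3, 32)], use_fallback=True))
```

When migrating code that still passes use_fallback, passing method explicitly avoids relying on the deprecated mapping.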
Note
In e3nn there are irrep_normalization and path_normalization parameters. This module currently supports only “component” irrep normalization and “element” path normalization.
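To illustrate what filter_irreps_out restricts, here is a pure-Python sketch that enumerates the (input 1, input 2, output) irrep combinations of a tensor product and drops those whose output irrep is not in the filter. It assumes O(3) irreps modelled as (l, parity) tuples and the usual selection rule |l1 - l2| <= l <= l1 + l2 with multiplied parities. The names product_irreps and channelwise_paths are hypothetical, and multiplicities are deliberately left out.

```python
from typing import List, Optional, Sequence, Tuple

# An O(3) irrep is modelled as (l, parity), with parity +1 (even) or -1 (odd).
Irrep = Tuple[int, int]


def product_irreps(ir1: Irrep, ir2: Irrep) -> List[Irrep]:
    """Irreps appearing in the product ir1 x ir2 (standard selection rule)."""
    (l1, p1), (l2, p2) = ir1, ir2
    return [(l, p1 * p2) for l in range(abs(l1 - l2), l1 + l2 + 1)]


def channelwise_paths(
    irreps_in1: Sequence[Irrep],
    irreps_in2: Sequence[Irrep],
    filter_irreps_out: Optional[Sequence[Irrep]] = None,
) -> List[Tuple[Irrep, Irrep, Irrep]]:
    """List (ir1, ir2, ir_out) paths, keeping only outputs allowed by the filter."""
    paths = []
    for ir1 in irreps_in1:
        for ir2 in irreps_in2:
            for ir_out in product_irreps(ir1, ir2):
                if filter_irreps_out is not None and ir_out not in filter_irreps_out:
                    continue  # this output irrep was filtered out
                paths.append((ir1, ir2, ir_out))
    return paths


# "0e + 1o" for both operands.
in1 = [(0, 1), (1, -1)]
in2 = [(0, 1), (1, -1)]

print(len(channelwise_paths(in1, in2)))  # no filter: every allowed path
print(channelwise_paths(in1, in2, filter_irreps_out=[(0, 1), (1, -1)]))
```

With a filter of "0e" and "1o", the 1e and 2e outputs of 1o x 1o are dropped, which reduces both the output dimension and the number of weights.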
Forward Pass
- forward(x1, x2, weight=None, indices_1=None, indices_2=None, indices_out=None, size_out=None)
Perform the forward pass of the channel-wise tensor product operation.
- Parameters:
x1 (torch.Tensor) – Input tensor for the first operand. It should have the shape (:, irreps_in1.dim).
x2 (torch.Tensor) – Input tensor for the second operand. It should have the shape (:, irreps_in2.dim).
weight (torch.Tensor, optional) – Weights for the tensor product. It should have the shape (batch_size, weight_numel) if shared_weights is False, or (1, weight_numel) if shared_weights is True. If None, the internal weights are used.
indices_1 (torch.Tensor, optional) – Indices to gather elements for the first operand.
indices_2 (torch.Tensor, optional) – Indices to gather elements for the second operand.
indices_out (torch.Tensor, optional) – Indices to scatter elements for the output.
size_out (int, optional) – Batch dimension of the output. Needed if indices_out are provided.
- Returns:
Output tensor resulting from the channel-wise tensor product operation. It will have the shape (batch_size, irreps_out.dim).
- Return type:
torch.Tensor
- Raises:
ValueError – If internal weights are used and weight is not None, if shared weights are used and weight is not a 1D tensor, if shared weights are not used and weight is not a 2D tensor, or if indices_out is provided without size_out.
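The roles of indices_1, indices_2, indices_out and size_out can be illustrated with a pure-Python sketch of a gather, compute, scatter pattern. This is not the library implementation: indexed_apply is a hypothetical helper, an elementwise product stands in for the tensor product, and accumulating rows that scatter to the same output index by summation is an assumption (the parameter descriptions only say "scatter").

```python
from typing import Callable, List, Optional, Sequence

Row = List[float]


def indexed_apply(
    op: Callable[[Row, Row], Row],
    x1: Sequence[Row],
    x2: Sequence[Row],
    indices_1: Optional[Sequence[int]] = None,
    indices_2: Optional[Sequence[int]] = None,
    indices_out: Optional[Sequence[int]] = None,
    size_out: Optional[int] = None,
) -> List[Row]:
    """Hypothetical sketch: gather rows, apply op row by row, scatter-add results."""
    if indices_out is not None and size_out is None:
        raise ValueError("size_out is required when indices_out is provided")

    # Batch size comes from the gather indices when given, else from x1.
    if indices_1 is not None:
        n = len(indices_1)
    elif indices_2 is not None:
        n = len(indices_2)
    else:
        n = len(x1)

    rows = []
    for i in range(n):
        a = x1[indices_1[i]] if indices_1 is not None else x1[i]
        b = x2[indices_2[i]] if indices_2 is not None else x2[i]
        rows.append(op(a, b))

    if indices_out is None:
        return rows

    # Assumption: rows mapped to the same output index are summed.
    width = len(rows[0]) if rows else 0
    out = [[0.0] * width for _ in range(size_out)]
    for i, row in enumerate(rows):
        target = out[indices_out[i]]
        for k, value in enumerate(row):
            target[k] += value
    return out


def elementwise(a: Row, b: Row) -> Row:
    # Stand-in for the tensor product of one pair of rows.
    return [u * v for u, v in zip(a, b)]


node_feats = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]              # 3 nodes
edge_attrs = [[1.0, 1.0], [2.0, 2.0], [0.5, 0.5], [1.0, 0.0]]  # 4 edges
sender = [0, 1, 2, 0]
receiver = [1, 1, 0, 2]

messages = indexed_apply(
    elementwise, node_feats, edge_attrs,
    indices_1=sender, indices_out=receiver, size_out=len(node_feats),
)
print(messages)
```

In a message-passing setting this pattern lets per-node features be gathered per edge and the per-edge results reduced back onto nodes in a single call, which is why size_out must be given whenever indices_out is used.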
- __init__(irreps_in1, irreps_in2, filter_irreps_out=None, *, layout=None, layout_in1=None, layout_in2=None, layout_out=None, shared_weights=True, internal_weights=None, device=None, dtype=None, math_dtype=None, use_fallback=None, method=None)
Initialize internal Module state, shared by both nn.Module and ScriptModule.
- Parameters:
irreps_in1 (Irreps)
irreps_in2 (Irreps)
layout (IrrepsLayout | None)
layout_in1 (IrrepsLayout | None)
layout_in2 (IrrepsLayout | None)
layout_out (IrrepsLayout | None)
shared_weights (bool)
internal_weights (bool)
device (device | None)
dtype (dtype | None)
use_fallback (bool | None)
method (str | None)
- classmethod __new__(*args, **kwargs)