ChannelWiseTensorProduct#
- class cuequivariance_torch.ChannelWiseTensorProduct#
Channel-wise tensor product layer.
- Parameters:
irreps_in1 (Irreps) – Input irreps for the first operand.
irreps_in2 (Irreps) – Input irreps for the second operand.
filter_irreps_out (Sequence of Irrep, optional) – Filter for the output irreps. Default is None.
layout (IrrepsLayout, optional) – The layout of the input and output irreps. Default is cue.mul_ir, which is the layout corresponding to e3nn.
layout_in1 (IrrepsLayout, optional) – The layout of the first input irreducible representations, by default layout.
layout_in2 (IrrepsLayout, optional) – The layout of the second input irreducible representations, by default layout.
layout_out (IrrepsLayout, optional) – The layout of the output irreducible representations, by default layout.
.shared_weights (bool, optional) – Whether to share weights across the batch dimension. Default is True.
internal_weights (bool, optional) – Whether to create module parameters for weights. Default is None.
device (torch.device, optional) – The device to use for the operation.
dtype (torch.dtype, optional) – The dtype to use for the operation weights, by default torch.float32.
math_dtype (torch.dtype, optional) – The dtype to use for the math operations; by default it follows the dtype of the input tensors.
method (str, optional) – The method to use for the operation, by default “uniform_1d” (using a CUDA kernel) if all segments have the same shape, otherwise “naive” (using a PyTorch implementation).
use_fallback (bool, optional, deprecated) – Whether to use a “fallback” implementation, now maps to method: If True the “naive” method is used. If False the “uniform_1d” method is used (make sure all segments have the same shape).
Note
In e3nn there were irrep_normalization and path_normalization parameters. This module currently supports only “component” irrep normalization and “element” path normalization.
Forward Pass
- forward(
- x1,
- x2,
- weight=None,
- indices_1=None,
- indices_2=None,
- indices_out=None,
- size_out=None,
- )#
Perform the forward pass of the channel-wise tensor product operation.
- Parameters:
x1 (torch.Tensor) – Input tensor for the first operand. It should have the shape (:, irreps_in1.dim).
x2 (torch.Tensor) – Input tensor for the second operand. It should have the shape (:, irreps_in2.dim).
weight (torch.Tensor, optional) – Weights for the tensor product. It should have the shape (batch_size, weight_numel) if shared_weights is False, or (1, weight_numel) if shared_weights is True. If None, the internal weights are used.
indices_1 (torch.Tensor, optional) – Indices to gather elements for the first operand.
indices_2 (torch.Tensor, optional) – Indices to gather elements for the second operand.
indices_out (torch.Tensor, optional) – Indices to scatter elements for the output.
size_out (int, optional) – Batch dimension of the output. Needed if indices_out are provided.
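The indices arguments implement a gather/scatter pattern around the product: rows of the inputs are gathered before the product, and rows of the result are written into an output with size_out rows. A plain-Python sketch of these indexing semantics, using a stand-in scalar product rather than the real kernel; that the scatter accumulates by summation is an assumption based on common scatter-add conventions, not stated in this page:

```python
def indexed_product(x1, x2, indices_1, indices_out, size_out, product):
    """Sketch of gather (indices_1) + scatter-add (indices_out) semantics."""
    out = [0.0] * size_out
    for i in range(len(indices_1)):
        row = product(x1[indices_1[i]], x2[i])  # gather a row of x1, pair it with x2[i]
        out[indices_out[i]] += row              # scatter-accumulate into the output batch
    return out

# Toy data: the "product" is a plain scalar multiply.
x1 = [1.0, 2.0, 3.0]
x2 = [10.0, 10.0, 10.0]
result = indexed_product(x1, x2, indices_1=[0, 2, 2], indices_out=[0, 1, 1],
                         size_out=2, product=lambda a, b: a * b)
# result == [10.0, 60.0]: rows 1 and 2 both land in output row 1 and accumulate
```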
- Returns:
Output tensor resulting from the channel-wise tensor product operation. It will have the shape (batch_size, irreps_out.dim).
- Return type:
torch.Tensor
- Raises:
ValueError – If internal weights are used and weight is not None, or if shared weights are used and weight is not a 1D tensor, or if shared weights are not used and weight is not a 2D tensor, or if indices_out is provided and size_out is not.
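The conditions above can be summarized as a small validation sketch. The helper name and the weight_ndim argument (standing in for weight.ndim) are hypothetical; only the rules themselves come from the documented Raises section:

```python
def validate_forward_args(weight_ndim, internal_weights, shared_weights,
                          indices_out=None, size_out=None):
    """Hypothetical helper mirroring the documented ValueError conditions."""
    if internal_weights and weight_ndim is not None:
        raise ValueError("internal weights are used, so weight must be None")
    if weight_ndim is not None:
        if shared_weights and weight_ndim != 1:
            raise ValueError("shared weights require a 1D weight tensor")
        if not shared_weights and weight_ndim != 2:
            raise ValueError("per-batch weights require a 2D weight tensor")
    if indices_out is not None and size_out is None:
        raise ValueError("size_out is required when indices_out is provided")

# Valid: shared 1D weight, no output indices.
validate_forward_args(weight_ndim=1, internal_weights=False, shared_weights=True)
```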
- __init__(
- irreps_in1,
- irreps_in2,
- filter_irreps_out=None,
- *,
- layout=None,
- layout_in1=None,
- layout_in2=None,
- layout_out=None,
- shared_weights=True,
- internal_weights=None,
- device=None,
- dtype=None,
- math_dtype=None,
- use_fallback=None,
- method=None,
- )#
Initialize internal Module state, shared by both nn.Module and ScriptModule.
- Parameters:
irreps_in1 (Irreps)
irreps_in2 (Irreps)
filter_irreps_out (Sequence[Irrep] | None)
layout (IrrepsLayout | None)
layout_in1 (IrrepsLayout | None)
layout_in2 (IrrepsLayout | None)
layout_out (IrrepsLayout | None)
shared_weights (bool)
internal_weights (bool)
device (device | None)
dtype (dtype | None)
math_dtype (dtype | None)
use_fallback (bool | None)
method (str | None)
- classmethod __new__(*args, **kwargs)#