ChannelWiseTensorProduct#

class cuequivariance_torch.ChannelWiseTensorProduct(
irreps_in1: Irreps,
irreps_in2: Irreps,
filter_irreps_out: Sequence[Irrep] = None,
*,
layout: IrrepsLayout | None = None,
layout_in1: IrrepsLayout | None = None,
layout_in2: IrrepsLayout | None = None,
layout_out: IrrepsLayout | None = None,
shared_weights: bool = True,
internal_weights: bool = None,
device: device | None = None,
dtype: dtype | None = None,
math_dtype: dtype | None = None,
use_fallback: bool | None = None,
)#

Channel-wise tensor product layer.

Parameters:
  • irreps_in1 (Irreps) – Input irreps for the first operand.

  • irreps_in2 (Irreps) – Input irreps for the second operand.

  • filter_irreps_out (Sequence of Irrep, optional) – Filter for the output irreps: if given, only output irreps contained in this sequence are produced. Default is None (no filtering).

  • layout (IrrepsLayout, optional) – The layout of the input and output irreps. Default is cue.mul_ir, which is the layout corresponding to e3nn.

  • shared_weights (bool, optional) – Whether to share weights across the batch dimension. Default is True.

  • internal_weights (bool, optional) – Whether to create module parameters for weights. Default is None.

  • use_fallback (bool, optional) – If None (default), a CUDA kernel will be used if available. If False, a CUDA kernel will be used, and an exception is raised if it’s not available. If True, a PyTorch fallback method is used regardless of CUDA kernel availability.

Note

In e3nn there were irrep_normalization and path_normalization parameters. This module currently supports only “component” irrep normalization and “element” path normalization.
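
Example

A minimal construction sketch. The group ("O3") and the irreps strings are illustrative assumptions, not values prescribed by this class:

import torch
import cuequivariance as cue
import cuequivariance_torch as cuet

# Illustrative irreps: 32 channels of scalars and vectors for the first
# operand, and a single set of low-degree irreps for the second.
irreps_in1 = cue.Irreps("O3", "32x0e + 32x1o")
irreps_in2 = cue.Irreps("O3", "0e + 1o + 2e")

module = cuet.ChannelWiseTensorProduct(
    irreps_in1,
    irreps_in2,
    layout=cue.mul_ir,   # e3nn-compatible layout (the default)
    shared_weights=True,
    internal_weights=True,
    dtype=torch.float32,
)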

Forward Pass

forward(
x1: Tensor,
x2: Tensor,
weight: Tensor | None = None,
) → Tensor#

Perform the forward pass of the channel-wise tensor product operation.

Parameters:
  • x1 (torch.Tensor) – Input tensor for the first operand. It should have the shape (batch_size, irreps_in1.dim).

  • x2 (torch.Tensor) – Input tensor for the second operand. It should have the shape (batch_size, irreps_in2.dim).

  • weight (torch.Tensor, optional) – Weights for the tensor product. It should have the shape (batch_size, weight_numel) if shared_weights is False, or (weight_numel,) if shared_weights is True. If None, the internal weights are used.

Returns:

Output tensor resulting from the channel-wise tensor product operation. It will have the shape (batch_size, irreps_out.dim).

Return type:

torch.Tensor

Raises:

ValueError – If internal weights are used and weight is not None, or if shared weights are used and weight is not a 1D tensor, or if shared weights are not used and weight is not a 2D tensor.
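
Example

A hedged usage sketch of the forward pass, continuing the construction example above. The batch size is arbitrary, and the commented alternative assumes the module was built with internal_weights=False and shared_weights=False (weight_numel is assumed to be exposed as a module attribute):

batch_size = 128
x1 = torch.randn(batch_size, irreps_in1.dim)   # (batch_size, irreps_in1.dim)
x2 = torch.randn(batch_size, irreps_in2.dim)   # (batch_size, irreps_in2.dim)

# With internal weights (as constructed above), no weight tensor is passed.
out = module(x1, x2)                           # (batch_size, irreps_out.dim)

# With internal_weights=False and shared_weights=False, per-sample weights
# of shape (batch_size, weight_numel) would be passed instead:
# weight = torch.randn(batch_size, module.weight_numel)
# out = module(x1, x2, weight)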