ChannelWiseTensorProduct

class cuequivariance_torch.ChannelWiseTensorProduct

Channel-wise tensor product layer.

Parameters:
  • irreps_in1 (Irreps) – Input irreps for the first operand.

  • irreps_in2 (Irreps) – Input irreps for the second operand.

  • filter_irreps_out (Sequence of Irrep, optional) – Filter for the output irreps. Default is None.

  • layout (IrrepsLayout, optional) – The layout of the input and output irreps. Default is cue.mul_ir, which is the layout corresponding to e3nn.

  • layout_in1 (IrrepsLayout, optional) – The layout of the first input irreducible representations. Defaults to layout.

  • layout_in2 (IrrepsLayout, optional) – The layout of the second input irreducible representations. Defaults to layout.

  • layout_out (IrrepsLayout, optional) – The layout of the output irreducible representations. Defaults to layout.

  • shared_weights (bool, optional) – Whether to share weights across the batch dimension. Default is True.

  • internal_weights (bool, optional) – Whether to create module parameters for weights. Default is None.

  • device (torch.device, optional) – The device to use for the operation.

  • dtype (torch.dtype, optional) – The dtype to use for the operation weights, by default torch.float32.

  • math_dtype (torch.dtype, optional) – The dtype to use for the math operations, by default it follows the dtype of the input tensors.

  • method (str, optional) – The method to use for the operation, by default “uniform_1d” (using a CUDA kernel) if all segments have the same shape, otherwise “naive” (using a PyTorch implementation).

  • use_fallback (bool, optional, deprecated) – Whether to use a “fallback” implementation. This flag now maps to method: if True, the “naive” method is used; if False, the “uniform_1d” method is used (make sure all segments have the same shape).

Note

e3nn exposes irrep_normalization and path_normalization parameters. This module currently supports only “component” irrep normalization and “element” path normalization.
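The interaction between method and the deprecated use_fallback flag described in the parameter list can be sketched in plain Python. This is only an illustration of the documented defaults: resolve_method and segments_uniform are hypothetical names, not part of cuequivariance_torch, and the choice to let an explicit method win over use_fallback is an assumption that the text above does not confirm.

```python
def resolve_method(method=None, use_fallback=None, segments_uniform=True):
    """Hypothetical helper mirroring the documented defaults (not library code)."""
    if method is not None:
        # Assumption: an explicit method takes precedence over the deprecated flag.
        return method
    if use_fallback is True:
        # Documented mapping: use_fallback=True selects the PyTorch implementation.
        return "naive"
    if use_fallback is False:
        # Documented mapping: use_fallback=False selects the CUDA kernel.
        return "uniform_1d"
    # Documented default: CUDA kernel only when all segments share one shape.
    return "uniform_1d" if segments_uniform else "naive"


print(resolve_method())                        # uniform_1d
print(resolve_method(segments_uniform=False))  # naive
print(resolve_method(use_fallback=True))       # naive
```

New code would normally pass method directly and leave use_fallback unset.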

Forward Pass

forward(x1, x2, weight=None, indices_1=None, indices_2=None, indices_out=None, size_out=None)

Perform the forward pass of the channel-wise tensor product operation.

Parameters:
  • x1 (torch.Tensor) – Input tensor for the first operand. It should have the shape (:, irreps_in1.dim).

  • x2 (torch.Tensor) – Input tensor for the second operand. It should have the shape (:, irreps_in2.dim).

  • weight (torch.Tensor, optional) – Weights for the tensor product. It should have the shape (batch_size, weight_numel) if shared_weights is False, or (1, weight_numel) if shared_weights is True. If None, the internal weights are used.

  • indices_1 (torch.Tensor, optional) – Indices to gather elements for the first operand.

  • indices_2 (torch.Tensor, optional) – Indices to gather elements for the second operand.

  • indices_out (torch.Tensor, optional) – Indices to scatter elements for the output.

  • size_out (int, optional) – Batch dimension of the output. Needed if indices_out are provided.
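The gather and scatter roles of indices_1, indices_2, indices_out and size_out can be pictured with a small pure-Python sketch. It is not the library implementation: indexed_apply and multiply_rows are hypothetical names, the per-row product is replaced by a stand-in function, and summing rows that land on the same output index is an assumption about the scatter semantics.

```python
def indexed_apply(row_op, x1, x2, indices_1=None, indices_2=None,
                  indices_out=None, size_out=None):
    """Hypothetical sketch: gather rows, apply row_op per row, scatter the results."""
    if indices_out is not None and size_out is None:
        raise ValueError("size_out is required when indices_out is provided")

    # Gather: pick the rows of each operand that take part in each product.
    rows1 = x1 if indices_1 is None else [x1[i] for i in indices_1]
    rows2 = x2 if indices_2 is None else [x2[i] for i in indices_2]
    results = [row_op(a, b) for a, b in zip(rows1, rows2)]

    if indices_out is None:
        return results

    # Scatter: rows sent to the same output index are summed (an assumption).
    width = len(results[0]) if results else 0
    out = [[0.0] * width for _ in range(size_out)]
    for target, row in zip(indices_out, results):
        for k, value in enumerate(row):
            out[target][k] += value
    return out


def multiply_rows(a, b):
    # Stand-in for the real tensor product of two rows.
    return [a[0] * b[0]]


# Three "node" rows, four "edge" rows: gather node rows per edge, scatter to nodes.
x1 = [[1.0], [2.0], [3.0]]
x2 = [[10.0], [20.0], [30.0], [40.0]]
out = indexed_apply(multiply_rows, x1, x2,
                    indices_1=[0, 1, 2, 0], indices_out=[1, 1, 0, 2], size_out=3)
print(out)  # [[90.0], [50.0], [40.0]]
```

In this picture size_out fixes the number of output rows, which cannot be inferred from indices_out alone when some output rows receive no contribution.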

Returns:

Output tensor resulting from the channel-wise tensor product operation. It will have the shape (batch_size, irreps_out.dim).

Return type:

torch.Tensor
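To make the shapes irreps_in1.dim, irreps_in2.dim and irreps_out.dim concrete, the following pure-Python sketch computes them for O(3) irreps written as strings such as "4x0e + 4x1o". The helpers parse_irreps, dim and channelwise_shapes are hypothetical and do not use the library. The sketch assumes that each pair of input irreps contributes every irrep allowed by the usual angular-momentum selection rule, that each such path has multiplicity mul1 * mul2, and that each path carries one weight per channel pair; the library's actual irreps_out and weight_numel should be taken from the module itself.

```python
def parse_irreps(text):
    """Parse 'MULxLp + ...' (e.g. '4x0e + 4x1o') into (mul, l, parity) triples."""
    triples = []
    for term in text.split("+"):
        mul, ir = term.strip().split("x")
        parity = 1 if ir[-1] == "e" else -1
        triples.append((int(mul), int(ir[:-1]), parity))
    return triples


def dim(irreps):
    # An O(3) irrep of degree l has 2l + 1 components.
    return sum(mul * (2 * l + 1) for mul, l, _ in irreps)


def channelwise_shapes(irreps_in1, irreps_in2, filter_irreps_out=None):
    """Return (output dim, weight count) under the assumptions stated above."""
    out_dim = 0
    weight_numel = 0
    for mul1, l1, p1 in irreps_in1:
        for mul2, l2, p2 in irreps_in2:
            for l3 in range(abs(l1 - l2), l1 + l2 + 1):
                p3 = p1 * p2
                if filter_irreps_out is not None and (l3, p3) not in filter_irreps_out:
                    continue
                out_dim += mul1 * mul2 * (2 * l3 + 1)
                weight_numel += mul1 * mul2  # assumed: one weight per channel pair
    return out_dim, weight_numel


in1 = parse_irreps("4x0e + 4x1o")
in2 = parse_irreps("1x0e + 1x1o")
print(dim(in1), dim(in2))                                # 16 4
print(channelwise_shapes(in1, in2))                      # (64, 24)
print(channelwise_shapes(in1, in2, {(0, 1), (1, -1)}))   # (32, 16)
```

Under these assumptions, x1 would have shape (batch_size, 16), x2 shape (batch_size, 4), and the unfiltered output shape (batch_size, 64); keeping only 0e and 1o outputs halves the output width.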

Raises:

ValueError – If internal weights are used and weight is not None, or if shared weights are used and weight is not a 1D tensor, or if shared weights are not used and weight is not a 2D tensor, or if indices_out is provided and size_out is not.

__init__(irreps_in1, irreps_in2, filter_irreps_out=None, *, layout=None, layout_in1=None, layout_in2=None, layout_out=None, shared_weights=True, internal_weights=None, device=None, dtype=None, math_dtype=None, use_fallback=None, method=None)

Initialize internal Module state, shared by both nn.Module and ScriptModule.

classmethod __new__(*args, **kwargs)