EquivariantTensorProduct

class cuequivariance_torch.EquivariantTensorProduct(
e: EquivariantTensorProduct,
*,
layout: IrrepsLayout | None = None,
layout_in: IrrepsLayout | tuple[IrrepsLayout | None, ...] | None = None,
layout_out: IrrepsLayout | None = None,
device: device | None = None,
math_dtype: dtype | None = None,
use_fallback: bool | None = None,
)

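A torch.nn.Module that evaluates the equivariant tensor product described by a cuequivariance.EquivariantTensorProduct descriptor.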

Parameters:
  • e (cuequivariance.EquivariantTensorProduct) – Equivariant tensor product.

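  • layout (IrrepsLayout) – memory layout (for example cue.ir_mul) used for all inputs and the output.

  • layout_in (IrrepsLayout | tuple[IrrepsLayout | None, ...]) – layout for the inputs: either a single layout for all of them or a tuple with one entry per input.

  • layout_out (IrrepsLayout) – layout for the output.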

  • device (torch.device) – device of the Module.

  • math_dtype (torch.dtype) – dtype for internal computations.

  • use_fallback (bool, optional) – Determines the computation method. If None (default), the CUDA kernel is used when available and the PyTorch fallback otherwise. If False, the CUDA kernel is used, and an exception is raised if it is not available. If True, the PyTorch fallback is used regardless of CUDA kernel availability.
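The layout parameters only change how each irrep block is flattened in memory. A minimal sketch, assuming that cue.mul_ir stores a block of `mul` copies of an irrep of dimension `ir_dim` as a row-major `(mul, ir_dim)` array and cue.ir_mul stores its transpose `(ir_dim, mul)`; the helper `mul_ir_to_ir_mul` is hypothetical and not part of cuequivariance_torch:

```python
def mul_ir_to_ir_mul(flat, segments):
    """Hypothetical helper: reorder one flat feature vector from mul_ir to ir_mul.

    `segments` lists (mul, ir_dim) pairs; "2x1" in SO3 is [(2, 3)].
    Assumption: element (u, m) of a segment sits at offset u * ir_dim + m
    in mul_ir and at offset m * mul + u in ir_mul.
    """
    out = []
    start = 0
    for mul, ir_dim in segments:
        block = flat[start:start + mul * ir_dim]
        # Irrep component m in the outer loop, multiplicity index u in the inner loop.
        for m in range(ir_dim):
            for u in range(mul):
                out.append(block[u * ir_dim + m])
        start += mul * ir_dim
    if start != len(flat):
        raise ValueError("segments do not cover the whole vector")
    return out


# "2x1": two copies (u = 0, 1) of a 3-dimensional irrep (m = 0, 1, 2).
x_mul_ir = ["u0m0", "u0m1", "u0m2", "u1m0", "u1m1", "u1m2"]
print(mul_ir_to_ir_mul(x_mul_ir, [(2, 3)]))
# ['u0m0', 'u1m0', 'u0m1', 'u1m1', 'u0m2', 'u1m2']
```

Under this assumption the two orderings coincide whenever mul = 1 or ir_dim = 1, so the choice of layout only matters for blocks with several copies of a non-scalar irrep.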

Raises:
  • RuntimeError – If use_fallback is False and no CUDA kernel is available.
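The three use_fallback settings and the RuntimeError above amount to a small decision rule. A minimal sketch, assuming the default (None) falls back to PyTorch when no kernel is available; `select_backend` and its return strings are hypothetical, not part of cuequivariance_torch:

```python
from typing import Optional


def select_backend(use_fallback: Optional[bool], cuda_kernel_available: bool) -> str:
    """Hypothetical helper mirroring the documented use_fallback rules."""
    if use_fallback is None:
        # Default: prefer the CUDA kernel, otherwise use the PyTorch fallback.
        return "cuda_kernel" if cuda_kernel_available else "pytorch_fallback"
    if use_fallback is False:
        # Kernel required: fail loudly instead of silently falling back.
        if not cuda_kernel_available:
            raise RuntimeError("use_fallback=False but no CUDA kernel is available")
        return "cuda_kernel"
    # use_fallback is True: always use the PyTorch implementation.
    return "pytorch_fallback"


print(select_backend(None, cuda_kernel_available=False))  # pytorch_fallback
print(select_backend(True, cuda_kernel_available=True))   # pytorch_fallback
```

Setting use_fallback=False is the option to choose when a silent fallback to the slower path would be a bug rather than a convenience.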

Examples

>>> import torch
>>> import cuequivariance as cue
>>> import cuequivariance_torch as cuet
>>> device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
>>> e = cue.descriptors.fully_connected_tensor_product(
...    cue.Irreps("SO3", "2x1"), cue.Irreps("SO3", "2x1"), cue.Irreps("SO3", "2x1")
... )
>>> w = torch.ones(1, e.inputs[0].dim, device=device)
>>> x1 = torch.ones(17, e.inputs[1].dim, device=device)
>>> x2 = torch.ones(17, e.inputs[2].dim, device=device)
>>> tp = cuet.EquivariantTensorProduct(e, layout=cue.ir_mul, device=device)
>>> tp(w, x1, x2)
tensor([[0., 0., 0., 0., 0., 0.],...)
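The all-zero output follows from the symmetry of the inputs. In SO3, l = 1 ⊗ l = 1 contains l = 1 exactly once, and that coupling is, up to normalization, the cross product, whose coefficients are the Levi-Civita symbol. Since x1 and x2 are filled with ones, every channel pair (u, v) contributes the cross product of two identical vectors, which vanishes whatever the weights. A minimal pure-Python sketch of that contraction, assuming unit normalization and all-ones weights; `fully_connected_1x1_to_1` is a hypothetical toy and ignores the module's actual normalization and memory layout:

```python
def levi_civita(i, j, k):
    # epsilon_ijk for i, j, k in {0, 1, 2}: +1 for even, -1 for odd permutations, else 0.
    return (i - j) * (j - k) * (k - i) // 2


def fully_connected_1x1_to_1(x1, x2, weights, mul_out):
    """Toy "2x1 x 2x1 -> 2x1": out[w][k] = sum_uv W[u][v][w] * (x1[u] x x2[v])[k]."""
    out = [[0.0] * 3 for _ in range(mul_out)]
    for u, a in enumerate(x1):
        for v, b in enumerate(x2):
            for w in range(mul_out):
                for k in range(3):
                    # k-th component of the cross product a x b.
                    cross_k = sum(levi_civita(i, j, k) * a[i] * b[j]
                                  for i in range(3) for j in range(3))
                    out[w][k] += weights[u][v][w] * cross_k
    return out


ones = [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]]  # two l=1 vectors per input, all ones
weights = [[[1.0, 1.0], [1.0, 1.0]], [[1.0, 1.0], [1.0, 1.0]]]
print(fully_connected_1x1_to_1(ones, ones, weights, mul_out=2))
# [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]
```

The six zeros correspond to the six entries of each output row in the doctest; with non-parallel vectors in x1 and x2 the toy product becomes nonzero.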

You can optionally index the first input tensor, so that each of the 17 batch elements uses the row of w selected by indices:

>>> w = torch.ones(3, e.inputs[0].dim, device=device)
>>> indices = torch.randint(3, (17,), device=device)
>>> tp(w, x1, x2, indices=indices)
tensor([[0., 0., 0., 0., 0., 0.],...)
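The example suggests that indexing selects, for each batch element, one of the 3 rows of w, so that tp(w, x1, x2, indices=indices) should agree with tp(w[indices], x1, x2). A minimal pure-Python sketch of that row gather, assuming this semantics; `gather_rows` is hypothetical and only illustrates the meaning, not how the module computes it:

```python
def gather_rows(x0, indices):
    """Hypothetical: return [x0[i] for i in indices], checking every index is in range."""
    n = len(x0)
    for i in indices:
        if not 0 <= i < n:
            raise IndexError(f"index {i} out of range for {n} rows")
    return [x0[i] for i in indices]


w = [[0.1, 0.2], [1.1, 1.2], [2.1, 2.2]]  # 3 weight rows, like w in the example
indices = [2, 0, 0, 1]                     # one row choice per batch element
print(gather_rows(w, indices))
# [[2.1, 2.2], [0.1, 0.2], [0.1, 0.2], [1.1, 1.2]]
```

Indexing lets a small table of weights be shared across a large batch without repeating the rows by hand.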

Forward Pass

forward(
x0: Tensor,
x1: Tensor | None = None,
x2: Tensor | None = None,
x3: Tensor | None = None,
indices: Tensor | None = None,
) → Tensor

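Computes the tensor product of the inputs x0, x1, ..., passing one tensor per input of the descriptor and leaving the remaining arguments as None. If indices is not None, the first input x0 is indexed by indices, so that batch element i uses row indices[i] of x0.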
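The forward signature accepts up to four operands because the descriptor decides how many the product has (three in the example: w, x1 and x2). A minimal sketch of how such optional positional inputs could be validated against the descriptor's arity, assuming inputs are filled from x0 onward; `collect_inputs` and `num_inputs` are hypothetical names, not the module's internals:

```python
from typing import Any, List, Optional, Sequence


def collect_inputs(
    num_inputs: int,
    x0: Any,
    x1: Optional[Any] = None,
    x2: Optional[Any] = None,
    x3: Optional[Any] = None,
    indices: Optional[Sequence[int]] = None,
) -> List[Any]:
    """Hypothetical: validate forward() arguments against the descriptor's arity."""
    args = [x0, x1, x2, x3]
    used, unused = args[:num_inputs], args[num_inputs:]
    # The first num_inputs arguments must be given and the remaining ones left as None.
    if any(x is None for x in used) or any(x is not None for x in unused):
        raise ValueError(f"expected exactly {num_inputs} inputs (x0..x{num_inputs - 1})")
    if indices is not None:
        # Only the first input is indexed; the other inputs are used as given.
        used[0] = [used[0][i] for i in indices]
    return used


print(collect_inputs(3, [[1], [2], [3]], "x1", "x2", indices=[2, 0]))
# [[[3], [1]], 'x1', 'x2']
```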