EquivariantTensorProduct
class cuequivariance_torch.EquivariantTensorProduct(
    e: EquivariantTensorProduct,
    *,
    layout: IrrepsLayout | None = None,
    layout_in: IrrepsLayout | tuple[IrrepsLayout | None, ...] | None = None,
    layout_out: IrrepsLayout | None = None,
    device: device | None = None,
    math_dtype: dtype | None = None,
    use_fallback: bool | None = None,
)
Equivariant tensor product.
Parameters:
    e (cuequivariance.EquivariantTensorProduct) – descriptor of the equivariant tensor product to compute.
    layout (IrrepsLayout) – layout shared by all inputs and the output.
    layout_in (IrrepsLayout) – layout for the inputs, either a single layout or one per input (see the mixed-layout sketch in the Examples below).
    layout_out (IrrepsLayout) – layout for the output.
    device (torch.device) – device of the Module.
    math_dtype (torch.dtype) – dtype for internal computations.
    use_fallback (bool, optional) – selects the computation method. If None (the default), a CUDA kernel is used if available. If False, a CUDA kernel is required and an exception is raised when none is available. If True, the PyTorch fallback is used regardless of CUDA kernel availability (see the fallback sketch in the Examples below).
Raises:
    RuntimeError – If use_fallback is False and no CUDA kernel is available.
Examples
>>> device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
>>> e = cue.descriptors.fully_connected_tensor_product(
...     cue.Irreps("SO3", "2x1"), cue.Irreps("SO3", "2x1"), cue.Irreps("SO3", "2x1")
... )
>>> w = torch.ones(1, e.inputs[0].dim, device=device)
>>> x1 = torch.ones(17, e.inputs[1].dim, device=device)
>>> x2 = torch.ones(17, e.inputs[2].dim, device=device)
>>> tp = cuet.EquivariantTensorProduct(e, layout=cue.ir_mul, device=device)
>>> tp(w, x1, x2)
tensor([[0., 0., 0., 0., 0., 0.],...)
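To force the pure-PyTorch path regardless of CUDA kernel availability, pass use_fallback=True at construction. A minimal sketch reusing the tensors above (tp_fallback is an illustrative name; the fallback computes the same product):

>>> tp_fallback = cuet.EquivariantTensorProduct(
...     e, layout=cue.ir_mul, device=device, use_fallback=True
... )
>>> tp_fallback(w, x1, x2)
tensor([[0., 0., 0., 0., 0., 0.],...)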
You can optionally index the first input tensor:
>>> w = torch.ones(3, e.inputs[0].dim, device=device)
>>> indices = torch.randint(3, (17,))
>>> tp(w, x1, x2, indices=indices)
tensor([[0., 0., 0., 0., 0., 0.],...)
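Per the signature, layout_in also accepts one layout per input, and layout_out may differ from the input layouts. A hypothetical sketch assuming all three inputs are stored in mul_ir layout while the output uses ir_mul (tp_mixed is an illustrative name):

>>> w = torch.ones(1, e.inputs[0].dim, device=device)  # back to a single weight set
>>> tp_mixed = cuet.EquivariantTensorProduct(
...     e,
...     layout_in=(cue.mul_ir, cue.mul_ir, cue.mul_ir),
...     layout_out=cue.ir_mul,
...     device=device,
... )
>>> tp_mixed(w, x1, x2).shape
torch.Size([17, 6])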
Forward Pass