SymmetricTensorProduct

class cuequivariance_torch.SymmetricTensorProduct(
descriptors: list[SegmentedTensorProduct],
*,
device: device | None = None,
math_dtype: dtype | None = None,
optimize_fallback: bool | None = None,
)

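PyTorch module that evaluates the symmetric tensor product defined by a list of SegmentedTensorProduct descriptors on a single input tensor.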

Parameters

descriptors : list[stp.SegmentedTensorProduct]

The list of SegmentedTensorProduct descriptors.

device : torch.device, optional

The device on which the module's internal tensors are created.

math_dtype : torch.dtype, optional

The data type of the coefficients and calculations.

optimize_fallback : bool, optional

If True, the torch.fx graph of the fallback implementation is optimized before execution. Because the optimization takes time, it is turned off by default.

forward(
x0: Tensor,
use_fallback: bool | None = None,
) → Tensor

Perform the forward pass of the indexed symmetric tensor product operation.

Parameters

x0 : torch.Tensor

The input tensor for the first operand. It should have the shape (batch, x0_size).

use_fallback : Optional[bool], optional

If None (default), a CUDA kernel is used if available; otherwise the PyTorch fallback is used. If False, a CUDA kernel is used, and an exception is raised if it is not available. If True, the PyTorch fallback is used regardless of CUDA kernel availability.
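The three use_fallback modes can be summarized with a small pure-Python sketch. The helper choose_backend and its cuda_kernel_available flag are hypothetical names used only for illustration; the module's real dispatch code, and the exception type it raises, may differ.

```python
from typing import Optional


def choose_backend(use_fallback: Optional[bool], cuda_kernel_available: bool) -> str:
    """Hypothetical helper mirroring the documented use_fallback rules."""
    if use_fallback is None:
        # Default: prefer the CUDA kernel, otherwise use the PyTorch fallback.
        return "cuda" if cuda_kernel_available else "fallback"
    if use_fallback is False:
        # The CUDA kernel is required; fail loudly if it is missing.
        # (RuntimeError is an assumption, not necessarily the library's type.)
        if not cuda_kernel_available:
            raise RuntimeError("use_fallback=False but no CUDA kernel is available")
        return "cuda"
    # use_fallback is True: always use the PyTorch fallback.
    return "fallback"
```

Treating None as "prefer the kernel" keeps the default fast on GPUs while still working on machines without the CUDA extension.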

Returns

torch.Tensor

The output tensor resulting from the indexed symmetric tensor product operation. It will have the shape (batch, x1_size), where x1_size is the size of the descriptors' output (last) operand.
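To make the (batch, x0_size) to (batch, x1_size) shape contract concrete, here is a dense NumPy reference sketch; it is not the library's segmented implementation. It assumes each descriptor has been densified into a hypothetical coefficient array whose first d axes all contract with the same input x0 (d being that term's degree) and whose last axis indexes the output, and that the contributions of all descriptors are summed.

```python
import numpy as np


def dense_symmetric_product(coeffs, x0):
    """Dense reference for a symmetric tensor product (illustrative only).

    coeffs: list of arrays; coeffs[k] has shape (x0_size,) * d_k + (out_size,),
            where d_k is how many times x0 enters term k (hypothetical layout).
    x0:     array of shape (batch, x0_size).
    Returns an array of shape (batch, out_size).
    """
    batch = x0.shape[0]
    out_size = coeffs[0].shape[-1]
    out = np.zeros((batch, out_size), dtype=np.result_type(x0, *coeffs))
    for c in coeffs:
        degree = c.ndim - 1
        # Give every batch element its own (read-only) view of the coefficients.
        t = np.broadcast_to(c, (batch,) + c.shape)
        # Contract the leading input axis with x0, once per degree.
        for _ in range(degree):
            t = np.einsum("bi...,bi->b...", t, x0)
        out += t
    return out


# Usage: x0_size = 2, out_size = 1.
x0 = np.array([[1.0, 2.0], [3.0, 4.0]])
c1 = np.ones((2, 1))       # degree 1: x[0] + x[1]
c2 = np.zeros((2, 2, 1))   # degree 2: x[0] * x[1]
c2[0, 1, 0] = 1.0
result = dense_symmetric_product([c1, c2], x0)
# result has shape (2, 1) with values 5.0 and 19.0
```

Because every degree-d term reuses the same x0, only the part of each coefficient array that is symmetric in its input axes affects the result, which is what makes the product "symmetric".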