SymmetricTensorProduct#
- class cuequivariance_torch.SymmetricTensorProduct(descriptors: list[SegmentedTensorProduct], *, device: device | None = None, math_dtype: dtype | None = None, optimize_fallback: bool | None = None)
PyTorch module that computes the symmetric tensor product defined by the given descriptors.
Parameters#
- descriptors : list[stp.SegmentedTensorProduct]
The list of SegmentedTensorProduct descriptors
- math_dtype : torch.dtype, optional
The data type of the coefficients and calculations
- optimize_fallback : bool, optional
If True, the torch.fx graph will be optimized before execution. Because the optimization takes time, it is turned off by default.
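A minimal construction sketch, assuming a list of SegmentedTensorProduct descriptors has already been built with cuequivariance; the names `descs` and `x0_size` below are placeholders for that list and the matching input width, not values defined on this page.

```python
# Sketch only: `descs` is assumed to be a list[cue.SegmentedTensorProduct]
# obtained elsewhere, and `x0_size` the input width those descriptors expect.
import torch
import cuequivariance_torch as cuet

module = cuet.SymmetricTensorProduct(
    descs,                        # placeholder: list of SegmentedTensorProduct descriptors
    device=torch.device("cuda"),
    math_dtype=torch.float32,     # dtype used for coefficients and internal math
)

x0 = torch.randn(32, x0_size, device="cuda")  # (batch, x0_size)
out = module(x0)                               # (batch, x1_size)
```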
- forward(x0: Tensor, use_fallback: bool | None = None) → Tensor#
Perform the forward pass of the indexed symmetric tensor product operation.
Parameters#
- x0 : torch.Tensor
The input tensor for the first operand. It should have the shape (batch, x0_size).
- use_fallback : Optional[bool], optional
If None (default), a CUDA kernel will be used if available. If False, a CUDA kernel will be used, and an exception is raised if it’s not available. If True, a PyTorch fallback method is used regardless of CUDA kernel availability.
Returns#
- torch.Tensor
The output tensor resulting from the indexed symmetric tensor product operation. It will have the shape (batch, x1_size).
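A hedged sketch of the three use_fallback modes described above, reusing the placeholder `module` and `x0` from the earlier sketch:

```python
# `module` and `x0` are the placeholders from the construction sketch above.
out = module(x0)                      # use_fallback=None: CUDA kernel if available, else fallback
out = module(x0, use_fallback=False)  # require the CUDA kernel; raises if it is not available
out = module(x0, use_fallback=True)   # force the PyTorch fallback, regardless of kernel availability
```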