TensorProduct
- class cuequivariance_torch.TensorProduct(descriptor: SegmentedTensorProduct, *, device: device | None = None, math_dtype: dtype | None = None, use_fallback: bool | None = None)
PyTorch module that computes the last operand of the segmented tensor product defined by the descriptor, taking the other operands as inputs.
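As a rough mental model of "computing the last operand", the sketch below evaluates a simplified segmented tensor product in pure Python. It is not the library's implementation: the descriptor format (segment sizes per operand, plus paths given as `(segment indices, coefficient)`) and the helper names are hypothetical, and every path is assumed to multiply equally sized segments element by element, roughly what subscripts such as `u,u,u` describe. Real descriptors support more general subscripts.

```python
from typing import List, Sequence, Tuple

# Hypothetical path format: (one segment index per operand, scalar coefficient).
Path = Tuple[Tuple[int, ...], float]


def segment_offsets(sizes: Sequence[int]) -> List[int]:
    """Start offset of each segment inside a flat operand vector."""
    offsets, start = [], 0
    for size in sizes:
        offsets.append(start)
        start += size
    return offsets


def segmented_product_last_operand(
    operand_segments: Sequence[Sequence[int]],  # segment sizes, output operand last
    paths: Sequence[Path],
    inputs: Sequence[Sequence[float]],  # flat vectors for all operands but the last
) -> List[float]:
    """Compute the last operand as a sum over paths of elementwise products.

    Simplification (assumption): all segments touched by a path have the same
    size and are multiplied position by position.
    """
    assert len(inputs) == len(operand_segments) - 1
    offsets = [segment_offsets(sizes) for sizes in operand_segments]
    out = [0.0] * sum(operand_segments[-1])
    for seg_ids, coeff in paths:
        out_seg = seg_ids[-1]
        size = operand_segments[-1][out_seg]
        for op, seg in enumerate(seg_ids[:-1]):
            assert operand_segments[op][seg] == size, "segment sizes must match"
        for k in range(size):
            term = coeff
            for op, seg in enumerate(seg_ids[:-1]):
                term *= inputs[op][offsets[op][seg] + k]
            # Paths that target the same output segment accumulate.
            out[offsets[-1][out_seg] + k] += term
    return out


# Size-1 segments with paths (0, 0, 0) and (1, 1, 0) reduce to a dot product.
print(segmented_product_last_operand(
    [[1, 1], [1, 1], [1]],
    [((0, 0, 0), 1.0), ((1, 1, 0), 1.0)],
    [[1.0, 2.0], [3.0, 4.0]],
))  # → [11.0]
```

In this model, only the last operand is produced; all others are read. That matches the description above, where the module's inputs are every operand except the last.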
- Parameters:
descriptor (SegmentedTensorProduct) – The descriptor of the segmented tensor product.
math_dtype (torch.dtype, optional) – The data type of the coefficients and calculations.
device (torch.device, optional) – The device on which the calculations are performed.
use_fallback (bool, optional) – Selects the computation method. If None (default), the CUDA kernel is used when available; otherwise the PyTorch fallback is used. If False, the CUDA kernel is used, and an exception is raised if it is not available. If True, the PyTorch fallback is used regardless of CUDA kernel availability.
- Raises:
RuntimeError – If use_fallback is False and no CUDA kernel is available.
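The `use_fallback` rules above amount to a small decision table. The following sketch restates them as a plain function; `choose_backend` and its `cuda_kernel_available` flag are hypothetical names used only for illustration, not part of `cuequivariance_torch`.

```python
from typing import Optional


def choose_backend(use_fallback: Optional[bool], cuda_kernel_available: bool) -> str:
    """Restate the documented use_fallback rules; returns "cuda" or "fallback"."""
    if use_fallback is True:
        # Always use the PyTorch fallback, even if a CUDA kernel exists.
        return "fallback"
    if cuda_kernel_available:
        # None or False: the CUDA kernel is used when it is available.
        return "cuda"
    if use_fallback is False:
        # The kernel was explicitly required but is missing.
        raise RuntimeError("use_fallback=False but no CUDA kernel is available")
    # None: fall back to PyTorch when the kernel is missing.
    return "fallback"
```

Leaving `use_fallback` as None is the permissive choice; passing False is useful when silently running the slower fallback would be a bug.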
Forward Pass
- forward(x0: Tensor, x1: Tensor | None = None, x2: Tensor | None = None, x3: Tensor | None = None, x4: Tensor | None = None, x5: Tensor | None = None, x6: Tensor | None = None)
Perform the tensor product based on the specified descriptor.
- Parameters:
x0, x1, x2, x3, x4, x5, x6 – The input tensors, one for each operand of the descriptor except the last, which is the output. The number of input tensors must therefore equal the number of operands in the descriptor minus one. Each input tensor must have shape (batch, operand_size) or (1, operand_size), where operand_size is the size of the corresponding operand as defined in the tensor product descriptor; an input with a leading dimension of 1 is shared across the batch.
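The input rules above can be checked before calling the module. This sketch uses a hypothetical helper, `forward_output_shape`, that receives the size of every descriptor operand (output last) and the input shapes as `(batch, size)` tuples. It assumes inputs with batch 1 are shared across the batch, and it only mirrors the documented contract, not the module's actual validation code.

```python
from typing import Sequence, Tuple

Shape = Tuple[int, int]


def forward_output_shape(operand_sizes: Sequence[int], input_shapes: Sequence[Shape]) -> Shape:
    """Validate input shapes against the descriptor and return the output shape.

    operand_sizes: size of every operand in the descriptor, output last.
    input_shapes: shapes of x0, x1, ... as (batch, operand_size) tuples.
    """
    if len(input_shapes) != len(operand_sizes) - 1:
        raise ValueError(
            f"expected {len(operand_sizes) - 1} inputs, got {len(input_shapes)}"
        )
    batch = 1
    for i, ((b, size), expected) in enumerate(zip(input_shapes, operand_sizes)):
        if size != expected:
            raise ValueError(f"x{i}: operand size {size}, expected {expected}")
        if b != 1:
            # Every input with a real batch dimension must agree on its length.
            if batch not in (1, b):
                raise ValueError(f"x{i}: batch {b} conflicts with batch {batch}")
            batch = b
    # Output is (batch, last_operand_size).
    return (batch, operand_sizes[-1])


print(forward_output_shape([3, 3, 9], [(32, 3), (1, 3)]))  # → (32, 9)
```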
- Returns:
The output tensor resulting from the tensor product. It has a shape of (batch, last_operand_size), where last_operand_size is the size of the last operand in the descriptor.
- Return type: