SegmentedPolynomial

class cuequivariance_torch.SegmentedPolynomial

PyTorch module that computes a segmented polynomial.

Parameters:
  • polynomial – The segmented polynomial to compute, an instance of cuequivariance.SegmentedPolynomial (cue.SegmentedPolynomial).

  • method

    Specifies the implementation method to use. Options are:

    • "naive": Uses a naive PyTorch implementation. It always works but is not optimized.

    • "uniform_1d": Uses a CUDA implementation for polynomials with a single uniform mode.

    • "fused_tp": Uses a CUDA implementation for polynomials with 3- and 4-operand contractions.

    • "indexed_linear": Uses a CUDA implementation for linear layers with indexed weights.

  • math_dtype

    Optional data type for computational operations. If specified, internal buffers will be of this dtype, and operands will be converted to this type for all computations.

    Note

    math_dtype is not affected by changes to the module dtype. Not all methods support all dtypes.

    If math_dtype is not specified:

    • For method "naive", the dtype of the input tensors will be used.

    • For method "uniform_1d", the dtype of the input tensors will be used if it is float32 or float64; otherwise float32 will be used.

    • For method "fused_tp", float32 will be used.

    • For method "indexed_linear", the dtype of the input tensors will be used (this is the only option available for this method).

  • output_dtype_map – Optional list that, for each output buffer, specifies the index of the input buffer from which it inherits its data type. -1 means the math_dtype is used. Default is 0 if there are input tensors, otherwise -1.

  • name – Optional name for the operation. Defaults to "segmented_polynomial".
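
The dtype rules above can be summarised in a small, dependency-free sketch. The helper names resolve_math_dtype and resolve_output_dtypes are hypothetical and not part of cuequivariance_torch, and dtypes are written as plain strings. The sketch only mirrors the documented defaults; it does not check whether a given method actually supports the requested dtype.

```python
from typing import List, Optional

# Hypothetical helpers that restate the documented defaults; not library code.
UNIFORM_1D_ALLOWED = {"float32", "float64"}


def resolve_math_dtype(
    method: str, input_dtype: str, math_dtype: Optional[str] = None
) -> str:
    """Return the dtype used for computation under the documented rules."""
    if math_dtype is not None:
        return math_dtype  # an explicit math_dtype always wins
    if method in ("naive", "indexed_linear"):
        return input_dtype  # follow the input tensors
    if method == "uniform_1d":
        # follow the inputs only if they are float32 or float64
        return input_dtype if input_dtype in UNIFORM_1D_ALLOWED else "float32"
    if method == "fused_tp":
        return "float32"
    raise ValueError(f"unknown method: {method!r}")


def resolve_output_dtypes(
    input_dtypes: List[str],
    math_dtype: str,
    num_outputs: int,
    output_dtype_map: Optional[List[int]] = None,
) -> List[str]:
    """Map each output to an input's dtype; -1 selects math_dtype."""
    if output_dtype_map is None:
        # documented default: 0 if there are inputs, otherwise -1
        default = 0 if input_dtypes else -1
        output_dtype_map = [default] * num_outputs
    return [math_dtype if i == -1 else input_dtypes[i] for i in output_dtype_map]


print(resolve_math_dtype("uniform_1d", "float16"))  # float32
print(resolve_output_dtypes(["float16", "float32"], "float64", 2, [1, -1]))
# ['float32', 'float64']
```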

Examples

Basic usage with spherical harmonics:

>>> import torch
>>> import cuequivariance as cue
>>> from cuequivariance_torch import SegmentedPolynomial
>>>
>>> # Create spherical harmonics polynomial
>>> poly = cue.descriptors.spherical_harmonics(cue.SO3(1), [0, 1, 2]).polynomial
>>> sp = SegmentedPolynomial(poly, method="naive")
>>>
>>> # Compute spherical harmonics for unit vector along y-axis
>>> x = torch.tensor([[0.0, 1.0, 0.0]])
>>> result = sp([x])
>>> print(result[0].shape)
torch.Size([1, 9])

Example with a linear layer:

>>> # Create a linear transformation
>>> input_irreps = cue.Irreps(cue.O3, "5x0e + 3x1o")
>>> output_irreps = cue.Irreps(cue.O3, "4x0e + 2x1o")
>>> poly = cue.descriptors.linear(input_irreps, output_irreps).polynomial
>>>
>>> # Create the module
>>> linear = SegmentedPolynomial(poly, method="naive")
>>>
>>> # Create random weights and input
>>> weights = torch.randn(1, poly.inputs[0].size)
>>> x = torch.randn(10, poly.inputs[1].size)
>>>
>>> # Forward pass
>>> result = linear([weights, x])
>>> print(result[0].shape)
torch.Size([10, 10])

Example with indexed operations:

>>> # Create indexed weights for different elements
>>> weights = torch.randn(3, poly.inputs[0].size)  # 3 different weight sets
>>> x = torch.randn(5, poly.inputs[1].size)        # 5 input vectors
>>>
>>> # Index tensor specifying which weights to use for each input
>>> weight_indices = torch.tensor([0, 1, 0, 2, 1])  # Use weights 0,1,0,2,1
>>>
>>> result = linear([weights, x],
...                input_indices={0: weight_indices})
>>> print(result[0].shape)
torch.Size([5, 10])
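
To make the shapes in the indexed example concrete, the following framework-free sketch mimics what the index tensor selects. Nested lists stand in for tensors, gather_rows is a hypothetical helper, and the row values are made up for illustration. It shows only the selection semantics, not how the library implements indexing.

```python
def gather_rows(rows, indices):
    """Pick one row per batch element: result[b] = rows[indices[b]]."""
    return [rows[i] for i in indices]


# 3 weight sets (2 values each, for brevity) and 5 batch elements
weights = [[0.1, 0.2], [1.1, 1.2], [2.1, 2.2]]
weight_indices = [0, 1, 0, 2, 1]

per_sample = gather_rows(weights, weight_indices)
print(len(per_sample))  # 5, one weight row per input vector
print(per_sample[3])    # [2.1, 2.2], the weight set with index 2
```

The leading dimension of the indexed operand (3 here) is independent of the batch size (5); the index tensor is what aligns the two.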

Forward Pass

forward(inputs, input_indices=None, output_shapes=None, output_indices=None)

Compute the segmented polynomial based on the specified descriptor.

Parameters:
  • inputs (List[Tensor]) – The input tensors. The number of input tensors must match the number of input buffers in the descriptor. Each input tensor should have shape (batch, operand_size), (1, operand_size), or, in the indexed case, (index, operand_size), where operand_size is the size of the corresponding operand as defined in the descriptor.

  • input_indices (Dict[int, Tensor] | None) –

    A dictionary that contains an optional indexing tensor for each input tensor. The key is the index into the inputs. If a key is not present, no indexing takes place. The contents of the index tensor must be suitable to index the input tensor (i.e., 0 <= index_tensor[i] < input.shape[0]).

    Note

    Method "indexed_linear" requires the indices to be sorted.

  • output_shapes (Dict[int, Tensor] | None) – A dictionary specifying the size of the output batch dimensions by means of tensors; only shape_tensor.shape[0] is read. This is mandatory if the output tensor is indexed. Otherwise, the default output shape is (batch, operand_size).

  • output_indices (Dict[int, Tensor] | None) – A dictionary that contains an optional indexing tensor for each output tensor. See input_indices for details.

Returns:

The output tensors resulting from the segmented polynomial. Their shapes are specified just like the inputs.
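
The constraints on index tensors described above can be checked before calling forward. The sketch below uses plain Python sequences in place of tensors; validate_indices is a hypothetical helper, and "sorted" is assumed here to mean non-decreasing, which the text does not spell out.

```python
from typing import Sequence


def validate_indices(
    indices: Sequence[int], num_rows: int, require_sorted: bool = False
) -> None:
    """Raise if indices cannot index a tensor with num_rows rows.

    require_sorted models the extra requirement of method "indexed_linear"
    (assumed to mean non-decreasing order).
    """
    for pos, i in enumerate(indices):
        if not 0 <= i < num_rows:
            raise IndexError(f"indices[{pos}] = {i} is outside [0, {num_rows})")
    if require_sorted and any(a > b for a, b in zip(indices, indices[1:])):
        raise ValueError("indices must be sorted for this method")


validate_indices([0, 1, 0, 2, 1], num_rows=3)                       # passes
validate_indices([0, 0, 1, 1, 2], num_rows=3, require_sorted=True)  # passes
```

With real tensors the same checks would be applied to the contents of each tensor in input_indices or output_indices, using input.shape[0] as num_rows.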

__init__(polynomial, method='', math_dtype=None, output_dtype_map=None, name='segmented_polynomial')

Initialize internal Module state, shared by both nn.Module and ScriptModule.

classmethod __new__(*args, **kwargs)