SegmentedPolynomial

class cuequivariance_torch.SegmentedPolynomial(
    polynomial: SegmentedPolynomial,
    math_dtype: dtype = torch.float32,
    output_dtype_map: List[int] = None,
    name: str = 'segmented_polynomial',
)

PyTorch module that computes a segmented polynomial.

Currently, it supports only segmented polynomials in which all segment sizes are the same and each operand is one- or zero-dimensional.

Parameters:
  • polynomial – The segmented polynomial to compute, an instance of cue.SegmentedPolynomial (cuequivariance.SegmentedPolynomial).

  • math_dtype – Data type used for internal computations. Defaults to torch.float32.

  • output_dtype_map – Optional list that specifies, for each output buffer, the index of the input buffer whose data type the output inherits. A value of -1 means math_dtype is used. Defaults to 0 if there are input tensors, otherwise -1.

  • name – Optional name for the operation. Defaults to “segmented_polynomial”.
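The output_dtype_map rule can be sketched as a small pure-Python function. The helper resolve_output_dtypes is hypothetical and not part of cuequivariance_torch. Dtypes are represented as strings so the sketch runs without PyTorch, and the documented default is assumed to apply to every output buffer.

```python
from typing import List, Optional, Sequence


def resolve_output_dtypes(
    input_dtypes: Sequence[str],
    num_outputs: int,
    math_dtype: str = "float32",
    output_dtype_map: Optional[List[int]] = None,
) -> List[str]:
    """Pick a dtype for each output buffer (hypothetical helper).

    Dtypes are plain strings so the sketch runs without PyTorch.
    """
    if output_dtype_map is None:
        # Documented default: 0 if there are input tensors, otherwise -1.
        # Assumption: the default is applied to every output buffer.
        default = 0 if len(input_dtypes) > 0 else -1
        output_dtype_map = [default] * num_outputs
    if len(output_dtype_map) != num_outputs:
        raise ValueError("output_dtype_map must have one entry per output buffer")
    dtypes = []
    for src in output_dtype_map:
        if src == -1:
            dtypes.append(math_dtype)  # -1 selects math_dtype
        elif 0 <= src < len(input_dtypes):
            dtypes.append(input_dtypes[src])  # inherit the dtype of input buffer `src`
        else:
            raise IndexError(f"no input buffer {src}")
    return dtypes
```

With the real module, the same choice is made through the constructor, for example cuequivariance_torch.SegmentedPolynomial(poly, math_dtype=torch.float64, output_dtype_map=[-1]), where poly is assumed to be an existing cue.SegmentedPolynomial with a single output.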

Forward Pass

forward(
    inputs: List[Tensor],
    input_indices: Dict[int, Tensor] | None = None,
    output_shapes: Dict[int, Tensor] | None = None,
    output_indices: Dict[int, Tensor] | None = None,
)

Computes the segmented polynomial based on the specified descriptor.

Parameters:
  • inputs – The input tensors. Their number must match the number of input buffers in the descriptor. Each input tensor should have shape (batch, operand_size), (1, operand_size), or, in the indexed case, (index, operand_size), where operand_size is the size of the corresponding operand as defined in the descriptor.

  • input_indices – A dictionary that contains an optional indexing tensor for each input tensor. The key is the position of the tensor in inputs. If a key is not present, that input is not indexed. The values of the index tensor must be valid row indices into the input tensor (i.e., 0 <= index_tensor[i] < input.shape[0]).

  • output_shapes – A dictionary that specifies the size of each output's batch dimension through a tensor; only shape_tensor.shape[0] is read. An entry is mandatory if the output tensor is indexed. Otherwise, the default shape is (batch, operand_size).

  • output_indices – A dictionary that contains an optional indexing tensor for each output tensor. See input_indices for details.
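The per-batch-element meaning of input_indices and output_indices can be approximated with a pure-Python reference on nested lists instead of tensors. Both helper names are hypothetical. Input row selection follows the shapes and index-range requirement described above; for indexed outputs, the sketch assumes that rows mapped to the same output index are summed (a scatter-add), which the description above does not state explicitly.

```python
from typing import List, Optional, Sequence

Row = List[float]


def gather_input_row(inp: Sequence[Row], b: int, index: Optional[Sequence[int]] = None) -> Row:
    """Return the row of one input that batch element b reads (hypothetical helper)."""
    if index is not None:
        i = index[b]
        # Documented requirement: 0 <= index_tensor[i] < input.shape[0]
        if not 0 <= i < len(inp):
            raise IndexError(f"index {i} out of range for input with {len(inp)} rows")
        return list(inp[i])
    if len(inp) == 1:
        return list(inp[0])  # shape (1, operand_size): broadcast to every batch element
    return list(inp[b])  # shape (batch, operand_size)


def scatter_output_rows(rows: Sequence[Row], num_out_rows: int, index: Sequence[int]) -> List[Row]:
    """Write per-batch-element rows into an indexed output (hypothetical helper).

    num_out_rows stands in for output_shapes[k].shape[0].
    ASSUMPTION: rows that map to the same output index are summed (scatter-add).
    """
    width = len(rows[0]) if rows else 0
    out = [[0.0] * width for _ in range(num_out_rows)]
    for b, row in enumerate(rows):
        j = index[b]
        if not 0 <= j < num_out_rows:
            raise IndexError(f"output index {j} out of range for {num_out_rows} rows")
        for c, value in enumerate(row):
            out[j][c] += value
    return out
```

In PyTorch terms, the gather corresponds roughly to indexing an input with its index tensor, and the assumed scatter-add to Tensor.index_add_ along dimension 0.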

Returns:

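The output tensors resulting from the segmented polynomial. Their shapes follow the same conventions as the inputs: (batch, operand_size) by default, with the leading dimension taken from output_shapes when an entry is given for that output.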

Return type:

List[torch.Tensor]
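The rule for each output's leading dimension can be sketched as follows. The helper output_rows is hypothetical, and it assumes that an explicit output_shapes entry takes precedence even for non-indexed outputs. It uses len(), which for a PyTorch tensor returns shape[0], so the shape entry can be either a tensor or a plain list.

```python
from typing import Any, Dict, Optional


def output_rows(
    k: int,
    batch: int,
    output_shapes: Optional[Dict[int, Any]] = None,
    output_indices: Optional[Dict[int, Any]] = None,
) -> int:
    """Leading dimension of output k (hypothetical helper).

    Only the first dimension of output_shapes[k] is read.
    """
    output_shapes = output_shapes or {}
    output_indices = output_indices or {}
    if k in output_shapes:
        # Assumption: an explicit entry wins; len(t) == t.shape[0] for a PyTorch tensor.
        return len(output_shapes[k])
    if k in output_indices:
        # An indexed output has no well-defined default size.
        raise ValueError(f"output_shapes[{k}] is required because output {k} is indexed")
    return batch  # default shape is (batch, operand_size)
```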