SegmentedPolynomial
- class cuequivariance_torch.SegmentedPolynomial(
- polynomial: SegmentedPolynomial,
- math_dtype: dtype = torch.float32,
- output_dtype_map: List[int] = None,
- name: str = 'segmented_polynomial',
- )
PyTorch module that computes a segmented polynomial.
Currently, it supports only segmented polynomials in which all segment sizes are equal and each operand is zero- or one-dimensional.
- Parameters:
polynomial – The segmented polynomial to compute, an instance of cuequivariance.SegmentedPolynomial (cue.SegmentedPolynomial).
math_dtype – Data type for computational operations, defaulting to float32.
output_dtype_map – Optional list with one entry per output buffer, giving the index of the input buffer whose data type that output inherits. A value of -1 selects math_dtype. Defaults to 0 if there are input tensors, and to -1 otherwise.
name – Optional name for the operation. Defaults to “segmented_polynomial”.
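The output_dtype_map rule above can be sketched in plain Python. This is only an illustration of the documented rule, not the module's implementation: resolve_output_dtypes is a hypothetical helper, and dtype names are plain strings so that the sketch runs without PyTorch.

```python
from typing import List, Optional


def resolve_output_dtypes(
    input_dtypes: List[str],
    math_dtype: str,
    num_outputs: int,
    output_dtype_map: Optional[List[int]] = None,
) -> List[str]:
    """Hypothetical helper: choose a dtype name for each output buffer."""
    if output_dtype_map is None:
        # Documented default: 0 if there are input tensors, otherwise -1.
        default = 0 if input_dtypes else -1
        output_dtype_map = [default] * num_outputs
    if len(output_dtype_map) != num_outputs:
        raise ValueError("output_dtype_map needs one entry per output buffer")

    resolved = []
    for idx in output_dtype_map:
        if idx == -1:
            # -1 means the output uses math_dtype.
            resolved.append(math_dtype)
        elif 0 <= idx < len(input_dtypes):
            # Otherwise the output inherits the dtype of input buffer idx.
            resolved.append(input_dtypes[idx])
        else:
            raise ValueError(f"invalid input buffer index {idx}")
    return resolved


# Two inputs (float16, float64), math_dtype float32, two outputs.
example_default = resolve_output_dtypes(["float16", "float64"], "float32", 2)
example_mixed = resolve_output_dtypes(["float16", "float64"], "float32", 2, [1, -1])
example_no_inputs = resolve_output_dtypes([], "float32", 1)
```

With the default map both outputs follow input 0, while [1, -1] makes the first output follow input 1 and the second use math_dtype.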
Forward Pass
- forward(
- inputs: List[Tensor],
- input_indices: Dict[int, Tensor] | None = None,
- output_shapes: Dict[int, Tensor] | None = None,
- output_indices: Dict[int, Tensor] | None = None,
- )
Computes the segmented polynomial based on the specified descriptor.
- Parameters:
inputs – The input tensors. The number of input tensors must match the number of input buffers in the descriptor. Each input tensor should have shape (batch, operand_size), (1, operand_size), or, in the indexed case, (index, operand_size). Here, operand_size is the size of the corresponding operand as defined in the descriptor.
input_indices – A dictionary holding an optional index tensor for each input tensor. Each key is a position in inputs. If a key is absent, that input is not indexed. The values of an index tensor must be valid row indices into the input tensor (i.e. 0 <= index_tensor[i] < input.shape[0]).
output_shapes – A dictionary of tensors specifying the size of the output batch dimension. Only shape_tensor.shape[0] is read. An entry is mandatory if the output tensor is indexed; otherwise the output shape defaults to (batch, operand_size).
output_indices – A dictionary that contains an optional indexing tensor for each output tensor. See input_indices for details.
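The input shape and index constraints above can be checked before calling forward. The following is a minimal sketch on plain shape tuples and integer lists, so it runs without PyTorch. check_index and batch_size are hypothetical helpers, not part of cuequivariance_torch, and the sketch assumes that an indexed input contributes len(index) rows to the batch, which the text above does not state explicitly.

```python
from typing import Dict, List, Optional, Tuple


def check_index(index: List[int], num_rows: int) -> None:
    """Raise if an entry violates 0 <= index[i] < num_rows."""
    for i, value in enumerate(index):
        if not 0 <= value < num_rows:
            raise IndexError(f"index[{i}] = {value} is outside [0, {num_rows})")


def batch_size(
    input_shapes: List[Tuple[int, int]],
    input_indices: Optional[Dict[int, List[int]]] = None,
) -> int:
    """Hypothetical helper: infer the common batch size of the inputs."""
    input_indices = input_indices or {}
    sizes = set()
    for k, (rows, _operand_size) in enumerate(input_shapes):
        if k in input_indices:
            # Indexed input of shape (index, operand_size).
            check_index(input_indices[k], rows)
            # Assumption: one index entry per batch element.
            sizes.add(len(input_indices[k]))
        elif rows != 1:
            # Plain input of shape (batch, operand_size).
            sizes.add(rows)
        # Shape (1, operand_size) is shared by every batch element.
    if len(sizes) > 1:
        raise ValueError(f"inconsistent batch sizes: {sorted(sizes)}")
    return sizes.pop() if sizes else 1


# Input 0 is batched, input 1 is shared, input 2 has 3 rows selected by an index.
example_batch = batch_size(
    [(8, 4), (1, 4), (3, 4)],
    {2: [0, 2, 1, 1, 0, 2, 2, 0]},
)
```

Here the key 2 in the index dictionary refers to the third entry of inputs, matching the "key is the index into the inputs" convention.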
- Returns:
The output tensors resulting from the segmented polynomial. Their shapes follow the same conventions as the inputs.
- Return type:
List[torch.Tensor]
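As a rough illustration of how the shape of each returned tensor follows from output_shapes and output_indices, here is a sketch in plain Python. output_shape is a hypothetical helper, not a library function, and each output_shapes entry is represented by the shape tuple of its shape tensor, of which only the first element is used.

```python
from typing import Dict, List, Optional, Tuple


def output_shape(
    output_id: int,
    batch: int,
    operand_size: int,
    output_shapes: Optional[Dict[int, Tuple[int, ...]]] = None,
    output_indices: Optional[Dict[int, List[int]]] = None,
) -> Tuple[int, int]:
    """Hypothetical helper: shape of one output tensor."""
    output_shapes = output_shapes or {}
    output_indices = output_indices or {}
    if output_id in output_shapes:
        # Only the leading dimension of the shape tensor is read.
        return (output_shapes[output_id][0], operand_size)
    if output_id in output_indices:
        # An indexed output has no default shape.
        raise ValueError("output_shapes is mandatory for an indexed output")
    # Default for a non-indexed output.
    return (batch, operand_size)


# Non-indexed output: default shape (batch, operand_size).
plain_shape = output_shape(0, batch=8, operand_size=16)

# Indexed output: 8 batch elements written into 3 rows.
indexed_shape = output_shape(
    0,
    batch=8,
    operand_size=16,
    output_shapes={0: (3, 1)},
    output_indices={0: [0, 2, 1, 1, 0, 2, 2, 0]},
)
```

The second call shows why the shape entry is needed: the index values alone do not determine how many rows the output should have.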