Segmented Polynomial in PyTorch#
The cuequivariance_torch.SegmentedPolynomial class wraps a cuequivariance.SegmentedPolynomial in a standard torch.nn.Module, so it can be called and composed like any other PyTorch layer.
Basic Usage#
import torch
import cuequivariance as cue
import cuequivariance_torch as cuet
# 1. Define a polynomial: Linear Layer (y = W @ x)
# 2 inputs (W, x), 1 output (y)
# Using a descriptor for convenience
equiv_poly = cue.descriptors.linear(
    cue.Irreps("SO3", "4x0"),
    cue.Irreps("SO3", "2x0"),
)
sp = equiv_poly.polynomial
# 2. Wrap in a module and execute
# method="naive" is the generic reference implementation; see below for "uniform_1d"
model = cuet.SegmentedPolynomial(sp, method="naive")
# Inputs: [weights, input vector]
# The module expects batched inputs of shape (batch, dim)
W = torch.randn(1, equiv_poly.inputs[0].dim)
x = torch.randn(1, equiv_poly.inputs[1].dim)
[y] = model([W, x])
print(f"Output shape: {y.shape}")
Output shape: torch.Size([1, 2])
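The same module handles any batch size. Below is a minimal sketch with ten examples, each paired with its own weight vector; the dimensions follow from the 4x0 → 2x0 linear map, which has 4 × 2 = 8 weights.
# Batched sketch: one weight vector per example
batch = 10
W = torch.randn(batch, equiv_poly.inputs[0].dim)  # (10, 8)
x = torch.randn(batch, equiv_poly.inputs[1].dim)  # (10, 4)
[y] = model([W, x])
print(y.shape)  # expected: torch.Size([10, 2])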
High Performance (Uniform 1D)#
For “Uniform 1D” polynomials, i.e. polynomials whose segments are all one-dimensional and of the same size, use method="uniform_1d" to enable optimized CUDA kernels.
# Create a Uniform 1D compatible polynomial (Element-wise product)
stp = cue.SegmentedTensorProduct.from_subscripts("i,i,i")
stp.add_segment(0, (32,))  # operand 0: one 1D segment of size 32
stp.add_segment(1, (32,))  # operand 1: same size, as Uniform 1D requires
stp.add_segment(2, (32,))  # operand 2 (the output)
stp.add_path(0, 0, 0, c=1.0)  # z_i = 1.0 * x_i * y_i
sp = cue.SegmentedPolynomial(
    inputs=[stp.operands[0], stp.operands[1]],
    outputs=[stp.operands[2]],
    operations=[(cue.Operation([0, 1, 2]), stp)],
)
# Initialize model with uniform_1d method
model = cuet.SegmentedPolynomial(sp, method="uniform_1d")
# Execute with batched inputs
batch_size = 10
x = torch.randn(batch_size, 32)
y = torch.randn(batch_size, 32)
[z] = model([x, y])
print(f"Output shape: {z.shape}")
Output shape: torch.Size([10, 32])
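Since this polynomial is an element-wise product with a single unit-coefficient path, we can sanity-check the module against plain PyTorch:
# The "i,i,i" product with one path and c=1.0 is an element-wise multiply
torch.testing.assert_close(z, x * y)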
Indexing#
Indexing lets you route data flexibly, for example applying different weight sets to different examples. The PyTorch module takes dictionaries that map operand positions to index tensors. The examples below reuse the element-wise product model from the previous section.
# Example: Select from 3 weight sets for 10 input examples
num_weights, num_examples, dim = 3, 10, 32
weights = torch.randn(num_weights, dim)
inputs = torch.randn(num_examples, dim)
# Index tensor: weight set 0 for the first three examples,
# set 1 for the next four, set 2 for the last three
w_idx = torch.tensor([0, 0, 0, 1, 1, 1, 1, 2, 2, 2])
# Execute
# input_indices maps Input 0 (weights) to w_idx
[z] = model(
[weights, inputs],
input_indices={0: w_idx}
)
print(f"Output shape: {z.shape}")
# Output indexing: accumulate results into specific bins
# We want 5 output bins; map the 10 results to bins [0, 0, 1, 1, ...]
out_idx = torch.tensor([0, 0, 1, 1, 2, 2, 3, 3, 4, 4])
# Tensor whose shape declares the output buffer: 5 bins of size dim
output_shape = torch.empty(5, dim)
[z_accum] = model(
[weights, inputs],
input_indices={0: w_idx},
output_indices={0: out_idx},
output_shapes={0: output_shape}
)
print(f"Accumulated Output shape: {z_accum.shape}")
Accumulated Output shape: torch.Size([5, 32])
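Rows mapped to the same bin are summed. Assuming this scatter-add semantics, the accumulated result can be reproduced from the unaccumulated z with index_add_:
# Accumulation check: sum rows that share a bin index
expected = torch.zeros(5, dim).index_add_(0, out_idx, z)
torch.testing.assert_close(z_accum, expected)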