Beta features with unstable API
The APIs of the following features are likely to change in the future.
Enabling the JIT kernel
A segmented tensor product with 3 or 4 operands and a single mode can be executed with an experimental JIT kernel. To enable it, set the CUEQUIVARIANCE_OPS_USE_JIT environment variable to "1":
import os
import torch
import cuequivariance as cue
import cuequivariance_torch as cuet
# Opt in to the experimental JIT kernel.
os.environ["CUEQUIVARIANCE_OPS_USE_JIT"] = "1"
e = (
    cue.descriptors.channelwise_tensor_product(
        128 * cue.Irreps("SO3", "0 + 1 + 2"),
        cue.Irreps("SO3", "0 + 1 + 2 + 3"),
        cue.Irreps("SO3", "0 + 1 + 2"),
    )
    .squeeze_modes()
    .flatten_coefficient_modes()
)
# Show the first segmented tensor product descriptor.
print(e.ds[0])
# Fall back to the CPU when CUDA is unavailable.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
m = cuet.EquivariantTensorProduct(e, layout=cue.ir_mul, device=device)
# One random input per operand, each with a batch size of 128.
x0 = torch.randn(128, e.inputs[0].dim, device=device)
x1 = torch.randn(128, e.inputs[1].dim, device=device)
x2 = torch.randn(128, e.inputs[2].dim, device=device)
print(m(x0, x1, x2).shape)
u,u,,u sizes=2304,1152,16,8192 num_segments=18,9,16,64 num_paths=207 u=128
torch.Size([128, 8192])
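The printed descriptor can be checked against the "3 or 4 operands with one mode" condition. The sketch below uses only the Python standard library. It assumes that an operand's size is its number of segments multiplied by the extent of its mode (u), and that an operand with an empty subscript has segments of size 1. The helper `parse_descriptor` is hypothetical and not part of cuequivariance.

```python
# Descriptor summary copied verbatim from the printed output above.
descriptor = "u,u,,u sizes=2304,1152,16,8192 num_segments=18,9,16,64 num_paths=207 u=128"


def parse_descriptor(text):
    """Hypothetical helper: split the printed summary into subscripts and integer fields."""
    subscripts, *fields = text.split()
    parsed = {"subscripts": subscripts.split(",")}
    for field in fields:
        key, value = field.split("=")
        parsed[key] = [int(v) for v in value.split(",")]
    return parsed


info = parse_descriptor(descriptor)
u = info["u"][0]

# Assumption: size = num_segments * u for operands carrying the mode "u",
# and size = num_segments for operands with an empty subscript.
expected_sizes = [
    n * (u if sub == "u" else 1)
    for sub, n in zip(info["subscripts"], info["num_segments"])
]

print(len(info["subscripts"]))          # number of operands
print(expected_sizes == info["sizes"])  # consistency check against the printed sizes
```

Under these assumptions the descriptor has four operands that share the single mode u, and the recomputed sizes match the printed ones (for example, 18 segments of extent 128 give 2304).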
Fused scatter/gather kernel
A segmented tensor product with 3 or 4 operands and a single mode can also use a fused scatter/gather kernel. Unlike the kernel above, this one is not JIT compiled.
from cuequivariance_torch.primitives.tensor_product import (
    TensorProductUniform4x1dIndexed,
)

# The indexed kernel is only built when running on a CUDA device.
if device.type == "cuda":
    m = TensorProductUniform4x1dIndexed(e.ds[0], device, torch.float32)
    # x0 holds 16 rows; i0 picks one of them for each of the 128 batch entries.
    x0 = torch.randn(16, e.inputs[0].dim, device=device)
    i0 = torch.randint(0, 16, (128,), device=device)
    x1 = torch.randn(128, e.inputs[1].dim, device=device)
    x2 = torch.randn(128, e.inputs[2].dim, device=device)
    # i_out assigns each batch entry to one of 16 output rows.
    i_out = torch.randint(0, 16, (128,), device=device)
    print(m(x0, x1, x2, i0, None, None, i_out, 16).shape)
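To make the indexing pattern in the call above concrete, here is a toy pure-Python sketch of a gather, multiply, scatter-add sequence. It is not the library's implementation. It assumes that an index such as i0 selects a row of its operand for each batch entry, that an index of None means the operand is already batched, and that rows sharing an output index are summed. An elementwise triple product stands in for the real tensor product, and `indexed_product_reference` is a hypothetical name.

```python
def indexed_product_reference(x0, x1, x2, i0, i_out, size_out):
    """Toy gather -> multiply -> scatter-add on plain Python lists."""
    width = len(x1[0])
    out = [[0.0] * width for _ in range(size_out)]
    for row, (src, dst) in enumerate(zip(i0, i_out)):
        gathered = x0[src]  # gather: the row of x0 selected by i0
        for k in range(width):
            # Stand-in for the tensor product: an elementwise triple product,
            # accumulated into the output row selected by i_out (scatter-add).
            out[dst][k] += gathered[k] * x1[row][k] * x2[row][k]
    return out


x0 = [[1.0, 2.0], [3.0, 4.0]]              # 2 shared rows
x1 = [[1.0, 1.0], [2.0, 2.0], [1.0, 0.0]]  # batch of 3
x2 = [[1.0, 1.0], [1.0, 1.0], [5.0, 5.0]]  # batch of 3
i0 = [0, 1, 1]                             # which row of x0 each batch entry uses
i_out = [2, 2, 0]                          # which output row each batch entry adds to

result = indexed_product_reference(x0, x1, x2, i0, i_out, size_out=3)
print(result)  # [[15.0, 0.0], [0.0, 0.0], [7.0, 10.0]]
```

Batch entries 0 and 1 both target output row 2, so their contributions are summed, while output row 1 receives nothing and stays zero. A fused kernel would perform these steps in one pass rather than materializing the gathered rows first.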