Beta features with unstable API

The APIs of the following features are likely to change in the future.

Enabling the JIT kernel

Segmented tensor products with 3 or 4 operands and a single mode can be executed with an experimental JIT kernel. To enable it, set the environment variable CUEQUIVARIANCE_OPS_USE_JIT to "1", as in the following example:

import os
import torch

import cuequivariance as cue
import cuequivariance_torch as cuet

# Opt in to the experimental JIT kernel.
os.environ["CUEQUIVARIANCE_OPS_USE_JIT"] = "1"

e = (
    cue.descriptors.channelwise_tensor_product(
        128 * cue.Irreps("SO3", "0 + 1 + 2"),
        cue.Irreps("SO3", "0 + 1 + 2 + 3"),
        cue.Irreps("SO3", "0 + 1 + 2"),
    )
    .squeeze_modes()
    .flatten_coefficient_modes()
)
print(e.ds[0])

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
m = cuet.EquivariantTensorProduct(e, layout=cue.ir_mul, device=device)
# Batch of 128: x0 holds the weights, x1 and x2 are the two inputs.
x0 = torch.randn(128, e.inputs[0].dim, device=device)
x1 = torch.randn(128, e.inputs[1].dim, device=device)
x2 = torch.randn(128, e.inputs[2].dim, device=device)
print(m(x0, x1, x2).shape)
u,u,,u sizes=2304,1152,16,8192 num_segments=18,9,16,64 num_paths=207 u=128
torch.Size([128, 8192])
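The printed descriptor lists the operand subscripts (u,u,,u), the size of each operand, the number of segments per operand, the number of paths, and the extent of mode u. The sketch below reproduces the sizes and segment counts by hand from the three irreps. It does not call cuequivariance; it assumes the usual SO(3) selection rule |l1 - l2| <= l_out <= l1 + l2, the operand order weights, x1, x2, output, and that the channelwise product keeps one weight block and one output block per allowed (l1, l2, l_out) path. The helper names are illustrative only.

```python
def so3_paths(ls1, ls2, ls_out):
    """All (l1, l2, l_out) triples allowed by the SO(3) selection rule."""
    return [
        (l1, l2, lo)
        for l1 in ls1
        for l2 in ls2
        for lo in ls_out
        if abs(l1 - l2) <= lo <= l1 + l2
    ]


def dim(l):
    """Dimension of the SO(3) irrep of degree l."""
    return 2 * l + 1


u = 128               # multiplicity of 128 * Irreps("SO3", "0 + 1 + 2")
ls1 = [0, 1, 2]       # degrees of the first input irreps
ls2 = [0, 1, 2, 3]    # degrees of the second input irreps
ls_out = [0, 1, 2]    # degrees allowed in the output
paths = so3_paths(ls1, ls2, ls_out)

# Assumed operand order: weights, x1, x2, output.
num_segments = (
    len(paths),                          # weights: one size-u segment per path
    sum(dim(l) for l in ls1),            # x1: one size-u segment per irrep component
    sum(dim(l) for l in ls2),            # x2: no mode, one size-1 segment per component
    sum(dim(lo) for _, _, lo in paths),  # output: one size-u segment per component of each path
)
sizes = (
    u * num_segments[0],
    u * num_segments[1],
    num_segments[2],
    u * num_segments[3],
)
print(num_segments)  # (18, 9, 16, 64)
print(sizes)         # (2304, 1152, 16, 8192)
```

Both tuples match the num_segments and sizes fields printed above. The num_paths=207 field is not reproduced by this sketch.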

Fused scatter/gather kernel

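The same class of segmented tensor products (3 or 4 operands, a single mode) can also be executed with a fused scatter/gather kernel, which applies index-based gathering of inputs and scattering of outputs inside the kernel itself. Unlike the JIT kernel, this kernel is not JIT compiled.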
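The operand and mode condition stated above can be read off the subscripts that print(e.ds[0]) shows, here u,u,,u. The helper below is a hypothetical sketch, not part of cuequivariance; it assumes "one mode" means every operand subscript is either empty or the same single mode letter.

```python
def is_single_mode(subscripts: str) -> bool:
    """Hypothetical check for the "3 or 4 operands with one mode" condition.

    `subscripts` is the first token printed by print(e.ds[0]), e.g. "u,u,,u".
    Assumes "one mode" means each operand subscript is empty or one shared letter.
    """
    operands = subscripts.split(",")
    if len(operands) not in (3, 4):
        return False
    if any(len(op) > 1 for op in operands):  # some operand carries several modes
        return False
    modes = {op for op in operands if op}  # distinct non-empty subscripts
    return len(modes) == 1


printed = "u,u,,u sizes=2304,1152,16,8192 num_segments=18,9,16,64 num_paths=207 u=128"
print(is_single_mode(printed.split()[0]))  # True
print(is_single_mode("uv,iu,jv,kuv"))     # False: operands carry several modes
```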

from cuequivariance_torch.primitives.tensor_product import (
    TensorProductUniform4x1dIndexed,
)

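# This example only runs on a CUDA device.
if device.type == "cuda":
    m = TensorProductUniform4x1dIndexed(e.ds[0], device, torch.float32)

    # 16 weight rows; i0 selects one of them for each of the 128 batch elements (gather).
    x0 = torch.randn(16, e.inputs[0].dim, device=device)
    i0 = torch.randint(0, 16, (128,), device=device)
    # x1 and x2 are not indexed (i1 = i2 = None): one row per batch element.
    x1 = torch.randn(128, e.inputs[1].dim, device=device)
    x2 = torch.randn(128, e.inputs[2].dim, device=device)
    # i_out sends each of the 128 results to one of 16 output rows (scatter).
    i_out = torch.randint(0, 16, (128,), device=device)
    # The final argument, 16, is the number of output rows.
    print(m(x0, x1, x2, i0, None, None, i_out, 16).shape)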
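To make the indexed call m(x0, x1, x2, i0, None, None, i_out, 16) concrete, here is an unfused NumPy reference of the gather, compute, scatter pattern. It is a sketch under stated assumptions, not the library's implementation: indexed_reference and toy_op are hypothetical names, a None index is assumed to mean the operand already has one row per batch element, results sharing an i_out entry are assumed to be summed, and an elementwise triple product stands in for the segmented tensor product.

```python
import numpy as np


def indexed_reference(op, x0, x1, x2, i0, i1, i2, i_out, num_output):
    """Hypothetical NumPy reference for the gather -> compute -> scatter pattern.

    Assumes a None index means the operand already has one row per batch element,
    and that results sharing an i_out entry are summed.
    """
    g0 = x0 if i0 is None else x0[i0]  # gather: row i0[b] of x0 for batch element b
    g1 = x1 if i1 is None else x1[i1]
    g2 = x2 if i2 is None else x2[i2]
    rows = op(g0, g1, g2)  # one result row per batch element
    if i_out is None:
        return rows
    out = np.zeros((num_output, rows.shape[1]), dtype=rows.dtype)
    np.add.at(out, i_out, rows)  # scatter-add: out[i_out[b]] += rows[b]
    return out


def toy_op(a, b, c):
    """Stand-in for the segmented tensor product: an elementwise triple product."""
    return a * b * c


rng = np.random.default_rng(0)
a0 = rng.standard_normal((16, 4))   # 16 rows, gathered by j0
a1 = rng.standard_normal((128, 4))  # one row per batch element
a2 = rng.standard_normal((128, 4))
j0 = rng.integers(0, 16, size=128)
j_out = rng.integers(0, 16, size=128)

out = indexed_reference(toy_op, a0, a1, a2, j0, None, None, j_out, 16)
print(out.shape)  # (16, 4)
```

Doing the gather and scatter inside one kernel presumably avoids materializing intermediates such as g0 and rows, which this unfused reference allocates explicitly.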