Rotation

class cuequivariance_torch.Rotation

A module that rotates tensors carrying SO3 or O3 representations.

Parameters:
  • irreps (Irreps) – The irreducible representations of the tensor to rotate.

  • layout (IrrepsLayout, optional) – The memory layout of the tensor; cue.ir_mul is preferred.

  • layout_in (IrrepsLayout, optional) – The layout of the input tensor. Defaults to layout.

  • layout_out (IrrepsLayout, optional) – The layout of the output tensor. Defaults to layout.

  • device (torch.device, optional) – The device to use for the operation.

  • math_dtype (torch.dtype, optional) – The dtype used for the math operations. By default, it follows the dtype of the input tensors.

  • method (str, optional) – The method used to run the operation. Defaults to “uniform_1d” (a CUDA kernel) if all segments have the same shape, and to “naive” (a PyTorch implementation) otherwise.

  • use_fallback (bool, optional, deprecated) – Whether to use a “fallback” implementation. This flag now maps onto method: if True, the “naive” method is used; if False, the “uniform_1d” method is used, which requires all segments to have the same shape.
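The layout parameters control how each irreps segment is ordered in memory. As a minimal sketch, assuming that cue.mul_ir stores a segment `mul x ir` as a row-major (mul, dim) block and cue.ir_mul stores it as a row-major (dim, mul) block, switching between the two is a per-segment transpose. Irreps are modeled here as a plain list of hypothetical `(mul, dim)` pairs, and the helper names are illustrative only, not part of cuequivariance_torch.

```python
from typing import List, Sequence, Tuple

# Hypothetical irreps description: (mul, dim) pairs, e.g. "2x0 + 2x1" -> [(2, 1), (2, 3)].
IrrepsSpec = Sequence[Tuple[int, int]]


def _transpose_block(block: Sequence[float], rows: int, cols: int) -> List[float]:
    """Transpose a row-major (rows, cols) block stored as a flat sequence."""
    return [block[r * cols + c] for c in range(cols) for r in range(rows)]


def mul_ir_to_ir_mul(x: Sequence[float], irreps: IrrepsSpec) -> List[float]:
    """Reorder a flat vector from mul_ir to ir_mul, one segment at a time (illustrative)."""
    out: List[float] = []
    offset = 0
    for mul, dim in irreps:
        size = mul * dim
        # Assumed: mul_ir segment is a (mul, dim) block, ir_mul segment is (dim, mul).
        out.extend(_transpose_block(x[offset:offset + size], mul, dim))
        offset += size
    return out


def ir_mul_to_mul_ir(x: Sequence[float], irreps: IrrepsSpec) -> List[float]:
    """Inverse reordering: from ir_mul back to mul_ir (illustrative)."""
    out: List[float] = []
    offset = 0
    for mul, dim in irreps:
        size = mul * dim
        out.extend(_transpose_block(x[offset:offset + size], dim, mul))
        offset += size
    return out
```

For example, with irreps `[(2, 1), (2, 3)]` (two scalars followed by two 3-vectors), the mul_ir vector `[0, 1, 10, 11, 12, 20, 21, 22]` becomes `[0, 1, 10, 20, 11, 21, 12, 22]` in ir_mul: the scalar segment is unchanged because a (2, 1) block transposes to itself, while the components of the two vectors are interleaved.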
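The interplay between method and the deprecated use_fallback flag can be summarized as a small decision function. This is a hypothetical helper written from the parameter descriptions above, not the library's code: `segments_uniform` stands in for the class's internal check that all segments have the same shape, and rejecting calls that pass both arguments is an assumption of this sketch.

```python
from typing import Optional


def resolve_method(
    method: Optional[str],
    use_fallback: Optional[bool],
    segments_uniform: bool,
) -> str:
    """Pick an execution method from the documented rules (hypothetical helper)."""
    if use_fallback is not None:
        if method is not None:
            # Assumption: this sketch treats passing both arguments as an error.
            raise ValueError("pass either method or use_fallback, not both")
        # Deprecated flag: True -> "naive", False -> "uniform_1d".
        return "naive" if use_fallback else "uniform_1d"
    if method is not None:
        # An explicit method is used as given.
        return method
    # Default: the CUDA kernel is chosen only when every segment has the same shape.
    return "uniform_1d" if segments_uniform else "naive"
```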

Forward Pass

forward(gamma, beta, alpha, x)

Forward pass of the rotation layer.

Parameters:
  • gamma (torch.Tensor) – The gamma angles. First rotation around the y-axis.

  • beta (torch.Tensor) – The beta angles. Second rotation around the x-axis.

  • alpha (torch.Tensor) – The alpha angles. Third rotation around the y-axis.

  • x (torch.Tensor) – The input tensor.

Returns:

The rotated tensor.

Return type:

torch.Tensor
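The order of the three angles can be illustrated with ordinary 3x3 rotation matrices. The sketch below applies gamma about the y-axis first, then beta about the x-axis, then alpha about the y-axis to a Cartesian 3-vector, which amounts to multiplying by R_y(alpha) R_x(beta) R_y(gamma). It only illustrates the composition order: it uses scalar angles and right-handed matrices, and the sign conventions, basis ordering, and batching that cuequivariance_torch uses for each irrep are not modeled here.

```python
import math
from typing import List, Sequence

Matrix = List[List[float]]


def rot_x(angle: float) -> Matrix:
    """Right-handed rotation by `angle` about the x-axis."""
    c, s = math.cos(angle), math.sin(angle)
    return [[1.0, 0.0, 0.0], [0.0, c, -s], [0.0, s, c]]


def rot_y(angle: float) -> Matrix:
    """Right-handed rotation by `angle` about the y-axis."""
    c, s = math.cos(angle), math.sin(angle)
    return [[c, 0.0, s], [0.0, 1.0, 0.0], [-s, 0.0, c]]


def matvec(m: Matrix, v: Sequence[float]) -> List[float]:
    """Multiply a 3x3 matrix by a 3-vector."""
    return [sum(m[i][j] * v[j] for j in range(3)) for i in range(3)]


def rotate_yxy(gamma: float, beta: float, alpha: float, v: Sequence[float]) -> List[float]:
    """Illustrative YXY rotation of a Cartesian vector (not the library's implementation)."""
    v = matvec(rot_y(gamma), v)     # first rotation: gamma about y
    v = matvec(rot_x(beta), v)      # second rotation: beta about x
    return matvec(rot_y(alpha), v)  # third rotation: alpha about y
```

Two quick sanity checks follow from the construction: the rotation preserves vector length, and with beta = 0 the two y-rotations merge into a single rotation by gamma + alpha.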

__init__(irreps, *, layout=None, layout_in=None, layout_out=None, device=None, math_dtype=None, use_fallback=None, method=None)

Initialize internal Module state, shared by both nn.Module and ScriptModule.

classmethod __new__(*args, **kwargs)