Rotation
- class cuequivariance_torch.Rotation
A class that represents a rotation layer for SO3 or O3 representations.
- Parameters:
irreps (Irreps) – The irreducible representations of the tensor to rotate.
layout (IrrepsLayout, optional) – The memory layout of the tensor; cue.ir_mul is preferred.
layout_in (IrrepsLayout, optional) – The layout of the input irreducible representations; defaults to layout.
layout_out (IrrepsLayout, optional) – The layout of the output irreducible representations; defaults to layout.
device (torch.device, optional) – The device to use for the operation.
math_dtype (torch.dtype, optional) – The dtype to use for the math operations; defaults to the dtype of the input tensors.
method (str, optional) – The method to use for the operation; defaults to “uniform_1d” (a CUDA kernel) if all segments have the same shape, otherwise “naive” (a PyTorch implementation).
use_fallback (bool, optional, deprecated) – Whether to use a “fallback” implementation; now mapped to method: True selects the “naive” method and False selects the “uniform_1d” method (which requires all segments to have the same shape).
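To make the layout, layout_in and layout_out parameters more concrete, the sketch below reorders a single segment between the two layouts. It assumes that, for a segment "mul x ir" whose irrep has dimension d, mul_ir stores the data as a row-major (mul, d) array and ir_mul stores it as a row-major (d, mul) array, so converting between them is a transpose. The helper names are hypothetical illustrations, not part of cuequivariance.

```python
# Hypothetical helpers (not cuequivariance API) illustrating the two layouts.
# Assumption: for one segment "mul x ir" with irrep dimension `dim`,
#   mul_ir stores a row-major (mul, dim) array,
#   ir_mul stores a row-major (dim, mul) array.

def mul_ir_to_ir_mul(flat, mul, dim):
    """Reorder one flattened segment from mul_ir to ir_mul (a transpose)."""
    assert len(flat) == mul * dim
    return [flat[m * dim + i] for i in range(dim) for m in range(mul)]


def ir_mul_to_mul_ir(flat, mul, dim):
    """Inverse reordering: ir_mul back to mul_ir."""
    assert len(flat) == mul * dim
    return [flat[i * mul + m] for m in range(mul) for i in range(dim)]


# Segment "2x1o": two copies (mul=2) of a 3-dimensional irrep.
# In mul_ir, each copy's 3 components are contiguous.
seg = ["u0", "u1", "u2", "v0", "v1", "v2"]
print(mul_ir_to_ir_mul(seg, mul=2, dim=3))
# → ['u0', 'v0', 'u1', 'v1', 'u2', 'v2']
```

In the ir_mul ordering, the copies of each irrep component sit next to each other, which is why a layout mismatch between producer and consumer silently scrambles data rather than raising an error.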
Forward Pass
- forward(gamma, beta, alpha, x)
Forward pass of the rotation layer.
- Parameters:
gamma (torch.Tensor) – The gamma angles. First rotation around the y-axis.
beta (torch.Tensor) – The beta angles. Second rotation around the x-axis.
alpha (torch.Tensor) – The alpha angles. Third rotation around the y-axis.
x (torch.Tensor) – The input tensor.
- Returns:
The rotated tensor.
- Return type:
torch.Tensor
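The angle convention of forward (gamma about y first, then beta about x, then alpha about y) composes into a single rotation R = R_y(alpha) · R_x(beta) · R_y(gamma). The sketch below builds that 3x3 matrix for the l=1 (vector) case in a plain Cartesian (x, y, z) basis with right-handed rotations; the basis ordering and sign conventions actually used by cuequivariance for its irreps are an assumption here, so treat this as an illustration of the composition order only.

```python
import math


def rot_y(t):
    """Right-handed rotation by angle t about the y-axis (assumed convention)."""
    c, s = math.cos(t), math.sin(t)
    return [[c, 0.0, s], [0.0, 1.0, 0.0], [-s, 0.0, c]]


def rot_x(t):
    """Right-handed rotation by angle t about the x-axis (assumed convention)."""
    c, s = math.cos(t), math.sin(t)
    return [[1.0, 0.0, 0.0], [0.0, c, -s], [0.0, s, c]]


def matmul(a, b):
    return [[sum(a[i][k] * b[k][j] for k in range(3)) for j in range(3)] for i in range(3)]


def matvec(a, v):
    return [sum(a[i][k] * v[k] for k in range(3)) for i in range(3)]


def euler_yxy(gamma, beta, alpha):
    # gamma acts first (y), then beta (x), then alpha (y), so the
    # composed matrix is R_y(alpha) @ R_x(beta) @ R_y(gamma).
    return matmul(rot_y(alpha), matmul(rot_x(beta), rot_y(gamma)))


R = euler_yxy(gamma=0.3, beta=1.1, alpha=-0.7)
v = [1.0, 2.0, 3.0]
w = matvec(R, v)
# A rotation preserves the norm of the vector.
print(math.isclose(math.hypot(*v), math.hypot(*w)))  # → True
```

A quick sanity check of the order: with beta = 0 the two y-rotations merge, so euler_yxy(gamma, 0, alpha) equals rot_y(gamma + alpha).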
- __init__(irreps, *, layout=None, layout_in=None, layout_out=None, device=None, math_dtype=None, use_fallback=None, method=None)
Initialize internal Module state, shared by both nn.Module and ScriptModule.
- Parameters:
irreps (Irreps)
layout (IrrepsLayout | None)
layout_in (IrrepsLayout | None)
layout_out (IrrepsLayout | None)
device (device | None)
math_dtype (dtype | None)
use_fallback (bool | None)
method (str | None)
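The documented defaults for method and the deprecated use_fallback flag can be summarized as a small decision function. This is a hypothetical sketch of the documented rules, not the library's implementation: segment_shapes stands in for whatever per-segment shapes the layer sees, and raising an error when both method and use_fallback are given is an assumption, since that case is not documented.

```python
def resolve_method(method=None, use_fallback=None, segment_shapes=()):
    """Hypothetical helper mirroring the documented defaults (not library code)."""
    if method is not None and use_fallback is not None:
        # Assumption: the combined case is undocumented, so reject it here.
        raise ValueError("pass either method or use_fallback, not both")
    if use_fallback is not None:
        # Deprecated flag: True -> "naive", False -> "uniform_1d".
        return "naive" if use_fallback else "uniform_1d"
    if method is not None:
        return method
    # Default: use the CUDA kernel only when every segment has the same shape.
    uniform = len(set(segment_shapes)) <= 1
    return "uniform_1d" if uniform else "naive"


print(resolve_method(segment_shapes=[(3, 4), (3, 4)]))  # → uniform_1d
print(resolve_method(segment_shapes=[(1, 4), (3, 4)]))  # → naive
```

Note that use_fallback=False forces "uniform_1d" regardless of the segment shapes, which is why the flag's description warns that all segments must share the same shape.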
- classmethod __new__(*args, **kwargs)