Linear#
- class cuequivariance_torch.Linear#
A class that represents an equivariant linear layer.
- Parameters:
irreps_in (Irreps) – The input irreducible representations.
irreps_out (Irreps) – The output irreducible representations.
layout (IrrepsLayout, optional) – The layout of the irreducible representations, by default cue.mul_ir. This is the layout used in the e3nn library.
layout_in (IrrepsLayout, optional) – The layout of the input irreducible representations, by default layout.
layout_out (IrrepsLayout, optional) – The layout of the output irreducible representations, by default layout.
shared_weights (bool, optional) – Whether to use shared weights, by default True.
internal_weights (bool, optional) – Whether to use internal weights, by default True if shared_weights is True, otherwise False.
weight_classes (int, optional) – If provided, the weight tensor will have this as a batch dimension (expected when using external weights). If this is specified and greater than 1, each batch element uses a slice of the weight tensor at forward time, selected by weight_indices.
device (torch.device, optional) – The device to use for the linear layer.
dtype (torch.dtype, optional) – The dtype to use for the linear layer weights, by default torch.float32.
math_dtype (torch.dtype or str, optional) – The dtype to use for the math operations. By default it follows the dtype of the input tensors if possible, otherwise the torch default dtype (see SegmentedPolynomial for more details).
method (str, optional) – The method to use for the linear layer, by default "naive" (a pure PyTorch implementation).
use_fallback (bool, optional, deprecated) – Whether to use a "fallback" implementation; now maps to method: if True or None (default), the "naive" method is used; if False, the "fused_tp" method is used.
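Example (illustrative): a minimal construction-and-call sketch with internal, shared weights. The cue.Irreps group/string constructor, the Irreps.dim attribute, and the specific irreps strings are assumptions based on the cuequivariance documentation conventions, not part of this reference entry.

```python
import torch
import cuequivariance as cue
import cuequivariance_torch as cuet

# Illustrative irreps: 32 even scalars and 16 odd vectors of O(3).
irreps_in = cue.Irreps("O3", "32x0e + 16x1o")
irreps_out = cue.Irreps("O3", "48x0e + 16x1o")

# Internal, shared weights (the defaults); cue.mul_ir is the e3nn-compatible layout.
linear = cuet.Linear(
    irreps_in,
    irreps_out,
    layout=cue.mul_ir,
    dtype=torch.float32,
)

x = torch.randn(128, irreps_in.dim)  # (batch, irreps_in.dim)
y = linear(x)                        # (batch, irreps_out.dim)
```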
Forward Pass
- forward(x, weight=None, weight_indices=None)#
Forward pass of the linear layer.
- Parameters:
x (torch.Tensor) – The input tensor.
weight (torch.Tensor, optional) – The weight tensor. If None, the internal weight tensor is used, otherwise: If weights are not shared, this should be a tensor of shape (batch_size, weight_numel). If weights are shared, this should be a tensor of shape (weight_classes, weight_numel) (where weight_classes is 1 if unspecified).
weight_indices (torch.Tensor, optional) – The indices of the weight tensor: if weight_classes > 1, this is an integer tensor of shape (batch_size,), indicating which weight slice to use for each batch element.
- Returns:
The output tensor after applying the linear transformation.
- Return type:
torch.Tensor
- Raises:
ValueError – If internal weights are used and weight is not None, or if shared weights are used and weight is not a 1D tensor, or if shared weights are not used and weight is not a 2D tensor.
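Example (illustrative): external, shared weights with several weight classes, where weight_indices selects one weight slice per batch element. The linear.weight_numel attribute is assumed to expose the weight_numel size referenced in the weight parameter description above; treat it as an assumption.

```python
import torch
import cuequivariance as cue
import cuequivariance_torch as cuet

irreps = cue.Irreps("O3", "8x0e + 8x1o")

# Shared external weights with 4 weight classes: each batch element
# picks one of the 4 weight slices via weight_indices.
linear = cuet.Linear(
    irreps,
    irreps,
    shared_weights=True,
    internal_weights=False,
    weight_classes=4,
)

batch = 32
x = torch.randn(batch, irreps.dim)
weight = torch.randn(4, linear.weight_numel)    # (weight_classes, weight_numel); weight_numel assumed exposed
weight_indices = torch.randint(0, 4, (batch,))  # class index for each batch element
y = linear(x, weight=weight, weight_indices=weight_indices)
```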
- __init__(irreps_in, irreps_out, *, layout=None, layout_in=None, layout_out=None, shared_weights=True, internal_weights=None, weight_classes=1, device=None, dtype=None, math_dtype=None, use_fallback=None, method=None)#
Initialize internal Module state, shared by both nn.Module and ScriptModule.
- Parameters:
irreps_in (Irreps)
irreps_out (Irreps)
layout (IrrepsLayout | None)
layout_in (IrrepsLayout | None)
layout_out (IrrepsLayout | None)
shared_weights (bool)
internal_weights (bool)
weight_classes (int | None)
device (device | None)
dtype (dtype | None)
math_dtype (dtype | str | None)
use_fallback (bool | None)
method (str | None)
- classmethod __new__(*args, **kwargs)#
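Example (illustrative): the non-shared case from the forward parameters, where each batch element supplies its own weight row. Same assumptions as above regarding cue.Irreps and linear.weight_numel.

```python
import torch
import cuequivariance as cue
import cuequivariance_torch as cuet

irreps = cue.Irreps("O3", "8x0e + 8x1o")

# Per-sample weights: shared_weights=False expects one weight row per batch element.
linear = cuet.Linear(irreps, irreps, shared_weights=False, internal_weights=False)

batch = 32
x = torch.randn(batch, irreps.dim)
weight = torch.randn(batch, linear.weight_numel)  # (batch_size, weight_numel); weight_numel assumed exposed
y = linear(x, weight=weight)
```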