Linear
- class cuequivariance_torch.Linear
A class that represents an equivariant linear layer.
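The sketch below shows what "equivariant linear" means for a single irrep block. It mixes the copies (multiplicities) of the irrep with a weight matrix and leaves the irrep components untouched, so the map commutes with rotations. This is plain PyTorch and not the cuequivariance_torch implementation. The following are assumptions made for illustration:

- the helper names `mix_multiplicities` and `rotation_z`;
- a mul_ir-shaped input of shape `(batch, mul_in, ir_dim)`;
- the absence of any normalization factor.

```python
import math
import torch


def mix_multiplicities(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: linearly mix the copies of one irrep.

    x:      (batch, mul_in, ir_dim), one irrep block in mul_ir layout
    weight: (mul_in, mul_out)
    returns (batch, mul_out, ir_dim)
    """
    # Every irrep component is multiplied by the same (mul_in, mul_out) matrix,
    # and components are never mixed with each other.
    return torch.einsum("bud,uv->bvd", x, weight)


def rotation_z(angle: float) -> torch.Tensor:
    """3x3 rotation about the z axis, acting on l=1 (vector) components."""
    c, s = math.cos(angle), math.sin(angle)
    return torch.tensor([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]])


torch.manual_seed(0)
x = torch.randn(5, 4, 3)  # 5 samples, 4 copies of a vector irrep
w = torch.randn(4, 2)     # mix 4 input copies into 2 output copies
R = rotation_z(0.3)

# Rotating and then mixing gives the same result as mixing and then rotating.
lhs = mix_multiplicities(x @ R.T, w)
rhs = mix_multiplicities(x, w) @ R.T
print(lhs.shape, torch.allclose(lhs, rhs, atol=1e-5))
```

Because the weights never touch the component axis, any orthogonal transformation acting on that axis passes straight through the layer.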
- Parameters:
irreps_in (Irreps) – The input irreducible representations.
irreps_out (Irreps) – The output irreducible representations.
layout (IrrepsLayout, optional) – The layout of the irreducible representations, by default cue.mul_ir. This is the layout used in the e3nn library.
layout_in (IrrepsLayout, optional) – The layout of the input irreducible representations, by default layout.
layout_out (IrrepsLayout, optional) – The layout of the output irreducible representations, by default layout.
shared_weights (bool, optional) – Whether to use shared weights, by default True.
internal_weights (bool, optional) – Whether to use internal weights, by default True if shared_weights is True, otherwise False.
weight_classes (int, optional) – If provided, the weight tensor has this as a batch dimension (expected when using external weights). If it is greater than 1, at forward time each batch element uses the slice of the weight tensor selected by weight_indices.
device (torch.device, optional) – The device to use for the linear layer.
dtype (torch.dtype, optional) – The dtype to use for the linear layer weights, by default torch.float32.
math_dtype (torch.dtype, optional) – The dtype to use for the math operations. By default it follows the dtype of the input tensors if possible, and otherwise the torch default dtype.
method (str, optional) – The method to use for the linear layer, by default “naive” (using a PyTorch implementation).
use_fallback (bool, optional, deprecated) – Whether to use a “fallback” implementation. This option now maps to method: if True or None (default), the “naive” method is used; if False, the “fused_tp” method is used.
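The two layouts differ only in how each `mul x ir` block is ordered along the feature axis:

- In mul_ir (the e3nn convention), the components of each copy are contiguous.
- In cue.ir_mul, the multiplicity index runs fastest.

The plain-PyTorch sketch below converts between the two layouts block by block. The helper `convert_layout` and its `(mul, ir_dim)` block list are hypothetical stand-ins for what a cue.Irreps object describes. They are not part of the cuequivariance_torch API.

```python
import torch


def convert_layout(x: torch.Tensor, blocks, to_ir_mul: bool = True) -> torch.Tensor:
    """Hypothetical helper: reorder the last axis of x between mul_ir and ir_mul.

    blocks lists (mul, ir_dim) pairs, e.g. [(2, 1), (2, 3)] for "2x0e + 2x1o".
    mul_ir stores element (u, i) of a block at u * ir_dim + i;
    ir_mul stores it at i * mul + u.
    """
    batch_shape = x.shape[:-1]
    pieces = []
    start = 0
    for mul, ir_dim in blocks:
        size = mul * ir_dim
        seg = x[..., start:start + size]
        # Read the block in its current layout, then swap the two axes.
        shape = (mul, ir_dim) if to_ir_mul else (ir_dim, mul)
        seg = seg.reshape(*batch_shape, *shape).transpose(-1, -2)
        pieces.append(seg.reshape(*batch_shape, size))
        start += size
    return torch.cat(pieces, dim=-1)


x = torch.arange(8.0)          # "2x0e + 2x1o" in mul_ir layout
blocks = [(2, 1), (2, 3)]
y = convert_layout(x, blocks)  # the same data in ir_mul layout
print(y.tolist())  # [0.0, 1.0, 2.0, 5.0, 3.0, 6.0, 4.0, 7.0]
```

For scalar blocks (ir_dim = 1) the two layouts coincide, so only higher-order irreps are actually reordered.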
Forward Pass
- forward(x, weight=None, weight_indices=None)
Forward pass of the linear layer.
- Parameters:
x (torch.Tensor) – The input tensor.
weight (torch.Tensor, optional) – The weight tensor. If None, the internal weight tensor is used. Otherwise, if weights are not shared, this should be a tensor of shape (batch_size, weight_numel); if weights are shared, it should be a tensor of shape (weight_classes, weight_numel), where weight_classes is 1 if unspecified.
weight_indices (torch.Tensor, optional) – The indices of the weight tensor: if weight_classes > 1, this is an integer tensor of shape (batch_size,), indicating which weight slice to use for each batch element.
- Returns:
The output tensor after applying the linear transformation.
- Return type:
torch.Tensor
- Raises:
ValueError – If internal weights are used and weight is not None, or if shared weights are used and weight is not a 1D tensor, or if shared weights are not used and weight is not a 2D tensor.
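The weight shapes accepted by forward can be illustrated with a small plain-PyTorch sketch. It assumes, as in e3nn-style linear layers, that only identical irreps in the input and output are connected. Under that assumption, weight_numel is the sum of mul_in * mul_out over matching pairs.

The helpers `weight_numel` and `per_sample_weights` are hypothetical. They only mimic how a (weight_classes, weight_numel) tensor and weight_indices could be turned into one weight row per batch element.

```python
import torch


def weight_numel(irreps_in, irreps_out) -> int:
    """Hypothetical: count weights when only identical irreps are connected.

    Irreps are given as lists of (mul, ir) pairs, e.g. [(4, "0e"), (4, "1o")].
    """
    return sum(
        mul_in * mul_out
        for mul_in, ir_in in irreps_in
        for mul_out, ir_out in irreps_out
        if ir_in == ir_out
    )


def per_sample_weights(weight, batch_size, weight_indices=None):
    """Hypothetical: return a (batch_size, weight_numel) tensor of weights."""
    if weight.dim() != 2:
        raise ValueError("expected a 2D weight tensor")
    if weight_indices is not None:
        # weight is (weight_classes, weight_numel): pick one class per sample.
        return weight[weight_indices]
    if weight.shape[0] == 1:
        # A single weight row shared by every sample.
        return weight.expand(batch_size, -1)
    if weight.shape[0] == batch_size:
        # Non-shared weights: already one row per sample.
        return weight
    raise ValueError("without weight_indices, weight must have 1 or batch_size rows")


n = weight_numel([(4, "0e"), (4, "1o")], [(8, "0e"), (2, "1o")])
print(n)  # 4*8 + 4*2 = 40

classes = torch.arange(6.0).reshape(3, 2)  # weight_classes=3, weight_numel=2
idx = torch.tensor([2, 0, 2])              # class used by each of 3 samples
print(per_sample_weights(classes, 3, idx).tolist())
```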
- __init__(irreps_in, irreps_out, *, layout=None, layout_in=None, layout_out=None, shared_weights=True, internal_weights=None, weight_classes=1, device=None, dtype=None, math_dtype=None, use_fallback=None, method=None)
Initialize internal Module state, shared by both nn.Module and ScriptModule.
- Parameters:
irreps_in (Irreps)
irreps_out (Irreps)
layout (IrrepsLayout | None)
layout_in (IrrepsLayout | None)
layout_out (IrrepsLayout | None)
shared_weights (bool)
internal_weights (bool)
weight_classes (int | None)
device (device | None)
dtype (dtype | None)
math_dtype (dtype | None)
use_fallback (bool | None)
method (str | None)
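The documented defaults chain together:

- layout falls back to cue.mul_ir;
- layout_in and layout_out fall back to layout;
- internal_weights follows shared_weights;
- the deprecated use_fallback is translated into method.

The pure-Python sketch below spells out that resolution order. It is hypothetical in two ways. Plain strings stand in for cue.mul_ir and the other IrrepsLayout values. Letting an explicit method take precedence over use_fallback is also an assumption; the library may handle that combination differently.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class LinearOptions:
    """Hypothetical container for the resolved constructor options."""
    layout_in: str
    layout_out: str
    internal_weights: bool
    method: str


def resolve_linear_defaults(
    layout: Optional[str] = None,
    layout_in: Optional[str] = None,
    layout_out: Optional[str] = None,
    shared_weights: bool = True,
    internal_weights: Optional[bool] = None,
    use_fallback: Optional[bool] = None,
    method: Optional[str] = None,
) -> LinearOptions:
    # "mul_ir" stands in for cue.mul_ir, the documented default layout.
    layout = "mul_ir" if layout is None else layout
    layout_in = layout if layout_in is None else layout_in
    layout_out = layout if layout_out is None else layout_out
    # internal_weights defaults to True exactly when shared_weights is True.
    if internal_weights is None:
        internal_weights = shared_weights
    # Deprecated use_fallback: True/None -> "naive", False -> "fused_tp".
    # Assumption: an explicitly given method wins over use_fallback.
    if method is None:
        method = "naive" if use_fallback is None or use_fallback else "fused_tp"
    return LinearOptions(layout_in, layout_out, internal_weights, method)


print(resolve_linear_defaults())
print(resolve_linear_defaults(layout="ir_mul", shared_weights=False, use_fallback=False))
```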
- classmethod __new__(*args, **kwargs)