BatchNorm
- class cuequivariance_torch.layers.BatchNorm(irreps: Irreps, *, layout: IrrepsLayout = None, eps: float = 1e-05, momentum: float = 0.1, affine: bool = True, reduce: str = 'mean', instance: bool = False, include_bias: bool = True)
Batch normalization for orthonormal representations.
It normalizes the input using the norms of the representations. Note that the norm is invariant under the group action only for orthonormal representations.
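The following PyTorch snippet is an illustrative check, not part of cuequivariance_torch. It compares the Euclidean norm of a 3-vector before and after an orthogonal matrix (the Q factor of a QR decomposition) and after an invertible but non-orthogonal matrix. Only the orthogonal map preserves the norm, which is why the norm is a safe normalization statistic only when the representation matrices are orthogonal.

```python
import torch

torch.manual_seed(0)

x = torch.tensor([1.0, 1.0, 1.0], dtype=torch.float64)

# Orthogonal matrix: the Q factor of a QR decomposition satisfies Q^T Q = I.
q, _ = torch.linalg.qr(torch.randn(3, 3, dtype=torch.float64))

# Invertible but non-orthogonal matrix (a shear combined with a scaling).
a = torch.tensor(
    [[1.0, 2.0, 0.0],
     [0.0, 1.0, 0.0],
     [0.0, 0.0, 3.0]],
    dtype=torch.float64,
)

norm_x = torch.linalg.vector_norm(x)       # sqrt(3)
norm_qx = torch.linalg.vector_norm(q @ x)  # equals norm_x up to rounding
norm_ax = torch.linalg.vector_norm(a @ x)  # a @ x = [3, 1, 3], so sqrt(19)

print(torch.isclose(norm_x, norm_qx).item(), torch.isclose(norm_x, norm_ax).item())  # True False
```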
- Parameters:
irreps (Irreps) – Input irreps.
layout (IrrepsLayout, optional) – Layout of the input tensor, by default IrrepsLayout.mul_ir.
eps (float, optional) – Epsilon value for numerical stability, by default 1e-5.
momentum (float, optional) – Momentum for the running mean and variance, by default 0.1.
affine (bool, optional) – Whether to apply an affine transformation, by default True.
reduce (str, optional) – How to reduce the norm of the representations, by default “mean”.
instance (bool, optional) – Whether to use instance normalization, by default False.
include_bias (bool, optional) – Whether to include a bias term, by default True.
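To make the parameters concrete, the sketch below is a simplified, hypothetical reimplementation in plain PyTorch; the class name NormBatchNormSketch and the list-of-(mul, l) irreps format are assumptions for illustration, not the library's API or its actual code. It assumes a mul_ir layout, treats only l = 0 channels as scalars, and does not model reduce or instance. Each irrep copy is scaled by one factor 1 / sqrt(mean squared component + eps), multiplied by a learned weight when affine=True; in this sketch, include_bias controls both centering of the scalar channels and their learned bias. Running statistics follow the momentum convention of torch.nn.BatchNorm1d: running = (1 - momentum) * running + momentum * batch_statistic.

```python
import torch
from torch import nn


class NormBatchNormSketch(nn.Module):
    """Hypothetical, simplified norm-based batch norm (not the cuequivariance_torch code).

    Assumptions of this sketch:
    - irreps is a list of (mul, l) pairs; an irrep of degree l has 2l + 1 components.
    - the input has shape (batch, dim) in a mul_ir layout: each (mul, l) block is
      `mul` consecutive copies of a (2l + 1)-dimensional irrep.
    - only l = 0 channels are treated as scalars (centered and given a bias).
    - `reduce` and `instance` are not modeled; statistics are always over the batch.
    """

    def __init__(self, irreps, eps=1e-5, momentum=0.1, affine=True, include_bias=True):
        super().__init__()
        self.irreps = list(irreps)
        self.eps, self.momentum = eps, momentum
        self.affine, self.include_bias = affine, include_bias
        num_channels = sum(mul for mul, _ in self.irreps)
        num_scalars = sum(mul for mul, l in self.irreps if l == 0)
        self.register_buffer("running_var", torch.ones(num_channels))
        self.register_buffer("running_mean", torch.zeros(num_scalars))
        self.weight = nn.Parameter(torch.ones(num_channels)) if affine else None
        self.bias = nn.Parameter(torch.zeros(num_scalars)) if affine and include_bias else None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        outputs = []
        offset = ch = sc = 0  # offsets into features, channels, and scalar channels
        for mul, l in self.irreps:
            d = 2 * l + 1
            field = x[:, offset : offset + mul * d].reshape(-1, mul, d)
            offset += mul * d

            if l == 0 and self.include_bias:
                # Center scalars with the batch mean (training) or running mean (eval).
                if self.training:
                    mean = field.mean(dim=(0, 2))
                    with torch.no_grad():
                        self.running_mean[sc : sc + mul].lerp_(mean.detach(), self.momentum)
                else:
                    mean = self.running_mean[sc : sc + mul]
                field = field - mean.reshape(1, mul, 1)

            if self.training:
                # Squared components averaged over the 2l + 1 components, then over the batch.
                var = field.pow(2).mean(dim=2).mean(dim=0)
                with torch.no_grad():
                    self.running_var[ch : ch + mul].lerp_(var.detach(), self.momentum)
            else:
                var = self.running_var[ch : ch + mul]

            scale = (var + self.eps).rsqrt()
            if self.affine:
                scale = scale * self.weight[ch : ch + mul]
            # One factor per irrep copy, shared by all of its 2l + 1 components.
            field = field * scale.reshape(1, mul, 1)

            if l == 0 and self.bias is not None:
                field = field + self.bias[sc : sc + mul].reshape(1, mul, 1)

            outputs.append(field.reshape(-1, mul * d))
            ch += mul
            if l == 0:
                sc += mul
        return torch.cat(outputs, dim=-1)


bn = NormBatchNormSketch([(4, 0), (2, 1)])  # 4 scalars + 2 vectors: dim 4 + 6 = 10
y = bn(torch.randn(32, 10))
print(y.shape)  # torch.Size([32, 10])
```

In eval mode (bn.eval()), the sketch uses the running statistics instead of batch statistics, as torch.nn.BatchNorm1d does.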
Forward Pass
- forward(input: Tensor) → Tensor
Normalize the input tensor.
- Parameters:
input (torch.Tensor) – Input tensor. Its last dimension must equal the total dimension of the input irreps.
- Returns:
Normalized tensor.
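A norm-based normalization is expected to commute with the group action because every component of an irrep copy is scaled by the same norm-derived factor. The PyTorch check below is illustrative, not the library's code: the hypothetical helper normalize_vector_channels mimics only that scaling step for l = 1 channels (no affine weight, bias, or running statistics) and confirms that applying a random orthogonal matrix before normalization gives the same result as applying it afterwards.

```python
import torch

torch.manual_seed(0)


def normalize_vector_channels(v: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
    """Hypothetical helper: scale each of the `mul` l = 1 channels of a
    (batch, mul, 3) tensor by 1 / sqrt(mean squared component over the batch + eps)."""
    var = v.pow(2).mean(dim=(0, 2))  # one statistic per channel, shape (mul,)
    return v * (var + eps).rsqrt().reshape(1, -1, 1)


v = torch.randn(16, 2, 3, dtype=torch.float64)  # 16 samples, 2 copies of a 3-vector
q, _ = torch.linalg.qr(torch.randn(3, 3, dtype=torch.float64))  # random orthogonal matrix

# Rows are vectors, so right-multiplying by q.T applies q to every vector.
lhs = normalize_vector_channels(v @ q.T)  # transform, then normalize
rhs = normalize_vector_channels(v) @ q.T  # normalize, then transform
print(torch.allclose(lhs, rhs))  # True
```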
- Return type: