BatchNorm

class cuequivariance_torch.layers.BatchNorm(
irreps: Irreps,
*,
layout: IrrepsLayout = None,
eps: float = 1e-05,
momentum: float = 0.1,
affine: bool = True,
reduce: str = 'mean',
instance: bool = False,
include_bias: bool = True,
)

Batch normalization for orthonormal representations.

The layer normalizes features by the norm of their representations. This norm is invariant under the group action only when the representations are orthonormal, which is why the layer is restricted to that case.
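The role of the invariant norm can be illustrated with a small, library-independent sketch. It is not the implementation used by cuequivariance_torch. It assumes a batch of plain 3-vectors, a rotation about the z axis as the group action, and an `eps` added to the batch-mean squared norm; the helper names `rotate_z` and `norm_normalize` are hypothetical. Because the rotation leaves every squared norm unchanged, normalizing and rotating commute.

```python
import math

def rotate_z(v, angle):
    """Rotate a 3-vector about the z axis (an orthogonal transformation)."""
    c, s = math.cos(angle), math.sin(angle)
    x, y, z = v
    return (c * x - s * y, s * x + c * y, z)

def norm_normalize(batch, eps=1e-5):
    """Scale every vector by 1 / sqrt(batch-mean squared norm + eps).

    Assumption: eps is added to the mean squared norm; the source does not
    specify where the library applies it.
    """
    mean_sq = sum(sum(c * c for c in v) for v in batch) / len(batch)
    scale = 1.0 / math.sqrt(mean_sq + eps)
    return [tuple(scale * c for c in v) for v in batch]

batch = [(1.0, 2.0, 2.0), (0.0, 3.0, 4.0), (2.0, 0.0, 0.0)]
angle = 0.7

# Normalize then rotate, versus rotate then normalize.
a = [rotate_z(v, angle) for v in norm_normalize(batch)]
b = norm_normalize([rotate_z(v, angle) for v in batch])

# Largest componentwise difference between the two results.
max_diff = max(abs(p - q) for u, w in zip(a, b) for p, q in zip(u, w))
```

A single scale factor is shared by all components of a vector, so the direction of each vector is preserved and only its length changes. Normalizing each component independently would not commute with the rotation.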

Parameters:
  • irreps (Irreps) – Input irreps.

  • layout (IrrepsLayout, optional) – Layout of the input tensor. If None, IrrepsLayout.mul_ir is used.

  • eps (float, optional) – Epsilon value for numerical stability, by default 1e-5.

  • momentum (float, optional) – Momentum for the running mean and variance, by default 0.1.

  • affine (bool, optional) – Whether to apply an affine transformation, by default True.

  • reduce (str, optional) – How to reduce the norm of the representations, by default “mean”.

  • instance (bool, optional) – Whether to use instance normalization, by default False.

  • include_bias (bool, optional) – Whether to include a bias term, by default True.
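As a rough illustration of the momentum parameter, the sketch below applies an exponential moving average to a running statistic. It assumes the same convention as torch.nn.BatchNorm1d, where the new value is (1 - momentum) * running + momentum * batch_statistic. The source does not state the exact update rule used by this layer, and the helper name `update_running` is hypothetical.

```python
def update_running(running, batch_stat, momentum=0.1):
    """Exponential moving average update (assumed PyTorch-style convention)."""
    return (1.0 - momentum) * running + momentum * batch_stat

# Start from a running variance of 1.0 and observe three batches
# whose variance is 4.0; the running value drifts toward 4.0.
running_var = 1.0
for batch_var in [4.0, 4.0, 4.0]:
    running_var = update_running(running_var, batch_var)
```

Under this convention, a larger momentum makes the running statistics track recent batches more closely, while a smaller one averages over a longer history.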

Forward Pass

forward(input: Tensor) → Tensor

Normalize the input tensor.

Parameters:

input (torch.Tensor) – Input tensor. The size of its last dimension must match the dimension of the input irreps.

Returns:

Normalized tensor.

Return type:

torch.Tensor