cuequivariance_torch.layers.BatchNorm

class cuequivariance_torch.layers.BatchNorm(irreps: Irreps, *, layout: IrrepsLayout = None, eps: float = 1e-05, momentum: float = 0.1, affine: bool = True, reduce: str = 'mean', instance: bool = False, include_bias: bool = True)

Batch normalization for orthonormal representations.

Features are normalized by the norm of each representation. The norm is invariant under the group action only when the representations are orthonormal, which is why this layer assumes orthonormal representations.
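The invariance claim can be checked with a small standard-library sketch. It is an illustration, not the library's implementation: the helpers `norm`, `apply`, and `rotation_z` are hypothetical names, and the feature is assumed to be a single 3-component vector.

```python
import math

def norm(v):
    """Euclidean norm of a feature vector."""
    return math.sqrt(sum(c * c for c in v))

def apply(matrix, v):
    """Multiply a square matrix (given as a list of rows) by a vector."""
    return [sum(m * c for m, c in zip(row, v)) for row in matrix]

def rotation_z(angle):
    """Orthogonal matrix for a rotation about the z axis."""
    c, s = math.cos(angle), math.sin(angle)
    return [[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]]

v = [1.0, 2.0, 2.0]
rotated = apply(rotation_z(0.7), v)

# Orthonormal representation: the group action leaves the norm unchanged,
# so dividing by the norm commutes with the rotation.
assert math.isclose(norm(rotated), norm(v))

# Non-orthonormal basis: rescale the x axis by 2 (B = diag(2, 1, 1)).
# In that basis the same rotation acts as B R B^-1 on B v, giving B R v.
B = [[2.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]
skewed_before = apply(B, v)
skewed_after = apply(B, rotated)

# The norm now changes under the rotation, so it is no longer an invariant.
assert not math.isclose(norm(skewed_after), norm(skewed_before))
```

In the orthonormal case, scaling a feature by the inverse of its norm gives the same result before and after the rotation; in the rescaled basis it would not.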

Parameters:
  • irreps (Irreps) – Input irreps.

  • layout (IrrepsLayout, optional) – Layout of the input tensor, by default IrrepsLayout.mul_ir (used when layout is None).

  • eps (float, optional) – Epsilon value for numerical stability, by default 1e-5.

  • momentum (float, optional) – Momentum for the running mean and variance, by default 0.1.

  • affine (bool, optional) – Whether to apply an affine transformation, by default True.

  • reduce (str, optional) – How to reduce the norm of the representations, by default 'mean'.

  • instance (bool, optional) – Whether to use instance normalization, by default False.

  • include_bias (bool, optional) – Whether to include a bias term, by default True.
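The roles of momentum and eps can be sketched as follows. This assumes the layer follows the convention documented for torch.nn.BatchNorm1d (running = (1 - momentum) * running + momentum * batch statistic) and that eps is added under the square root; neither is confirmed by this page. `update_running` and `scale_factor` are hypothetical helpers.

```python
import math

def update_running(running, batch_value, momentum=0.1):
    """Exponential moving average, assuming the torch.nn.BatchNorm convention."""
    return (1.0 - momentum) * running + momentum * batch_value

def scale_factor(mean_squared_norm, eps=1e-5):
    """Assumed normalization factor; eps keeps it finite for zero-norm input."""
    return 1.0 / math.sqrt(mean_squared_norm + eps)

# The running statistic starts at 1.0 and sees three batches of value 2.0.
running = 1.0
for batch_value in [2.0, 2.0, 2.0]:
    running = update_running(running, batch_value)

# With momentum=0.1 the estimate moves slowly toward the batch value.
assert 1.0 < running < 2.0

# Without eps this would divide by zero.
zero_norm_scale = scale_factor(0.0)
assert math.isfinite(zero_norm_scale)
```

A larger momentum makes the running statistics track recent batches more closely; a smaller one averages over a longer history.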

forward(input: Tensor) → Tensor

Normalize the input tensor.

Parameters:
  • input (torch.Tensor) – Input tensor. The size of its last dimension must equal the total dimension of the input irreps.

Returns:

Normalized tensor.

Return type:

torch.Tensor
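To make the constraint on the last dimension of input concrete, here is a minimal sketch that computes the expected size from an irreps string. It assumes O(3) irreps written as "<mul>x<l><parity>" (for example "4x0e + 4x1o"), where an irrep of degree l has 2l + 1 components. `irreps_dim` is a hypothetical helper, not part of the library.

```python
def irreps_dim(irreps: str) -> int:
    """Total dimension of an O(3) irreps string such as "4x0e + 4x1o".

    Hypothetical helper: each term is "<mul>x<l><parity>", and a missing
    multiplicity is read as 1.
    """
    total = 0
    for term in irreps.split("+"):
        term = term.strip()
        mul, ir = term.split("x") if "x" in term else ("1", term)
        degree = int(ir[:-1])  # drop the trailing parity letter "e" or "o"
        total += int(mul) * (2 * degree + 1)
    return total

# Four scalars (1 component each) plus four vectors (3 components each).
width = irreps_dim("4x0e + 4x1o")
assert width == 4 * 1 + 4 * 3
```

Under these assumptions, an input for "4x0e + 4x1o" would have shape (..., 16), and the normalized output keeps that shape.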