Normalization#

class physicsnemo.nn.module.group_norm.GroupNorm(
num_channels: int,
num_groups: int = 32,
min_channels_per_group: int = 4,
eps: float = 1e-05,
use_apex_gn: bool = False,
fused_act: bool = False,
act: str | None = None,
amp_mode: bool = False,
)[source]#

Bases: Module

A custom Group Normalization layer implementation.

Group Normalization (GN) divides the channels of the input tensor into groups and normalizes the features within each group independently. Unlike Batch Normalization, its statistics are computed per sample rather than across the batch, making it suitable for any batch size, including a batch size of one.

Parameters:
  • num_channels (int) – Number of channels in the input tensor.

  • num_groups (int, optional, default=32) – Desired number of groups to divide the input channels. This might be adjusted based on the min_channels_per_group.

  • min_channels_per_group (int, optional, default=4) – Minimum channels required per group. This ensures that no group has fewer channels than this number.

  • eps (float, optional, default=1e-5) – A small number added to the variance to prevent division by zero.

  • use_apex_gn (bool, optional, default=False) – Deprecated. Please use get_group_norm() instead.

  • fused_act (bool, optional, default=False) – Deprecated. Please use get_group_norm() instead.

  • act (str, optional, default=None) – The activation function to use when fusing activation with GroupNorm.

  • amp_mode (bool, optional, default=False) – A boolean flag indicating whether mixed-precision (AMP) training is enabled.

Forward:

x (torch.Tensor) – 4-D input tensor of shape \((B, C, H, W)\), where \(B\) is batch size, \(C\) is num_channels, and \(H, W\) are spatial dimensions.

Outputs:
  • torch.Tensor – Output tensor of the same shape as input: \((B, C, H, W)\).

Note

If num_channels is not divisible by num_groups, the actual number of groups may be adjusted to satisfy the min_channels_per_group condition.
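As a sketch of the computation described above, the following NumPy snippet normalizes a \((B, C, H, W)\) array within each channel group. This is an illustration of the math only, not the library's implementation, and it omits the learnable affine parameters and the group-count adjustment:

```python
import numpy as np

def group_norm_reference(x, num_groups, eps=1e-5):
    """Normalize a (B, C, H, W) array within each channel group."""
    b, c, h, w = x.shape
    assert c % num_groups == 0, "channels must divide evenly into groups"
    # Reshape so each group's channels share one mean/variance.
    xg = x.reshape(b, num_groups, c // num_groups, h, w)
    mean = xg.mean(axis=(2, 3, 4), keepdims=True)
    var = xg.var(axis=(2, 3, 4), keepdims=True)
    out = (xg - mean) / np.sqrt(var + eps)
    return out.reshape(b, c, h, w)

x = np.random.randn(2, 8, 4, 4)
y = group_norm_reference(x, num_groups=4)
# Each group of 2 channels is now approximately zero-mean, unit-variance.
```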

forward(x)[source]#

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

get_activation_function()[source]#

Get the activation function corresponding to the given string name.

physicsnemo.nn.module.group_norm.get_group_norm(
num_channels: int,
num_groups: int = 32,
min_channels_per_group: int = 4,
eps: float = 1e-05,
use_apex_gn: bool = False,
act: str | None = None,
amp_mode: bool = False,
) Module[source]#

Utility function to get a GroupNorm layer, either from Apex or from torch.

Parameters:
  • num_channels (int) – Number of channels in the input tensor.

  • num_groups (int, optional, default=32) – Desired number of groups to divide the input channels. This might be adjusted based on the min_channels_per_group.

  • min_channels_per_group (int, optional, default=4) – Minimum channels required per group. This ensures that no group has fewer channels than this number.

  • eps (float, optional, default=1e-5) – A small number added to the variance to prevent division by zero.

  • use_apex_gn (bool, optional, default=False) – Whether to use Apex GroupNorm for the NHWC layout. Must be set to False when running on CPU.

  • act (str, optional, default=None) – The activation function to use when fusing activation with GroupNorm.

  • amp_mode (bool, optional, default=False) – A boolean flag indicating whether mixed-precision (AMP) training is enabled.

Returns:

  • torch.nn.Module – The GroupNorm layer. If use_apex_gn is True, returns an ApexGroupNorm layer; otherwise returns an instance of GroupNorm.

Note

If num_channels is not divisible by num_groups, the actual number of groups may be adjusted to satisfy the min_channels_per_group condition.
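The adjustment mentioned in the note can be sketched as follows. adjust_num_groups is a hypothetical helper written for illustration; the exact rule used by physicsnemo may differ:

```python
def adjust_num_groups(num_channels, num_groups=32, min_channels_per_group=4):
    # Cap the group count so every group keeps at least
    # min_channels_per_group channels; always use at least one group.
    # (Assumed rule for illustration -- not the library's code.)
    groups = min(num_groups, num_channels // min_channels_per_group)
    return max(groups, 1)

adjust_num_groups(64)   # 16 groups of 4 channels each
adjust_num_groups(256)  # capped at the default 32 groups
```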

class physicsnemo.nn.module.layer_norm.LayerNorm(*args, **kwargs)#

Bases: LayerNorm

Wrapper around layer norm utilities.

This class will default to using the transformer engine implementation of LayerNorm - it is significantly faster in the backwards pass.

If transformer engine is not available, it will fall back to the pytorch implementation of LayerNorm.

Additionally, this class registers pre- and post-hooks so that you can train with or without transformer engine and run inference with or without transformer engine.

Note

Transformer engine adds additional state parameters that affect fp8 stability. Do NOT switch from transformer engine to pytorch or from pytorch to transformer engine with a checkpoint if you are using fp8 precision in the layer norm regions.
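Whichever backend is selected, the wrapper computes standard layer normalization. A minimal NumPy reference of that computation (an illustration only; it omits the learnable weight and bias):

```python
import numpy as np

def layer_norm_reference(x, eps=1e-5):
    # Normalize over the last dimension, as torch.nn.LayerNorm does
    # for a one-dimensional normalized_shape (no affine parameters here).
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)

x = np.random.randn(3, 16)
y = layer_norm_reference(x)
# Each row of y is approximately zero-mean with unit variance.
```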

physicsnemo.nn.module.layer_norm.get_layer_norm_class() Module[source]#

Dynamically pick the layer norm provider based on availability of transformer engine. If transformer engine is available, it will use the transformer engine implementation of LayerNorm. Otherwise, it will use the pytorch implementation of LayerNorm.

Override the default behavior by setting the PHYSICSNEMO_FORCE_TE environment variable.
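The selection logic can be sketched as below. The pick_layer_norm_provider name, the transformer_engine import path, and the exact semantics of PHYSICSNEMO_FORCE_TE here are assumptions for illustration; the real function returns the module class rather than a string:

```python
import os

def pick_layer_norm_provider():
    # Explicit override via the environment variable documented above;
    # treating "1" as force-TE and anything else as force-torch is an
    # assumption made for this sketch.
    force_te = os.environ.get("PHYSICSNEMO_FORCE_TE")
    if force_te is not None:
        return "transformer_engine" if force_te == "1" else "torch"
    # Otherwise, prefer transformer engine when it can be imported.
    try:
        import transformer_engine  # noqa: F401
        return "transformer_engine"
    except ImportError:
        return "torch"
```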

physicsnemo.nn.module.layer_norm.ignore_missing_extra_state_key(
module: Module,
incompatible_keys: _IncompatibleKeys,
) None[source]#

Post-hook to ignore missing ‘ln.norm._extra_state’ key when loading state_dict.

This function removes ‘ln.norm._extra_state’ from the list of missing keys in the IncompatibleKeys object. This is useful when loading a checkpoint saved from a Transformer Engine LayerNorm into a PyTorch LayerNorm, where this extra state is not present or needed.

Parameters:
  • module (nn.Module) – The module into which the state_dict is being loaded.

  • incompatible_keys – An object with a ‘missing_keys’ attribute (typically torch.nn.modules.module._IncompatibleKeys).
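A minimal sketch of the hook's effect, using a SimpleNamespace stand-in for _IncompatibleKeys (an illustration only; the real hook is registered on the module and invoked by load_state_dict):

```python
from types import SimpleNamespace

def ignore_missing_extra_state_key_sketch(module, incompatible_keys):
    # Drop the TE-specific key so load_state_dict does not report it
    # as missing when loading into a plain PyTorch LayerNorm.
    key = "ln.norm._extra_state"
    if key in incompatible_keys.missing_keys:
        incompatible_keys.missing_keys.remove(key)

keys = SimpleNamespace(missing_keys=["ln.norm._extra_state", "ln.norm.weight"])
ignore_missing_extra_state_key_sketch(None, keys)
# keys.missing_keys is now ["ln.norm.weight"]
```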

physicsnemo.nn.module.layer_norm.remove_extra_state_hook_for_torch(
module: Module,
state_dict: dict,
prefix: str,
local_metadata: dict,
strict: bool,
missing_keys: list,
unexpected_keys: list,
error_msgs: list,
) None[source]#

Pre-hook to remove Transformer Engine’s extra state from the state_dict when loading into a PyTorch LayerNorm.

This function scans the state_dict for any keys that match the pattern ‘{prefix}norm._extra_state’ and removes them. These keys are specific to Transformer Engine’s LayerNorm and are not needed (and may cause errors) when loading into a standard PyTorch LayerNorm.

Parameters:
  • module (nn.Module) – The module into which the state_dict is being loaded.

  • state_dict (dict) – The state dictionary being loaded.

  • prefix (str) – The prefix for parameters in this module.

  • local_metadata (dict) – Metadata for this module.

  • strict (bool) – Whether to strictly enforce that the keys in state_dict match the keys returned by this module’s state_dict function.

  • missing_keys (list) – List of missing keys.

  • unexpected_keys (list) – List of unexpected keys.

  • error_msgs (list) – List of error messages.
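The key scan can be sketched as follows. remove_extra_state_sketch is a hypothetical simplification of the real pre-hook, which receives the full signature listed above:

```python
def remove_extra_state_sketch(state_dict, prefix):
    # Drop the Transformer Engine-specific '{prefix}norm._extra_state'
    # entry, if present, before loading into a torch LayerNorm.
    state_dict.pop(f"{prefix}norm._extra_state", None)

sd = {
    "encoder.ln.norm._extra_state": b"te-serialized-state",
    "encoder.ln.norm.weight": [1.0],
}
remove_extra_state_sketch(sd, "encoder.ln.")
# sd now contains only 'encoder.ln.norm.weight'
```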