core.transformer.torch_norm

Module Contents

Classes

LayerNormInterface

Interface that all LayerNorm implementations should follow.

LayerNormBuilder

A protocol showing how Modules are expected to construct LayerNorms.

WrappedTorchNorm

A conditional wrapper that initializes an instance of PyTorch’s LayerNorm or RMSNorm based on its input arguments.

L2Norm

Applies L2 normalization to the input tensor along the last dimension.

API

class core.transformer.torch_norm.LayerNormInterface

Bases: typing.Protocol

Interface that all LayerNorm implementations should follow.

forward(x: torch.Tensor, /) → torch.Tensor

Forward method for a LayerNorm implementation.
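Because LayerNormInterface is a typing.Protocol, conformance is structural: any object with a matching forward(x) method fits, with no subclassing required. The sketch below redeclares a local copy of the protocol so it runs without Megatron installed; TinyRMSNorm and apply_norm are hypothetical names used only for illustration.

```python
from typing import Protocol

import torch


class LayerNormInterface(Protocol):
    """Local mirror of the documented protocol so this block runs on its own."""

    def forward(self, x: torch.Tensor, /) -> torch.Tensor: ...


class TinyRMSNorm(torch.nn.Module):
    """Hypothetical custom norm: satisfies the protocol by defining forward(x)."""

    def __init__(self, hidden_size: int, eps: float = 1e-5) -> None:
        super().__init__()
        self.weight = torch.nn.Parameter(torch.ones(hidden_size))
        self.eps = eps

    def forward(self, x: torch.Tensor, /) -> torch.Tensor:
        # Divide by the root mean square over the last dim, then apply a learned scale.
        rms = torch.sqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return self.weight * (x / rms)


def apply_norm(norm: LayerNormInterface, x: torch.Tensor) -> torch.Tensor:
    # Relies only on the protocol's forward(x) -> Tensor contract.
    # Inside nn.Module code you would normally call norm(x) so that hooks run.
    return norm.forward(x)


torch.manual_seed(0)
x = torch.randn(2, 4, 8)
out_builtin = apply_norm(torch.nn.LayerNorm(8), x)  # built-in module, no subclassing needed
out_custom = apply_norm(TinyRMSNorm(8), x)
```

A type checker accepts both calls because each argument exposes a compatible forward method; nothing is checked at runtime.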

class core.transformer.torch_norm.LayerNormBuilder

Bases: typing.Protocol

A protocol showing how Modules are expected to construct LayerNorms.

__call__(
*,
config: megatron.core.transformer.TransformerConfig,
hidden_size: int,
eps: float,
) → core.transformer.torch_norm.LayerNormInterface
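The `*` in the signature makes config, hidden_size, and eps keyword-only, so a builder is always invoked as `builder(config=..., hidden_size=..., eps=...)`. A minimal sketch, assuming a hypothetical FakeConfig in place of megatron.core.transformer.TransformerConfig and a hypothetical Block module that receives the builder; the mirrored protocol returns torch.nn.Module for brevity instead of LayerNormInterface. Note that the documented L2Norm signature (hidden_size, eps, **kwargs) also appears compatible with this calling convention, since **kwargs would absorb config.

```python
from dataclasses import dataclass
from typing import Any, Protocol

import torch


@dataclass
class FakeConfig:
    """Hypothetical stand-in for megatron.core.transformer.TransformerConfig."""

    normalization: str = "LayerNorm"


class LayerNormBuilder(Protocol):
    """Local mirror of the documented protocol (keyword-only arguments)."""

    def __call__(self, *, config: Any, hidden_size: int, eps: float) -> torch.nn.Module: ...


def torch_layernorm_builder(*, config: Any, hidden_size: int, eps: float) -> torch.nn.Module:
    # A plain function satisfies a callable Protocol when its signature matches.
    # This sketch ignores config and always builds torch.nn.LayerNorm.
    return torch.nn.LayerNorm(hidden_size, eps=eps)


class Block(torch.nn.Module):
    """Hypothetical module that receives a builder instead of a concrete norm class."""

    def __init__(self, config: Any, hidden_size: int, norm_builder: LayerNormBuilder) -> None:
        super().__init__()
        # Keyword-only call, matching the protocol's `*` marker.
        self.norm = norm_builder(config=config, hidden_size=hidden_size, eps=1e-5)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.norm(x)


block = Block(FakeConfig(), hidden_size=16, norm_builder=torch_layernorm_builder)
y = block(torch.randn(3, 16))
```

Passing a builder rather than a class lets the surrounding module stay agnostic about which norm implementation it gets.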

class core.transformer.torch_norm.WrappedTorchNorm

A conditional wrapper that initializes an instance of PyTorch’s LayerNorm or RMSNorm based on its input arguments.

__new__(
config: megatron.core.transformer.TransformerConfig,
hidden_size: int,
eps: float = 1e-05,
persist_layer_norm: bool = False,
zero_centered_gamma: bool = False,
normalization: str = 'LayerNorm',
) → core.transformer.torch_norm.LayerNormInterface
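Because the work happens in `__new__`, constructing WrappedTorchNorm hands back a PyTorch module rather than a WrappedTorchNorm instance. The sketch below illustrates that pattern only. It assumes the choice is driven by a `normalization` string and drops config, persist_layer_norm, and zero_centered_gamma; which of the documented inputs actually selects the class is not stated above. WrappedTorchNormSketch is a hypothetical name.

```python
import torch


class WrappedTorchNormSketch:
    """Hypothetical re-creation of the __new__-based factory pattern."""

    def __new__(cls, hidden_size: int, eps: float = 1e-5, normalization: str = "LayerNorm"):
        # Returning an object that is not an instance of cls means
        # Python never calls cls.__init__ on it.
        if normalization == "LayerNorm":
            return torch.nn.LayerNorm(hidden_size, eps=eps)
        if normalization == "RMSNorm":
            # torch.nn.RMSNorm is available from PyTorch 2.4 onward.
            return torch.nn.RMSNorm(hidden_size, eps=eps)
        raise ValueError(f"unsupported normalization: {normalization!r}")


norm = WrappedTorchNormSketch(hidden_size=32, normalization="LayerNorm")
print(type(norm).__name__)  # LayerNorm
```

Callers therefore get a plain torch.nn.LayerNorm or torch.nn.RMSNorm, which already satisfies LayerNormInterface through its forward method.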

class core.transformer.torch_norm.L2Norm(hidden_size: int, eps: float = 1e-06, **kwargs)

Bases: torch.nn.Module, core.transformer.torch_norm.LayerNormInterface

Applies L2 normalization to the input tensor along the last dimension.

This module rescales the input tensor so that the mean of the squared values along the last dimension is approximately 1; the small deviation comes from an epsilon added for numerical stability.
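Written out for a single vector x of size d along the last dimension, and assuming ε sits inside the square root (the text says only that it is added to the denominator), the operation and its resulting mean of squares are:

```latex
y_i = \frac{x_i}{\sqrt{\frac{1}{d}\sum_{j=1}^{d} x_j^2 + \epsilon}},
\qquad
\frac{1}{d}\sum_{i=1}^{d} y_i^2 = \frac{m}{m + \epsilon},
\quad m = \frac{1}{d}\sum_{j=1}^{d} x_j^2 .
```

The mean of squares is therefore slightly below 1 and approaches 1 as ε → 0. Equivalently, the Euclidean norm of the output is √(d·m/(m+ε)) ≈ √d, so each vector is scaled to root-mean-square 1 rather than to unit L2 length.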

Parameters:
  • hidden_size (int) – Expected input shape for normalization (not used internally).

  • eps (float, optional) – A small value added to the denominator for numerical stability. Default: 1e-6.


_norm(x: torch.Tensor) → torch.Tensor

Performs the actual L2 normalization.

Parameters:

x (torch.Tensor) – The input tensor to normalize.

Returns:

The L2-normalized tensor.

Return type:

torch.Tensor
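The normalization described above can be sketched as a standalone function. This is an illustration under the same assumption about ε placement, not the module's actual source; l2norm_sketch is a hypothetical name.

```python
import torch


def l2norm_sketch(x: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Mean of squares over the last dimension, kept as size 1 for broadcasting.
    mean_sq = x.pow(2).mean(dim=-1, keepdim=True)
    # Scale by 1 / sqrt(mean_sq + eps); placing eps inside the sqrt is an assumption.
    return x * torch.rsqrt(mean_sq + eps)


torch.manual_seed(0)
x = torch.randn(4, 64)
y = l2norm_sketch(x)
print(y.pow(2).mean(dim=-1))  # each entry is very close to 1.0
print(y.norm(dim=-1))         # each entry is close to sqrt(64) = 8.0
```

torch.rsqrt computes the reciprocal square root in one call, which avoids a separate division.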

forward(x: torch.Tensor) → torch.Tensor

Forward pass of the L2Norm module.

Parameters:

x (torch.Tensor) – Input tensor.

Returns:

L2-normalized tensor with the same dtype as the input.

Return type:

torch.Tensor
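One plausible way to return a tensor with the input's dtype is to normalize in float32 and cast back, which also keeps low-precision inputs such as bfloat16 numerically stable. The module below is a hypothetical sketch following the documented constructor signature: hidden_size is stored but unused, extra keyword arguments are accepted and ignored, and it has no learnable scale. Whether the real L2Norm upcasts this way is an assumption.

```python
import torch


class L2NormSketch(torch.nn.Module):
    """Hypothetical module mirroring the documented L2Norm signature."""

    def __init__(self, hidden_size: int, eps: float = 1e-6, **kwargs) -> None:
        super().__init__()
        self.hidden_size = hidden_size  # stored, not used in the computation
        self.eps = eps
        # Extra keyword arguments (e.g. a config object) are accepted and ignored.

    def _norm(self, x: torch.Tensor) -> torch.Tensor:
        return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Normalize in float32, then cast back to the input's dtype.
        return self._norm(x.float()).type_as(x)


norm = L2NormSketch(hidden_size=128, config=None)
x = torch.randn(2, 5, 128).to(torch.bfloat16)
y = norm(x)
```

Upcasting costs a temporary float32 copy but avoids accumulating the mean of squares in bfloat16's short mantissa.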