core.transformer.torch_norm#

Module Contents#

Classes#

WrappedTorchNorm

A conditional wrapper that initializes an instance of PyTorch's LayerNorm or RMSNorm based on the `normalization` argument.

L2Norm

Applies L2 normalization to the input tensor along the last dimension.

API#

class core.transformer.torch_norm.WrappedTorchNorm#

A conditional wrapper that initializes an instance of PyTorch's LayerNorm or RMSNorm based on the `normalization` argument.

__new__(
config: megatron.core.transformer.TransformerConfig,
hidden_size: int,
eps: float = 1e-05,
persist_layer_norm: bool = False,
zero_centered_gamma: bool = False,
normalization: str = 'LayerNorm',
)#
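The `__new__` signature above suggests that the wrapper picks a norm class at construction time rather than wrapping one at call time. A minimal sketch of that dispatch pattern, assuming the wrapper simply maps the `normalization` string onto the corresponding class (the `_LayerNorm`/`_RMSNorm` stand-ins below are placeholders so the sketch runs without `torch` installed; the real wrapper returns `torch.nn.LayerNorm` or `torch.nn.RMSNorm`):

```python
class _LayerNorm:  # stand-in for torch.nn.LayerNorm
    def __init__(self, hidden_size, eps=1e-5):
        self.hidden_size, self.eps = hidden_size, eps

class _RMSNorm:  # stand-in for torch.nn.RMSNorm
    def __init__(self, hidden_size, eps=1e-5):
        self.hidden_size, self.eps = hidden_size, eps

class WrappedNormSketch:
    """Return a LayerNorm or RMSNorm instance depending on `normalization`."""

    def __new__(cls, hidden_size: int, eps: float = 1e-5,
                normalization: str = "LayerNorm"):
        norm_classes = {"LayerNorm": _LayerNorm, "RMSNorm": _RMSNorm}
        if normalization not in norm_classes:
            raise ValueError(f"Unsupported normalization: {normalization}")
        # __new__ returns an instance of the selected class, not of the wrapper
        return norm_classes[normalization](hidden_size, eps=eps)

norm = WrappedNormSketch(1024, normalization="RMSNorm")
print(type(norm).__name__)  # _RMSNorm
```

Because `__new__` returns an instance of the selected norm class directly, callers receive a plain norm module and never hold a reference to the wrapper itself.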
class core.transformer.torch_norm.L2Norm(hidden_size: int, eps: float = 1e-06, **kwargs)#

Bases: torch.nn.Module

Applies L2 normalization to the input tensor along the last dimension.

This module normalizes the input tensor such that the mean of the squared values along the last dimension is 1 (within a small epsilon for numerical stability).

Parameters:
  • hidden_size (int) – Expected input shape for normalization (not used internally).

  • eps (float, optional) – A small value added to the denominator for numerical stability. Default: 1e-6.

Initialization

_norm(x)#

Performs the actual L2 normalization.

Parameters:
  • x (torch.Tensor) – The input tensor to normalize.

Returns:
  The L2-normalized tensor.

Return type:
  torch.Tensor
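The normalization described above scales the input so that the mean of its squared values along the last dimension is approximately 1, i.e. it multiplies by 1/sqrt(mean(x²) + eps). A pure-Python sketch of that computation on a 1-D vector (plain lists stand in for `torch.Tensor` so the sketch runs without torch):

```python
import math

def l2_norm_sketch(x, eps=1e-6):
    """Scale x so the mean of its squared values is ~1 (within eps)."""
    mean_sq = sum(v * v for v in x) / len(x)
    inv_rms = 1.0 / math.sqrt(mean_sq + eps)
    return [v * inv_rms for v in x]

x = [3.0, 4.0]                 # mean of squares = 12.5
y = l2_norm_sketch(x)
mean_sq_after = sum(v * v for v in y) / len(y)
print(mean_sq_after)           # ≈ 1.0 (off by a tiny eps-dependent amount)
```

Note that `hidden_size` plays no role in this computation, which matches the parameter description above ("not used internally").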

forward(x)#

Forward pass of the L2Norm module.

Parameters:
  • x (torch.Tensor) – Input tensor.

Returns:
  L2-normalized tensor with the same dtype as the input.

Return type:
  torch.Tensor