core.transformer.torch_norm#
Module Contents#
Classes#
WrappedTorchNorm | A conditional wrapper to initialize an instance of PyTorch's LayerNorm or RMSNorm based on input. |
L2Norm | Applies L2 normalization to the input tensor along the last dimension. |
API#
- class core.transformer.torch_norm.WrappedTorchNorm#
A conditional wrapper to initialize an instance of PyTorch's LayerNorm or RMSNorm based on input.

- __new__(
- config: megatron.core.transformer.TransformerConfig,
- hidden_size: int,
- eps: float = 1e-05,
- persist_layer_norm: bool = False,
- zero_centered_gamma: bool = False,
- normalization: str = 'LayerNorm',

)#
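As a rough illustration of the conditional behaviour the class name implies, the sketch below maps the `normalization` string to the corresponding PyTorch class name. This is a hypothetical reconstruction for illustration only, not Megatron's actual dispatch logic, and the helper name `select_norm_class` is made up:

```python
# Hypothetical sketch of the dispatch WrappedTorchNorm performs: choose a
# PyTorch normalization class based on the `normalization` config string.
# The real wrapper also validates config flags (e.g. persist_layer_norm,
# zero_centered_gamma) that plain torch modules do not support.
_SUPPORTED = {
    "LayerNorm": "torch.nn.LayerNorm",
    "RMSNorm": "torch.nn.RMSNorm",
}

def select_norm_class(normalization: str = "LayerNorm") -> str:
    if normalization not in _SUPPORTED:
        raise ValueError(f"Unsupported normalization: {normalization!r}")
    return _SUPPORTED[normalization]

print(select_norm_class("RMSNorm"))  # torch.nn.RMSNorm
```

Any other string would raise, mirroring the fact that the wrapper only covers the two norms PyTorch provides.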
- class core.transformer.torch_norm.L2Norm(hidden_size: int, eps: float = 1e-06, **kwargs)#
Bases:
torch.nn.Module

Applies L2 normalization to the input tensor along the last dimension.
This module normalizes the input tensor such that the mean of the squared values along the last dimension is 1 (within a small epsilon for numerical stability).
- Parameters:
hidden_size (int) – Expected input shape for normalization (not used internally).
eps (float, optional) – A small value added to the denominator for numerical stability. Default: 1e-6.
Initialization
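The transformation described above is commonly written as x / sqrt(mean(x**2) + eps). A minimal pure-Python sketch of that arithmetic, shown as an illustration of the formula rather than the module's actual tensor code:

```python
import math

def l2_normalize(x, eps=1e-6):
    # Scale the vector so that the mean of its squared values along the
    # (single) dimension is ~1, per the description above.
    mean_sq = sum(v * v for v in x) / len(x)
    return [v / math.sqrt(mean_sq + eps) for v in x]

out = l2_normalize([3.0, 4.0])
mean_sq_out = sum(v * v for v in out) / len(out)
print(round(mean_sq_out, 6))  # 1.0
```

Note that the direction of the vector is preserved; only its scale changes, which is why `hidden_size` is not needed for the computation itself.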
- _norm(x)#
Performs the actual L2 normalization.
- Parameters:
x (torch.Tensor) – The input tensor to normalize.
- Returns:
The L2-normalized tensor.
- Return type:
torch.Tensor
- forward(x)#
Forward pass of the L2Norm module.
- Parameters:
x (torch.Tensor) – Input tensor.
- Returns:
L2-normalized tensor with the same dtype as input.
- Return type:
torch.Tensor