core.optimizer.clip_grads
Gradient clipping.
Module Contents
Functions
- get_grad_norm_fp32: Calculate the p-norm of gradients in FP32 precision.
- clip_grad_by_total_norm_fp32: Clips the gradients of an iterable of parameters in FP32 by total norm.
- count_zeros_fp32: Counts the number of zero values in the gradients of the given parameters.
API
- core.optimizer.clip_grads.get_grad_norm_fp32(
- grads_for_norm: Union[List[torch.Tensor], torch.Tensor],
- norm_type: Union[int, float] = 2,
- grad_stats_parallel_group: Optional[torch.distributed.ProcessGroup] = None,
- )
Calculate the p-norm of gradients in FP32 precision.
This function is adapted from torch.nn.utils.clip_grad.clip_grad_norm_ and extends it with functionality to handle model-parallel parameters. It ensures that the norm is correctly computed and reduced across the specified process group (typically the model-parallel group for non-distributed optimizers or the entire world for distributed optimizers).
- Parameters:
grads_for_norm (Union[List[torch.Tensor], torch.Tensor]) – An iterable of Tensors or a single Tensor used to calculate the gradient norm.
norm_type (Union[int, float]) – The type of the p-norm to use. Can be ‘inf’ for infinity norm. Defaults to 2.
grad_stats_parallel_group (ProcessGroup, optional) – The process group used for reducing gradient statistics (e.g., norms and zero counts).
- Returns:
The total norm of the gradients, treated as a single vector.
- Return type:
float
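To make the reduction behavior concrete, here is a minimal, simplified sketch of the computation (not the actual Megatron implementation, which is more involved, e.g. it uses fused multi-tensor kernels for the 2-norm): the local norm is accumulated in FP32 and, when a process group is given, combined across ranks with a MAX reduction for the infinity norm or a SUM reduction of the p-th powers otherwise.

```python
import torch
import torch.distributed as dist

def grad_norm_fp32_sketch(grads, norm_type=2, group=None):
    # Simplified sketch only; see the note above for what the real function adds.
    if norm_type == float('inf'):
        # Infinity norm: take the max |g| locally, then MAX-reduce across ranks.
        local_norm = max(g.abs().max() for g in grads)
        total = torch.tensor([float(local_norm)], dtype=torch.float32)
        if group is not None:
            dist.all_reduce(total, op=dist.ReduceOp.MAX, group=group)
        return total.item()
    # p-norm: accumulate sum(|g|^p) in FP32, SUM-reduce, then take the p-th root.
    local_sum = sum(g.float().norm(norm_type) ** norm_type for g in grads)
    total = torch.tensor([float(local_sum)], dtype=torch.float32)
    if group is not None:
        dist.all_reduce(total, op=dist.ReduceOp.SUM, group=group)
    return total.item() ** (1.0 / norm_type)
```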
- core.optimizer.clip_grads.clip_grad_by_total_norm_fp32(
- parameters: Union[List[torch.Tensor], torch.Tensor],
- max_norm: Union[int, float],
- total_norm: float,
- use_decoupled_grad: bool = False,
- )
Clips the gradients of an iterable of parameters in FP32 by total norm.
Note that the gradients are modified in-place.
- Parameters:
parameters (Union[List[torch.Tensor], torch.Tensor]) – An iterable of Tensors or a single Tensor that will have gradients normalized.
max_norm (Union[int, float]) – The maximum permissible total norm of the gradients.
total_norm (float) – The current total norm of the gradients.
use_decoupled_grad (bool, optional) – Whether to read from the ‘.decoupled_grad’ attribute instead of the standard ‘.grad’. Defaults to False.
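A typical clipping step combines this function with get_grad_norm_fp32: compute the global norm first, then scale the gradients in place. The sketch below assumes the module is importable as megatron.core.optimizer.clip_grads, and that params and grad_stats_parallel_group are placeholders supplied by the surrounding training loop.

```python
from megatron.core.optimizer.clip_grads import (
    clip_grad_by_total_norm_fp32,
    get_grad_norm_fp32,
)

def clip_step(params, grad_stats_parallel_group, max_norm=1.0):
    # Collect the gradients that should contribute to the norm; a full Megatron
    # setup would also exclude shared and tensor-parallel-duplicated parameters.
    grads_for_norm = [p.grad for p in params if p.grad is not None]

    # Global norm, reduced over the group that owns the gradient statistics.
    total_norm = get_grad_norm_fp32(
        grads_for_norm,
        norm_type=2,
        grad_stats_parallel_group=grad_stats_parallel_group,
    )

    # Scale gradients in place so their total norm does not exceed max_norm.
    clip_grad_by_total_norm_fp32(params, max_norm, total_norm)
    return total_norm
```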
- core.optimizer.clip_grads.count_zeros_fp32(
- parameters: Union[List[torch.Tensor], torch.Tensor],
- grad_stats_parallel_group: torch.distributed.ProcessGroup,
- use_decoupled_grad: bool = False,
- tp_group: Optional[torch.distributed.ProcessGroup] = None,
- )
Counts the number of zero values in the gradients of the given parameters.
The count is performed in FP32. To avoid double-counting, the method filters parameters: a gradient is counted only if it is not None, the parameter is not shared, and the parameter is not a replica created by tensor model parallelism. Parameters managed by Megatron FSDP receive dedicated handling.
- Parameters:
parameters (Union[List[torch.Tensor], torch.Tensor]) – An iterable of Tensors or a single Tensor whose gradients will be checked for zeros.
grad_stats_parallel_group (ProcessGroup) – The process group used for reducing the zero count across distributed ranks.
use_decoupled_grad (bool, optional) – If True, reads from the ‘.decoupled_grad’ attribute instead of the standard ‘.grad’. Defaults to False.
tp_group (ProcessGroup, optional) – The tensor model-parallel process group used to identify duplicated (replica) parameters. Defaults to None.
- Returns:
The total number of zeros in the gradients across the process group.
- Return type:
float
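For reference, here is a simplified local version of the zero count. It is a sketch only: it omits the filtering of shared and tensor-parallel-duplicated parameters and the all-reduce over grad_stats_parallel_group that the real function performs.

```python
import torch

def count_zeros_local_sketch(grads):
    # Accumulate the zero count in FP32, mirroring the documented precision.
    num_zeros = torch.tensor(0.0, dtype=torch.float32)
    for g in grads:
        num_zeros += (g.float() == 0.0).sum().float()
    return num_zeros.item()

# Example: a gradient with two zero entries.
g = torch.tensor([0.0, 1.5, 0.0, -2.0])
print(count_zeros_local_sketch([g]))  # 2.0
```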