core.optimizer.clip_grads

Gradient clipping.

Module Contents

Functions

get_grad_norm_fp32

Calculate the p-norm of gradients in FP32 precision.

clip_grad_by_total_norm_fp32

Clips the gradients of an iterable of parameters in FP32 by total norm.

count_zeros_fp32

Counts the number of zero values in the gradients of the given parameters.

API

core.optimizer.clip_grads.get_grad_norm_fp32(
grads_for_norm: Union[List[torch.Tensor], torch.Tensor],
norm_type: Union[int, float] = 2,
grad_stats_parallel_group: Optional[torch.distributed.ProcessGroup] = None,
) -> float

Calculate the p-norm of gradients in FP32 precision.

This function is adapted from torch.nn.utils.clip_grad.clip_grad_norm_ and extends it to handle model-parallel parameters: the norm is computed locally and then reduced across the specified process group (typically the model-parallel group for non-distributed optimizers, or the entire world for distributed optimizers).

Parameters:
  • grads_for_norm (Union[List[torch.Tensor], torch.Tensor]) – An iterable of Tensors or a single Tensor used to calculate the gradient norm.

  • norm_type (Union[int, float]) – The type of the p-norm to use. Can be ‘inf’ for the infinity norm. Defaults to 2.

  • grad_stats_parallel_group (ProcessGroup, optional) – The process group used for reducing gradient statistics (e.g., norms and zero counts). Defaults to None.

Returns:

The total norm of the gradients, treated as a single vector.

Return type:

float
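The reduction described above can be sketched in plain Python. This is a minimal, framework-free illustration under stated assumptions, not the Megatron implementation. Each inner list stands in for the gradient tensors held by one rank. Summing (or taking the max of) the per-rank partial results stands in for the all-reduce over `grad_stats_parallel_group`. The function name `grad_norm_across_ranks` is hypothetical.

```python
import math
from typing import Sequence, Union


def grad_norm_across_ranks(
    per_rank_grads: Sequence[Sequence[Sequence[float]]],
    norm_type: Union[int, float] = 2,
) -> float:
    """Hypothetical sketch: p-norm of all gradients, treated as one vector.

    per_rank_grads[r] is the list of (flattened) gradient tensors on rank r;
    combining the per-rank results below stands in for a collective all-reduce.
    """
    norm_type = float(norm_type)
    if norm_type == math.inf:
        # Infinity norm: each rank takes its local max |g|, then a MAX reduction.
        local = [
            max((abs(x) for g in grads for x in g), default=0.0)
            for grads in per_rank_grads
        ]
        return max(local, default=0.0)
    # p-norm: each rank sums |g|^p locally, the partial sums are added
    # (a SUM reduction), and the p-th root is taken only at the very end.
    local = [
        sum(abs(x) ** norm_type for g in grads for x in g)
        for grads in per_rank_grads
    ]
    return sum(local) ** (1.0 / norm_type)


# Two "ranks", each holding part of the gradients.
grads = [[[3.0]], [[4.0]]]
print(grad_norm_across_ranks(grads))                # 5.0
print(grad_norm_across_ranks(grads, float("inf")))  # 4.0
```

The p-th root is taken after the reduction because norms do not add. In the example, summing the per-rank 2-norms would give 7.0 instead of the correct 5.0.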

core.optimizer.clip_grads.clip_grad_by_total_norm_fp32(
parameters: Union[List[torch.Tensor], torch.Tensor],
max_norm: Union[int, float],
total_norm: float,
use_decoupled_grad: bool = False,
)

Clips the gradients of an iterable of parameters in FP32 by total norm.

Note that the gradients are modified in-place.

Parameters:
  • parameters (Union[List[torch.Tensor], torch.Tensor]) – An iterable of Tensors or a single Tensor whose gradients will be clipped.

  • max_norm (Union[int, float]) – The maximum permissible total norm of the gradients.

  • total_norm (float) – The current total norm of the gradients.

  • use_decoupled_grad (bool, optional) – Whether to read from the ‘.decoupled_grad’ attribute instead of the standard ‘.grad’. Defaults to False.
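The clipping step can be sketched in plain Python. This minimal illustration assumes the same rule as torch.nn.utils.clip_grad_norm_: a scale factor `max_norm / (total_norm + eps)` is computed and applied only when it is below 1. The value of `eps` (here `1e-6`) and the helper name `clip_by_total_norm` are assumptions for illustration. Lists of floats stand in for gradient tensors and are modified in place.

```python
from typing import List, Union


def clip_by_total_norm(
    grads: List[List[float]],
    max_norm: Union[int, float],
    total_norm: float,
    eps: float = 1e-6,  # assumed small constant guarding against division by zero
) -> None:
    """Hypothetical sketch: scale every gradient in place so that the
    total norm does not exceed max_norm."""
    clip_coeff = float(max_norm) / (total_norm + eps)
    if clip_coeff < 1.0:
        # Only shrink; gradients already within the limit are left untouched.
        for g in grads:
            for i in range(len(g)):
                g[i] *= clip_coeff


grads = [[3.0], [4.0]]  # total 2-norm is 5.0
clip_by_total_norm(grads, max_norm=1.0, total_norm=5.0)
print([[round(x, 4) for x in g] for g in grads])  # [[0.6], [0.8]]
```

The coefficient is applied only when it is below 1, so `max_norm` acts purely as an upper bound and gradients are never scaled up. Because `total_norm` is passed in rather than recomputed, it is typically the value returned by get_grad_norm_fp32.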

core.optimizer.clip_grads.count_zeros_fp32(
parameters: Union[List[torch.Tensor], torch.Tensor],
grad_stats_parallel_group: torch.distributed.ProcessGroup,
use_decoupled_grad: bool = False,
tp_group: Optional[torch.distributed.ProcessGroup] = None,
) -> float

Counts the number of zero values in the gradients of the given parameters.

The count is performed in FP32. To avoid double-counting, a gradient is included only if it is not None, its parameter is not shared, and its parameter is not a replica created by tensor model parallelism. Parameters managed by Megatron FSDP are handled separately.

Parameters:
  • parameters (Union[List[torch.Tensor], torch.Tensor]) – An iterable of Tensors or a single Tensor whose gradients will be checked for zeros.

  • grad_stats_parallel_group (ProcessGroup) – The process group used for reducing the zero count across distributed ranks.

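  • use_decoupled_grad (bool, optional) – If True, reads from the ‘.decoupled_grad’ attribute instead of the standard ‘.grad’. Defaults to False.

  • tp_group (ProcessGroup, optional) – The tensor-parallel process group, used when checking whether a parameter is a tensor-model-parallel replica. Defaults to None.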

Returns:

The total number of zeros in the gradients across the process group.

Return type:

float
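The filtering and reduction described above can be sketched in plain Python. This minimal illustration is not the Megatron implementation. `FakeParam` and its `shared` and `tp_duplicate` fields are hypothetical stand-ins for the checks on shared parameters and tensor-parallel replicas. Each inner list of parameters stands in for one rank, and summing the per-rank counts stands in for the all-reduce over `grad_stats_parallel_group`. Megatron FSDP handling is omitted.

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class FakeParam:
    """Hypothetical stand-in for a parameter and its gradient."""
    grad: Optional[List[float]]
    shared: bool = False        # e.g. a tied weight that is counted elsewhere
    tp_duplicate: bool = False  # replica of a tensor that is not TP-sharded


def count_zeros_across_ranks(per_rank_params: List[List[FakeParam]]) -> float:
    """Hypothetical sketch: count zero gradient entries without double-counting."""
    total = 0.0
    for params in per_rank_params:
        local = 0.0
        for p in params:
            # Skip gradients that are missing or would be counted more than once.
            if p.grad is None or p.shared or p.tp_duplicate:
                continue
            local += sum(1 for x in p.grad if x == 0.0)
        total += local  # stands in for a SUM all-reduce across ranks
    return total


ranks = [
    [FakeParam([0.0, 1.0, 0.0]), FakeParam(None)],
    [FakeParam([0.0, 2.0]), FakeParam([0.0, 0.0], tp_duplicate=True)],
]
print(count_zeros_across_ranks(ranks))  # 3.0
```

The two zeros in the replicated parameter are skipped, so only the three zeros in unique gradients contribute to the total.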