nemo_automodel.components.distributed.grad_utils
Module Contents
Functions
clip_grad_by_total_norm_ | Clips gradients of an iterable of parameters by total norm.
get_grad_norm | Calculates the norm of gradients.
API
- nemo_automodel.components.distributed.grad_utils.clip_grad_by_total_norm_(
    parameters: Union[list[Union[torch.Tensor, torch.distributed.tensor.DTensor]], Union[torch.Tensor, torch.distributed.tensor.DTensor]],
    max_grad_norm: Union[int, float],
    total_norm: float,
    dtype: torch.dtype = torch.float32,
)
Clips the gradients of an iterable of parameters by their total norm.
Taken and modified from: https://github.com/NVIDIA/Megatron-LM/blob/a695b2bd2a0ca9ca63385a48c41a1c5a033cdd1e/megatron/core/optimizer/clip_grads.py#L138
Note that the gradients are modified in place.
- Parameters:
parameters (Union[list[Union[torch.Tensor, DTensor]], Union[torch.Tensor, DTensor]]) – An iterable of Tensors or DTensors, or a single Tensor or DTensor that will have gradients normalized.
max_grad_norm (Union[float, int]) – Maximum allowed total norm of the gradients.
total_norm (float) – The pre-computed total norm of the gradients to use for scaling.
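The scaling step can be sketched for a single process with plain tensors. This sketch assumes the function keeps the Megatron-LM rule it was adapted from: every gradient is multiplied by `max_grad_norm / (total_norm + eps)` only when that coefficient is below 1. `clip_by_total_norm_sketch` and its `eps` argument are hypothetical names, and DTensor handling and the `dtype` argument are omitted.

```python
import torch


def clip_by_total_norm_sketch(parameters, max_grad_norm, total_norm, eps=1e-6):
    """Hypothetical helper: scale each .grad in place by one shared clip coefficient.

    Assumes the Megatron-LM rule; `total_norm` must already be the global norm.
    """
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    grads = [p.grad for p in parameters if p.grad is not None]
    clip_coeff = max_grad_norm / (total_norm + eps)
    # Only shrink gradients; norms already below the limit are left untouched.
    if clip_coeff < 1.0:
        for g in grads:
            g.mul_(clip_coeff)
    return clip_coeff


# Usage: one parameter whose gradient norm is sqrt(4 * 3**2) = 6.
w = torch.nn.Parameter(torch.ones(4))
w.grad = torch.full((4,), 3.0)
total = torch.linalg.vector_norm(w.grad).item()
clip_by_total_norm_sketch([w], max_grad_norm=1.0, total_norm=total)
print(w.grad.norm().item())  # approximately 1.0
```

Because every gradient is scaled by the same coefficient, the direction of the concatenated gradient vector is preserved and only its length changes.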
- nemo_automodel.components.distributed.grad_utils.get_grad_norm(
    parameters: Union[list[Union[torch.Tensor, torch.distributed.tensor.DTensor]], Union[torch.Tensor, torch.distributed.tensor.DTensor]],
    dp_cp_group: torch.distributed.ProcessGroup,
    tp_group: torch.distributed.ProcessGroup,
    norm_type: Union[int, float] = 2,
    dtype: torch.dtype = torch.float32,
)
Calculate the norm of gradients.
Taken and modified from: https://github.com/NVIDIA/Megatron-LM/blob/a695b2bd2a0ca9ca63385a48c41a1c5a033cdd1e/megatron/core/optimizer/clip_grads.py#L51
- Parameters:
parameters (Union[list[Union[torch.Tensor, DTensor]], Union[torch.Tensor, DTensor]]) – An iterable of Tensors or DTensors, or a single Tensor or DTensor that will have gradient norm calculated.
dp_cp_group (torch.distributed.ProcessGroup) – Process group spanning the data-parallel and context-parallel ranks.
tp_group (torch.distributed.ProcessGroup) – Process group for tensor parallel communication.
norm_type (Union[int, float]) – Type of p-norm to compute. Defaults to 2; can be 'inf' for the infinity norm.
- Returns:
Total norm of the gradients, treating all gradients as a single concatenated vector.
- Return type:
float
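The norm computation can be approximated on a single process as follows. The sketch assumes the Megatron-LM approach the function was adapted from: for a finite p it sums |g|^p over all local gradients and takes the p-th root, and for 'inf' it takes the largest absolute value. In the real function those local partial results would additionally be all-reduced over `dp_cp_group` and `tp_group` (sum for finite p, max for 'inf'); that communication is only indicated in comments here, and `grad_norm_sketch` is a hypothetical name.

```python
import math

import torch


def grad_norm_sketch(parameters, norm_type=2):
    """Hypothetical single-process version of the total-gradient-norm computation."""
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    norm_type = float(norm_type)  # accepts 2, 2.0, "inf", or math.inf
    grads = [p.grad.detach().float() for p in parameters if p.grad is not None]
    if not grads:
        return 0.0
    if norm_type == math.inf:
        local = torch.stack([g.abs().max() for g in grads]).max()
        # Distributed version (assumption): all_reduce(local, op=MAX) over dp_cp_group and tp_group.
        return local.item()
    # Sum of |g|^p across all gradients, i.e. treating them as one long vector.
    local = torch.stack([torch.linalg.vector_norm(g, norm_type) ** norm_type for g in grads]).sum()
    # Distributed version (assumption): all_reduce(local, op=SUM) over dp_cp_group and tp_group.
    return local.item() ** (1.0 / norm_type)


a = torch.nn.Parameter(torch.zeros(1))
b = torch.nn.Parameter(torch.zeros(1))
a.grad = torch.tensor([3.0])
b.grad = torch.tensor([-4.0])
print(grad_norm_sketch([a, b]))         # 5.0
print(grad_norm_sketch([a, b], "inf"))  # 4.0
```

Reducing the p-th-power partial sums, rather than per-rank norms, is what makes the computation composable: summing them across ranks before taking the root gives the same value as the norm of the fully gathered gradient.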