nemo_automodel.components.distributed.grad_utils


Module Contents

Functions

Name | Description
clip_grad_by_total_norm_ | Clips the gradients of an iterable of parameters by their total norm.
get_grad_norm | Calculates the norm of gradients.

API

nemo_automodel.components.distributed.grad_utils.clip_grad_by_total_norm_(
parameters: typing.Union[list[typing.Union[torch.Tensor, torch.distributed.tensor.DTensor]], typing.Union[torch.Tensor, torch.distributed.tensor.DTensor]],
max_grad_norm: typing.Union[int, float],
total_norm: float,
dtype: torch.dtype = torch.float32
)

Clips the gradients of an iterable of parameters by their total norm.

Taken and modified from: https://github.com/NVIDIA/Megatron-LM/blob/a695b2bd2a0ca9ca63385a48c41a1c5a033cdd1e/megatron/core/optimizer/clip_grads.py#L138

Note that the gradients are modified in place.

Parameters:

parameters
Union[list[Union[torch.Tensor, DTensor]], Union[torch.Tensor, DTensor]]

An iterable of Tensors or DTensors, or a single Tensor or DTensor, whose gradients will be clipped.

max_grad_norm
Union[float, int]

Maximum allowed total norm of the gradients.

total_norm
float

The pre-computed total norm of the gradients to use for scaling.
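The scaling rule can be sketched in plain Python. This is a framework-free illustration, not the library's implementation: plain lists of floats stand in for tensors, the helper name clip_grads_by_total_norm is hypothetical, and the eps of 1e-6 is an assumption carried over from the Megatron-LM code this function is adapted from. In practice total_norm would presumably come from get_grad_norm.

```python
import math


def clip_grads_by_total_norm(grads, max_grad_norm, total_norm, eps=1.0e-6):
    """Hypothetical sketch: scale gradients in place so their combined norm is <= max_grad_norm.

    grads: list of gradients, each a mutable list of floats (stand-in for tensors).
    total_norm: norm computed beforehand over *all* gradients, possibly across ranks.
    eps: small constant guarding against division by zero (assumed, Megatron-style).
    """
    clip_coeff = max_grad_norm / (total_norm + eps)
    # Only shrink gradients; a coefficient >= 1 means they are already small enough.
    if clip_coeff < 1.0:
        for g in grads:
            for i in range(len(g)):
                g[i] *= clip_coeff  # in-place update, mirroring the real function
    return clip_coeff


# Usage: two "parameters" whose combined L2 norm is 5.0 (a 3-4-5 triangle).
grads = [[3.0], [4.0]]
total = math.sqrt(sum(x * x for g in grads for x in g))
coeff = clip_grads_by_total_norm(grads, max_grad_norm=1.0, total_norm=total)
print(grads)  # roughly [[0.6], [0.8]]
```

Because every gradient is multiplied by the same coefficient, the direction of the overall update is preserved; only its length is capped.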

nemo_automodel.components.distributed.grad_utils.get_grad_norm(
parameters: typing.Union[list[typing.Union[torch.Tensor, torch.distributed.tensor.DTensor]], typing.Union[torch.Tensor, torch.distributed.tensor.DTensor]],
dp_cp_group: torch.distributed.ProcessGroup,
tp_group: torch.distributed.ProcessGroup,
norm_type: typing.Union[int, float] = 2,
dtype: torch.dtype = torch.float32
) -> float

Calculate the norm of gradients.

Taken and modified from: https://github.com/NVIDIA/Megatron-LM/blob/a695b2bd2a0ca9ca63385a48c41a1c5a033cdd1e/megatron/core/optimizer/clip_grads.py#L51

Parameters:

parameters
Union[list[Union[torch.Tensor, DTensor]], Union[torch.Tensor, DTensor]]

An iterable of Tensors or DTensors, or a single Tensor or DTensor, whose gradient norm will be computed.

dp_cp_group
torch.distributed.ProcessGroup

Process group spanning the data-parallel and context-parallel dimensions.

tp_group
torch.distributed.ProcessGroup

Process group for tensor parallel communication.

norm_type
Union[int, float], defaults to 2

Order of the p-norm to compute. Use float('inf') for the infinity (maximum absolute value) norm.
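For reference, the returned quantity can be written as follows, where the gradient entries g_i are split into disjoint local shards S_r held by ranks r. This assumes each entry lives on exactly one rank (sharded, not replicated), which is what allows per-rank partial results to be combined without double counting.

```latex
\|g\|_p = \Bigl( \sum_{r} \sum_{i \in S_r} |g_i|^{p} \Bigr)^{1/p},
\qquad
\|g\|_\infty = \max_{r} \, \max_{i \in S_r} |g_i|
```

For finite p, each rank contributes its local sum of |g_i|^p, and these sums are added across ranks before the single 1/p power is taken; for p = inf, the local maxima are combined with a max instead.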

Returns: float

Total norm of all gradients, treated as a single concatenated vector.
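The reduction pattern can be simulated in a single process. This is a hedged sketch, not the library's code: the helper simulated_grad_norm is hypothetical, each inner list plays the role of one rank's local gradient shard, and the plain sum/max over ranks stands in for what would presumably be torch.distributed.all_reduce calls (ReduceOp.SUM for finite p, ReduceOp.MAX for p = inf) over dp_cp_group and tp_group.

```python
import math


def simulated_grad_norm(shards_per_rank, norm_type=2.0):
    """Hypothetical sketch: total gradient norm from per-rank gradient shards.

    shards_per_rank: one entry per simulated rank; each entry is a list of that
    rank's local gradient values. Assumes every gradient element lives on exactly
    one rank, so partial results can be combined without double counting.
    """
    norm_type = float(norm_type)
    if norm_type == math.inf:
        # Each rank finds its local max |g|; a MAX all-reduce would combine them.
        local = [max((abs(x) for x in shard), default=0.0) for shard in shards_per_rank]
        return max(local, default=0.0)
    # Each rank sums |g|^p locally; a SUM all-reduce would combine the partial sums.
    local = [sum(abs(x) ** norm_type for x in shard) for shard in shards_per_rank]
    total = sum(local)  # stands in for all_reduce(op=SUM) over the process groups
    return total ** (1.0 / norm_type)


# Usage: the vector [3, 4] split across two simulated ranks.
print(simulated_grad_norm([[3.0], [4.0]]))               # 5.0
print(simulated_grad_norm([[3.0], [-4.0]], math.inf))    # 4.0
```

Reducing p-th powers rather than norms is the key design point: per-rank norms cannot simply be added, but per-rank sums of |g|^p can, and one 1/p power at the end recovers the global norm.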