nemo_automodel.components.training.utils


Module Contents

Classes

Name | Description
ScopedModuleOffloading | Context manager that temporarily moves a module between CPU and CUDA.

Functions

Name | Description
_clip_grad_norm_impl | -
clip_grad_norm | Common gradient clipping helper.
count_tail_padding | Count the total number of padding tokens at the tail of each row of labels.
move_to_device | Move a model and its buffers to a device and release stale CUDA cache.
prepare_after_first_microbatch | Disable the first-microbatch flag after the first forward-backward pass.
prepare_for_final_backward | Prepare model parts before the final backward pass.
prepare_for_grad_accumulation | Prepare model parts before starting gradient accumulation.
scale_grads_and_clip_grad_norm | Scale gradients for PP/EP in a single pass, then clip.

Data

_TE_EXPERT_PARAM_PATTERN

API

class nemo_automodel.components.training.utils.ScopedModuleOffloading(
model,
enabled = False
)

Context manager that temporarily moves a module between CPU and CUDA.

nemo_automodel.components.training.utils.ScopedModuleOffloading.__enter__()
nemo_automodel.components.training.utils.ScopedModuleOffloading.__exit__(
exc_type,
exc_val,
exc_tb
)
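The description does not say which direction the move goes. A minimal sketch of the context-manager pattern, assuming the module normally lives on CPU, is moved to CUDA on entry, is returned to CPU on exit, and that `enabled=False` makes the whole thing a no-op. `ScopedOffloadSketch` and `FakeModule` are hypothetical names; `FakeModule` only records `.to()` calls so the sketch runs without PyTorch or a GPU. This is not the library's implementation.

```python
class ScopedOffloadSketch:
    """Hypothetical re-creation of a scoped CPU <-> CUDA move (assumed semantics)."""

    def __init__(self, model, enabled=False):
        self.model = model
        self.enabled = enabled

    def __enter__(self):
        # Assumption: the module is placed on the GPU only for the duration of the block.
        if self.enabled:
            self.model.to("cuda")
        return self.model

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Assumption: the module is always offloaded back to CPU, even if the block raised.
        if self.enabled:
            self.model.to("cpu")
        return False  # do not swallow exceptions


class FakeModule:
    """Hypothetical stand-in for torch.nn.Module that records device moves."""

    def __init__(self):
        self.device = "cpu"
        self.moves = []

    def to(self, device):
        self.moves.append(device)
        self.device = device
        return self


m = FakeModule()
with ScopedOffloadSketch(m, enabled=True) as inner:
    assert inner.device == "cuda"
print(m.device, m.moves)  # cpu ['cuda', 'cpu']
```

Returning `False` from `__exit__` lets any exception raised inside the block propagate after the module has been moved back.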
nemo_automodel.components.training.utils._clip_grad_norm_impl(
parameters: torch.Tensor | typing.Iterable[torch.Tensor],
max_norm: float,
norm_type: float = 2.0,
error_if_nonfinite: bool = False,
foreach: bool | None = None,
pp_mesh: torch.distributed.device_mesh.DeviceMesh | None = None
) -> torch.Tensor
nemo_automodel.components.training.utils.clip_grad_norm(
max_grad_norm: float | None,
model_parts: list[torch.nn.Module],
norm_type: float = 2.0,
pp_enabled: bool = False,
device_mesh: torch.distributed.device_mesh.DeviceMesh | None = None,
pp_axis_name: str | None = None,
foreach: bool = True,
use_torch_clip_grad_norm: bool = False
)

Common gradient clipping helper.

Handles all parallelism strategies (TP, PP, EP/MoE) with automatic sharding-aware grouping. Returns the gradient norm as a float, or 0.0 if clipping is skipped.

This function automatically:

  • Groups parameters by sharding pattern (device mesh + placements)
  • Computes norms correctly across different sharding strategies
  • Handles MoE with separate DP/EP meshes
  • Reduces norms across pipeline parallel stages when enabled
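For a p-norm, the total norm over disjoint parameter groups (for example, groups with different sharding patterns, or different pipeline stages) can be recovered from per-group partial sums: each group contributes the sum of |g|^p over its elements, the partial sums are added (the role an all-reduce plays across ranks or PP stages), and the p-th root is taken once at the end; for the infinity norm the reduction is a max instead. A minimal pure-Python sketch of that arithmetic, with hypothetical helper names and plain lists standing in for gradient shards:

```python
import math


def local_partial(grads, norm_type):
    """Per-group contribution: sum(|g|^p), or max(|g|) for the inf-norm."""
    if math.isinf(norm_type):
        return max((abs(g) for g in grads), default=0.0)
    return sum(abs(g) ** norm_type for g in grads)


def total_norm(groups, norm_type=2.0):
    """Combine partial results from disjoint groups (shards, PP stages) into one norm."""
    partials = [local_partial(g, norm_type) for g in groups]
    if math.isinf(norm_type):
        # The inf-norm reduces with max rather than sum.
        return max(partials, default=0.0)
    # Sum the partial sums (what an all-reduce would do), then take the p-th root once.
    return sum(partials) ** (1.0 / norm_type)


# Two "stages" holding disjoint gradient shards.
stage0 = [3.0]
stage1 = [4.0]
print(total_norm([stage0, stage1]))            # 5.0
print(total_norm([stage0, stage1], math.inf))  # 4.0
```

Adding the per-group norms directly would give the wrong answer (3 + 4 = 7, not 5), which is why partial sums of |g|^p, not norms, are what get reduced.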

Parameters:

max_grad_norm
float | None

Maximum gradient norm. If None, skips clipping.

model_parts
list[torch.nn.Module]

List of model modules to clip.

norm_type
float, defaults to 2.0

Type of norm to use (default: 2.0 for L2).

pp_enabled
bool, defaults to False

Whether pipeline parallelism is enabled.

device_mesh
DeviceMesh | None, defaults to None

Device mesh for parallelism.

moe_mesh

MoE-specific device mesh (unused, kept for API compatibility).

ep_axis_name

Expert parallel axis name (unused, kept for API compatibility).

pp_axis_name
str | None, defaults to None

Pipeline parallel axis name.

foreach
bool, defaults to True

Whether to use the foreach (multi-tensor) implementation for clipping.

use_torch_clip_grad_norm
bool, defaults to False

Use PyTorch’s optimized regular-tensor clipping path when possible.

Returns:

Total gradient norm as a float, or 0.0 if max_grad_norm is None and clipping is skipped.
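Given the total norm, clipping in the style of `torch.nn.utils.clip_grad_norm_` scales every gradient by `min(1, max_norm / (total_norm + eps))`. The sketch below shows that step on a plain list, together with the documented behavior of returning 0.0 when `max_grad_norm` is None. The helper name `clip_sketch`, the list-based gradients, and `eps = 1e-6` are assumptions for illustration, not this module's code.

```python
import math


def clip_sketch(grads, max_grad_norm, norm_type=2.0, eps=1e-6):
    """Hypothetical helper: clip a flat list of grads in place, return the pre-clip norm."""
    if max_grad_norm is None:
        # Mirrors the documented behavior: skip clipping and report 0.0.
        return 0.0
    if math.isinf(norm_type):
        total = max((abs(g) for g in grads), default=0.0)
    else:
        total = sum(abs(g) ** norm_type for g in grads) ** (1.0 / norm_type)
    # Scale down only when the norm exceeds the threshold (coefficient capped at 1).
    coef = min(1.0, max_grad_norm / (total + eps))
    for i, g in enumerate(grads):
        grads[i] = g * coef
    return float(total)


g = [3.0, 4.0]
norm = clip_sketch(g, max_grad_norm=1.0)
print(norm, g)  # 5.0, then approximately [0.6, 0.8]
```

The returned value is the norm measured before clipping, which is what is usually logged.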

nemo_automodel.components.training.utils.count_tail_padding(
labels,
ignore_label = -100
)

Count the total number of padding tokens at the tail of each row of labels.

For example, given

    labels = torch.tensor([
        [-100, 1, 1, -100, -100],  # 2 tail -100s
        [-100, -100, 2, 3, 4],     # 0 tail -100s
        [5, 6, -100, -100, -100],  # 3 tail -100s
    ])

count_tail_padding returns 5. Note that labels contains more than 5 ignore labels in total (8); only the trailing ones in each row are counted.

Parameters:

labels
torch.Tensor

The labels.

ignore_label
int, defaults to -100

Ignore label index.

Returns:

Total number of trailing ignore_label tokens across all rows of labels.
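A pure-Python equivalent of the counting rule, using nested lists instead of a `torch.Tensor` (an assumption made so the sketch runs without PyTorch; `count_tail_padding_lists` is a hypothetical name): walk each row from the end and count `ignore_label` entries until the first real label.

```python
def count_tail_padding_lists(labels, ignore_label=-100):
    """Sum, over rows, the number of trailing ignore_label entries."""
    total = 0
    for row in labels:
        # Scan from the right; stop at the first non-padding label.
        for value in reversed(row):
            if value != ignore_label:
                break
            total += 1
    return total


labels = [
    [-100, 1, 1, -100, -100],  # 2 tail -100s
    [-100, -100, 2, 3, 4],     # 0 tail -100s
    [5, 6, -100, -100, -100],  # 3 tail -100s
]
print(count_tail_padding_lists(labels))  # 5
```

The leading -100s in the first two rows are not counted, which is why the result is 5 rather than 8.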

nemo_automodel.components.training.utils.move_to_device(
model,
device
)

Move a model and its buffers to a device and release stale CUDA cache.

nemo_automodel.components.training.utils.prepare_after_first_microbatch()

Disable the first-microbatch flag after the first forward-backward pass.

Called after the first microbatch in gradient accumulation so that subsequent microbatches reuse cached FP8 weights instead of re-quantizing.

nemo_automodel.components.training.utils.prepare_for_final_backward(
model_parts: list[torch.nn.Module],
pp_enabled: bool = False
)

Prepare model parts before the final backward pass.

This is typically called before the final gradient accumulation step to prepare FSDP states for gradient synchronization and resharding.

Parameters:

model_parts
list[torch.nn.Module]

List of model parts (modules) to prepare.

pp_enabled
bool, defaults to False

Whether pipeline parallelism is enabled.

nemo_automodel.components.training.utils.prepare_for_grad_accumulation(
model_parts: list[torch.nn.Module],
pp_enabled: bool = False
)

Prepare model parts before starting gradient accumulation.

This is typically called once at the start of gradient accumulation to prepare FSDP states for the upcoming forward and backward passes.

Parameters:

model_parts
list[torch.nn.Module]

List of model parts (modules) to prepare.

pp_enabled
bool, defaults to False

Whether pipeline parallelism is enabled.
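Read together, the three prepare_* descriptions imply a call order inside one optimizer step. A minimal sketch of that ordering, assuming a simple loop over microbatches. The stubs below are hypothetical placeholders that only record the call sequence; they stand in for the real functions (which take model_parts and pp_enabled) and for the forward/backward pass.

```python
calls = []


# Hypothetical stand-ins that only record that they were called.
def prepare_for_grad_accumulation(model_parts, pp_enabled=False):
    calls.append("prepare_for_grad_accumulation")


def prepare_for_final_backward(model_parts, pp_enabled=False):
    calls.append("prepare_for_final_backward")


def prepare_after_first_microbatch():
    calls.append("prepare_after_first_microbatch")


def forward_backward(mb):
    calls.append(f"fwd_bwd[{mb}]")


def train_step(model_parts, num_microbatches):
    # Once, before any microbatch of this step.
    prepare_for_grad_accumulation(model_parts)
    for i in range(num_microbatches):
        # Before the last microbatch's backward: prepare FSDP for grad sync/resharding.
        if i == num_microbatches - 1:
            prepare_for_final_backward(model_parts)
        forward_backward(i)
        # After the first microbatch: later ones reuse cached FP8 weights.
        if i == 0:
            prepare_after_first_microbatch()
    # optimizer.step() and zero_grad() would follow here.


train_step(model_parts=[], num_microbatches=3)
print(calls)
```

With three microbatches the recorded order is: prepare_for_grad_accumulation, fwd_bwd[0], prepare_after_first_microbatch, fwd_bwd[1], prepare_for_final_backward, fwd_bwd[2].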

nemo_automodel.components.training.utils.scale_grads_and_clip_grad_norm(
max_grad_norm: float | None,
model_parts: list[torch.nn.Module],
norm_type: float = 2.0,
pp_enabled: bool = False,
device_mesh: torch.distributed.device_mesh.DeviceMesh | None = None,
moe_mesh: torch.distributed.device_mesh.DeviceMesh | None = None,
ep_axis_name: str | None = None,
pp_axis_name: str | None = None,
foreach: bool = True,
num_label_tokens: int | None = None,
dp_group_size: int | None = None,
use_torch_clip_grad_norm: bool = False
)

Scale gradients for PP/EP in a single pass, then clip.

  • PP scaling: divide all local grads by (num_label_tokens / dp_group_size).
  • EP scaling: for parameters on the expert axis, divide grads by (dp_group_size / ep_shard_size).
  • Finally, perform grad clipping with PP/EP-aware reductions.
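The two scaling rules can be illustrated with plain numbers. A minimal sketch, assuming gradients are held in a dict of float lists keyed by parameter name, that a hypothetical `is_expert` predicate identifies expert-axis parameters, and that an expert parameter receives both divisions. The description says both scalings happen in a single pass but does not spell out how they combine. `scale_grads_sketch` is a hypothetical name.

```python
def scale_grads_sketch(grads, is_expert, num_label_tokens, dp_group_size, ep_shard_size):
    """Hypothetical helper: return a new dict of scaled grads (lists of floats)."""
    # PP scaling factor, applied to every local gradient.
    pp_divisor = num_label_tokens / dp_group_size
    # EP scaling factor, applied only to parameters on the expert axis.
    ep_divisor = dp_group_size / ep_shard_size
    out = {}
    for name, g in grads.items():
        divisor = pp_divisor
        if is_expert(name):
            # Assumption: expert grads get the PP factor and the EP factor together.
            divisor *= ep_divisor
        out[name] = [x / divisor for x in g]
    return out


grads = {"attn.weight": [8.0], "mlp.experts.w": [8.0]}
scaled = scale_grads_sketch(
    grads,
    is_expert=lambda n: ".experts." in n,
    num_label_tokens=8,
    dp_group_size=2,
    ep_shard_size=1,
)
print(scaled)  # {'attn.weight': [2.0], 'mlp.experts.w': [1.0]}
```

Clipping would then run on these scaled gradients, so the reported norm reflects the scaled values.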
nemo_automodel.components.training.utils._TE_EXPERT_PARAM_PATTERN = re.compile('(^|\\.)mlp\\.experts\\.(gate_up_linear|down_linear)\\.(weight|bias)\...
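Only a prefix of the compiled pattern is visible in the rendered value above (it is cut off after `(weight|bias)`), so the sketch below uses just that visible prefix with `re.search` to show which parameter names it picks out; whatever the remainder of the real pattern requires is not modeled. `VISIBLE_PREFIX` and the sample names are hypothetical.

```python
import re

# Only the visible prefix of _TE_EXPERT_PARAM_PATTERN; the real pattern continues
# past "(weight|bias)", so its exact behavior on the tail of a name is unknown here.
VISIBLE_PREFIX = re.compile(r"(^|\.)mlp\.experts\.(gate_up_linear|down_linear)\.(weight|bias)")

names = [
    "model.layers.0.mlp.experts.gate_up_linear.weight",
    "mlp.experts.down_linear.bias",
    "model.layers.0.mlp.gate_proj.weight",
    "model.layers.0.foomlp.experts.down_linear.weight",
]
for n in names:
    # search(), not fullmatch(): the unknown tail of the real pattern is not modeled.
    print(n, bool(VISIBLE_PREFIX.search(n)))
```

The `(^|\.)` anchor means `mlp` must start the name or follow a dot, so `foomlp.experts...` does not match.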