Source code for nemo_automodel.components.training.step_scheduler

# Copyright (c) 2025, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional

from torch.distributed.checkpoint.stateful import Stateful
from torch.utils.data import DataLoader


class StepScheduler(Stateful):
    """
    Scheduler for managing gradient accumulation and checkpointing steps.
    """

    def __init__(
        self,
        grad_acc_steps: int,
        ckpt_every_steps: int,
        dataloader: Optional[DataLoader],
        val_every_steps: Optional[int] = None,
        start_step: int = 0,
        start_epoch: int = 0,
        num_epochs: int = 10,
        max_steps: int = 9223372036854775807,
    ):
        """
        Initialize the StepScheduler.

        Args:
            grad_acc_steps (int): Number of micro-batches accumulated per optimizer step.
            ckpt_every_steps (int): Frequency of checkpoint steps.
            dataloader (Optional[DataLoader]): The training dataloader.
            val_every_steps (Optional[int]): Number of training steps between validations.
            start_step (int): Initial global step.
            start_epoch (int): Initial epoch.
            num_epochs (int): Total number of epochs.
            max_steps (int): Total number of steps to run. Default is 2**63 - 1.
        """
        assert grad_acc_steps > 0, "grad_acc_steps must be greater than 0"
        self.grad_acc_steps = grad_acc_steps
        assert ckpt_every_steps > 0, "ckpt_every_steps must be greater than 0"
        self.ckpt_every_steps = ckpt_every_steps
        self.dataloader = dataloader
        assert start_step >= 0, "start_step must be greater than or equal to 0"
        self.step = start_step
        assert start_epoch >= 0, "start_epoch must be greater than or equal to 0"
        self.epoch = start_epoch
        assert num_epochs > 0, "num_epochs must be greater than 0"
        self.num_epochs = num_epochs
        # Guard the None case so is_ckpt_step can check epoch_len safely.
        self.epoch_len = len(dataloader) if dataloader is not None else None
        assert val_every_steps is None or val_every_steps > 0, "val_every_steps must be greater than 0 if not None"
        self.val_every_steps = val_every_steps
        assert max_steps > 0, "max_steps must be greater than 0"
        self.max_steps = max_steps
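
    # Construction sketch (toy numbers, illustrative only, not part of the
    # original module): with an 8-batch dataloader and grad_acc_steps=4, each
    # scheduler step consumes 4 micro-batches, so one epoch advances `step` by 2.
    #
    #   loader = DataLoader(dataset, batch_size=2)   # assume 8 batches total
    #   scheduler = StepScheduler(
    #       grad_acc_steps=4,
    #       ckpt_every_steps=10,
    #       dataloader=loader,
    #       val_every_steps=5,
    #   )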
    def __iter__(self):
        """
        Iterate over the dataloader while keeping track of the step/epoch counters.

        Batches are buffered and yielded in groups of ``grad_acc_steps``; the
        iterator returns early once ``max_steps`` is reached.

        Yields:
            list: Up to ``grad_acc_steps`` batches, one group per optimizer step.
        """
        if self.step >= self.max_steps:
            return
        batch_buffer = []
        for batch in self.dataloader:
            batch_buffer.append(batch)
            if len(batch_buffer) == self.grad_acc_steps:
                self.step += 1
                yield batch_buffer
                batch_buffer = []
                if self.step >= self.max_steps:
                    return
        # Flush a partial group at the end of the epoch.
        if batch_buffer:
            self.step += 1
            yield batch_buffer
        self.epoch += 1
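
    # Iteration sketch (hypothetical `scheduler`, `model`, and `optimizer`):
    # each yielded item is a list of micro-batches making up one optimizer step.
    #
    #   for micro_batches in scheduler:
    #       for batch in micro_batches:
    #           loss = model(batch)
    #           (loss / len(micro_batches)).backward()
    #       optimizer.step()
    #       optimizer.zero_grad()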
    def set_epoch(self, epoch: int):
        """
        Set the epoch on the scheduler and, if it supports it, the dataloader's sampler.

        Args:
            epoch (int): The epoch to set.
        """
        self.epoch = epoch
        if hasattr(getattr(self.dataloader, "sampler", None), "set_epoch"):
            self.dataloader.sampler.set_epoch(epoch)
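
    # Sketch (assumed distributed setup): with a DistributedSampler, calling
    # set_epoch at the top of each epoch reshuffles shards consistently across
    # ranks; with samplers lacking set_epoch, only the scheduler epoch changes.
    #
    #   sampler = torch.utils.data.distributed.DistributedSampler(dataset)
    #   loader = DataLoader(dataset, sampler=sampler, batch_size=8)
    #   scheduler.set_epoch(epoch)   # forwards to loader.sampler.set_epoch(epoch)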
    @property
    def is_val_step(self):
        """
        Returns whether validation should run at the current step.

        Returns:
            bool: True if validation should run.
        """
        if self.val_every_steps is None:
            return False
        return (self.step % self.val_every_steps) == 0

    @property
    def is_ckpt_step(self):
        """
        Returns whether a checkpoint should be saved at the current step.

        Returns:
            bool: If true, the checkpoint should run.
        """
        last_batch = False
        if self.epoch_len is not None:
            last_batch = (self.step % self.epoch_len) == self.epoch_len - 1
        finished = self.step >= self.max_steps
        return ((self.step % self.ckpt_every_steps) == 0 and self.step != 0) or last_batch or finished

    @property
    def epochs(self):
        """
        Epoch iterator.

        Yields:
            int: The next epoch index, stopping early if ``max_steps`` is reached.
        """
        for e in range(self.epoch, self.num_epochs):
            if self.step >= self.max_steps:
                return
            yield e
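
    # Bookkeeping sketch (hypothetical train_step/run_validation/save_checkpoint
    # helpers): the properties are queried once per optimizer step.
    #
    #   for epoch in scheduler.epochs:
    #       scheduler.set_epoch(epoch)
    #       for micro_batches in scheduler:
    #           train_step(micro_batches)
    #           if scheduler.is_val_step:
    #               run_validation()
    #           if scheduler.is_ckpt_step:
    #               save_checkpoint()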
    def state_dict(self):
        """
        Get the current state of the scheduler.

        Returns:
            dict: Current state with 'step' and 'epoch' keys.
        """
        return {"step": self.step, "epoch": self.epoch}
    def load_state_dict(self, s):
        """
        Load the scheduler state from a dictionary.

        Args:
            s (dict): Dictionary containing 'step' and 'epoch'.
        """
        self.step, self.epoch = s["step"], s["epoch"]
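

# End-to-end sketch (not part of the original module): drives the scheduler
# over a toy dataloader the way a training loop would, saving and restoring
# its state. The dataset contents, step frequencies, and checkpoint path are
# assumptions for illustration.
if __name__ == "__main__":
    import torch
    from torch.utils.data import TensorDataset

    loader = DataLoader(TensorDataset(torch.arange(16.0)), batch_size=2)  # 8 batches
    scheduler = StepScheduler(
        grad_acc_steps=2,       # 2 micro-batches per optimizer step -> 4 steps/epoch
        ckpt_every_steps=4,
        dataloader=loader,
        val_every_steps=2,
        num_epochs=2,
    )
    for epoch in scheduler.epochs:
        scheduler.set_epoch(epoch)
        for micro_batches in scheduler:
            # One optimizer step would happen here.
            if scheduler.is_val_step:
                print(f"step {scheduler.step}: validation")
            if scheduler.is_ckpt_step:
                torch.save(scheduler.state_dict(), "/tmp/step_scheduler.pt")

    # Restoring mid-run state mirrors torch.distributed.checkpoint's Stateful API.
    restored = StepScheduler(grad_acc_steps=2, ckpt_every_steps=4, dataloader=loader)
    restored.load_state_dict(torch.load("/tmp/step_scheduler.pt"))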