Source code for nemo_automodel.training.step_scheduler

# Copyright (c) 2025, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional

from torch.distributed.checkpoint.stateful import Stateful
from torch.utils.data import DataLoader


class StepScheduler(Stateful):
    """
    Scheduler for managing gradient accumulation, validation, and checkpointing steps.
    """

    def __init__(
        self,
        grad_acc_steps: int,
        ckpt_every_steps: int,
        dataloader: Optional[DataLoader],
        val_every_steps: Optional[int] = None,
        start_step: int = 0,
        start_epoch: int = 0,
        num_epochs: int = 10,
        max_steps: Optional[int] = None,
    ):
        """
        Initialize the StepScheduler.

        Args:
            grad_acc_steps (int): Number of steps for gradient accumulation.
            ckpt_every_steps (int): Frequency of checkpoint steps.
            dataloader (Optional[DataLoader]): The training dataloader.
            val_every_steps (Optional[int]): Number of optimizer steps between validation runs.
            start_step (int): Initial global step.
            start_epoch (int): Initial epoch.
            num_epochs (int): Total number of epochs.
            max_steps (Optional[int]): Total number of steps to run.
        """
        self.grad_acc_steps = grad_acc_steps
        self.ckpt_every_steps = ckpt_every_steps
        self.dataloader = dataloader
        self.step = start_step
        self.epoch = start_epoch
        self.num_epochs = num_epochs
        # Guard against a dataloader without a known length (or None).
        self.epoch_len = len(dataloader) if dataloader is not None else None
        self.grad_step = 0  # number of optimizer steps taken
        self.val_every_steps = val_every_steps
        self.max_steps = max_steps

    def __iter__(self):
        """
        Iterate over the dataloader while keeping track of counters.

        Raises:
            StopIteration: If the dataloader is exhausted or max_steps is reached.

        Yields:
            dict: batch
        """
        for batch in self.dataloader:
            self.step += 1
            if isinstance(self.max_steps, int) and self.step > self.max_steps:
                return
            yield batch

    def set_epoch(self, epoch: int):
        """
        Set the epoch for the dataloader and its sampler.
        """
        self.epoch = epoch
        self.dataloader.sampler.set_epoch(epoch)

    @property
    def is_optim_step(self):
        """
        Returns whether this step needs to call the optimizer step.

        Note: reading this property also advances the optimizer-step
        counter (grad_step) when it returns True.

        Returns:
            bool: if true, the optimizer should run.
        """
        is_grad = (self.step % self.grad_acc_steps) == 0
        self.grad_step += int(is_grad)
        return is_grad

    @property
    def is_val_step(self):
        """
        Returns whether this step needs to run validation.
        """
        is_val = False
        if self.val_every_steps and self.val_every_steps > 0 and self.is_optim_step:
            is_val = (self.grad_step % self.val_every_steps) == 0
        return is_val

    @property
    def is_ckpt_step(self):
        """
        Returns whether this step needs to save a checkpoint.

        Returns:
            bool: if true, the checkpoint should run.
        """
        # epoch_len may be None when the dataloader has no known length.
        last_batch = False
        if self.epoch_len is not None:
            batch_idx = self.step % self.epoch_len
            last_batch = batch_idx == self.epoch_len - 1
        return ((self.step % self.ckpt_every_steps) == 0 and self.step != 0) or last_batch

    @property
    def epochs(self):
        """
        Epoch iterator.

        Yields:
            iterator: over epochs
        """
        yield from range(self.epoch, self.num_epochs)

    def state_dict(self):
        """
        Get the current state of the scheduler.

        Returns:
            dict: Current state with 'step' and 'epoch' keys.
        """
        return {"step": self.step, "epoch": self.epoch}

    def load_state_dict(self, s):
        """
        Load the scheduler state from a dictionary.

        Args:
            s (dict): Dictionary containing 'step' and 'epoch'.
        """
        self.step, self.epoch = s["step"], s["epoch"]
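
A minimal usage sketch, not part of this module: it assumes a dataloader whose sampler supports set_epoch (e.g. a DistributedSampler), and model, optimizer, loss_fn, run_validation, and save_checkpoint are hypothetical placeholders.

sched = StepScheduler(
    grad_acc_steps=4,
    ckpt_every_steps=100,
    dataloader=train_loader,  # placeholder DataLoader
    val_every_steps=50,
    num_epochs=3,
)

for epoch in sched.epochs:
    sched.set_epoch(epoch)
    for batch in sched:  # increments sched.step per batch, stops at max_steps
        loss = loss_fn(model(batch))
        loss.backward()
        if sched.is_optim_step:  # True every grad_acc_steps batches; advances grad_step
            optimizer.step()
            optimizer.zero_grad()
        if sched.is_val_step:  # every val_every_steps optimizer steps (re-reads is_optim_step)
            run_validation(model)
        if sched.is_ckpt_step:  # every ckpt_every_steps, or on the last batch of an epoch
            save_checkpoint(model, optimizer, sched)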
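
Because StepScheduler implements the torch.distributed.checkpoint Stateful protocol, its state_dict/load_state_dict are invoked automatically when the scheduler is included in a torch.distributed.checkpoint save. The plain-dict round trip below is a hedged sketch of resuming; constructor arguments and train_loader are illustrative.

state = sched.state_dict()  # {"step": ..., "epoch": ...}
# ... persist `state` alongside the model/optimizer checkpoint ...

resumed = StepScheduler(
    grad_acc_steps=4,
    ckpt_every_steps=100,
    dataloader=train_loader,
    val_every_steps=50,
    num_epochs=3,
)
resumed.load_state_dict(state)  # resumes counting from the saved step/epoch

Since epochs yields from range(self.epoch, self.num_epochs), iterating resumed.epochs picks up at the restored epoch rather than restarting from zero.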