nemo_automodel.components.training.step_scheduler#

Module Contents#

Classes#

StepScheduler

Scheduler for managing gradient accumulation and checkpointing steps.

Functions#

_calculate_max_steps

Calculate the maximum number of steps.

Data#

API#

nemo_automodel.components.training.step_scheduler.logger#

‘getLogger(…)’

nemo_automodel.components.training.step_scheduler._calculate_max_steps(
num_epochs: int,
epoch_len: Optional[int],
default_max_steps: int = 9223372036854775807,
) -> int#

Calculate the maximum number of steps.

class nemo_automodel.components.training.step_scheduler.StepScheduler(
global_batch_size: int,
local_batch_size: int,
dp_size: int,
dataloader: Optional[int],
ckpt_every_steps: Optional[int] = None,
val_every_steps: Optional[int] = None,
start_step: int = 0,
start_epoch: int = 0,
num_epochs: int = 10,
max_steps: Optional[int] = None,
)#

Bases: torch.distributed.checkpoint.stateful.Stateful

Scheduler for managing gradient accumulation and checkpointing steps.

Initialization

Initialize the StepScheduler.

Parameters:
  • global_batch_size (int) – Total batch size per optimizer step, summed across all data-parallel ranks.

  • local_batch_size (int) – Batch size per data-parallel rank per micro-step.

  • dp_size (int) – Data-parallel world size (number of ranks).

  • ckpt_every_steps (Optional[int]) – Save a checkpoint every this many training steps.

  • dataloader (Optional[int]) – The training dataloader.

  • val_every_steps (Optional[int]) – Number of training steps between validation runs.

  • start_step (int) – Initial global step.

  • start_epoch (int) – Initial epoch.

  • num_epochs (int) – Total number of epochs.

  • max_steps (Optional[int]) – Total number of steps to run. If None, it is computed from num_epochs and the dataloader length, capped at 2^63-1.
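The three batch-size parameters together determine how many forward/backward micro-steps are accumulated per optimizer step. A minimal sketch of that arithmetic (the exact formula is an assumption inferred from the parameter names):

```python
def grad_acc_steps(global_batch_size: int, local_batch_size: int, dp_size: int) -> int:
    # Each optimizer step consumes global_batch_size samples; each micro-step
    # consumes local_batch_size samples on each of dp_size data-parallel ranks.
    micro_batch = local_batch_size * dp_size
    assert global_batch_size % micro_batch == 0, "global batch must divide evenly"
    return global_batch_size // micro_batch

# e.g. a global batch of 512 with per-rank batches of 8 on 8 ranks
# requires 8 accumulation micro-steps per optimizer step.
steps = grad_acc_steps(512, 8, 8)
```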

__iter__()#

Iterates over the dataloader while keeping track of the step and epoch counters.

Raises:

StopIteration – If the dataloader is exhausted or max_steps is reached.

Yields:

dict – The next batch from the dataloader.
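To illustrate the step bookkeeping, here is a minimal stand-in generator (not the actual implementation) that yields batches while advancing a global step counter and stops once the step limit is hit:

```python
from typing import Iterable, Iterator

def iter_with_step_limit(dataloader: Iterable, start_step: int, max_steps: int) -> Iterator:
    # Mimics the documented behaviour: yield batches while advancing a
    # global step counter, stopping once max_steps is reached.
    step = start_step
    for batch in dataloader:
        if step >= max_steps:
            return  # generator return -> StopIteration for the caller
        step += 1
        yield batch

# Resuming at step 95 with a limit of 100 yields only the first 5 batches.
batches = list(iter_with_step_limit(range(100), start_step=95, max_steps=100))
```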

set_epoch(epoch: int)#

Set the epoch for the sampler.

property is_val_step#

Returns whether validation should run at this step.

property is_ckpt_step#

Returns whether a checkpoint should be saved at this step.

Returns:

True if a checkpoint should be saved at this step.

Return type:

bool
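A plausible sketch of the underlying test, assuming it fires every `ckpt_every_steps` global steps and never when checkpointing is disabled (the real property may also account for the last step; the exact condition is an assumption):

```python
from typing import Optional

def is_ckpt_step(step: int, ckpt_every_steps: Optional[int]) -> bool:
    # Checkpoint periodically; a None interval disables checkpointing.
    if ckpt_every_steps is None:
        return False
    return step > 0 and step % ckpt_every_steps == 0
```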

property is_last_step#

Returns whether this is the final training step.

property is_last_batch#

Returns whether this is the last batch for this epoch.

property sigterm_received#

Returns whether SIGTERM was received.

property epochs#

Epoch iterator.

Yields:

int – each epoch index in turn.

state_dict()#

Get the current state of the scheduler.

Returns:

Current state with ‘step’ and ‘epoch’ keys.

Return type:

dict

load_state_dict(s)#

Load the scheduler state from a dictionary.

Parameters:

s (dict) – Dictionary containing ‘step’ and ‘epoch’.
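The save/restore round trip can be sketched as follows; the two-key dict shape (`'step'`, `'epoch'`) follows the `state_dict()` documentation above, while the `ResumableCounters` stand-in is purely illustrative:

```python
# As returned by scheduler.state_dict() and written into a checkpoint.
state = {"step": 1200, "epoch": 3}

class ResumableCounters:
    """Illustrative stand-in for the scheduler's resumable counters."""

    def __init__(self) -> None:
        self.step = 0
        self.epoch = 0

    def load_state_dict(self, s: dict) -> None:
        # Restore the counters checkpointed earlier.
        self.step = s["step"]
        self.epoch = s["epoch"]

sched = ResumableCounters()
sched.load_state_dict(state)  # training resumes from step 1200, epoch 3
```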