nemo_automodel.components.training.step_scheduler
Module Contents
Classes
StepScheduler | Scheduler for managing gradient accumulation and checkpointing steps.
Functions
_calculate_max_steps | Calculate the maximum number of steps.
Data
logger
API
- nemo_automodel.components.training.step_scheduler.logger
‘getLogger(…)’
- nemo_automodel.components.training.step_scheduler._calculate_max_steps(
- num_epochs: int,
- epoch_len: Optional[int],
- default_max_steps: int = 9223372036854775807,
- )
Calculate the maximum number of steps.
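The page gives only the signature and a one-line summary, so the exact rule is not stated here. The following is a minimal sketch of one plausible reading, assuming the step budget is `num_epochs * epoch_len` when the epoch length is known and falls back to `default_max_steps` otherwise. The function name `calculate_max_steps_sketch` is hypothetical and is not the library's code.

```python
from typing import Optional

# 2**63 - 1, the default shown in the documented signature.
INT64_MAX = 9223372036854775807


def calculate_max_steps_sketch(
    num_epochs: int,
    epoch_len: Optional[int],
    default_max_steps: int = INT64_MAX,
) -> int:
    """Hypothetical stand-in for _calculate_max_steps (assumed behavior)."""
    if epoch_len is None:
        # Epoch length unknown (for example, a dataloader without __len__):
        # assume the default step budget applies.
        return default_max_steps
    # Assumed rule: one pass over the data per epoch.
    return num_epochs * epoch_len


print(calculate_max_steps_sketch(3, 100))
print(calculate_max_steps_sketch(10, None) == 2**63 - 1)
```

Under this assumption, an unknown epoch length leaves training effectively unbounded by step count, so the epoch loop or an explicit `max_steps` would decide when to stop.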
- class nemo_automodel.components.training.step_scheduler.StepScheduler(
- global_batch_size: int,
- local_batch_size: int,
- dp_size: int,
- dataloader: Optional[int],
- ckpt_every_steps: Optional[int] = None,
- val_every_steps: Optional[int] = None,
- start_step: int = 0,
- start_epoch: int = 0,
- num_epochs: int = 10,
- max_steps: int = None,
- )
Bases: torch.distributed.checkpoint.stateful.Stateful
Scheduler for managing gradient accumulation and checkpointing steps.
Initialization
Initialize the StepScheduler.
- Parameters:
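global_batch_size (int) – Global batch size, i.e. the number of samples per optimizer step across all data-parallel ranks.
local_batch_size (int) – Batch size processed by each rank in a single forward/backward pass.
dp_size (int) – Data-parallel world size. Together with the two batch sizes, this determines the number of gradient accumulation steps.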
ckpt_every_steps (Optional[int]) – Frequency of checkpoint steps.
dataloader (Optional[int]) – The training dataloader.
val_every_steps (Optional[int]) – Number of training steps between validation runs.
start_step (int) – Initial global step.
start_epoch (int) – Initial epoch.
num_epochs (int) – Total number of epochs.
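max_steps (int) – Total number of steps to run. The signature default is None; the documented effective default is 2^63-1, which matches default_max_steps in _calculate_max_steps.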
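The three batch-size arguments are tied together by gradient accumulation. As a rough sketch of that arithmetic, assuming the number of accumulation steps is `global_batch_size` divided by `local_batch_size * dp_size` (an assumption, since the page does not state the formula), with a hypothetical helper name:

```python
def grad_acc_steps_sketch(
    global_batch_size: int, local_batch_size: int, dp_size: int
) -> int:
    """Hypothetical helper: micro-batches accumulated per optimizer step."""
    # Samples processed across all data-parallel ranks in one forward/backward pass.
    samples_per_pass = local_batch_size * dp_size
    if samples_per_pass <= 0:
        raise ValueError("local_batch_size and dp_size must be positive")
    if global_batch_size % samples_per_pass != 0:
        # Assumed constraint: the global batch must split evenly.
        raise ValueError(
            "global_batch_size must be divisible by local_batch_size * dp_size"
        )
    return global_batch_size // samples_per_pass


# 32 samples per optimizer step, 4 per rank, 2 ranks: 4 accumulation steps.
print(grad_acc_steps_sketch(32, 4, 2))
```

The divisibility check is a design choice of this sketch; whether the library raises, rounds, or asserts in that case is not documented on this page.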
- __iter__()
Iterates over the dataloader while keeping track of the step and epoch counters.
- Raises:
StopIteration – If the dataloader was exhausted or max_steps was reached.
- Yields:
dict – batch
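To illustrate the iteration contract described above (yield batches, count steps, stop when the data or the step budget runs out), here is a self-contained toy. `ToyStepIterator` is a hypothetical stand-in written for this example; it is not `StepScheduler` and it omits gradient accumulation, epochs, and signal handling.

```python
from typing import Any, Dict, Iterable, Iterator


class ToyStepIterator:
    """Hypothetical stand-in that mimics only the counting behavior."""

    def __init__(
        self,
        dataloader: Iterable[Dict[str, Any]],
        start_step: int = 0,
        max_steps: int = 2**63 - 1,
    ) -> None:
        self.dataloader = dataloader
        self.step = start_step
        self.max_steps = max_steps

    def __iter__(self) -> Iterator[Dict[str, Any]]:
        for batch in self.dataloader:
            if self.step >= self.max_steps:
                # Step budget reached: end iteration early.
                return
            yield batch
            # Count the step once the consumer has finished with the batch.
            self.step += 1


toy = ToyStepIterator([{"x": i} for i in range(5)], max_steps=3)
seen = [batch["x"] for batch in toy]
print(seen, toy.step)
```

In a generator, returning ends the iteration, which is how the consumer's `for` loop observes the `StopIteration` condition listed under Raises.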
- set_epoch(epoch: int)
Set the epoch for the sampler.
- property is_val_step
Returns whether validation should run at this step.
- property is_ckpt_step
Returns whether a checkpoint should be saved at this step.
- Returns:
True if a checkpoint should be saved.
- Return type:
bool
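A periodic check such as this one is often a modulo test against the configured interval. The sketch below shows only that periodic component, under the assumption that an interval of `None` disables the check. The helper name is hypothetical, and the real property may apply further conditions (for example, the last step or a received SIGTERM) that are not documented on this page.

```python
from typing import Optional


def is_every_n_step(step: int, every: Optional[int]) -> bool:
    """Hypothetical helper: True when `step` lands on a multiple of `every`."""
    if every is None or every <= 0:
        # Assumed: no interval configured means the periodic check never fires.
        return False
    # Skip step 0 so nothing triggers before any training has happened.
    return step > 0 and step % every == 0


print(is_every_n_step(100, 50))
print(is_every_n_step(75, 50))
print(is_every_n_step(10, None))
```

The same helper would serve for a validation interval such as `val_every_steps`, since both are expressed as a number of training steps.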
- property is_last_step
Returns whether the training is finished.
- property is_last_batch
Returns whether this is the last batch for this epoch.
- property sigterm_received
Returns whether SIGTERM was received.
- property epochs
Epoch iterator.
- Yields:
iterator – over epochs
- state_dict()
Get the current state of the scheduler.
- Returns:
Current state with ‘step’ and ‘epoch’ keys.
- Return type:
dict
- load_state_dict(s)
Load the scheduler state from a dictionary.
- Parameters:
s (dict) – Dictionary containing ‘step’ and ‘epoch’.
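The `state_dict()` / `load_state_dict()` pair follows the usual save-and-resume pattern: the state is a plain dictionary with 'step' and 'epoch' keys. A minimal round-trip sketch, using a hypothetical `ToyCounterState` class in place of the real scheduler so the example runs without the library installed:

```python
from typing import Dict


class ToyCounterState:
    """Hypothetical stand-in exposing the same two-method state interface."""

    def __init__(self, step: int = 0, epoch: int = 0) -> None:
        self.step = step
        self.epoch = epoch

    def state_dict(self) -> Dict[str, int]:
        # Same keys as documented for StepScheduler.state_dict().
        return {"step": self.step, "epoch": self.epoch}

    def load_state_dict(self, s: Dict[str, int]) -> None:
        self.step = s["step"]
        self.epoch = s["epoch"]


# Save from one instance, then restore into a fresh one.
saved = ToyCounterState(step=1200, epoch=3).state_dict()
resumed = ToyCounterState()
resumed.load_state_dict(saved)
print(resumed.step, resumed.epoch)
```

Because the state holds only counters, anything else needed for resumption, such as the dataloader position, has to be restored separately.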