nemo_automodel.components.training.step_scheduler

Module Contents

Classes

| Name | Description |
| --- | --- |
| StepScheduler | Scheduler for managing gradient accumulation and checkpointing steps. |
| StepSchedulerConfig | User-facing step scheduler configuration. |

Functions

| Name | Description |
| --- | --- |
| _calculate_max_steps | Calculate the maximum number of steps. |
| _calculate_num_epochs | Calculate the number of epochs from the maximum number of steps. |

Data

logger

API

class nemo_automodel.components.training.step_scheduler.StepScheduler(
global_batch_size: int,
local_batch_size: int,
dp_size: int,
dataloader: typing.Optional[int],
ckpt_every_steps: typing.Optional[int] = None,
save_checkpoint_every_epoch: bool = True,
val_every_steps: typing.Optional[int] = None,
log_remote_every_steps: int = 1,
loss_average_window_steps: int = 50,
gc_every_steps: typing.Optional[int] = None,
start_step: int = 0,
start_epoch: int = 0,
num_epochs: typing.Optional[int] = None,
max_steps: typing.Optional[int] = None
)

Bases: Stateful

Scheduler for managing gradient accumulation and checkpointing steps.

epoch_len = ceil(len(dataloader) / self.grad_acc_steps)

epochs

Epoch iterator.

grad_acc_steps = global_batch_size // (local_batch_size * dp_size)
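
The grad_acc_steps and epoch_len formulas can be checked with a short worked example. This is a minimal sketch that reimplements the two documented expressions as standalone helper functions (the function names are illustrative, not part of the library), using made-up batch sizes and a made-up dataloader length.

```python
import math


def grad_acc_steps(global_batch_size: int, local_batch_size: int, dp_size: int) -> int:
    # Floor division, mirroring the documented formula for grad_acc_steps.
    return global_batch_size // (local_batch_size * dp_size)


def epoch_len(num_batches: int, acc_steps: int) -> int:
    # Ceiling division, mirroring the documented formula for epoch_len:
    # a trailing partial group of micro-batches is counted as one more step.
    return math.ceil(num_batches / acc_steps)


# Illustrative numbers only; these are not defaults taken from the library.
acc = grad_acc_steps(global_batch_size=32, local_batch_size=2, dp_size=4)
steps = epoch_len(num_batches=250, acc_steps=acc)
print(acc, steps)  # 4 63
```

With 4 data-parallel ranks each processing 2 samples per micro-batch, 4 micro-batches are accumulated per optimizer step, and 250 dataloader batches yield 63 steps per epoch.
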
is_ckpt_step

Returns whether a checkpoint should be saved at this step.

is_gc_step

Returns whether this step needs to run manual garbage collection.

is_last_batch

Returns whether this is the last batch for this epoch.

is_last_step

Returns whether the current step is the final training step.

Training stops at whichever comes first: reaching max_steps or exhausting the configured number of epochs (see __iter__ and epochs). Checking max_steps alone is therefore not enough to detect the end: a small dataset can run out of epochs long before max_steps is reached (e.g. max_steps=100 with only 60 steps’ worth of data), in which case the last batch of the last epoch is the final step. This flag detects that case so that the final checkpoint and consolidated export, which key off it (see is_ckpt_step and the recipes’ is_final_checkpoint), are still written.
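
The "whichever comes first" rule above can be sketched as a single min(). This is an illustration of the stopping arithmetic only, under the assumption that training starts from step 0 and every epoch has the same length; final_step is a hypothetical helper, not a function in this module.

```python
def final_step(max_steps: int, num_epochs: int, epoch_len: int) -> int:
    """Return the number of steps taken before training stops.

    Hypothetical helper: training ends at max_steps or when all epochs
    are exhausted, whichever happens first.
    """
    return min(max_steps, num_epochs * epoch_len)


# Small dataset: 3 epochs of 20 steps run out before max_steps=100 is reached.
print(final_step(max_steps=100, num_epochs=3, epoch_len=20))   # 60

# Larger budget of epochs: max_steps is the binding limit.
print(final_step(max_steps=100, num_epochs=10, epoch_len=20))  # 100
```

In the first case a check of the form step == max_steps would never fire, which is why the last batch of the last epoch also has to count as the final step.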

is_remote_logging_step

Returns whether this step should log to remote services (WandB, MLflow, etc.).

is_val_step

Returns whether validation should run at this step.

sig_handler = DistributedSignalHandler().__enter__()

sigterm_received

Returns whether SIGTERM was received.

nemo_automodel.components.training.step_scheduler.StepScheduler.__iter__()

Iterates over the dataloader while keeping track of counters.

Raises:

  • StopIteration: If the dataloader was exhausted or max_steps was reached.

nemo_automodel.components.training.step_scheduler.StepScheduler.load_state_dict(
s
)

Load the scheduler state from a dictionary.

Parameters:

s
dict

Dictionary containing ‘step’ and ‘epoch’.

nemo_automodel.components.training.step_scheduler.StepScheduler.set_epoch(
epoch: int
)

Set the epoch for the sampler.

nemo_automodel.components.training.step_scheduler.StepScheduler.state_dict()

Get the current state of the scheduler.

Returns:

Current state with ‘step’ and ‘epoch’ keys.
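
The save/restore round trip implied by state_dict and load_state_dict can be sketched with a toy stand-in. ToyScheduler below is hypothetical and is not the real StepScheduler; only the ‘step’ and ‘epoch’ keys come from the documentation above, and the attribute names are assumptions.

```python
class ToyScheduler:
    """Hypothetical stand-in that tracks only the two documented counters."""

    def __init__(self, step: int = 0, epoch: int = 0):
        self.step = step
        self.epoch = epoch

    def state_dict(self) -> dict:
        # Same keys as documented for StepScheduler.state_dict().
        return {"step": self.step, "epoch": self.epoch}

    def load_state_dict(self, s: dict) -> None:
        # Restore the counters from a previously saved dictionary.
        self.step = s["step"]
        self.epoch = s["epoch"]


# Save the state of a run that stopped at step 120 in epoch 2 ...
saved = ToyScheduler(step=120, epoch=2).state_dict()

# ... and restore it into a fresh object, as a resumed run would.
resumed = ToyScheduler()
resumed.load_state_dict(saved)
print(resumed.step, resumed.epoch)  # 120 2
```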

class nemo_automodel.components.training.step_scheduler.StepSchedulerConfig(
global_batch_size: int = 32,
num_epochs: int | None = 10,
max_steps: int | None = None,
ckpt_every_steps: int | None = 100,
save_checkpoint_every_epoch: bool = True,
val_every_steps: int | None = None,
log_remote_every_steps: int = 1,
loss_average_window_steps: int = 50,
gc_every_steps: int | None = None,
start_step: int = 0,
start_epoch: int = 0
)
Dataclass

User-facing step scheduler configuration.

These fields correspond to the YAML-configurable parameters of the training loop. Runtime-only values (dataloader, dp_size, local_batch_size) are passed separately to build_step_scheduler.

ckpt_every_steps
int | None = 100
gc_every_steps
int | None = None
global_batch_size
int = 32
log_remote_every_steps
int = 1
loss_average_window_steps
int = 50
max_steps
int | None = None
num_epochs
int | None = 10
save_checkpoint_every_epoch
bool = True
start_epoch
int = 0
start_step
int = 0
val_every_steps
int | None = None

nemo_automodel.components.training.step_scheduler.StepSchedulerConfig.build(
dataloader: torch.utils.data.DataLoader,
dp_group_size: int,
local_batch_size: int
) -> nemo_automodel.components.training.step_scheduler.StepScheduler

Build the step scheduler.

Parameters:

dataloader
DataLoader

The training dataloader.

dp_group_size
int

The size of the data parallel group.

local_batch_size
int

The size of the local batch.

Returns: StepScheduler

Configured StepScheduler.

nemo_automodel.components.training.step_scheduler._calculate_max_steps(
num_epochs: int,
epoch_len: typing.Optional[int],
default_max_steps: int = 9223372036854775807
) -> int

Calculate the maximum number of steps.

nemo_automodel.components.training.step_scheduler._calculate_num_epochs(
max_steps: typing.Optional[int],
epoch_len: typing.Optional[int],
default_num_epochs: int = 10
) -> int

Calculate the number of epochs from the maximum number of steps.

nemo_automodel.components.training.step_scheduler.logger = logging.getLogger(__name__)