bridge.training.state

Module Contents

Classes

TrainState

Dataclass to hold the state of the training process.

FaultToleranceState

Dataclass to hold state specific to fault tolerance mechanisms.

GlobalState

Manages the global state of the training process.

Functions

_timers_write_to_wandb

Patch for Megatron Core Timers that writes timer values to Weights & Biases (wandb).

API

class bridge.training.state.TrainState

Bases: torch.distributed.checkpoint.stateful.Stateful

Dataclass to hold the state of the training process.

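Inherits from Stateful for compatibility with distributed checkpointing. Tracks the iteration count (step), consumed and skipped training samples, consumed validation samples, the cumulative floating-point operation count, and flags for the train, validation, and test phases.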

step: int = 0

consumed_train_samples: int = 0

skipped_train_samples: int = 0

consumed_valid_samples: int = 0

floating_point_operations_so_far: int = 0

do_train: bool = False

do_valid: bool = False

do_test: bool = False

state_dict() -> dict[str, torch.Tensor]

Serializes the training state into a dictionary of tensors.

Conforms to the Stateful interface for distributed checkpointing.

Returns:

A dictionary where keys are state variable names and values are their corresponding tensor representations.

load_state_dict(state_dict: dict[str, torch.Tensor]) -> None

Loads the training state from a state dictionary.

Parameters:

state_dict – A dictionary containing the state variables as tensors.
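A minimal sketch of the Stateful round trip, assuming each field is stored as a 0-d tensor under its own name. `TrainStateSketch` is a hypothetical stand-in, and the tensor dtypes (inferred by `torch.tensor`) are assumptions, not necessarily what `TrainState.state_dict()` produces.

```python
from dataclasses import dataclass, fields

import torch


@dataclass
class TrainStateSketch:
    """Hypothetical mirror of TrainState's fields; not the library class."""

    step: int = 0
    consumed_train_samples: int = 0
    skipped_train_samples: int = 0
    consumed_valid_samples: int = 0
    floating_point_operations_so_far: int = 0
    do_train: bool = False
    do_valid: bool = False
    do_test: bool = False

    def state_dict(self) -> dict[str, torch.Tensor]:
        # Wrap every scalar field in a 0-d tensor so a checkpointer can save it.
        # Dtypes are whatever torch.tensor infers (an assumption of this sketch).
        return {f.name: torch.tensor(getattr(self, f.name)) for f in fields(self)}

    def load_state_dict(self, state_dict: dict[str, torch.Tensor]) -> None:
        # .item() converts each 0-d tensor back into a Python int or bool.
        for f in fields(self):
            setattr(self, f.name, state_dict[f.name].item())


# Usage: save one state and restore it into a fresh object.
state = TrainStateSketch(step=100, consumed_train_samples=3200, do_train=True)
restored = TrainStateSketch()
restored.load_state_dict(state.state_dict())
```

Keeping every value as a tensor is what lets a tensor-oriented checkpointer treat the training counters like any other entry in the checkpoint.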

class bridge.training.state.FaultToleranceState

Dataclass to hold state specific to fault tolerance mechanisms.

ft_state_path: Optional[str] = None

is_persistent_chkpt_loaded: bool = False

is_async_chkpt_enabled: bool = False

is_calculating_timeouts: bool = False

is_setup_section_open: bool = False

seen_checkpoints_cnt: int = 0

seen_tr_iters_cnt: int = 0

curr_eval_iter_idx: int = 0

class bridge.training.state.GlobalState

Manages the global state of the training process.

Provides access to the configuration, tokenizer, loggers, timers, training state, fault tolerance state, signal handler, and straggler detector through lazily initialized properties.

Initialization

Initializes the GlobalState object.
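The lazy-initialization pattern described above can be sketched as a property that builds its object on first access and caches it. `GlobalStateSketch` and `TrainStateStandIn` are hypothetical names for illustration; the real `GlobalState` may build and cache its components differently.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class TrainStateStandIn:
    """Hypothetical placeholder for TrainState."""

    step: int = 0


class GlobalStateSketch:
    """Hypothetical illustration of lazily initialized properties."""

    def __init__(self) -> None:
        # Nothing is built at construction time.
        self._train_state: Optional[TrainStateStandIn] = None
        self.build_count = 0  # counts constructions, for illustration only

    @property
    def train_state(self) -> TrainStateStandIn:
        # Build on first access, then keep returning the cached object.
        if self._train_state is None:
            self._train_state = TrainStateStandIn()
            self.build_count += 1
        return self._train_state


gs = GlobalStateSketch()
gs.train_state.step += 1  # first access triggers construction
```

Deferring construction this way means components that depend on configuration or distributed setup are only created once something actually needs them.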

property cfg: Optional[megatron.bridge.training.config.ConfigContainer]

The main configuration container object.

property tokenizer: Any

The tokenizer instance, lazily built based on the config.

property tensorboard_logger: Optional[torch.utils.tensorboard.writer.SummaryWriter]

The TensorBoard SummaryWriter instance, lazily initialized on the last rank (rank N-1, where N is the world size).

property wandb_logger: Optional[Any]

The Weights & Biases logger instance, lazily initialized on the last rank (rank N-1, where N is the world size).
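Both loggers are created only on the last rank. A minimal sketch of that gating, assuming every other rank simply gets `None` (consistent with the `Optional` return types). `LastRankLogger` and its `factory` argument are hypothetical; in practice the factory might construct a SummaryWriter or a wandb run.

```python
from typing import Any, Callable, Optional


class LastRankLogger:
    """Hypothetical sketch: build a logger lazily, only on rank world_size - 1."""

    def __init__(self, rank: int, world_size: int, factory: Callable[[], Any]) -> None:
        self.rank = rank
        self.world_size = world_size
        self._factory = factory  # hypothetical callable that creates the logger
        self._logger: Optional[Any] = None
        self._initialized = False

    @property
    def logger(self) -> Optional[Any]:
        if not self._initialized:
            # Only the last rank ("rank N-1") builds a logger; others keep None.
            if self.rank == self.world_size - 1:
                self._logger = self._factory()
            self._initialized = True
        return self._logger


# Usage: simulate four ranks, each with its own view of the logger.
loggers = [LastRankLogger(rank, 4, object) for rank in range(4)]
```

Restricting logging to a single rank avoids writing duplicate metrics from every process.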

property timers: megatron.core.timers.Timers

The Megatron Timers instance used for tracking execution times.

property train_state: bridge.training.state.TrainState

The TrainState object holding training progress information.

property fault_tolerance_state: bridge.training.state.FaultToleranceState

The FaultToleranceState object holding FT-specific information.

property signal_handler: megatron.bridge.training.utils.sig_utils.DistributedSignalHandler

The DistributedSignalHandler instance for graceful shutdown.

property straggler_timer: megatron.core.utils.StragglerDetector

The StragglerDetector instance for tracking slow GPUs.

property async_calls_queue: Optional[megatron.core.dist_checkpointing.strategies.async_utils.AsyncCallsQueue]

The AsyncCallsQueue instance for handling asynchronous checkpoint saves.

Creates a persistent AsyncCallsQueue when async_save is enabled in the checkpoint config. Returns None if async_save is disabled.
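A sketch of the "create once if enabled, otherwise None" behavior. `CheckpointConfigSketch`, `QueueStandIn`, and `AsyncQueueOwner` are hypothetical; the stand-in queue does no asynchronous work, and the sketch models "persistent" only as a single instance reused across saves, not whatever background machinery the real `AsyncCallsQueue` keeps alive.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class CheckpointConfigSketch:
    """Hypothetical stand-in for the checkpoint section of the config."""

    async_save: bool = False


class QueueStandIn:
    """Hypothetical placeholder for AsyncCallsQueue; performs no real work."""


class AsyncQueueOwner:
    """Hypothetical owner exposing an async_calls_queue-like property."""

    def __init__(self, ckpt_cfg: CheckpointConfigSketch) -> None:
        self.ckpt_cfg = ckpt_cfg
        self._queue: Optional[QueueStandIn] = None

    @property
    def async_calls_queue(self) -> Optional[QueueStandIn]:
        if not self.ckpt_cfg.async_save:
            return None  # async saving disabled: no queue is created
        if self._queue is None:
            # Created on first access and reused for every later save.
            self._queue = QueueStandIn()
        return self._queue


disabled = AsyncQueueOwner(CheckpointConfigSketch(async_save=False))
enabled = AsyncQueueOwner(CheckpointConfigSketch(async_save=True))
```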

property nvrx_straggler_manager: Optional[megatron.bridge.training.nvrx_straggler.NVRxStragglerDetectionManager]

The NVRx straggler detection manager, if enabled.

property energy_monitor: Optional[megatron.core.energy_monitor.EnergyMonitor]

The EnergyMonitor instance for tracking energy consumption.

_set_signal_handler() -> None

Initializes the distributed signal handler based on the configuration.
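A single-process sketch of the idea behind a graceful-shutdown signal handler: record that a signal arrived instead of exiting immediately, so the training loop can stop at a safe point. `SignalFlagSketch` and `any_rank_signaled` are hypothetical. A distributed handler would typically also combine per-rank flags (for example with an all-gather), which this sketch only imitates with `any()` over a list. It uses Python's standard `signal` module and must run in the main thread.

```python
import signal
from types import FrameType
from typing import Optional


class SignalFlagSketch:
    """Hypothetical stand-in for a signal handler that only sets a flag."""

    def __init__(self, sig: int = signal.SIGTERM) -> None:
        self.sig = sig
        self.received = False
        self._previous = None

    def _handler(self, signum: int, frame: Optional[FrameType]) -> None:
        # Record the signal; the training loop decides when to shut down.
        self.received = True

    def __enter__(self) -> "SignalFlagSketch":
        self._previous = signal.signal(self.sig, self._handler)
        return self

    def __exit__(self, *exc: object) -> None:
        # Restore the handler that was installed before, if Python knows it.
        if self._previous is not None:
            signal.signal(self.sig, self._previous)


def any_rank_signaled(flags: list[bool]) -> bool:
    # Hypothetical reduction: a real distributed handler would gather these
    # flags from all ranks; here they are passed in as a plain list.
    return any(flags)


with SignalFlagSketch() as handler:
    signal.raise_signal(signal.SIGTERM)  # simulate the scheduler's signal
```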

bridge.training.state._timers_write_to_wandb(
self: megatron.core.timers.Timers,
names: list[str],
writer: Any,
iteration: int,
normalizer: float = 1.0,
reset: bool = True,
barrier: bool = False,
) -> None

Patch for Megatron Core Timers that writes timer values to Weights & Biases (wandb).
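A sketch of what such a patch might look like, using a hypothetical `TimersStandIn` instead of the real `megatron.core.timers.Timers`. Several details are assumptions: the writer is assumed to expose a wandb-style `log(data, step=...)` method, each value is divided by `normalizer`, metric keys use a `"<name>-time"` suffix, and the method is attached under the hypothetical name `write_to_wandb`. The real function likely reduces timings across ranks (hence `barrier`), which this single-process sketch skips.

```python
from typing import Any


class TimersStandIn:
    """Hypothetical stand-in for megatron.core.timers.Timers (not the real class)."""

    def __init__(self, elapsed: dict[str, float]) -> None:
        self._elapsed = dict(elapsed)  # seconds accumulated per timer name

    def elapsed(self, name: str, reset: bool = True) -> float:
        value = self._elapsed[name]
        if reset:
            self._elapsed[name] = 0.0
        return value


def write_timers_to_wandb_sketch(
    self: TimersStandIn,
    names: list[str],
    writer: Any,
    iteration: int,
    normalizer: float = 1.0,
    reset: bool = True,
    barrier: bool = False,  # unused: there are no other ranks to synchronize
) -> None:
    assert normalizer > 0.0
    # One log call per iteration, one key per timer, values scaled by normalizer.
    metrics = {f"{name}-time": self.elapsed(name, reset=reset) / normalizer for name in names}
    writer.log(metrics, step=iteration)


# Monkey-patch the function onto the class under a hypothetical method name.
TimersStandIn.write_to_wandb = write_timers_to_wandb_sketch


class RecordingWriter:
    """Fake writer that records wandb-style log(data, step=...) calls."""

    def __init__(self) -> None:
        self.calls: list[tuple[dict, int]] = []

    def log(self, data: dict, step: int) -> None:
        self.calls.append((data, step))


timers = TimersStandIn({"forward": 2.0, "backward": 4.0})
w = RecordingWriter()
timers.write_to_wandb(["forward", "backward"], w, iteration=10, normalizer=2.0)
```

Because the function takes `self` as its first parameter, assigning it to the class makes it callable as an ordinary bound method on every timers instance.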