bridge.training.state
Module Contents
Classes
- TrainState: Dataclass to hold the state of the training process.
- FaultToleranceState: Dataclass to hold state specific to fault tolerance mechanisms.
- GlobalState: Manages the global state of the training process.
Functions
- _timers_write_to_wandb: Patch to write timers to wandb for Megatron Core Timers.
API
- class bridge.training.state.TrainState
Bases: torch.distributed.checkpoint.stateful.Stateful
Dataclass to hold the state of the training process.
Inherits from Stateful for compatibility with distributed checkpointing. Tracks the iteration count, consumed and skipped sample counts, flags for the train/valid/test phases, and the cumulative number of floating-point operations.
- step: int = 0
- consumed_train_samples: int = 0
- skipped_train_samples: int = 0
- consumed_valid_samples: int = 0
- floating_point_operations_so_far: int = 0
- do_train: bool = False
- do_valid: bool = False
- do_test: bool = False
- state_dict() -> dict[str, torch.Tensor]
Serializes the training state into a dictionary of tensors.
Conforms to the Stateful interface for distributed checkpointing.
- Returns:
A dictionary where keys are state variable names and values are their corresponding tensor representations.
- load_state_dict(state_dict: dict[str, torch.Tensor]) -> None
Load the training state from a state dictionary.
- Parameters:
state_dict – A dictionary containing the state variables as tensors.
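The round trip through `state_dict()` and `load_state_dict()` can be illustrated with a small stand-in. This is a minimal sketch, not the library's implementation: `SketchTrainState` is a hypothetical class that mirrors a few of the documented `TrainState` fields, and it assumes each scalar field is stored as a 0-d tensor and converted back with `.item()`.

```python
from dataclasses import dataclass, fields

import torch


@dataclass
class SketchTrainState:
    """Hypothetical stand-in mirroring a subset of TrainState's documented fields."""

    step: int = 0
    consumed_train_samples: int = 0
    skipped_train_samples: int = 0
    do_train: bool = False

    def state_dict(self) -> dict[str, torch.Tensor]:
        # Wrap each scalar field in a 0-d tensor so a tensor-based
        # checkpointing backend can save it.
        return {f.name: torch.tensor(getattr(self, f.name)) for f in fields(self)}

    def load_state_dict(self, state_dict: dict[str, torch.Tensor]) -> None:
        # .item() turns a 0-d tensor back into a Python int or bool.
        for f in fields(self):
            setattr(self, f.name, state_dict[f.name].item())


# Usage: save, then restore into a fresh instance.
state = SketchTrainState(step=100, consumed_train_samples=3200, do_train=True)
restored = SketchTrainState()
restored.load_state_dict(state.state_dict())
assert restored == state
```

A real implementation has to choose dtypes deliberately: a cumulative counter such as `floating_point_operations_so_far` can exceed the int64 range on long runs, so the sketch leaves that field out rather than guessing how the library stores it.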
- class bridge.training.state.FaultToleranceState
Dataclass to hold state specific to fault tolerance mechanisms.
- ft_state_path: Optional[str] = None
- is_persistent_chkpt_loaded: bool = False
- is_async_chkpt_enabled: bool = False
- is_calculating_timeouts: bool = False
- is_setup_section_open: bool = False
- seen_checkpoints_cnt: int = 0
- seen_tr_iters_cnt: int = 0
- curr_eval_iter_idx: int = 0
- class bridge.training.state.GlobalState
Manages the global state of the training process.
Provides access to configuration, tokenizer, loggers, timers, training state, fault tolerance state, signal handler, and straggler detector through properties with lazy initialization.
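The lazy-initialization pattern can be sketched as follows. The class `SketchGlobalState` and its placeholder `tokenizer` and `logger` values are hypothetical; the sketch only assumes that each property builds its object on first access and caches it, and that a logger is created only on the last rank (rank N-1), as the logger property descriptions state.

```python
from typing import Optional


class SketchGlobalState:
    """Hypothetical stand-in illustrating lazily initialized, cached properties."""

    def __init__(self, rank: int, world_size: int) -> None:
        self.rank = rank
        self.world_size = world_size
        self._tokenizer: Optional[dict] = None
        self._logger: Optional[list] = None
        self._logger_built = False  # None is a valid result, so track "built" separately
        self.build_count = 0  # counts how often the expensive tokenizer build ran

    @property
    def tokenizer(self) -> dict:
        # Build on first access, then return the cached object.
        if self._tokenizer is None:
            self.build_count += 1
            self._tokenizer = {"<pad>": 0, "<eos>": 1}  # placeholder "tokenizer"
        return self._tokenizer

    @property
    def logger(self) -> Optional[list]:
        # Only the last rank (world_size - 1) creates a logger; others get None.
        if not self._logger_built:
            self._logger_built = True
            if self.rank == self.world_size - 1:
                self._logger = []  # placeholder for e.g. a SummaryWriter
        return self._logger


state = SketchGlobalState(rank=3, world_size=4)
assert state.tokenizer is state.tokenizer  # built once, then reused
```

The separate `_logger_built` flag is a design choice of the sketch: because None is a legitimate final value on non-logging ranks, it cannot double as the "not built yet" marker.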
Initialization
Initializes the GlobalState object.
- property cfg: Optional[megatron.bridge.training.config.ConfigContainer]
The main configuration container object.
- property tokenizer: Any
The tokenizer instance, lazily built based on the config.
- property tensorboard_logger: Optional[torch.utils.tensorboard.writer.SummaryWriter]
The TensorBoard SummaryWriter instance, lazily initialized only on the last rank (rank N-1, where N is the world size).
- property wandb_logger: Optional[Any]
The Weights & Biases logger instance, lazily initialized only on the last rank (rank N-1, where N is the world size).
- property timers: megatron.core.timers.Timers
The Megatron Timers instance used for tracking execution times.
- property train_state: bridge.training.state.TrainState
The TrainState object holding training progress information.
- property fault_tolerance_state: bridge.training.state.FaultToleranceState
The FaultToleranceState object holding FT-specific information.
- property signal_handler: megatron.bridge.training.utils.sig_utils.DistributedSignalHandler
The DistributedSignalHandler instance for graceful shutdown.
- property straggler_timer: megatron.core.utils.StragglerDetector
The StragglerDetector instance for tracking slow GPUs.
- property async_calls_queue: Optional[megatron.core.dist_checkpointing.strategies.async_utils.AsyncCallsQueue]
The AsyncCallsQueue instance for handling asynchronous checkpoint saves.
Creates a persistent AsyncCallsQueue when async_save is enabled in the checkpoint config. Returns None if async_save is disabled.
- property nvrx_straggler_manager: Optional[megatron.bridge.training.nvrx_straggler.NVRxStragglerDetectionManager]
The NVRx straggler detection manager, if enabled.
- property energy_monitor: Optional[megatron.core.energy_monitor.EnergyMonitor]
The EnergyMonitor instance for tracking energy consumption.
- _set_signal_handler() -> None
Initializes the distributed signal handler based on the configuration.
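The conditional, persistent behavior described for `async_calls_queue` can be sketched with hypothetical stand-ins: `SketchCheckpointConfig` models only the `async_save` flag, and `SketchAsyncQueue` is a placeholder for `AsyncCallsQueue`. In this sketch the property returns None when asynchronous saving is disabled and otherwise creates one queue and reuses it on every later access.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class SketchCheckpointConfig:
    async_save: bool = False  # mirrors the async_save flag named in the docs


class SketchAsyncQueue:
    """Hypothetical placeholder for AsyncCallsQueue."""


class SketchState:
    def __init__(self, ckpt_cfg: SketchCheckpointConfig) -> None:
        self.ckpt_cfg = ckpt_cfg
        self._async_calls_queue: Optional[SketchAsyncQueue] = None

    @property
    def async_calls_queue(self) -> Optional[SketchAsyncQueue]:
        if not self.ckpt_cfg.async_save:
            return None  # async saving disabled: no queue
        if self._async_calls_queue is None:
            # Create once and keep it, so every save reuses the same queue.
            self._async_calls_queue = SketchAsyncQueue()
        return self._async_calls_queue


enabled = SketchState(SketchCheckpointConfig(async_save=True))
assert enabled.async_calls_queue is enabled.async_calls_queue
```

Keeping one persistent queue, rather than creating a new one per save, means pending asynchronous saves remain tracked in a single place across iterations.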
- bridge.training.state._timers_write_to_wandb(self: megatron.core.timers.Timers, names: list[str], writer: Any, iteration: int, normalizer: float = 1.0, reset: bool = True, barrier: bool = False)
Patch for Megatron Core Timers that writes the timers listed in names to Weights & Biases (wandb).
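A minimal sketch of the patching idea, under stated assumptions: `SketchTimers`, `_sketch_write_to_wandb`, and `RecordingWriter` are hypothetical, the `writer` is assumed to expose a wandb-style `log(dict, step=...)` method, and the `"<name>-time"` key format is illustrative only. Because the function takes `self` as its first parameter, assigning it to a class attribute makes it callable as a bound method, which is one way such a function can act as a patch on an existing class.

```python
from typing import Any


class SketchTimers:
    """Hypothetical stand-in for megatron.core.timers.Timers."""

    def __init__(self, elapsed: dict[str, float]) -> None:
        self._elapsed = dict(elapsed)  # seconds accumulated per timer name

    def elapsed(self, name: str, reset: bool = True) -> float:
        value = self._elapsed[name]
        if reset:
            self._elapsed[name] = 0.0  # start accumulating again from zero
        return value


def _sketch_write_to_wandb(
    self: SketchTimers,
    names: list[str],
    writer: Any,
    iteration: int,
    normalizer: float = 1.0,
    reset: bool = True,
    barrier: bool = False,
) -> None:
    # barrier is accepted for signature parity but ignored in this
    # single-process sketch.
    assert normalizer > 0.0
    metrics = {f"{name}-time": self.elapsed(name, reset=reset) / normalizer for name in names}
    writer.log(metrics, step=iteration)  # wandb-style log(dict, step=...)


# The "patch": attach the free function so `self` binds to the instance.
SketchTimers.write_to_wandb = _sketch_write_to_wandb


class RecordingWriter:
    """Hypothetical writer that records log calls instead of sending them."""

    def __init__(self) -> None:
        self.calls: list[tuple[dict, int]] = []

    def log(self, data: dict, step: int) -> None:
        self.calls.append((data, step))


timers = SketchTimers({"forward": 2.0, "backward": 4.0})
writer = RecordingWriter()
timers.write_to_wandb(["forward", "backward"], writer, iteration=10, normalizer=2.0)
```

Dividing by `normalizer` lets a caller report, for example, an average time per iteration when a timer accumulated over several iterations.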