bridge.training.utils.train_utils

Module Contents

Functions
- Check if a parameter is marked as not shared.
- Calculate the L2 norm of model parameters across all GPUs.
- Calculate the L2 norm of DTensor parameters.
- Calculates the max of a stat across the model parallel group.
- Performs a logical AND operation across the model parallel group.
- Log training stats (losses, learning rate, timings, etc.).
- Logs the memory usage of the model.
- Computes and logs the L2 norm of gradients.
- Estimates total training time.
- Logs the training throughput and utilization.
- Convenience function to check and inject GlobalState in one call.
- Check if a forward step function needs GlobalState injection.
- Optionally inject GlobalState into forward_step functions that expect it.
Data

API
- bridge.training.utils.train_utils.MEMORY_KEYS: dict[str, str]
None
Check if a parameter is marked as not shared.
- Parameters:
param (torch.nn.Parameter) – The parameter to check.
- Returns:
True if the parameter does not have a shared attribute or if param.shared is False.
- Return type:
bool
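The check reduces to a single attribute test. A minimal sketch (the function name here is illustrative, since the entry above does not show it):

```python
import torch

def param_is_not_shared(param: torch.nn.Parameter) -> bool:
    # "Not shared" means the parameter either lacks a `shared` attribute
    # entirely or has `shared` set to False.
    return not getattr(param, "shared", False)
```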
- bridge.training.utils.train_utils.calc_params_l2_norm(
- model: Union[megatron.core.transformer.module.MegatronModule, list[megatron.core.transformer.module.MegatronModule]],
- model_config: Any,
- use_megatron_fsdp: bool = False,
- force_create_fp32_copy: bool = False,
- )
Calculate the L2 norm of model parameters across all GPUs.
Handles parameter sharding (DP, TP, PP, EP) and different parameter types (dense, MoE, sharded main params).
- Parameters:
model (Union[torch.nn.Module, list[torch.nn.Module]]) – The model or list of model chunks.
model_config – The model configuration object.
force_create_fp32_copy (bool, optional) – If True, always creates an FP32 copy for norm calculation, ignoring potential main_param attributes. Defaults to False.
- Returns:
The L2 norm of all parameters.
- Return type:
float
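A minimal usage sketch, assuming model-parallel groups are already initialized; model_chunks and model_config are placeholders for objects from the surrounding training setup:

```python
from megatron.bridge.training.utils.train_utils import calc_params_l2_norm

# Compute the global parameter L2 norm across all GPUs; the result is the
# same on every rank.
params_norm = calc_params_l2_norm(model_chunks, model_config)
print(f"Global parameter L2 norm: {params_norm}")
```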
- bridge.training.utils.train_utils.calc_dtensor_params_l2_norm(params)
Calculate the L2 norm of DTensor parameters.
- bridge.training.utils.train_utils.reduce_max_stat_across_model_parallel_group(
- stat: Optional[float],
- )
Calculates the max of a stat across the model parallel group.
Handles cases where some ranks might have the stat as None (e.g., grad norm on ranks without an optimizer).
- Parameters:
stat (Optional[float]) – The statistic value (or None) on the current rank.
- Returns:
The maximum value of the statistic across the model parallel group, or None if all ranks had None.
- Return type:
Optional[float]
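A usage sketch; rank_has_optimizer and the local value are placeholders:

```python
from megatron.bridge.training.utils.train_utils import (
    reduce_max_stat_across_model_parallel_group,
)

# Placeholder: pretend only ranks that own an optimizer computed a grad norm.
rank_has_optimizer = True
local_grad_norm = 1.7 if rank_has_optimizer else None

# Every rank receives the group-wide maximum (or None if all ranks had None).
grad_norm = reduce_max_stat_across_model_parallel_group(local_grad_norm)
```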
- bridge.training.utils.train_utils.logical_and_across_model_parallel_group(input: bool) → bool
Performs a logical AND operation across the model parallel group.
- Parameters:
input (bool) – The boolean value on the current rank.
- Returns:
The result of the logical AND across all ranks in the group.
- Return type:
bool
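A usage sketch; the per-rank condition is a placeholder:

```python
from megatron.bridge.training.utils.train_utils import (
    logical_and_across_model_parallel_group,
)

# Proceed only if every rank in the model parallel group passed its local
# check (e.g., a required file exists on each rank).
local_ok = True  # placeholder for a per-rank condition
all_ok = logical_and_across_model_parallel_group(local_ok)
```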
- bridge.training.utils.train_utils.training_log(
- loss_dict: dict[str, torch.Tensor],
- total_loss_dict: dict[str, Any],
- learning_rate: Optional[float],
- decoupled_learning_rate: Optional[float],
- loss_scale: float,
- report_memory_flag: bool,
- skipped_iter: int,
- grad_norm: Optional[float],
- params_norm: Optional[float],
- num_zeros_in_grad: Optional[int],
- config: megatron.bridge.training.config.ConfigContainer,
- global_state: megatron.bridge.training.state.GlobalState,
- history_wct: list,
- model: list[megatron.core.transformer.module.MegatronModule],
- )
Log training stats (losses, learning rate, timings, etc.).
Aggregates losses, logs metrics to TensorBoard and WandB (if enabled), and prints a formatted log string to the console on the last rank.
- Parameters:
loss_dict (dict[str, torch.Tensor]) – Dictionary of losses for the current step.
total_loss_dict (dict[str, Any]) – Dictionary that accumulates losses and stats across logging intervals.
learning_rate (Optional[float]) – Current learning rate.
decoupled_learning_rate (Optional[float]) – Current decoupled learning rate (if used).
loss_scale (float) – Current loss scale value.
report_memory_flag (bool) – Flag indicating whether memory usage should be reported.
skipped_iter (int) – 1 if the iteration was skipped, 0 otherwise.
grad_norm (Optional[float]) – Gradient norm if computed, else None.
params_norm (Optional[float]) – Parameter L2 norm if computed, else None.
num_zeros_in_grad (Optional[int]) – Number of zeros in the gradient if computed, else None.
config (ConfigContainer) – The main configuration container.
global_state (GlobalState) – The global training state.
history_wct (list) – List of elapsed wall-clock time per iteration.
model (list[MegatronModule]) – The Megatron model chunks.
- Returns:
The updated report_memory_flag.
- Return type:
bool
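A sketch of the call site inside a training loop; every argument is a placeholder supplied by the surrounding loop, in the parameter order documented above:

```python
# Called after each optimizer step; the returned flag is fed back in on
# the next iteration.
report_memory_flag = training_log(
    loss_dict,
    total_loss_dict,
    learning_rate,
    decoupled_learning_rate,
    loss_scale,
    report_memory_flag,
    skipped_iter,
    grad_norm,
    params_norm,
    num_zeros_in_grad,
    config,
    global_state,
    history_wct,
    model_chunks,
)
```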
- bridge.training.utils.train_utils.report_memory(memory_keys: Optional[dict[str, str]]) → dict
Logs the memory usage of the model. This metric calls the torch memory stats API for CUDA and reports different memory statistics. The following statistics are recorded:

| Statistic | Description |
| --- | --- |
| current_allocated_mem | Current amount of allocated memory in gigabytes. |
| current_active_mem | Current amount of active memory in gigabytes at the time of recording. |
| current_inactive_mem | Current amount of inactive, non-releasable memory in gigabytes. |
| current_reserved_mem | Current amount of reserved memory in gigabytes at the time of recording. |
| peak_allocated_mem | Peak amount of allocated memory in gigabytes. |
| peak_active_mem | Peak amount of active memory in gigabytes at the time of recording. |
| peak_inactive_mem | Peak amount of inactive, non-releasable memory in gigabytes at the time of recording. |
| peak_reserved_mem | Peak amount of reserved memory in gigabytes at the time of recording. |
| alloc_retries | Number of failed cudaMalloc calls that result in a cache flush and retry. |
- Parameters:
memory_keys (dict[str, str], optional) – A dict specifying memory statistics to log. Keys are the names of memory statistics from torch.cuda.memory_stats(), and values are the names they will be logged under. If not provided, the statistics above are logged. Defaults to None.
- Returns:
Memory metrics dictionary.
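A hypothetical call that logs just two statistics under custom names; the dict keys are real entries in torch.cuda.memory_stats(), and the values are the names they will be logged under:

```python
from megatron.bridge.training.utils.train_utils import report_memory

memory_metrics = report_memory(
    {
        "allocated_bytes.all.current": "current_allocated_mem",
        "allocated_bytes.all.peak": "peak_allocated_mem",
    }
)
```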
- bridge.training.utils.train_utils.report_l2_norm_grad(
- model: list[megatron.core.transformer.module.MegatronModule],
- )
Computes and logs the L2 norm of gradients.

L2 norms are calculated after the reduction of gradients across GPUs. This function iterates over the parameters of the model and may reduce throughput when training large models. To ensure the correctness of the norm, this function should be called after gradient unscaling in cases where gradients are scaled.

The following statistics are recorded:

| Key | Logged data |
| --- | --- |
| l2_norm/grad/global | L2 norm of the gradients of all parameters in the model. |
| l2_norm/grad/LAYER_NAME | Layer-wise L2 norms. |

- Parameters:
model (Union[MegatronModule, list[MegatronModule]]) – The Megatron model chunks.
- Returns:
Dictionary with L2 norms for each layer.
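A usage sketch; model_chunks is a placeholder for the trained model chunks:

```python
from megatron.bridge.training.utils.train_utils import report_l2_norm_grad

# Call after gradients have been reduced (and unscaled, if a loss scaler
# is in use).
grad_norms = report_l2_norm_grad(model_chunks)
print(grad_norms["l2_norm/grad/global"])
```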
- bridge.training.utils.train_utils.report_runtime(
- train_state: megatron.bridge.training.state.TrainState,
- start_time: int,
- seq_length: int,
- train_iters: int,
- time_unit: str = 'seconds',
- )
Estimates total training time.

The training time is computed by taking the time elapsed for the current duration and multiplying out to the full extended length of the training run.

This metric provides a best-effort estimate. The estimate may be inaccurate if throughput changes during training or other significant changes are made to the model or dataloader.

The following statistics are recorded:

| Key | Logged data |
| --- | --- |
| time/remaining_estimate | Estimated time to completion. |
| time/tokens | Number of consumed tokens. |
| time/samples | Number of consumed samples. |
| time/batches | Number of consumed batches. |
| time/total | Total training time. |

- Parameters:
train_state (TrainState) – The current training state.
start_time (int) – Time when training was started.
seq_length (int) – Model sequence length.
train_iters (int) – Total number of training iterations for the run.
time_unit (str, optional) – Time unit to use for time logging. Can be one of "seconds", "minutes", "hours", or "days". Defaults to "seconds".
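A usage sketch at a logging step; train_state and start_time are placeholders from the surrounding training setup:

```python
from megatron.bridge.training.utils.train_utils import report_runtime

report_runtime(
    train_state,
    start_time,
    seq_length=4096,
    train_iters=100_000,
    time_unit="hours",
)
```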
- bridge.training.utils.train_utils.report_throughput(
- train_config: megatron.bridge.training.config.TrainingConfig,
- iteration: int,
- seq_length: int,
- history_wct: list,
- window_size: int,
- )
Logs the training throughput and utilization.

The training throughput is logged once the number of iterations reaches the window_size threshold.

The following statistics are recorded:

| Key | Logged data |
| --- | --- |
| throughput/batches_per_sec | Rolling average (over the window_size most recent batches) of the number of batches processed per second. |
| throughput/samples_per_sec | Rolling average (over the window_size most recent batches) of the number of samples processed per second. |
| throughput/tokens_per_sec | Rolling average (over the window_size most recent batches) of the number of tokens processed per second. Only logged if the dataspec returns tokens per batch. |
| throughput/device/batches_per_sec | throughput/batches_per_sec divided by world size. |
| throughput/device/samples_per_sec | throughput/samples_per_sec divided by world size. |
| throughput/device/tokens_per_sec | throughput/tokens_per_sec divided by world size. Only logged if the dataspec returns tokens per batch. |

- Parameters:
train_config (TrainingConfig) – The model training config.
iteration (int) – Current training iteration.
seq_length (int) – Model sequence length.
history_wct (list) – List of elapsed wall-clock time per iteration.
window_size (int, optional) – Number of batches to use for a rolling average of throughput.
- Returns:
Dictionary with throughput metrics.
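A usage sketch at a logging step; train_config and history_wct are placeholders from the surrounding training loop:

```python
from megatron.bridge.training.utils.train_utils import report_throughput

throughput_metrics = report_throughput(
    train_config,
    iteration=200,
    seq_length=4096,
    history_wct=history_wct,
    window_size=50,
)
```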
- bridge.training.utils.train_utils.prepare_forward_step_func(
- forward_step_func: megatron.bridge.training.forward_step_func_types.ForwardStepCallable,
- state: megatron.bridge.training.state.GlobalState,
- )
Convenience function to check and inject GlobalState in one call.

This combines needs_global_state_injection() and maybe_inject_state() for cleaner code. Call this once at the beginning of train() or evaluate() to avoid creating new partial objects every iteration.

Wrapping once is safe because:
- functools.partial stores a reference to the state object, not a copy.
- When state.train_state.step or other fields change, the partial sees those changes.
- There are no staleness issues, because GlobalState is mutable and passed by reference.

Functor support:
- Works with both regular functions (def forward_step(...)) and callable classes.
- For functors, inspect.signature() inspects the __call__ method.
- For functors, partial(functor_instance, state) preserves the functor's internal state. For example, if a functor has self.call_count, it still increments correctly.

- Parameters:
forward_step_func – The original forward step function or functor.
state – The GlobalState object to inject if needed.
- Returns:
The wrapped function (if injection was needed) or the original function.
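A usage sketch; global_state, data_iterator, and model are placeholders:

```python
from megatron.bridge.training.utils.train_utils import prepare_forward_step_func

# A forward step that expects GlobalState as its first argument:
def forward_step(state, data_iterator, model):
    ...

# Wrap once before the training loop.
step_fn = prepare_forward_step_func(forward_step, global_state)

# Inside the loop, call with the remaining arguments only:
# output = step_fn(data_iterator, model)
```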
- bridge.training.utils.train_utils.needs_global_state_injection(
- forward_step_func: megatron.bridge.training.forward_step_func_types.ForwardStepCallable,
- )
Check if a forward step function needs GlobalState injection.

This function performs the signature inspection once to determine whether state should be injected. It is more efficient than repeated signature inspection in the training loop.

Detection logic:
- First checks for a GlobalState type annotation on any parameter.
- Falls back to checking whether the first parameter is named "state" or "global_state".

- Parameters:
forward_step_func – The forward step function to inspect.
- Returns:
True if GlobalState should be injected, False otherwise.
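A sketch of the name-based detection on two hypothetical forward steps:

```python
from megatron.bridge.training.utils.train_utils import needs_global_state_injection

def step_without_state(data_iterator, model): ...

def step_with_state(state, data_iterator, model): ...

# The first parameter of step_with_state is named "state", so it matches
# the name-based fallback; step_without_state does not.
assert needs_global_state_injection(step_with_state)
assert not needs_global_state_injection(step_without_state)
```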
- bridge.training.utils.train_utils.maybe_inject_state(
- forward_step_func: megatron.bridge.training.forward_step_func_types.ForwardStepCallable,
- state: megatron.bridge.training.state.GlobalState,
- needs_injection: Optional[bool] = None,
- )
Optionally inject GlobalState into forward_step functions that expect it.

Determines whether to inject state by inspecting the function signature:
- First checks for a GlobalState type annotation on any parameter.
- Falls back to checking whether the first parameter is named "state".
- Otherwise assumes the function does not expect state.

Supported signatures:
- (data_iterator, model) → no injection
- (data_iterator, model, return_schedule_plan) → no injection
- (state: GlobalState, data_iterator, model) → inject state
- (state: GlobalState, data_iterator, model, return_schedule_plan) → inject state
- (state, data_iterator, model) → inject state (fallback to name-based detection)
- Parameters:
forward_step_func – The original forward step function.
state – The GlobalState object to potentially inject.
needs_injection – Whether injection is needed (optional; determined by signature inspection if None). Pass this to avoid repeated signature inspection in training loops.
- Returns:
The original function or a partial function with GlobalState injected.
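A usage sketch combining both helpers so the signature is inspected only once; forward_step, global_state, data_iterator, and model are placeholders:

```python
from megatron.bridge.training.utils.train_utils import (
    maybe_inject_state,
    needs_global_state_injection,
)

# Inspect the signature once, then pass the cached result to avoid
# repeated inspection in the training loop.
needs = needs_global_state_injection(forward_step)
step_fn = maybe_inject_state(forward_step, global_state, needs_injection=needs)
output = step_fn(data_iterator, model)
```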