bridge.training.utils.train_utils#
Module Contents#
Functions#

| Function | Description |
|---|---|
| param_is_not_shared | Check if a parameter is marked as not shared. |
| calc_params_l2_norm | Calculate the L2 norm of model parameters across all GPUs. |
| reduce_max_stat_across_model_parallel_group | Calculates the max of a stat across the model parallel group. |
| logical_and_across_model_parallel_group | Performs a logical AND operation across the model parallel group. |
| training_log | Log training stats (losses, learning rate, timings, etc.). |
| report_memory | Report current and peak GPU memory usage for the current rank. |
| track_moe_metrics | Track and log Mixture of Experts (MoE) specific metrics. |
| clear_aux_losses_tracker | Clear the MoE auxiliary loss tracker in the parallel state. |
| reduce_aux_losses_tracker_across_ranks | Reduce the MoE auxiliary losses across pipeline and specified reduction groups. |
| maybe_inject_state | Optionally inject the GlobalState object into the forward step function. |
| check_forward_step_func_num_args | Check if the forward step function has a supported number of arguments. |
API#
- bridge.training.utils.train_utils.param_is_not_shared(param: torch.nn.Parameter) → bool#
Check if a parameter is marked as not shared.
- Parameters:
param (torch.nn.Parameter) – The parameter to check.
- Returns:
True if the parameter does not have a `shared` attribute or if param.shared is False.
- Return type:
bool
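The check itself is simple. A minimal sketch of the documented behavior (the helper name is illustrative, not the library source):

```python
import torch

def param_is_not_shared_sketch(param: torch.nn.Parameter) -> bool:
    # True when the parameter has no `shared` attribute, or when it is False.
    return not hasattr(param, "shared") or not param.shared
```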
- bridge.training.utils.train_utils.calc_params_l2_norm(model: Union[megatron.core.transformer.module.MegatronModule, list[megatron.core.transformer.module.MegatronModule]], model_config: Any, force_create_fp32_copy: bool = False) → float#
Calculate the L2 norm of model parameters across all GPUs.
Handles parameter sharding (DP, TP, PP, EP) and different parameter types (dense, MoE, sharded main params).
- Parameters:
model (Union[torch.nn.Module, list[torch.nn.Module]]) – The model or list of model chunks.
model_config – The model configuration object.
force_create_fp32_copy (bool, optional) – If True, always creates an FP32 copy for the norm calculation, ignoring potential main_param attributes. Defaults to False.
- Returns:
The L2 norm of all parameters.
- Return type:
float
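A hypothetical call site, assuming an initialized Megatron-style distributed setup; `model` and `config.model_config` are placeholders for objects the training loop already holds:

```python
from megatron.bridge.training.utils.train_utils import calc_params_l2_norm

def log_params_norm(model, config) -> float:
    # One scalar covering DP/TP/PP/EP-sharded parameters across all GPUs.
    norm = calc_params_l2_norm(model, config.model_config)
    print(f"params L2 norm: {norm:.4f}")
    return norm
```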
- bridge.training.utils.train_utils.reduce_max_stat_across_model_parallel_group(stat: Optional[float]) → Optional[float]#
Calculates the max of a stat across the model parallel group.
Handles cases where some ranks might have the stat as None (e.g., grad norm on ranks without an optimizer).
- Parameters:
stat (Optional[float]) – The statistic value (or None) on the current rank.
- Returns:
The maximum value of the statistic across the model parallel group, or None if all ranks had None.
- Return type:
Optional[float]
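The None handling can be pictured as mapping None to a sentinel that never wins the max. A hedged sketch using torch.distributed (the real implementation may differ):

```python
import torch
import torch.distributed as dist

def reduce_max_with_none(stat, group):
    # Ranks without the stat (e.g., no optimizer) contribute -inf,
    # which never wins the MAX reduction.
    t = torch.tensor(
        [stat if stat is not None else float("-inf")],
        dtype=torch.float32, device="cuda",
    )
    dist.all_reduce(t, op=dist.ReduceOp.MAX, group=group)
    value = t.item()
    return None if value == float("-inf") else value
```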
- bridge.training.utils.train_utils.logical_and_across_model_parallel_group(input: bool) → bool#
Performs a logical AND operation across the model parallel group.
- Parameters:
input (bool) – The boolean value on the current rank.
- Returns:
The result of the logical AND across all ranks in the group.
- Return type:
bool
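A boolean AND across ranks reduces to a MIN over 0/1 flags. A sketch of that idea (an assumption about the approach, not the library source):

```python
import torch
import torch.distributed as dist

def logical_and_sketch(flag: bool, group) -> bool:
    # 0/1 encoding: the AND is True only if every rank contributes 1.
    t = torch.tensor([int(flag)], dtype=torch.int32, device="cuda")
    dist.all_reduce(t, op=dist.ReduceOp.MIN, group=group)
    return bool(t.item())
```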
- bridge.training.utils.train_utils.training_log(
- loss_dict: dict[str, torch.Tensor],
- total_loss_dict: dict[str, Any],
- learning_rate: Optional[float],
- decoupled_learning_rate: Optional[float],
- loss_scale: float,
- report_memory_flag: bool,
- skipped_iter: int,
- grad_norm: Optional[float],
- params_norm: Optional[float],
- num_zeros_in_grad: Optional[int],
- config: megatron.bridge.training.config.ConfigContainer,
- global_state: megatron.bridge.training.state.GlobalState,
) → bool#
Log training stats (losses, learning rate, timings, etc.).
Aggregates losses, logs metrics to TensorBoard and WandB (if enabled), and prints a formatted log string to the console on the last rank.
- Parameters:
loss_dict (dict[str, torch.Tensor]) – Dictionary of losses for the current step.
total_loss_dict (dict[str, Any]) – Dictionary to accumulate losses and stats across logging intervals.
learning_rate (Optional[float]) – Current learning rate.
decoupled_learning_rate (Optional[float]) – Current decoupled learning rate (if used).
loss_scale (float) – Current loss scale value.
report_memory_flag (bool) – Flag to indicate if memory usage should be reported.
skipped_iter (int) – 1 if the iteration was skipped, 0 otherwise.
grad_norm (Optional[float]) – Gradient norm if computed, else None.
params_norm (Optional[float]) – Parameter L2 norm if computed, else None.
num_zeros_in_grad (Optional[int]) – Number of zeros in gradient if computed, else None.
config – The main configuration container.
global_state – The global training state.
- Returns:
The updated report_memory_flag.
- Return type:
bool
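The per-interval accumulation into total_loss_dict can be sketched as follows; this is a greatly simplified assumption that omits the TensorBoard/WandB and timer logging the real function performs:

```python
import torch

def accumulate_losses(loss_dict, total_loss_dict, skipped_iter):
    # Count skipped iterations separately; only non-skipped steps
    # contribute their losses to the running totals.
    total_loss_dict["skipped iterations"] = (
        total_loss_dict.get("skipped iterations", 0) + skipped_iter
    )
    if not skipped_iter:
        for key, loss in loss_dict.items():
            prev = total_loss_dict.get(key, torch.tensor(0.0))
            total_loss_dict[key] = prev + loss.detach().cpu()
    return total_loss_dict
```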
- bridge.training.utils.train_utils.report_memory(name: str) → None#
Report current and peak GPU memory usage for the current rank.
- Parameters:
name (str) – A name to include in the output message (e.g., stage of training).
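A minimal sketch of the documented behavior using torch.cuda statistics (the library version may format or gate the output differently):

```python
import torch

def report_memory_sketch(name: str) -> None:
    mib = 1024 * 1024
    print(
        f"[{name}] allocated: {torch.cuda.memory_allocated() / mib:.1f} MiB | "
        f"max allocated: {torch.cuda.max_memory_allocated() / mib:.1f} MiB | "
        f"reserved: {torch.cuda.memory_reserved() / mib:.1f} MiB"
    )
```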
- bridge.training.utils.train_utils.track_moe_metrics(
- loss_scale: float,
- iteration: int,
- tb_logger: Any,
- wandb_logger: Optional[Any] = None,
- total_loss_dict: Optional[dict] = None,
- per_layer_logging: bool = False,
)#
Track and log Mixture of Experts (MoE) specific metrics.
Reduces auxiliary losses across ranks and logs them to TensorBoard and WandB.
- Parameters:
loss_scale (float) – The current loss scale.
iteration (int) – The current training iteration.
tb_logger – The TensorBoard logger instance.
wandb_logger – The WandB logger instance (optional).
total_loss_dict (Optional[dict]) – Dictionary to accumulate total losses (optional).
per_layer_logging (bool) – If True, logs metrics for each MoE layer individually.
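A hypothetical call site after an MoE training step; it assumes an initialized model-parallel environment so the cross-rank reduction can run, and the log directory and iteration number are placeholders:

```python
from torch.utils.tensorboard import SummaryWriter
from megatron.bridge.training.utils.train_utils import track_moe_metrics

tb_logger = SummaryWriter(log_dir="./tb")  # placeholder log directory
total_loss_dict: dict = {}

track_moe_metrics(
    loss_scale=1.0,
    iteration=100,  # placeholder iteration number
    tb_logger=tb_logger,
    wandb_logger=None,
    total_loss_dict=total_loss_dict,
    per_layer_logging=False,
)
```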
- bridge.training.utils.train_utils.clear_aux_losses_tracker()#
Clear the MoE auxiliary loss tracker in the parallel state.
- bridge.training.utils.train_utils.reduce_aux_losses_tracker_across_ranks()#
Reduce the MoE auxiliary losses across pipeline and specified reduction groups.
- bridge.training.utils.train_utils.maybe_inject_state(
- forward_step_func: Callable,
- state: megatron.bridge.training.state.GlobalState,
- num_fw_args: Optional[int] = None,
)#
Optionally inject the GlobalState object into the forward step function.
Checks the number of arguments of forward_step_func. If it expects 3 arguments, it assumes the first argument is the GlobalState and returns a partial function with the state injected.
- Parameters:
forward_step_func – The original forward step function.
state – The GlobalState object to potentially inject.
num_fw_args – The number of arguments the forward_step_func expects (optional; inspected if None).
- Returns:
The original function or a partial function with GlobalState injected.
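The dispatch described above amounts to binding the state as the first positional argument when the function's arity is 3. A sketch of that logic (an assumption mirroring the documented behavior, not necessarily the exact source):

```python
import inspect
from functools import partial
from typing import Callable, Optional

def maybe_inject_state_sketch(
    forward_step_func: Callable, state, num_fw_args: Optional[int] = None
) -> Callable:
    if num_fw_args is None:
        num_fw_args = len(inspect.signature(forward_step_func).parameters)
    if num_fw_args == 3:
        # func(state, data_iterator, model): bind the GlobalState up front.
        return partial(forward_step_func, state)
    # func(data_iterator, model): return unchanged.
    return forward_step_func
```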
- bridge.training.utils.train_utils.check_forward_step_func_num_args(forward_step_func: Callable) → int#
Check if the forward step function has a supported number of arguments.
Currently supports 2 or 3 arguments:
- func(data_iterator, model)
- func(state, data_iterator, model)
- Parameters:
forward_step_func – The function to check.
- Returns:
The number of arguments the function takes.
- Raises:
AssertionError – If the function does not have 2 or 3 arguments.
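A hedged sketch of the arity check via inspect (the real function may also handle bound methods or functools.partial wrappers differently):

```python
import inspect
from typing import Callable

def check_num_args_sketch(forward_step_func: Callable) -> int:
    num_args = len(inspect.signature(forward_step_func).parameters)
    assert num_args in (2, 3), (
        f"forward_step_func must take 2 or 3 arguments, got {num_args}"
    )
    return num_args
```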