bridge.training.utils.train_utils#

Module Contents#

Functions#

• param_is_not_shared – Check if a parameter is marked as not shared.

• calc_params_l2_norm – Calculate the L2 norm of model parameters across all GPUs.

• reduce_max_stat_across_model_parallel_group – Calculates the max of a stat across the model parallel group.

• logical_and_across_model_parallel_group – Performs a logical AND operation across the model parallel group.

• training_log – Log training stats (losses, learning rate, timings, etc.).

• report_memory – Report current and peak GPU memory usage for the current rank.

• track_moe_metrics – Track and log Mixture of Experts (MoE) specific metrics.

• clear_aux_losses_tracker – Clear the MoE auxiliary loss tracker in the parallel state.

• reduce_aux_losses_tracker_across_ranks – Reduce the MoE auxiliary losses across pipeline and specified reduction groups.

• maybe_inject_state – Optionally inject the GlobalState object into the forward step function.

• check_forward_step_func_num_args – Check if the forward step function has a supported number of arguments.

API#

bridge.training.utils.train_utils.param_is_not_shared(param: torch.nn.Parameter) → bool#

Check if a parameter is marked as not shared.

Parameters:

param (torch.nn.Parameter) – The parameter to check.

Returns:

True if the parameter does not have a 'shared' attribute or if param.shared is False.

Return type:

bool
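
A quick sanity check of the documented behavior (a minimal sketch; the shared flag is set manually here for illustration):

```python
import torch

from megatron.bridge.training.utils.train_utils import param_is_not_shared

param = torch.nn.Parameter(torch.zeros(4))
print(param_is_not_shared(param))  # True: no 'shared' attribute yet

param.shared = True  # convention used to mark tied/shared parameters
print(param_is_not_shared(param))  # False
```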

bridge.training.utils.train_utils.calc_params_l2_norm(
model: Union[megatron.core.transformer.module.MegatronModule, list[megatron.core.transformer.module.MegatronModule]],
model_config: Any,
force_create_fp32_copy: bool = False,
) → float#

Calculate the L2 norm of model parameters across all GPUs.

Handles parameter sharding (DP, TP, PP, EP) and different parameter types (dense, MoE, sharded main params).

Parameters:
  • model (Union[MegatronModule, list[MegatronModule]]) – The model or list of model chunks.

  • model_config – The model configuration object.

  • force_create_fp32_copy (bool, optional) – If True, always creates an FP32 copy for norm calculation, ignoring potential main_param attributes. Defaults to False.

Returns:

The L2 norm of all parameters.

Return type:

float
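
A minimal sketch of a call site, assuming `model` (a MegatronModule or list of pipeline chunks) and `model_config` come from an already-initialized distributed training setup:

```python
from megatron.bridge.training.utils.train_utils import calc_params_l2_norm

# `model` and `model_config` are assumed to exist in the surrounding
# training code; the call reduces across all parallel groups internally.
params_norm = calc_params_l2_norm(model, model_config)
print(f"params L2 norm: {params_norm:.4f}")
```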

bridge.training.utils.train_utils.reduce_max_stat_across_model_parallel_group(
stat: Optional[float],
) → Optional[float]#

Calculates the max of a stat across the model parallel group.

Handles cases where some ranks might have the stat as None (e.g., grad norm on ranks without an optimizer).

Parameters:

stat (Optional[float]) – The statistic value (or None) on the current rank.

Returns:

The maximum value of the statistic across the model parallel group, or None if all ranks had None.

Return type:

Optional[float]
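
A sketch of typical use, assuming distributed and model-parallel state are initialized; `grad_norm` stands in for a per-rank value that may be None:

```python
from megatron.bridge.training.utils.train_utils import (
    reduce_max_stat_across_model_parallel_group,
)

# Per-rank value computed earlier in the step; None on ranks that do
# not hold optimizer state (illustrative literal here).
grad_norm = 1.7
grad_norm = reduce_max_stat_across_model_parallel_group(grad_norm)
if grad_norm is not None:
    print(f"max grad norm across the model parallel group: {grad_norm:.4f}")
```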

bridge.training.utils.train_utils.logical_and_across_model_parallel_group(input: bool) → bool#

Performs a logical AND operation across the model parallel group.

Parameters:

input (bool) – The boolean value on the current rank.

Returns:

The result of the logical AND across all ranks in the group.

Return type:

bool
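
For example, to proceed only when every rank in the model parallel group passed a local check (a sketch; `local_ok` is an illustrative per-rank condition):

```python
from megatron.bridge.training.utils.train_utils import (
    logical_and_across_model_parallel_group,
)

local_ok = True  # hypothetical per-rank condition, e.g. "loss is finite"
if logical_and_across_model_parallel_group(local_ok):
    pass  # every rank in the group agreed; safe to proceed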

bridge.training.utils.train_utils.training_log(
loss_dict: dict[str, torch.Tensor],
total_loss_dict: dict[str, Any],
learning_rate: Optional[float],
decoupled_learning_rate: Optional[float],
loss_scale: float,
report_memory_flag: bool,
skipped_iter: int,
grad_norm: Optional[float],
params_norm: Optional[float],
num_zeros_in_grad: Optional[int],
config: megatron.bridge.training.config.ConfigContainer,
global_state: megatron.bridge.training.state.GlobalState,
) → bool#

Log training stats (losses, learning rate, timings, etc.).

Aggregates losses, logs metrics to TensorBoard and WandB (if enabled), and prints a formatted log string to the console on the last rank.

Parameters:
  • loss_dict (dict[str, torch.Tensor]) – Dictionary of losses for the current step.

  • total_loss_dict (dict[str, Any]) – Dictionary to accumulate losses and stats across logging intervals.

  • learning_rate (Optional[float]) – Current learning rate.

  • decoupled_learning_rate (Optional[float]) – Current decoupled learning rate (if used).

  • loss_scale (float) – Current loss scale value.

  • report_memory_flag (bool) – Flag to indicate if memory usage should be reported.

  • skipped_iter (int) – 1 if the iteration was skipped, 0 otherwise.

  • grad_norm (Optional[float]) – Gradient norm if computed, else None.

  • params_norm (Optional[float]) – Parameter L2 norm if computed, else None.

  • num_zeros_in_grad (Optional[int]) – Number of zeros in gradient if computed, else None.

  • config – The main configuration container.

  • global_state – The global training state.

Returns:

The updated report_memory_flag.

Return type:

bool
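
A sketch of the call as it might appear inside a training loop; every argument except `config` and `global_state` is an illustrative per-step local:

```python
from megatron.bridge.training.utils.train_utils import training_log

# All inputs below are assumed to be produced earlier in the step
# (illustrative names); `config` and `global_state` come from setup.
report_memory_flag = training_log(
    loss_dict=loss_dict,
    total_loss_dict=total_loss_dict,
    learning_rate=lr,
    decoupled_learning_rate=None,
    loss_scale=loss_scale,
    report_memory_flag=report_memory_flag,
    skipped_iter=skipped_iter,
    grad_norm=grad_norm,
    params_norm=params_norm,
    num_zeros_in_grad=num_zeros_in_grad,
    config=config,
    global_state=global_state,
)
```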

bridge.training.utils.train_utils.report_memory(name: str) → None#

Report current and peak GPU memory usage for the current rank.

Parameters:

name (str) – A name to include in the output message (e.g., stage of training).
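
For example:

```python
from megatron.bridge.training.utils.train_utils import report_memory

report_memory("after iteration 100")  # label is free-form
```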

bridge.training.utils.train_utils.track_moe_metrics(
loss_scale: float,
iteration: int,
tb_logger: Any,
wandb_logger: Optional[Any] = None,
total_loss_dict: Optional[dict] = None,
per_layer_logging: bool = False,
) → None#

Track and log Mixture of Experts (MoE) specific metrics.

Reduces auxiliary losses across ranks and logs them to TensorBoard and WandB.

Parameters:
  • loss_scale (float) – The current loss scale.

  • iteration (int) – The current training iteration.

  • tb_logger – The TensorBoard logger instance.

  • wandb_logger – The WandB logger instance (optional).

  • total_loss_dict (Optional[dict]) – Dictionary to accumulate total losses (optional).

  • per_layer_logging (bool) – If True, logs metrics for each MoE layer individually.
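
A sketch of a logging call, assuming `iteration`, `tb_logger`, `wandb_logger`, and `total_loss_dict` already exist in the training loop:

```python
from megatron.bridge.training.utils.train_utils import track_moe_metrics

# Logger handles and per-step values are assumed from the training code.
track_moe_metrics(
    loss_scale=1.0,
    iteration=iteration,
    tb_logger=tb_logger,
    wandb_logger=wandb_logger,
    total_loss_dict=total_loss_dict,
    per_layer_logging=False,
)
```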

bridge.training.utils.train_utils.clear_aux_losses_tracker()#

Clear the MoE auxiliary loss tracker in the parallel state.

bridge.training.utils.train_utils.reduce_aux_losses_tracker_across_ranks()#

Reduce the MoE auxiliary losses across pipeline and specified reduction groups.
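
track_moe_metrics performs this reduction internally, but if you consume the tracker manually, a plausible reduce-then-clear cycle (an assumed pattern, not prescribed by the API) looks like this:

```python
from megatron.bridge.training.utils.train_utils import (
    clear_aux_losses_tracker,
    reduce_aux_losses_tracker_across_ranks,
)

reduce_aux_losses_tracker_across_ranks()  # make values consistent across ranks
# ... read the reduced aux losses from the parallel state for custom logging ...
clear_aux_losses_tracker()  # reset before the next accumulation window
```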

bridge.training.utils.train_utils.maybe_inject_state(
forward_step_func: Callable,
state: megatron.bridge.training.state.GlobalState,
num_fw_args: Optional[int] = None,
) → Callable#

Optionally inject the GlobalState object into the forward step function.

Checks the number of arguments of forward_step_func. If it expects 3 arguments, it assumes the first argument is the GlobalState and returns a partial function with the state injected.

Parameters:
  • forward_step_func – The original forward step function.

  • state – The GlobalState object to potentially inject.

  • num_fw_args – The number of arguments the forward_step_func expects (optional, will be inspected if None).

Returns:

The original function or a partial function with GlobalState injected.
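
A sketch with a hypothetical 3-argument forward step; `global_state`, `data_iterator`, and `model` are assumed from the surrounding training code:

```python
from megatron.bridge.training.utils.train_utils import maybe_inject_state

def forward_step(state, data_iterator, model):
    """Hypothetical 3-argument forward step; `state` is the GlobalState."""
    ...

# Three arguments -> a partial with `global_state` bound as the first
# argument is returned; a 2-argument function would come back unchanged.
wrapped = maybe_inject_state(forward_step, global_state)
output = wrapped(data_iterator, model)  # state is already bound
```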

bridge.training.utils.train_utils.check_forward_step_func_num_args(
forward_step_func: Callable,
) → int#

Check if the forward step function has a supported number of arguments.

Currently supports 2 or 3 arguments:

  • func(data_iterator, model)

  • func(state, data_iterator, model)

Parameters:

forward_step_func – The function to check.

Returns:

The number of arguments the function takes.

Raises:

AssertionError – If the function does not have 2 or 3 arguments.
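
A sketch combining both helpers, so the inspection happens only once (`forward_step` and `global_state` are assumed from the training code):

```python
from megatron.bridge.training.utils.train_utils import (
    check_forward_step_func_num_args,
    maybe_inject_state,
)

num_args = check_forward_step_func_num_args(forward_step)  # asserts 2 or 3
forward_step = maybe_inject_state(forward_step, global_state, num_fw_args=num_args)
```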