bridge.training.utils.train_utils#

Module Contents#

Functions#

param_is_not_shared

Check if a parameter is marked as not shared.

calc_params_l2_norm

Calculate the L2 norm of model parameters across all GPUs.

calc_dtensor_params_l2_norm

Calculate the L2 norm of DTensor parameters.

reduce_max_stat_across_model_parallel_group

Calculates the max of a stat across the model parallel group.

logical_and_across_model_parallel_group

Performs a logical AND operation across the model parallel group.

training_log

Log training stats (losses, learning rate, timings, etc.).

report_memory

Report current and peak GPU memory usage for the current rank.

maybe_inject_state

Optionally inject GlobalState into a 4-arg forward_step function.

check_forward_step_func_num_args

Check if the forward step function has a supported number of arguments.

API#

bridge.training.utils.train_utils.param_is_not_shared(param: torch.nn.Parameter) bool#

Check if a parameter is marked as not shared.

Parameters:

param (torch.nn.Parameter) – The parameter to check.

Returns:

True if the parameter does not have a 'shared' attribute or if param.shared is False.

Return type:

bool
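
A minimal usage sketch. The import path is an assumption based on the module name above; the tied-weight scenario is illustrative:

    import torch

    from megatron.bridge.training.utils.train_utils import param_is_not_shared

    p = torch.nn.Parameter(torch.randn(4))
    assert param_is_not_shared(p)  # no 'shared' attribute -> counts as not shared

    p.shared = True  # e.g., a weight tied across layers (illustrative)
    assert not param_is_not_shared(p)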

bridge.training.utils.train_utils.calc_params_l2_norm(
model: Union[megatron.core.transformer.module.MegatronModule, list[megatron.core.transformer.module.MegatronModule]],
model_config: Any,
use_megatron_fsdp: bool = False,
force_create_fp32_copy: bool = False,
) float#

Calculate the L2 norm of model parameters across all GPUs.

Handles parameter sharding (DP, TP, PP, EP) and different parameter types (dense, MoE, sharded main params).

Parameters:
  • model (Union[torch.nn.Module, list[torch.nn.Module]]) – The model or list of model chunks.

  • model_config – The model configuration object.

  • use_megatron_fsdp (bool, optional) – Whether the model is sharded with Megatron FSDP. Defaults to False.

  • force_create_fp32_copy (bool, optional) – If True, always creates an FP32 copy for norm calculation, ignoring potential main_param attributes. Defaults to False.

Returns:

The L2 norm of all parameters.

Return type:

float
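
For illustration, a hypothetical call site inside a training loop; model and model_config are assumed to come from the surrounding setup:

    # Works for a single module or a list of virtual-pipeline model chunks.
    params_norm = calc_params_l2_norm(model, model_config)

    # Force a fresh FP32 copy instead of reusing sharded main params.
    params_norm_fp32 = calc_params_l2_norm(
        model, model_config, force_create_fp32_copy=True
    )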

bridge.training.utils.train_utils.calc_dtensor_params_l2_norm(params)#

Calculate the L2 norm of DTensor parameters.

Parameters:

params – The DTensor parameters to compute the norm over.

bridge.training.utils.train_utils.reduce_max_stat_across_model_parallel_group(
stat: Optional[float],
) Optional[float]#

Calculates the max of a stat across the model parallel group.

Handles cases where some ranks might have the stat as None (e.g., grad norm on ranks without an optimizer).

Parameters:

stat (Optional[float]) – The statistic value on the current rank, or None if unavailable.

Returns:

The maximum value of the statistic across the model parallel group, or None if all ranks had None.

Return type:

Optional[float]
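
A sketch of the None-handling semantics described above; local_grad_norm is an assumed local variable:

    # On ranks without an optimizer, local_grad_norm may be None; every rank
    # still receives the group-wide maximum (or None if no rank had a value).
    grad_norm = reduce_max_stat_across_model_parallel_group(local_grad_norm)
    if grad_norm is not None:
        print(f"grad norm: {grad_norm:.4f}")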

bridge.training.utils.train_utils.logical_and_across_model_parallel_group(input: bool) bool#

Performs a logical AND operation across the model parallel group.

Parameters:

input (bool) – The boolean value on the current rank.

Returns:

The result of the logical AND across all ranks in the group.

Return type:

bool
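
For example, to agree on whether an optimizer step succeeded on every model-parallel rank (update_successful is an assumed local flag):

    # True only if every rank in the model parallel group reports True.
    all_ok = logical_and_across_model_parallel_group(update_successful)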

bridge.training.utils.train_utils.training_log(
loss_dict: dict[str, torch.Tensor],
total_loss_dict: dict[str, Any],
learning_rate: Optional[float],
decoupled_learning_rate: Optional[float],
loss_scale: float,
report_memory_flag: bool,
skipped_iter: int,
grad_norm: Optional[float],
params_norm: Optional[float],
num_zeros_in_grad: Optional[int],
config: megatron.bridge.training.config.ConfigContainer,
global_state: megatron.bridge.training.state.GlobalState,
) bool#

Log training stats (losses, learning rate, timings, etc.).

Aggregates losses, logs metrics to TensorBoard and WandB (if enabled), and prints a formatted log string to the console on the last rank.

Parameters:
  • loss_dict (dict[str, torch.Tensor]) – Dictionary of losses for the current step.

  • total_loss_dict (dict[str, Any]) – Dictionary to accumulate losses and stats across logging intervals.

  • learning_rate (Optional[float]) – Current learning rate.

  • decoupled_learning_rate (Optional[float]) – Current decoupled learning rate (if used).

  • loss_scale (float) – Current loss scale value.

  • report_memory_flag (bool) – Flag to indicate if memory usage should be reported.

  • skipped_iter (int) – 1 if the iteration was skipped, 0 otherwise.

  • grad_norm (Optional[float]) – Gradient norm if computed, else None.

  • params_norm (Optional[float]) – Parameter L2 norm if computed, else None.

  • num_zeros_in_grad (Optional[int]) – Number of zeros in gradient if computed, else None.

  • config (megatron.bridge.training.config.ConfigContainer) – The main configuration container.

  • global_state (megatron.bridge.training.state.GlobalState) – The global training state.

Returns:

The updated report_memory_flag.

Return type:

bool
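
A hypothetical call at the end of a training iteration; the argument order mirrors the signature above, and the local variable names are assumptions:

    report_memory_flag = training_log(
        loss_dict,           # losses from this step
        total_loss_dict,     # running accumulator across the log interval
        learning_rate,
        None,                # decoupled_learning_rate (not used here)
        loss_scale,
        report_memory_flag,
        skipped_iter,
        grad_norm,
        params_norm,
        num_zeros_in_grad,
        config=cfg,
        global_state=state,
    )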

bridge.training.utils.train_utils.report_memory(name: str) None#

Report current and peak GPU memory usage for the current rank.

Parameters:

name (str) – A name to include in the output message (e.g., stage of training).
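
For example:

    # Tag the report with the training stage for easier log grepping.
    report_memory("after iteration 1")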

bridge.training.utils.train_utils.maybe_inject_state(
forward_step_func: Callable,
state: megatron.bridge.training.state.GlobalState,
num_fw_args: Optional[int] = None,
) Callable#

Optionally inject GlobalState into a 4-arg forward_step function.

  • If the function has 4 parameters (state, data_iterator, model, return_schedule_plan), bind the provided state via functools.partial to produce a callable that accepts (data_iterator, model, return_schedule_plan).

  • If the function already has 3 parameters (data_iterator, model, return_schedule_plan) or 2 parameters (data_iterator, model), return it unchanged.

Parameters:
  • forward_step_func – The original forward step function.

  • state – The GlobalState object to potentially inject.

  • num_fw_args – The number of arguments the forward_step_func expects (optional, will be inspected if None).

Returns:

The original function or a partial function with GlobalState injected.
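
A sketch of the binding behavior; forward_step, global_state, data_iterator, and model are assumed to exist in the caller:

    def forward_step(state, data_iterator, model, return_schedule_plan=False):
        ...  # 4-arg variant: receives GlobalState explicitly

    # Binds `global_state` as the first argument via functools.partial.
    bound = maybe_inject_state(forward_step, global_state)

    # `bound` now has the 3-arg shape (data_iterator, model, return_schedule_plan).
    output = bound(data_iterator, model)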

bridge.training.utils.train_utils.check_forward_step_func_num_args(
forward_step_func: Callable,
) int#

Check if the forward step function has a supported number of arguments.

Currently supports 2, 3, or 4 arguments:

  • func(data_iterator, model)

  • func(data_iterator, model, return_schedule_plan: bool = False) # state pre-bound via partial

  • func(state, data_iterator, model, return_schedule_plan: bool = False)

Parameters:

forward_step_func – The function to check.

Returns:

The number of arguments the function takes.

Raises:

AssertionError – If the function does not have 2, 3, or 4 arguments.
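
The three accepted shapes, as a compact sketch (function names are illustrative):

    def fw2(data_iterator, model): ...
    def fw3(data_iterator, model, return_schedule_plan=False): ...
    def fw4(state, data_iterator, model, return_schedule_plan=False): ...

    assert check_forward_step_func_num_args(fw2) == 2
    assert check_forward_step_func_num_args(fw3) == 3
    assert check_forward_step_func_num_args(fw4) == 4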