bridge.training.utils.train_utils#

Module Contents#

Functions#

param_is_not_shared

Check if a parameter is marked as not shared.

calc_params_l2_norm

Calculate the L2 norm of model parameters across all GPUs.

calc_dtensor_params_l2_norm

Calculate the L2 norm of DTensor parameters.

reduce_max_stat_across_model_parallel_group

Calculates the max of a stat across the model parallel group.

logical_and_across_model_parallel_group

Performs a logical AND operation across the model parallel group.

training_log

Log training stats (losses, learning rate, timings, etc.).

report_memory

Report current and peak GPU memory usage for the current rank.

prepare_forward_step_func

Convenience function to check and inject GlobalState in one call.

needs_global_state_injection

Check if a forward step function needs GlobalState injection.

maybe_inject_state

Optionally inject GlobalState into forward_step functions that expect it.

API#

bridge.training.utils.train_utils.param_is_not_shared(param: torch.nn.Parameter) → bool#

Check if a parameter is marked as not shared.

Parameters:

param (torch.nn.Parameter) – The parameter to check.

Returns:

True if the parameter does not have a ‘shared’ attribute or if param.shared is False.

Return type:

bool
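
Example (a minimal sketch of the documented behavior; the shared attribute is set manually here for illustration):

```python
import torch

from megatron.bridge.training.utils.train_utils import param_is_not_shared

param = torch.nn.Parameter(torch.zeros(4))
print(param_is_not_shared(param))  # True: no 'shared' attribute

param.shared = True  # mark the parameter as shared
print(param_is_not_shared(param))  # False
```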

bridge.training.utils.train_utils.calc_params_l2_norm(
model: Union[megatron.core.transformer.module.MegatronModule, list[megatron.core.transformer.module.MegatronModule]],
model_config: Any,
use_megatron_fsdp: bool = False,
force_create_fp32_copy: bool = False,
) → float#

Calculate the L2 norm of model parameters across all GPUs.

Handles parameter sharding (DP, TP, PP, EP) and different parameter types (dense, MoE, sharded main params).

Parameters:
  • model (Union[MegatronModule, list[MegatronModule]]) – The model or list of model chunks.

  • model_config – The model configuration object.

  • use_megatron_fsdp (bool, optional) – Whether the model is sharded with Megatron FSDP. Defaults to False.

  • force_create_fp32_copy (bool, optional) – If True, always creates an FP32 copy for norm calculation, ignoring potential main_param attributes. Defaults to False.

Returns:

The L2 norm of all parameters.

Return type:

float
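
Example (a hedged call-site sketch; model and model_config are assumed to come from the surrounding training setup, with distributed and model-parallel groups already initialized):

```python
import torch

from megatron.bridge.training.utils.train_utils import calc_params_l2_norm

# `model` (a MegatronModule or list of virtual-pipeline chunks) and `model_config`
# are assumed to exist from earlier setup; they are not constructed here.
params_norm = calc_params_l2_norm(model, model_config)
if torch.distributed.get_rank() == 0:
    print(f"param L2 norm: {params_norm:.4f}")
```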

bridge.training.utils.train_utils.calc_dtensor_params_l2_norm(params)#

Calculate the L2 norm of DTensor parameters.

bridge.training.utils.train_utils.reduce_max_stat_across_model_parallel_group(
stat: Optional[float],
) → Optional[float]#

Calculates the max of a stat across the model parallel group.

Handles cases where some ranks might have the stat as None (e.g., grad norm on ranks without an optimizer).

Parameters:

stat (Optional[float]) – The statistic value (or None) on the current rank.

Returns:

The maximum value of the statistic across the model parallel group, or None if all ranks had None.

Return type:

Optional[float]
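
Example (a hedged sketch of the pattern described above; compute_local_grad_norm_or_none is a hypothetical per-rank helper):

```python
from megatron.bridge.training.utils.train_utils import (
    reduce_max_stat_across_model_parallel_group,
)

# Some ranks (e.g., those without optimizer state) may produce None here.
grad_norm = compute_local_grad_norm_or_none()  # hypothetical helper
# Every rank ends up with the group-wide maximum, or None if all ranks had None.
grad_norm = reduce_max_stat_across_model_parallel_group(grad_norm)
```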

bridge.training.utils.train_utils.logical_and_across_model_parallel_group(input: bool) → bool#

Performs a logical AND operation across the model parallel group.

Parameters:

input (bool) – The boolean value on the current rank.

Returns:

The result of the logical AND across all ranks in the group.

Return type:

bool
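
Example (a hedged sketch; all_params_finite is a hypothetical per-rank boolean computed earlier):

```python
from megatron.bridge.training.utils.train_utils import (
    logical_and_across_model_parallel_group,
)

# `all_params_finite` is a hypothetical per-rank boolean computed earlier.
# Only treat the step as valid if every rank in the model parallel group agrees.
step_is_valid = logical_and_across_model_parallel_group(all_params_finite)
```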

bridge.training.utils.train_utils.training_log(
loss_dict: dict[str, torch.Tensor],
total_loss_dict: dict[str, Any],
learning_rate: Optional[float],
decoupled_learning_rate: Optional[float],
loss_scale: float,
report_memory_flag: bool,
skipped_iter: int,
grad_norm: Optional[float],
params_norm: Optional[float],
num_zeros_in_grad: Optional[int],
config: megatron.bridge.training.config.ConfigContainer,
global_state: megatron.bridge.training.state.GlobalState,
) → bool#

Log training stats (losses, learning rate, timings, etc.).

Aggregates losses, logs metrics to TensorBoard and WandB (if enabled), and prints a formatted log string to the console on the last rank.

Parameters:
  • loss_dict (dict[str, torch.Tensor]) – Dictionary of losses for the current step.

  • total_loss_dict (dict[str, Any]) – Dictionary to accumulate losses and stats across logging intervals.

  • learning_rate (Optional[float]) – Current learning rate.

  • decoupled_learning_rate (Optional[float]) – Current decoupled learning rate (if used).

  • loss_scale (float) – Current loss scale value.

  • report_memory_flag (bool) – Flag to indicate if memory usage should be reported.

  • skipped_iter (int) – 1 if the iteration was skipped, 0 otherwise.

  • grad_norm (Optional[float]) – Gradient norm if computed, else None.

  • params_norm (Optional[float]) – Parameter L2 norm if computed, else None.

  • num_zeros_in_grad (Optional[int]) – Number of zeros in gradient if computed, else None.

  • config (ConfigContainer) – The main configuration container.

  • global_state (GlobalState) – The global training state.

Returns:

The updated report_memory_flag.

Return type:

bool
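
Example (a hedged sketch of a typical call site; the per-step values and the config / global_state objects are assumed to come from the surrounding training loop and setup):

```python
from megatron.bridge.training.utils.train_utils import training_log

# Before the loop: accumulator dict and the initial memory-report flag.
total_loss_dict: dict = {}
report_memory_flag = True

# Inside the loop, after each step. loss_dict, learning_rate, loss_scale,
# skipped_iter, grad_norm, params_norm, num_zeros_in_grad, config, and
# global_state are assumed to come from the surrounding training code.
report_memory_flag = training_log(
    loss_dict=loss_dict,                  # per-step losses
    total_loss_dict=total_loss_dict,      # accumulated across the logging interval
    learning_rate=learning_rate,
    decoupled_learning_rate=None,         # only set if a decoupled LR is used
    loss_scale=loss_scale,
    report_memory_flag=report_memory_flag,
    skipped_iter=skipped_iter,            # 1 if the iteration was skipped, else 0
    grad_norm=grad_norm,                  # may be None
    params_norm=params_norm,              # may be None
    num_zeros_in_grad=num_zeros_in_grad,  # may be None
    config=config,
    global_state=global_state,
)
```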

bridge.training.utils.train_utils.report_memory(name: str) → None#

Report current and peak GPU memory usage for the current rank.

Parameters:

name (str) – A name to include in the output message (e.g., stage of training).
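
Example (a minimal sketch; assumes a CUDA-enabled process):

```python
from megatron.bridge.training.utils.train_utils import report_memory

# Print current and peak GPU memory for this rank, tagged with a stage label.
report_memory("after warmup iterations")
```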

bridge.training.utils.train_utils.prepare_forward_step_func(
forward_step_func: megatron.bridge.training.forward_step_func_types.ForwardStepCallable,
state: megatron.bridge.training.state.GlobalState,
) → megatron.bridge.training.forward_step_func_types.ForwardStepCallable#

Convenience function to check and inject GlobalState in one call.

This combines needs_global_state_injection() and maybe_inject_state() in a single call. Call it once at the beginning of train() or evaluate() to avoid creating a new partial object on every iteration.

Wrapping once is safe since:

  • functools.partial stores a reference to the state object, not a copy

  • When state.train_state.step or other fields change, the partial sees those changes

  • No staleness issues because GlobalState is mutable and passed by reference

Functor support:

  • Works with both regular functions (def forward_step(…)) and callable classes

  • For functors: inspect.signature() inspects the __call__ method

  • For functors: partial(functor_instance, state) preserves the functor’s internal state

  • Example: if a functor has a self.call_count attribute, it still increments correctly

Parameters:
  • forward_step_func – The original forward step function or functor

  • state – The GlobalState object to inject if needed

Returns:

The wrapped function (if injection needed) or original function
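
Example (a hedged sketch; state is assumed to be the GlobalState created during setup, and the forward step body is elided):

```python
from megatron.bridge.training.utils.train_utils import prepare_forward_step_func


def forward_step(state, data_iterator, model):
    """First parameter is named 'state', so GlobalState will be injected."""
    ...


# `state` is assumed to be the GlobalState created during setup.
# Wrap once before the training loop.
prepared_forward_step = prepare_forward_step_func(forward_step, state)
# Inside the loop the wrapped callable is invoked without the state argument,
# e.g. prepared_forward_step(data_iterator, model).
```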

bridge.training.utils.train_utils.needs_global_state_injection(
forward_step_func: megatron.bridge.training.forward_step_func_types.ForwardStepCallable,
) → bool#

Check if a forward step function needs GlobalState injection.

This function performs the signature inspection once to determine whether state should be injected, which is more efficient than repeating the inspection in the training loop.

Detection logic:

  1. First checks for a GlobalState type annotation on any parameter

  2. Falls back to checking whether the first parameter is named ‘state’ or ‘global_state’

Parameters:

forward_step_func – The forward step function to inspect.

Returns:

True if GlobalState should be injected, False otherwise.
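
Example (a minimal sketch of the detection logic; function bodies are elided):

```python
from megatron.bridge.training.state import GlobalState
from megatron.bridge.training.utils.train_utils import needs_global_state_injection


def plain_step(data_iterator, model):
    ...


def stateful_step(state: GlobalState, data_iterator, model):
    ...


print(needs_global_state_injection(plain_step))     # False
print(needs_global_state_injection(stateful_step))  # True (GlobalState annotation)
```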

bridge.training.utils.train_utils.maybe_inject_state(
forward_step_func: megatron.bridge.training.forward_step_func_types.ForwardStepCallable,
state: megatron.bridge.training.state.GlobalState,
needs_injection: Optional[bool] = None,
) → megatron.bridge.training.forward_step_func_types.ForwardStepCallable#

Optionally inject GlobalState into forward_step functions that expect it.

Determines whether to inject state by inspecting function signature:

  1. First checks for a GlobalState type annotation on any parameter

  2. Falls back to checking whether the first parameter is named ‘state’

  3. Otherwise assumes the function doesn’t expect state

Supported signatures:

  • (data_iterator, model) → no injection

  • (data_iterator, model, return_schedule_plan) → no injection

  • (state: GlobalState, data_iterator, model) → inject state

  • (state: GlobalState, data_iterator, model, return_schedule_plan) → inject state

  • (state, data_iterator, model) → inject state (fallback to name-based detection)

Parameters:
  • forward_step_func – The original forward step function.

  • state – The GlobalState object to potentially inject.

  • needs_injection – Whether injection is needed (optional, will be inspected if None). Pass this to avoid repeated signature inspection in training loops.

Returns:

The original function or a partial function with GlobalState injected.
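
Example (a hedged sketch combining both helpers; forward_step and state are assumed to exist as in the prepare_forward_step_func example above):

```python
from megatron.bridge.training.utils.train_utils import (
    maybe_inject_state,
    needs_global_state_injection,
)

# `forward_step` and `state` are assumed to exist (see the previous example).
# Inspect the signature once, then reuse the result to avoid re-inspection
# on every call to maybe_inject_state.
needs_injection = needs_global_state_injection(forward_step)
wrapped = maybe_inject_state(forward_step, state, needs_injection=needs_injection)
# If injection was needed, `wrapped` is a functools.partial with `state` bound
# and is called as wrapped(data_iterator, model); otherwise it is forward_step.
```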