bridge.training.utils.train_utils

Module Contents

Classes

LinearForLastLayer

Final replicated projection head compatible with Megatron output-layer calls.

_MoeMetricFanoutWriter

SummaryWriter-shaped adapter that fans add_scalar to MLFlow / Comet.

Functions

create_value_head_hook

Create a pre-wrap hook that replaces the final pipeline stage output head.

make_value_model

Create a value-head hook compatible with existing external trainer code.

freeze_moe_router

Freeze MoE router and shared-expert gate parameters in model chunks.

_ensure_model_list

_freeze_parameter_if_present

_register_linear_for_last_layer_mapping

start_memory_history_recording

Enable the CUDA caching allocator trace so memory snapshots contain history.

param_is_not_shared

Check if a parameter is marked as not shared.

calc_params_l2_norm

Calculate the L2 norm of model parameters across all GPUs.

calc_dtensor_params_l2_norm

Calculate the L2 norm of DTensor parameters.

reduce_max_stat_across_model_parallel_group

Calculates the max of a stat across the model parallel group.

logical_and_across_model_parallel_group

Performs a logical AND operation across the model parallel group.

reduce_max_memory_across_pp_group

Reduce per-rank memory metrics across the PP group with MAX.

_build_moe_metric_writer

Return a writer suitable for MCore’s MoE/MTP metric helpers.

training_log

Log training stats (losses, learning rate, timings, etc.).

report_memory

Logs the memory usage of the model.

report_l2_norm_grad

Computes and logs the L2 norm of gradients.

report_runtime

Estimates total training time.

report_throughput

Logs the training throughput and utilization.

prepare_forward_step_func

Convenience function to check and inject GlobalState in one call.

needs_global_state_injection

Check if a forward step function needs GlobalState injection.

maybe_inject_state

Optionally inject GlobalState into forward_step functions that expect it.

Data

API

bridge.training.utils.train_utils.logger

‘getLogger(…)’

bridge.training.utils.train_utils.ModelList

None

bridge.training.utils.train_utils.ModelHook

None

class bridge.training.utils.train_utils.LinearForLastLayer(
input_size: int,
output_size: int,
sequence_parallel: bool,
)

Bases: torch.nn.Linear

Final replicated projection head compatible with Megatron output-layer calls.

Megatron-Core output layers receive a few runtime-only arguments. This head accepts those arguments for call-site compatibility while using a standard replicated linear projection.

Initialization

Initialize a replicated final projection.

Parameters:
  • input_size – Hidden dimension of the transformer output.

  • output_size – Output dimension of the value/reward head.

  • sequence_parallel – Whether to gather sequence-parallel activations.

forward(
input_: torch.Tensor,
weight: torch.Tensor | None = None,
runtime_gather_output: bool | None = None,
) → tuple[torch.Tensor, None]

Run the final projection and return Megatron-style (output, bias).

bridge.training.utils.train_utils.create_value_head_hook(
hidden_size: int,
sequence_parallel: bool,
output_size: int = 1,
) → bridge.training.utils.train_utils.ModelHook

Create a pre-wrap hook that replaces the final pipeline stage output head.

Parameters:
  • hidden_size – Hidden dimension of the transformer output.

  • sequence_parallel – Whether the model uses sequence parallelism.

  • output_size – Number of outputs produced by the final head.

Returns:

A model hook suitable for external trainer provider construction.
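You can picture the hook contract with a small stand-in. This sketch is hypothetical. It assumes that a hook is a callable that takes the list of model chunks and returns it, that the caller already knows whether this rank is the last pipeline stage, and that the head lives on an `output_layer` attribute. This page does not confirm any of these names, and the real hook builds a LinearForLastLayer rather than the placeholder string used here.

```python
from types import SimpleNamespace
from typing import Callable, List

# Hypothetical types: a "chunk" is any object with an `output_layer` attribute,
# and a hook maps the list of chunks to a (possibly modified) list of chunks.
Chunks = List[SimpleNamespace]
Hook = Callable[[Chunks], Chunks]


def make_head_replacing_hook(head_factory: Callable[[], object], is_last_stage: bool) -> Hook:
    """Return a pre-wrap hook that swaps the output head on the last pipeline stage."""

    def hook(chunks: Chunks) -> Chunks:
        # Only the final pipeline stage owns the output head; with virtual
        # pipelining, that is the last chunk held by the last-stage rank.
        if is_last_stage:
            chunks[-1].output_layer = head_factory()
        return chunks

    return hook


# Usage: replace an LM head with a placeholder scalar value head.
chunks = [SimpleNamespace(output_layer="lm_head"), SimpleNamespace(output_layer="lm_head")]
hook = make_head_replacing_hook(lambda: "value_head", is_last_stage=True)
chunks = hook(chunks)
```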

bridge.training.utils.train_utils.make_value_model(
hidden_size: int,
sequence_parallel: bool,
) → bridge.training.utils.train_utils.ModelHook

Create a value-head hook compatible with existing external trainer code.

bridge.training.utils.train_utils.freeze_moe_router(
model: bridge.training.utils.train_utils.ModelList | megatron.core.transformer.module.MegatronModule,
) → bridge.training.utils.train_utils.ModelList

Freeze MoE router and shared-expert gate parameters in model chunks.

Parameters:

model – Single Megatron module or list of virtual-pipeline model chunks.

Returns:

The normalized model chunk list with router parameters frozen in place.
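The following is a minimal sketch of the normalize-then-freeze pattern. The substring patterns `router` and `shared_experts.gate`, the parameter names and the `FakeChunk` class are illustrative assumptions. They are not the function's actual matching rules, and the real function operates on MegatronModule chunks.

```python
from typing import Dict, Iterator, List, Tuple

# Hypothetical name patterns for parameters to freeze.
FROZEN_PATTERNS = ("router", "shared_experts.gate")


class FakeParam:
    def __init__(self) -> None:
        self.requires_grad = True


class FakeChunk:
    """Stand-in for a model chunk exposing named_parameters()."""

    def __init__(self, names: List[str]) -> None:
        self._params: Dict[str, FakeParam] = {n: FakeParam() for n in names}

    def named_parameters(self) -> Iterator[Tuple[str, FakeParam]]:
        return iter(self._params.items())


def ensure_model_list(model) -> list:
    # Accept a single module or a list of virtual-pipeline chunks.
    return list(model) if isinstance(model, (list, tuple)) else [model]


def freeze_router_params(model) -> list:
    chunks = ensure_model_list(model)
    for chunk in chunks:
        for name, param in chunk.named_parameters():
            if any(pattern in name for pattern in FROZEN_PATTERNS):
                param.requires_grad = False  # frozen in place; nothing is copied
    return chunks


chunk = FakeChunk([
    "decoder.layers.0.mlp.router.weight",
    "decoder.layers.0.mlp.shared_experts.gate_weight",
    "decoder.layers.0.mlp.experts.linear_fc1.weight",
])
chunks = freeze_router_params(chunk)
```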

bridge.training.utils.train_utils._ensure_model_list(
model: bridge.training.utils.train_utils.ModelList | megatron.core.transformer.module.MegatronModule,
) → bridge.training.utils.train_utils.ModelList

bridge.training.utils.train_utils._freeze_parameter_if_present(module: object, name: str) → None

bridge.training.utils.train_utils._register_linear_for_last_layer_mapping() → None

bridge.training.utils.train_utils.start_memory_history_recording(
profiling: megatron.bridge.training.config.ProfilingConfig | None,
) → None

Enable the CUDA caching allocator trace so memory snapshots contain history.

torch.cuda.memory._snapshot() only includes allocation/free events and Python stack context after _record_memory_history() has been enabled. Without this call, dumped snapshots contain only the current live allocations, with no timeline and no call sites.

Must be invoked before model construction so every tensor allocation is captured. Guarded by profile_ranks so only ranks that will dump a snapshot pay the recording overhead.

bridge.training.utils.train_utils.MEMORY_KEYS: dict[str, str]

None

bridge.training.utils.train_utils.param_is_not_shared(param: torch.nn.Parameter) → bool

Check if a parameter is marked as not shared.

Parameters:

param (torch.nn.Parameter) – The parameter to check.

Returns:

True if the parameter does not have a ‘shared’ attribute or if param.shared is False.

Return type:

bool
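The check reduces to a single `getattr` with a default. The sketch below mirrors the documented behavior and uses `SimpleNamespace` objects in place of `torch.nn.Parameter`. A filter like this is typically used so that parameters replicated across stages, such as tied embeddings, are not counted twice when norms are summed.

```python
from types import SimpleNamespace


def param_is_not_shared_sketch(param: object) -> bool:
    # A missing attribute defaults to False, i.e. "not shared".
    return not getattr(param, "shared", False)


plain = SimpleNamespace()                 # no `shared` attribute
tied = SimpleNamespace(shared=True)       # e.g. a tied embedding copy
explicit = SimpleNamespace(shared=False)  # explicitly marked not shared
```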

bridge.training.utils.train_utils.calc_params_l2_norm(
model: Union[megatron.core.transformer.module.MegatronModule, list[megatron.core.transformer.module.MegatronModule]],
model_config: Any,
use_megatron_fsdp: bool = False,
force_create_fp32_copy: bool = False,
) → float

Calculate the L2 norm of model parameters across all GPUs.

Handles parameter sharding (DP, TP, PP, EP) and different parameter types (dense, MoE, sharded main params).

Parameters:
  • model (Union[torch.nn.Module, list[torch.nn.Module]]) – The model or list of model chunks.

  • model_config – The model configuration object.

  • use_megatron_fsdp (bool, optional) – Whether the model uses Megatron FSDP. Defaults to False.

  • force_create_fp32_copy (bool, optional) – If True, always creates an FP32 copy for norm calculation, ignoring potential main_param attributes. Defaults to False.

Returns:

The L2 norm of all parameters.

Return type:

float
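Squared L2 norms are additive over disjoint shards, so the global norm can be assembled from per-rank partial sums. Each rank sums the squares of the elements it owns. The partial sums are then added (in practice with a SUM all-reduce), and the square root is taken once at the end. Summing per-shard norms instead would overestimate the result. The pure-Python sketch below simulates the all-reduce with a local `sum`. It ignores the bookkeeping the real function does to avoid double-counting replicated parameters.

```python
import math
from typing import List


def shard_sq_sum(shard: List[float]) -> float:
    """Sum of squares of the elements owned by one rank."""
    return sum(x * x for x in shard)


def global_l2_norm(shards: List[List[float]]) -> float:
    # In a real job each rank computes only its own partial sum and the
    # partials are combined with an all-reduce (SUM); here we just add them.
    total = sum(shard_sq_sum(s) for s in shards)
    return math.sqrt(total)


shards = [[3.0], [4.0]]  # two ranks, one element each
norm = global_l2_norm(shards)
```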

bridge.training.utils.train_utils.calc_dtensor_params_l2_norm(params)

Calculate the L2 norm of DTensor parameters.

bridge.training.utils.train_utils.reduce_max_stat_across_model_parallel_group(
stat: Optional[float],
mp_group: torch.distributed.distributed_c10d.ProcessGroup,
) → Optional[float]

Calculates the max of a stat across the model parallel group.

Handles cases where some ranks might have the stat as None (e.g., grad norm on ranks without an optimizer).

Parameters:
  • stat (Optional[float]) – The statistic value (or None) on the current rank.

  • mp_group – The process group to reduce across (typically pg_collection.mp).

Returns:

The maximum value of the statistic across the model parallel group, or None if all ranks had None.

Return type:

Optional[float]
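A MAX reduction can tolerate missing values if each missing value is replaced by a sentinel that can never win. The sketch below simulates the collective with Python's `max` over a list of per-rank values and assumes `float('-inf')` as the sentinel. This page does not document which sentinel the real function uses.

```python
from typing import List, Optional

_MISSING = float("-inf")  # assumed sentinel: loses every MAX comparison


def max_ignoring_none(per_rank: List[Optional[float]]) -> Optional[float]:
    # Each rank would contribute its value (or the sentinel) to an
    # all-reduce with MAX; here the "group" is just the list.
    reduced = max(_MISSING if v is None else v for v in per_rank)
    return None if reduced == _MISSING else reduced
```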

bridge.training.utils.train_utils.logical_and_across_model_parallel_group(
input: bool,
mp_group: torch.distributed.distributed_c10d.ProcessGroup,
) → bool

Performs a logical AND operation across the model parallel group.

Parameters:
  • input (bool) – The boolean value on the current rank.

  • mp_group – The process group to reduce across (typically pg_collection.mp).

Returns:

The result of the logical AND across all ranks in the group.

Return type:

bool
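A logical AND can be expressed as a MIN reduction over 0/1 integers, because the minimum is 1 only if every rank contributed True. The sketch simulates the process group with a list. Whether the real function uses MIN or another reduction op is an assumption here.

```python
from typing import List


def logical_and_via_min(per_rank: List[bool]) -> bool:
    # Encode True/False as 1/0; MIN over the group equals a logical AND.
    return bool(min(int(v) for v in per_rank))
```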

bridge.training.utils.train_utils.reduce_max_memory_across_pp_group(
memory_report: dict[str, Union[int, float]],
pp_group: torch.distributed.distributed_c10d.ProcessGroup,
) → dict[str, Union[int, float]]

Reduce per-rank memory metrics across the PP group with MAX.

With pipeline parallelism, peak GPU memory is typically dominated by the first PP stage (activation buildup). However, the TensorBoard / W&B / MLFlow / Comet writers only initialize on the last rank (world_size - 1). Without aggregation, the logged values therefore reflect only the last PP stage and under-report the true peak memory usage, which overstates the available headroom.

This helper performs a single bulk all-reduce with MAX over the PP group so that the writer rank emits the per-metric peak across the pipeline. Counter-style integer keys (e.g. alloc_retries) are preserved as int so dashboards continue to render them correctly.

No-op when distributed is uninitialized, the PP group has a single rank, or the report is empty.

Parameters:
  • memory_report – Mapping of metric name to per-rank value.

  • pp_group – The pipeline-parallel process group to reduce across.

Returns:

A new dict with values replaced by the per-metric MAX across the PP group, or the input report unchanged when no reduction is needed.
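The reduction can be sketched without a process group in three steps:

1. Pack each rank's report into one vector, using a fixed key order (this is the single bulk all-reduce).
2. Take the element-wise MAX.
3. Cast integer-valued counters back to int.

A list of per-rank dicts stands in for the PP group here. The helper name and the sample values are hypothetical.

```python
from typing import Dict, List, Union

Number = Union[int, float]


def reduce_max_reports(reports: List[Dict[str, Number]]) -> Dict[str, Number]:
    """Per-metric MAX over a simulated PP group (hypothetical helper)."""
    if not reports:
        return {}
    if len(reports) == 1 or not reports[0]:
        return reports[0]  # nothing to reduce: return the input unchanged
    keys = list(reports[0])  # every rank must pack keys in the same order
    # Pack each rank's report into one flat vector: one bulk reduction
    # instead of one collective per metric.
    packed = [[float(report[k]) for k in keys] for report in reports]
    maxed = [max(column) for column in zip(*packed)]  # stands in for all_reduce(MAX)
    result: Dict[str, Number] = {}
    for key, value in zip(keys, maxed):
        # Keep counter-style integer metrics (e.g. alloc_retries) as int.
        result[key] = int(value) if isinstance(reports[0][key], int) else value
    return result


first_stage = {"peak_allocated_mem": 61.2, "alloc_retries": 0}
last_stage = {"peak_allocated_mem": 38.7, "alloc_retries": 2}
merged = reduce_max_reports([first_stage, last_stage])
```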

class bridge.training.utils.train_utils._MoeMetricFanoutWriter(
tb_writer: Optional[Any],
comet_logger: Optional[Any],
mlflow_logger: Optional[Any],
)

SummaryWriter-shaped adapter that fans add_scalar to MLFlow / Comet.

MCore’s track_moe_metrics and track_mtp_metrics emit metrics through a TensorBoard writer.add_scalar(name, value, iteration) call (and a separate wandb_writer.log(...) call). They do not know about MLFlow or Comet, so those backends never see MoE / MTP metrics (see issue #2989).

Rather than fork MCore, this adapter wraps the real TB writer (or stands in for a missing one) and forwards every add_scalar to MLFlow and Comet using the same per-step value. W&B is unaffected: the underlying functions still receive wandb_writer directly, so their dict-based per-layer logging stays untouched.

Tensors are sanitized with .item() before being handed to MLFlow / Comet, matching the float/int conversion the existing MoE TensorBoard path implicitly relies on (TB tolerates 0-d tensors; MLFlow / Comet do not).

Initialization

static _sanitize(value: Any) → Any

Convert 0-d torch tensors to Python scalars; pass other values through.

The MLFlow / Comet client APIs either silently reject torch tensors or raise an error, whereas the existing TB call accepts them. Forcing a Python scalar makes all sinks behave consistently.

add_scalar(name: str, value: Any, iteration: int) → None

Forward an add_scalar call to TB (if any), MLFlow, and Comet.
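The fanout pattern fits in a few lines, as the sketch below shows. The class name and the sink interface `log_metric(name, value, step)` are assumptions made for illustration, not the adapter's actual internals. The sanitizer converts anything exposing a callable `.item()`, which covers 0-d tensors; multi-element tensors would raise, as `.item()` does in PyTorch.

```python
from typing import Any, List, Optional


class FanoutWriter:
    """Hypothetical sketch of a SummaryWriter-shaped fanout adapter."""

    def __init__(self, tb_writer: Optional[Any], sinks: List[Any]) -> None:
        self.tb_writer = tb_writer
        # Assumed sink interface: log_metric(name, value, step).
        self.sinks = sinks

    @staticmethod
    def sanitize(value: Any) -> Any:
        # 0-d tensors (and NumPy scalars) expose .item(); plain numbers do not.
        item = getattr(value, "item", None)
        return item() if callable(item) else value

    def add_scalar(self, name: str, value: Any, iteration: int) -> None:
        if self.tb_writer is not None:
            self.tb_writer.add_scalar(name, value, iteration)  # TB gets the raw value
        scalar = self.sanitize(value)
        for sink in self.sinks:
            sink.log_metric(name, scalar, iteration)
```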

bridge.training.utils.train_utils._build_moe_metric_writer(
tb_writer: Optional[Any],
comet_logger: Optional[Any],
mlflow_logger: Optional[Any],
) → Optional[Any]

Return a writer suitable for MCore’s MoE/MTP metric helpers.

  • When neither MLFlow nor Comet is wired up, the real TB writer is returned unchanged (zero overhead, no behavior change).

  • When at least one of MLFlow / Comet is wired up, return a fanout adapter that forwards add_scalar to all configured backends. The adapter is returned even when the TB writer itself is None, which is required to surface MoE / MTP metrics in Comet / MLFlow on rank N-1 even if the user hasn’t enabled TensorBoard.
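The two-branch selection can be sketched as follows. `FanoutAdapter` is a hypothetical placeholder that only records its backends, and `build_metric_writer` is an illustrative name rather than the module's private helper.

```python
from typing import Any, Optional


class FanoutAdapter:
    """Placeholder for the fanout adapter; only records its backends."""

    def __init__(self, tb_writer: Optional[Any], comet_logger: Optional[Any],
                 mlflow_logger: Optional[Any]) -> None:
        self.tb_writer = tb_writer
        self.comet_logger = comet_logger
        self.mlflow_logger = mlflow_logger


def build_metric_writer(tb_writer: Optional[Any], comet_logger: Optional[Any],
                        mlflow_logger: Optional[Any]) -> Optional[Any]:
    if comet_logger is None and mlflow_logger is None:
        # Fast path: the real TB writer (possibly None) is returned untouched.
        return tb_writer
    # At least one extra backend: always wrap, even when tb_writer is None,
    # so MoE / MTP metrics still reach Comet / MLFlow.
    return FanoutAdapter(tb_writer, comet_logger, mlflow_logger)
```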

bridge.training.utils.train_utils.training_log(
loss_dict: dict[str, torch.Tensor],
total_loss_dict: dict[str, Any],
learning_rate: Optional[float],
decoupled_learning_rate: Optional[float],
loss_scale: float,
report_memory_flag: bool,
skipped_iter: int,
grad_norm: Optional[float],
params_norm: Optional[float],
num_zeros_in_grad: Optional[int],
config: megatron.bridge.training.config.ConfigContainer,
global_state: megatron.bridge.training.state.GlobalState,
history_wct: list,
model: list[megatron.core.transformer.module.MegatronModule],
pg_collection: Optional[Any] = None,
log_max_attention_logit: Optional[float] = None,
loaded_iteration: int = 0,
seq_length: Optional[int] = None,
) → bool

Log training stats (losses, learning rate, timings, etc.).

Aggregates losses, logs metrics to TensorBoard and WandB (if enabled), and prints a formatted log string to the console on the last rank.

Parameters:
  • loss_dict (dict[str, torch.Tensor]) – Dictionary of losses for the current step.

  • total_loss_dict (dict[str, Any]) – Dictionary to accumulate losses and stats across logging intervals.

  • learning_rate (Optional[float]) – Current learning rate.

  • decoupled_learning_rate (Optional[float]) – Current decoupled learning rate (if used).

  • loss_scale (float) – Current loss scale value.

  • report_memory_flag (bool) – Flag to indicate if memory usage should be reported.

  • skipped_iter (int) – 1 if the iteration was skipped, 0 otherwise.

  • grad_norm (Optional[float]) – Gradient norm if computed, else None.

  • params_norm (Optional[float]) – Parameter L2 norm if computed, else None.

  • num_zeros_in_grad (Optional[int]) – Number of zeros in gradient if computed, else None.

  • config – The main configuration container.

  • global_state – The global training state.

  • history_wct (list) – List of elapsed wall-clock times, one entry per iteration.

  • model (list[MegatronModule]) – The Megatron model chunks.

  • pg_collection (Optional[Any]) – ProcessGroupCollection to use for logging reductions. If None, falls back to extracting from model wrappers.

  • log_max_attention_logit (Optional[float]) – Maximum attention logit if available, None otherwise.

Returns:

The updated report_memory_flag.

Return type:

bool

bridge.training.utils.train_utils.report_memory(memory_keys: Optional[dict[str, str]]) → dict

Logs the memory usage of the model. This function queries the torch CUDA memory stats API and records the following statistics:

=====================  =====================================================================================
Statistic              Description
=====================  =====================================================================================
current_allocated_mem  Current amount of allocated memory in gigabytes.
current_active_mem     Current amount of active memory in gigabytes at the time of recording.
current_inactive_mem   Current amount of inactive, non-releasable memory in gigabytes.
current_reserved_mem   Current amount of reserved memory in gigabytes at the time of recording.
peak_allocated_mem     Peak amount of allocated memory in gigabytes.
peak_active_mem        Peak amount of active memory in gigabytes at the time of recording.
peak_inactive_mem      Peak amount of inactive, non-releasable memory in gigabytes at the time of recording.
peak_reserved_mem      Peak amount of reserved memory in gigabytes at the time of recording.
alloc_retries          Number of failed cudaMalloc calls that result in a cache flush and retry.
=====================  =====================================================================================

Parameters:

memory_keys (dict[str, str], optional) – A dict specifying memory statistics to log. Keys are the names of memory statistics to log from torch.cuda.memory_stats(), and values are the names they will be logged under. If not provided, the above statistics are logged. Defaults to None.

Returns:

Memory metrics dictionary.
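Building the report amounts to renaming selected `torch.cuda.memory_stats()` entries and converting byte counts to gigabytes. The sketch below runs on a plain dictionary, so it needs no GPU. The key mapping is an assumption based on PyTorch's documented stat names, and the 1e9 bytes-per-gigabyte divisor is also assumed. On a CUDA machine you would pass `torch.cuda.memory_stats()` instead of the fake dict.

```python
from typing import Dict, Optional, Union

# Assumed mapping from torch.cuda.memory_stats() keys to logged names;
# the module's real MEMORY_KEYS may differ.
ASSUMED_MEMORY_KEYS: Dict[str, str] = {
    "allocated_bytes.all.current": "current_allocated_mem",
    "active_bytes.all.current": "current_active_mem",
    "inactive_split_bytes.all.current": "current_inactive_mem",
    "reserved_bytes.all.current": "current_reserved_mem",
    "allocated_bytes.all.peak": "peak_allocated_mem",
    "active_bytes.all.peak": "peak_active_mem",
    "inactive_split_bytes.all.peak": "peak_inactive_mem",
    "reserved_bytes.all.peak": "peak_reserved_mem",
    "num_alloc_retries": "alloc_retries",
}


def build_memory_report(memory_stats: Dict[str, int],
                        memory_keys: Optional[Dict[str, str]] = None) -> Dict[str, Union[int, float]]:
    keys = ASSUMED_MEMORY_KEYS if memory_keys is None else memory_keys
    report: Dict[str, Union[int, float]] = {}
    for stat_name, log_name in keys.items():
        if stat_name not in memory_stats:
            continue  # stat not reported by this allocator/backend
        value = memory_stats[stat_name]
        # Byte counters become gigabytes (1e9 bytes assumed); plain counters
        # such as num_alloc_retries are logged unchanged.
        report[log_name] = value / 1e9 if "bytes" in stat_name else value
    return report


fake_stats = {"allocated_bytes.all.current": 2_000_000_000, "num_alloc_retries": 3}
report = build_memory_report(fake_stats)
```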

bridge.training.utils.train_utils.report_l2_norm_grad(
model: list[megatron.core.transformer.module.MegatronModule],
) → dict

Computes and logs the L2 norm of gradients. L2 norms are calculated after gradients have been reduced across GPUs. This function iterates over all model parameters and may reduce throughput when training large models. To ensure a correct norm, call this function after gradient unscaling whenever gradients are scaled. The following statistics are recorded:

=======================  ========================================================
Key                      Logged data
=======================  ========================================================
l2_norm/grad/global      L2 norm of the gradients of all parameters in the model.
l2_norm/grad/LAYER_NAME  Layer-wise L2 norms of the gradients.
=======================  ========================================================

Parameters:

model (Union[MegatronModule, list[MegatronModule]]) – The Megatron model or list of model chunks.

Returns:

Dictionary with L2 norms for each layer.
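The global key relates to the layer-wise keys through sums of squares: the global norm is the square root of the sum of the squared layer norms. The sketch below uses plain lists as stand-ins for already-reduced (and, where applicable, unscaled) gradients. The helper name and the use of parameter names as LAYER_NAME are assumptions.

```python
import math
from typing import Dict, List


def grad_l2_norms(grads: Dict[str, List[float]]) -> Dict[str, float]:
    """Per-parameter and global gradient L2 norms (hypothetical helper)."""
    report: Dict[str, float] = {}
    total_sq = 0.0
    for name, grad in grads.items():
        sq = sum(g * g for g in grad)
        total_sq += sq  # squared norms add up across layers
        report[f"l2_norm/grad/{name}"] = math.sqrt(sq)
    report["l2_norm/grad/global"] = math.sqrt(total_sq)
    return report


norms = grad_l2_norms({"layer0.weight": [3.0], "layer1.weight": [4.0]})
```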

bridge.training.utils.train_utils.report_runtime(
train_state: megatron.bridge.training.state.TrainState,
start_time: int,
seq_length: int,
train_iters: int,
time_unit: str = 'seconds',
) → dict

Estimates total training time. The estimate extrapolates the time elapsed so far to the full length of the training run. It is a best-effort estimate and may be inaccurate if throughput changes during training or if other significant changes are made to the model or dataloader. The following statistics are recorded:

=======================  ============================
Key                      Logged data
=======================  ============================
time/remaining_estimate  Estimated time to completion
time/tokens              Number of consumed tokens
time/samples             Number of consumed samples
time/batches             Number of consumed batches
time/total               Total training time
=======================  ============================

Parameters:
  • train_state (TrainState) – The current training state.

  • start_time (int) – Time at which training started.

  • seq_length (int) – Model sequence length.

  • train_iters (int) – Total number of training iterations.

  • time_unit (str, optional) – Time unit used for time logging: one of ‘seconds’, ‘minutes’, ‘hours’, or ‘days’. Defaults to ‘seconds’.

Returns:

Dictionary with runtime metrics.
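The extrapolation described above is linear: the average time per completed iteration is multiplied by the number of iterations still to run. The helper and its arguments are hypothetical, since the real function reads progress from train_state and start_time. The sketch only illustrates the arithmetic and the unit conversion for the documented time_unit values.

```python
_SECONDS_PER_UNIT = {"seconds": 1, "minutes": 60, "hours": 3600, "days": 86400}


def estimate_remaining_time(elapsed: float, iteration: int, train_iters: int,
                            time_unit: str = "seconds") -> float:
    """Linear extrapolation of the remaining training time (hypothetical helper)."""
    if iteration <= 0:
        raise ValueError("at least one completed iteration is required")
    seconds_per_iter = elapsed / iteration
    remaining_seconds = seconds_per_iter * (train_iters - iteration)
    return remaining_seconds / _SECONDS_PER_UNIT[time_unit]
```

For example, if the first 100 of 1,000 iterations take 600 s, each iteration takes 6 s on average. The remaining 900 iterations therefore need about 5,400 s, or 1.5 hours.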

bridge.training.utils.train_utils.report_throughput(
train_config: megatron.bridge.training.config.TrainingConfig,
iteration: int,
seq_length: int,
history_wct: list,
window_size: int,
) → dict

Logs the training throughput and utilization. Throughput is only logged once the window_size threshold has been reached. The following statistics are recorded:

=================================  ============================================================
Key                                Logged data
=================================  ============================================================
throughput/batches_per_sec         Rolling average (over the window_size most recent batches) of the number of batches processed per second.
throughput/samples_per_sec         Rolling average (over the window_size most recent batches) of the number of samples processed per second.
throughput/tokens_per_sec          Rolling average (over the window_size most recent batches) of the number of tokens processed per second. Only logged if the dataspec returns tokens per batch.
throughput/device/batches_per_sec  throughput/batches_per_sec divided by world size.
throughput/device/samples_per_sec  throughput/samples_per_sec divided by world size.
throughput/device/tokens_per_sec   throughput/tokens_per_sec divided by world size. Only logged if the dataspec returns tokens per batch.
=================================  ============================================================

Parameters:
  • train_config (TrainingConfig) – Training configuration.

  • iteration (int) – Current training iteration.

  • seq_length (int) – Model sequence length.

  • history_wct (list) – List of elapsed wall-clock times, one entry per iteration.

  • window_size (int) – Number of batches to use for the rolling throughput average.

Returns:

Dictionary with throughput metrics.
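The rolling-window computation can be sketched as follows. It assumes that history_wct holds cumulative wall-clock timestamps in seconds, one per iteration, so window_size batches span window_size + 1 timestamps. The helper and its `global_batch_size` and `world_size` arguments are hypothetical; the real function takes these from train_config and the runtime.

```python
from typing import Dict, List


def rolling_throughput(history_wct: List[float], window_size: int,
                       global_batch_size: int, seq_length: int,
                       world_size: int) -> Dict[str, float]:
    """Rolling-window throughput (hypothetical helper)."""
    # Assumes cumulative wall-clock seconds, one entry per iteration.
    if len(history_wct) < window_size + 1:
        return {}  # below the window_size threshold: nothing is logged yet
    elapsed = history_wct[-1] - history_wct[-(window_size + 1)]
    batches_per_sec = window_size / elapsed
    samples_per_sec = batches_per_sec * global_batch_size
    tokens_per_sec = samples_per_sec * seq_length
    metrics = {
        "throughput/batches_per_sec": batches_per_sec,
        "throughput/samples_per_sec": samples_per_sec,
        "throughput/tokens_per_sec": tokens_per_sec,
    }
    # Per-device numbers are the global numbers divided by world size.
    for key, value in list(metrics.items()):
        metrics[key.replace("throughput/", "throughput/device/")] = value / world_size
    return metrics
```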

bridge.training.utils.train_utils.prepare_forward_step_func(
forward_step_func: megatron.bridge.training.forward_step_func_types.ForwardStepCallable,
state: megatron.bridge.training.state.GlobalState,
) → megatron.bridge.training.forward_step_func_types.ForwardStepCallable

Convenience function to check and inject GlobalState in one call.

This combines needs_global_state_injection() and maybe_inject_state() for cleaner code. Call this once at the beginning of train() or evaluate() to prevent creating new partial objects every iteration.

Wrapping once is safe since:

  • functools.partial stores a reference to the state object, not a copy

  • When state.train_state.step or other fields change, the partial sees those changes

  • No staleness issues because GlobalState is mutable and passed by reference

Functor support:

  • Works with both regular functions (def forward_step(…)) and callable classes

  • For functors: inspect.signature() inspects the __call__ method

  • For functors: partial(functor_instance, state) preserves functor’s internal state

  • Example: If functor has self.call_count, it still increments correctly

Parameters:
  • forward_step_func – The original forward step function or functor

  • state – The GlobalState object to inject if needed

Returns:

The wrapped function if injection is needed, otherwise the original function.
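The two safety claims above can be checked with plain `functools.partial`: the partial holds a reference to the state rather than a copy, and a functor's own state survives wrapping. `GlobalState` and `CountingForwardStep` below are hypothetical stand-ins, not the library's classes.

```python
import functools
import inspect


class GlobalState:
    """Hypothetical stand-in with a single mutable field."""

    def __init__(self) -> None:
        self.step = 0


class CountingForwardStep:
    """Hypothetical functor-style forward step."""

    def __init__(self) -> None:
        self.call_count = 0

    def __call__(self, state: GlobalState, data_iterator, model):
        self.call_count += 1
        return state.step


state = GlobalState()
functor = CountingForwardStep()
# inspect.signature on an instance reports __call__'s parameters without self.
param_names = list(inspect.signature(functor).parameters)
wrapped = functools.partial(functor, state)  # wrap once, before the loop

seen_steps = []
for step in range(3):
    state.step = step  # mutate the shared object in place
    seen_steps.append(wrapped(None, None))  # the partial sees the new value
```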

bridge.training.utils.train_utils.needs_global_state_injection(
forward_step_func: megatron.bridge.training.forward_step_func_types.ForwardStepCallable,
) → bool

Check if a forward step function needs GlobalState injection.

This function does the signature inspection once to determine if state should be injected. It’s more efficient than repeated signature inspection in the training loop.

Detection logic:

  1. First checks for GlobalState type annotation in any parameter

  2. Falls back to checking whether the first parameter is named ‘state’ or ‘global_state’

Parameters:

forward_step_func – The forward step function to inspect.

Returns:

True if GlobalState should be injected, False otherwise.
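The two-step detection can be sketched with `inspect.signature`. `GlobalState` here is a hypothetical stand-in class. Accepting the string annotation `"GlobalState"` is an added assumption, meant to cover modules that use postponed annotations; the real function may not do this.

```python
import inspect
from typing import Callable


class GlobalState:
    """Hypothetical stand-in for the real GlobalState class."""


def needs_state_injection(fn: Callable) -> bool:
    params = list(inspect.signature(fn).parameters.values())
    # 1. Any parameter annotated with GlobalState (class or, by assumption, string form).
    for p in params:
        if p.annotation is GlobalState or p.annotation == "GlobalState":
            return True
    # 2. Fall back to the name of the first parameter.
    return bool(params) and params[0].name in ("state", "global_state")


def fwd_annotated(gs: GlobalState, data_iterator, model):
    return None


def fwd_named(global_state, data_iterator, model):
    return None


def fwd_plain(data_iterator, model):
    return None
```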

bridge.training.utils.train_utils.maybe_inject_state(
forward_step_func: megatron.bridge.training.forward_step_func_types.ForwardStepCallable,
state: megatron.bridge.training.state.GlobalState,
needs_injection: Optional[bool] = None,
) → megatron.bridge.training.forward_step_func_types.ForwardStepCallable

Optionally inject GlobalState into forward_step functions that expect it.

Determines whether to inject state by inspecting function signature:

  1. First checks for GlobalState type annotation in any parameter

  2. Falls back to checking whether the first parameter is named ‘state’

  3. Otherwise assumes the function doesn’t expect state

Supported signatures:

  • (data_iterator, model) → no injection

  • (data_iterator, model, return_schedule_plan) → no injection

  • (state: GlobalState, data_iterator, model) → inject state

  • (state: GlobalState, data_iterator, model, return_schedule_plan) → inject state

  • (state, data_iterator, model) → inject state (fallback to name-based detection)

Parameters:
  • forward_step_func – The original forward step function.

  • state – The GlobalState object to potentially inject.

  • needs_injection – Whether injection is needed (optional, will be inspected if None). Pass this to avoid repeated signature inspection in training loops.

Returns:

The original function or a partial function with GlobalState injected.
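The wrapping step, including the short-circuit for a precomputed needs_injection, can be sketched as follows. `GlobalState`, `maybe_inject` and the forward-step functions are hypothetical stand-ins. Following the list above, the state is bound as the first positional argument.

```python
import functools
import inspect
from typing import Callable, Optional


class GlobalState:
    """Hypothetical stand-in for the real GlobalState class."""


def _needs_injection(fn: Callable) -> bool:
    params = list(inspect.signature(fn).parameters.values())
    if any(p.annotation is GlobalState for p in params):
        return True
    return bool(params) and params[0].name == "state"


def maybe_inject(fn: Callable, state: GlobalState,
                 needs_injection: Optional[bool] = None) -> Callable:
    if needs_injection is None:
        needs_injection = _needs_injection(fn)  # inspect only when not told
    # Bind state as the first positional argument; otherwise return fn as-is.
    return functools.partial(fn, state) if needs_injection else fn


def fwd_plain(data_iterator, model):
    return "plain"


def fwd_with_state(state: GlobalState, data_iterator, model, return_schedule_plan=False):
    return state


state = GlobalState()
```

If needs_injection is computed once outside the training loop, later calls skip signature inspection entirely.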