bridge.training.utils.flop_utils

Module Contents

Functions

vit_flops

Calculate FLOPs for a Vision Transformer (ViT) encoder + patch merger.

num_floating_point_operations

Return the number of floating point operations.

Data

API

bridge.training.utils.flop_utils._lora_seq_stats_cache: dict

None

bridge.training.utils.flop_utils.vit_flops(
cfg: megatron.bridge.training.config.ConfigContainer,
batch_size: int,
num_patches: int,
)

Calculate FLOPs for a Vision Transformer (ViT) encoder + patch merger.

Includes:

  • ViT transformer layers (bidirectional full attention, not causal)

  • Patch merger (spatial merge + MLP projection to LLM hidden size)

Parameters:
  • cfg – Configuration container. ViT hyper-parameters are read from cfg.model.vision_config (depth, hidden_size, num_heads, intermediate_size, spatial_merge_size, out_hidden_size). Passing the whole config keeps the public signature stable as the list of required ViT attributes grows.

  • batch_size – Batch size.

  • num_patches – Per-image number of vision patches (before spatial merge). Callers that track the total patch count across the batch should divide by batch_size before invoking, because ViT attention is per-image (not cross-image) and scales quadratically with the per-image patch count.

Returns:

Total training FLOPs, computed as 3 × the forward-pass FLOPs to cover both the forward and backward passes. Returns 0 when no vision_config is attached or when num_patches is non-positive.
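
The quadratic dependence on the per-image patch count and the 3 × forward convention can be illustrated with a rough sketch. This is a textbook-style estimate written for illustration, not the formula that vit_flops actually uses: the VisionConfigSketch class, the vit_flops_sketch function, the plain two-layer MLP, and the two-layer patch merger are all assumptions. The count does not depend on num_heads, so that field is omitted here.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class VisionConfigSketch:
    """Hypothetical stand-in for cfg.model.vision_config (illustration only)."""
    depth: int
    hidden_size: int
    intermediate_size: int
    spatial_merge_size: int
    out_hidden_size: int


def vit_flops_sketch(vc: Optional[VisionConfigSketch],
                     batch_size: int,
                     num_patches: int) -> int:
    """Rough training-FLOP estimate; NOT the library's actual formula."""
    # Mirror the documented behaviour: no vision config or no patches -> 0.
    if vc is None or num_patches <= 0:
        return 0

    n, h = num_patches, vc.hidden_size
    # A matmul of shape (n x a) @ (a x b) costs about 2 * n * a * b FLOPs.
    proj = 2 * n * h * h * 4                    # Q, K, V and output projections
    attn_core = 2 * n * n * h * 2               # QK^T and attn @ V, full (non-causal)
    mlp = 2 * n * h * vc.intermediate_size * 2  # assumed plain two-layer MLP
    encoder = vc.depth * (proj + attn_core + mlp)

    # Assumed merger: group spatial_merge_size^2 patches into one token,
    # then apply a two-layer MLP projecting to out_hidden_size.
    merge = vc.spatial_merge_size ** 2
    tokens = n // merge
    dim = h * merge
    merger = 2 * tokens * (dim * dim + dim * vc.out_hidden_size)

    forward = batch_size * (encoder + merger)   # attention is per image
    return 3 * forward                          # forward + backward (about 2x forward)


vc = VisionConfigSketch(depth=2, hidden_size=8, intermediate_size=16,
                        spatial_merge_size=2, out_hidden_size=32)
total_patches, batch_size = 32, 2
per_image = total_patches // batch_size         # assumes equally sized images
flops = vit_flops_sketch(vc, batch_size, per_image)
# Treating the whole batch as one image inflates the quadratic attention term.
overcount = vit_flops_sketch(vc, 1, total_patches)
```

Under these assumptions, passing the batch-wide patch total as if it were a single image gives a larger estimate than the per-image call, which is why the total should be divided by batch_size first.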

bridge.training.utils.flop_utils.num_floating_point_operations(
cfg: megatron.bridge.training.config.ConfigContainer,
batch_size: int = 1,
seqlen_sum: int | None = None,
seqlen_squared_sum: int | None = None,
num_vision_patches: int = 0,
)

Return the number of floating point operations.

Parameters:
  • cfg – Configuration container.

  • batch_size – Batch size.

  • seqlen_sum – Sum of the actual sequence lengths across the batch (equal to batch_size * actual_seq_length when every sequence has the same length). When provided, it overrides cfg.model.seq_length, giving a more accurate FLOPs estimate for dynamic-length sequences (e.g., VLM with dynamic padding).

  • seqlen_squared_sum – Sum of squared sequence lengths across the batch (sum_i actual_seq_length_i^2). Used for the attention-core FLOPs, which scale quadratically with sequence length. When omitted, it falls back to batch_size * effective_seq_length^2, so the result matches the legacy constant-length estimate.

  • num_vision_patches – Total number of vision patches in the batch (before spatial merge). Used to compute the ViT encoder FLOPs.
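
The two sequence-length statistics described above can be computed from a list of per-sample lengths. The following is a minimal sketch, assuming the caller already has those lengths; the helper names seq_stats and fallback_squared_sum are hypothetical and not part of the module.

```python
from typing import Sequence, Tuple


def seq_stats(seq_lengths: Sequence[int]) -> Tuple[int, int]:
    """Hypothetical helper: return (seqlen_sum, seqlen_squared_sum)."""
    seqlen_sum = sum(seq_lengths)
    seqlen_squared_sum = sum(n * n for n in seq_lengths)
    return seqlen_sum, seqlen_squared_sum


def fallback_squared_sum(batch_size: int, effective_seq_length: int) -> int:
    """The documented fallback when seqlen_squared_sum is omitted."""
    return batch_size * effective_seq_length ** 2


lengths = [512, 1024, 2048]                      # actual, unpadded lengths
seqlen_sum, seqlen_squared_sum = seq_stats(lengths)

# Padding every sample to the longest length overstates the quadratic term.
padded_estimate = fallback_squared_sum(len(lengths), max(lengths))

# With a real ConfigContainer `cfg`, the values would be passed as:
# flops = num_floating_point_operations(
#     cfg,
#     batch_size=len(lengths),
#     seqlen_sum=seqlen_sum,
#     seqlen_squared_sum=seqlen_squared_sum,
# )
```

For a batch whose sequences all share one length, seq_stats returns the same squared sum as the fallback, which is consistent with the legacy constant-length estimate.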