bridge.training.utils.flop_utils
Module Contents
Functions
vit_flops | Calculate FLOPs for a Vision Transformer (ViT) encoder + patch merger.
num_floating_point_operations | Return the number of floating point operations.
Data
API
- bridge.training.utils.flop_utils._lora_seq_stats_cache: dict
None
- bridge.training.utils.flop_utils.vit_flops(
    cfg: megatron.bridge.training.config.ConfigContainer,
    batch_size: int,
    num_patches: int,
  )
Calculate FLOPs for a Vision Transformer (ViT) encoder + patch merger.
Includes:
- ViT transformer layers (bidirectional full attention, not causal)
- Patch merger (spatial merge + MLP projection to LLM hidden size)
- Parameters:
cfg – Configuration container. ViT hyper-parameters are read from cfg.model.vision_config (depth, hidden_size, num_heads, intermediate_size, spatial_merge_size, out_hidden_size). Passing the whole config keeps the public signature stable as the list of required ViT attributes grows.
batch_size – Batch size.
num_patches – Per-image number of vision patches (before spatial merge). Callers that track the total patch count across the batch should divide by batch_size before invoking, because ViT attention is per-image (not cross-image) and scales quadratically with the per-image patch count.
- Returns:
Total training FLOPs (forward * 3 for fwd+bwd). Returns 0 when no vision_config is attached or num_patches is non-positive.
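The following is a minimal sketch of the kind of estimate described above, not the library's actual formula. The `VisionConfig` dataclass and `approx_vit_training_flops` are hypothetical names introduced for illustration; only the field names come from the parameter description. The sketch assumes a plain two-matmul MLP per layer and a patch merger made of two linear layers, both of which may differ from the real implementation. `num_heads` is carried but unused, since at a fixed hidden size the head count does not change the matmul FLOPs in this approximation.

```python
from dataclasses import dataclass


@dataclass
class VisionConfig:
    """Hypothetical stand-in for cfg.model.vision_config."""
    depth: int
    hidden_size: int
    num_heads: int
    intermediate_size: int
    spatial_merge_size: int
    out_hidden_size: int


def approx_vit_training_flops(vc, batch_size, num_patches):
    """Rough ViT + patch-merger training FLOPs (illustrative only)."""
    # Mirror the documented behavior: 0 without a vision config or patches.
    if vc is None or num_patches <= 0:
        return 0
    n, h = num_patches, vc.hidden_size
    # One multiply-accumulate is counted as 2 FLOPs throughout.
    # QKV projection (h -> 3h) plus attention output projection (h -> h).
    qkv_and_out_proj = 2 * n * h * (3 * h) + 2 * n * h * h
    # Bidirectional full attention: QK^T and weights @ V, n*n*h MACs each.
    attn_core = 2 * (2 * n * n * h)
    # Assumption: two-matmul MLP (a gated MLP would add a third matmul).
    mlp = 2 * (2 * n * h * vc.intermediate_size)
    per_layer = qkv_and_out_proj + attn_core + mlp
    # Assumption: the merger concatenates merge^2 neighboring patches and
    # applies two linear layers (merged_hidden -> merged_hidden -> out).
    m2 = vc.spatial_merge_size ** 2
    merged_tokens = n // m2
    merged_hidden = h * m2
    merger = 2 * merged_tokens * merged_hidden * (merged_hidden + vc.out_hidden_size)
    forward = batch_size * (vc.depth * per_layer + merger)
    # Forward * 3 covers the forward pass plus the backward pass.
    return 3 * forward


vc = VisionConfig(depth=2, hidden_size=8, num_heads=2,
                  intermediate_size=16, spatial_merge_size=2, out_hidden_size=4)
per_image = approx_vit_training_flops(vc, batch_size=2, num_patches=16)
pooled = approx_vit_training_flops(vc, batch_size=1, num_patches=32)
```

In this sketch `pooled` exceeds `per_image` even though both cover 32 patches in total, because the attention-core term grows with the square of the per-image patch count. That is the reason for dividing a batch-wide patch total by batch_size before the call.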
- bridge.training.utils.flop_utils.num_floating_point_operations(
    cfg: megatron.bridge.training.config.ConfigContainer,
    batch_size: int = 1,
    seqlen_sum: int | None = None,
    seqlen_squared_sum: int | None = None,
    num_vision_patches: int = 0,
  )
Return the number of floating point operations.
- Parameters:
cfg – Configuration container.
batch_size – Batch size.
seqlen_sum – Sum of actual sequence lengths across the batch (batch_size * actual_seq_length). When provided, overrides cfg.model.seq_length for more accurate FLOPs estimation with dynamic-length sequences (e.g., VLM with dynamic padding).
seqlen_squared_sum – Sum of squared sequence lengths across the batch (sum_i actual_seq_length_i^2). Used for attention core FLOPs, which scale quadratically with sequence length; when omitted, falls back to batch_size * effective_seq_length^2 so the result matches the legacy constant-length estimate.
num_vision_patches – Total number of vision patches in the batch (before spatial merge). Used to compute ViT encoder FLOPs.
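As a small illustration of the two sequence-length statistics described above, the sketch below computes them from a list of per-sample lengths and compares the squared sum against the documented fallback. The helper names `seqlen_stats` and `legacy_squared_sum` are hypothetical and not part of the module; the resulting values are what a caller would presumably pass as seqlen_sum and seqlen_squared_sum.

```python
def seqlen_stats(seq_lengths):
    """Return (batch_size, seqlen_sum, seqlen_squared_sum) for one batch."""
    batch_size = len(seq_lengths)
    seqlen_sum = sum(seq_lengths)
    seqlen_squared_sum = sum(n * n for n in seq_lengths)
    return batch_size, seqlen_sum, seqlen_squared_sum


def legacy_squared_sum(batch_size, effective_seq_length):
    """Documented fallback: batch_size * effective_seq_length^2."""
    return batch_size * effective_seq_length ** 2


# Every sample has the full length: the statistics match the fallback.
padded = seqlen_stats([4096, 4096])

# Dynamic lengths: the true squared sum is well below the fallback, so the
# constant-length estimate would overstate the attention-core FLOPs.
ragged = seqlen_stats([512, 4096])
fallback = legacy_squared_sum(2, 4096)
```

With equal-length samples the squared sum equals the fallback exactly, which is why omitting seqlen_squared_sum reproduces the legacy estimate. With mixed lengths the two diverge, and supplying the real statistics avoids counting attention work on padding that was never executed.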