bridge.models.mimo.mimo_builder
Module Contents
Functions
build_hypercomm_grids | Create HyperCommGrid objects per module from MIMO parallelism config.
_default_topology | Infer a default multi-encoder -> LLM topology.
create_embedding_and_position_groups | Create embedding-related process groups from PP group ranks.
is_current_rank_in_grid | Check if the current rank participates in this grid.
API
- bridge.models.mimo.mimo_builder.build_hypercomm_grids(
- mimo_parallelism_config: megatron.bridge.models.mimo.mimo_config.MimoParallelismConfig,
Create HyperCommGrid objects per module from MIMO parallelism config.
Creates grids on ALL ranks (required for consistent collective calls), but only ranks in each grid’s range will participate in its operations.
- Parameters:
mimo_parallelism_config – MimoParallelismConfig specifying parallelism per module.
- Returns:
Dict mapping module names to their HyperCommGrids.
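The "create everywhere, participate selectively" behaviour can be sketched without any distributed runtime. The example below is only an illustration: `layout_rank_ranges`, the module names, and the assumption that modules occupy consecutive, non-overlapping rank ranges are all hypothetical, and a plain `range` stands in for a HyperCommGrid.

```python
from __future__ import annotations


def layout_rank_ranges(module_sizes: dict[str, int]) -> dict[str, range]:
    """Give each module a contiguous, non-overlapping block of global ranks.

    Hypothetical stand-in for grid construction: the real function builds
    HyperCommGrid objects from a MimoParallelismConfig instead.
    """
    ranges: dict[str, range] = {}
    offset = 0
    for name, size in module_sizes.items():
        ranges[name] = range(offset, offset + size)
        offset += size
    return ranges


# Hypothetical module names and sizes (number of ranks per module).
module_sizes = {"vision_encoder": 2, "llm": 4}
world_size = sum(module_sizes.values())

# Every rank runs the same construction and ends up with the same mapping ...
per_rank_view = [layout_rank_ranges(module_sizes) for _ in range(world_size)]

# ... but a rank only takes part in the ranges that contain it.
participation = {
    rank: [name for name, ranks in per_rank_view[rank].items() if rank in ranks]
    for rank in range(world_size)
}
```

Because every rank derives the identical mapping, any collective setup that follows is entered in the same order everywhere, which is the consistency requirement stated above.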
- bridge.models.mimo.mimo_builder._default_topology(
- mimo_parallelism_config: megatron.bridge.models.mimo.mimo_config.MimoParallelismConfig,
Infer a default multi-encoder -> LLM topology.
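As a rough illustration of what a "multi-encoder -> LLM" default could look like, the sketch below wires every non-LLM module directly into the language model. The adjacency-list return type, the `"llm"` key, and the function name `default_topology_sketch` are assumptions made for this example; the page does not document the actual return value.

```python
from __future__ import annotations


def default_topology_sketch(
    module_names: list[str], llm_name: str = "llm"
) -> dict[str, list[str]]:
    """Map each module to the modules that consume its output.

    Hypothetical representation: every encoder feeds the LLM, and the LLM
    is a sink with no downstream modules.
    """
    if llm_name not in module_names:
        raise ValueError(f"expected a module named {llm_name!r}")
    return {
        name: ([] if name == llm_name else [llm_name]) for name in module_names
    }


# Hypothetical module names, for illustration only.
topology = default_topology_sketch(["images", "audio", "llm"])
```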
- bridge.models.mimo.mimo_builder.create_embedding_and_position_groups(
- pp_group: torch.distributed.ProcessGroup,
Create embedding-related process groups from PP group ranks.
Following MCore semantics:
pos_embd_pg: only the first PP stage (pp_ranks[0]), used for position embeddings.
embd_pg: the first and last PP stages (pp_ranks[0] and pp_ranks[-1]), used for tied word embeddings.
IMPORTANT: This calls dist.new_group, which is a collective operation, so it must be called on all ranks that could participate.
Note: VPP (virtual_pipeline_model_parallel_size > 1) is not supported. With VPP, pp_ranks[0]/pp_ranks[-1] do not reliably identify the stages that own embeddings. The caller is responsible for asserting VPP is disabled.
- Parameters:
pp_group – The pipeline parallel process group.
- Returns:
Tuple of (pos_embd_pg, embd_pg). Returns (None, None) if pp_group is None.
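The rank selection described above can be sketched as a small pure-Python helper. This is a minimal illustration, not the module's implementation: `embedding_group_ranks` is a hypothetical name, and collapsing both groups to the single stage when the pipeline has only one rank is an assumption about the single-stage case.

```python
from typing import List, Tuple


def embedding_group_ranks(pp_ranks: List[int]) -> Tuple[List[int], List[int]]:
    """Return (pos_embd_ranks, embd_ranks) for one pipeline-parallel group.

    pp_ranks lists the global ranks of the PP group in stage order.
    """
    if not pp_ranks:
        raise ValueError("pp_ranks must contain at least one rank")
    first, last = pp_ranks[0], pp_ranks[-1]
    # Position embeddings live only on the first stage.
    pos_embd_ranks = [first]
    # Tied word embeddings span the first and last stages; with a single
    # stage the two coincide, so the rank is listed once (assumption).
    embd_ranks = [first] if first == last else [first, last]
    return pos_embd_ranks, embd_ranks


# Example: a 4-stage pipeline living on global ranks 4..7.
pos_ranks, embd_ranks = embedding_group_ranks([4, 5, 6, 7])
```

In a real job the rank list would typically come from `torch.distributed.get_process_group_ranks(pp_group)`, and each resulting list would be passed to `torch.distributed.new_group(ranks=...)`.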
- bridge.models.mimo.mimo_builder.is_current_rank_in_grid(
- grid: megatron.core.hyper_comm_grid.HyperCommGrid,
Check if the current rank participates in this grid.
- Parameters:
grid – A HyperCommGrid instance.
- Returns:
True if dist.get_rank() is within the grid’s rank range.
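A range check of this kind can be sketched as follows. `GridStub` and its `rank_offset` and `size` fields are hypothetical stand-ins chosen for the example, and the assumption that a grid covers one contiguous block of global ranks is not confirmed by this page.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class GridStub:
    """Hypothetical stand-in exposing only a starting rank and a rank count."""

    rank_offset: int
    size: int


def rank_in_grid(grid: GridStub, rank: int) -> bool:
    """Return True if `rank` lies in [rank_offset, rank_offset + size)."""
    return grid.rank_offset <= rank < grid.rank_offset + grid.size


# A grid covering global ranks 2..5 in an 8-rank job.
llm_grid = GridStub(rank_offset=2, size=4)
active_ranks = [r for r in range(8) if rank_in_grid(llm_grid, r)]
```

In distributed code the `rank` argument would be the value of `dist.get_rank()`, and the result would typically guard any work that should run only on the grid's own ranks.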