bridge.models.mimo.mimo_builder

Module Contents

Functions

build_hypercomm_grids – Create one HyperCommGrid per module from a MIMO parallelism config.

populate_embedding_and_position_groups – Create embedding-related process groups from the ranks of a PP group.

is_pp_first_stage – Check whether the current rank is the first stage of the pipeline.

is_pp_last_stage – Check whether the current rank is the last stage of the pipeline.

API

bridge.models.mimo.mimo_builder.build_hypercomm_grids(
    mimo_parallelism_config: megatron.bridge.models.mimo.mimo_config.MimoParallelismConfig,
) → Dict[str, megatron.core.hyper_comm_grid.HyperCommGrid]

Create one HyperCommGrid per module from a MIMO parallelism config.

Grids are created on ALL ranks (required so that collective calls stay consistent), but only the ranks within each grid's range participate in that grid's operations.

Parameters:

mimo_parallelism_config – MimoParallelismConfig specifying parallelism per module.

Returns:

Dict mapping module names to their HyperCommGrids.
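The idea of building every module's grid on every rank, while only a contiguous slice of ranks actually participates in it, can be sketched without any distributed runtime. This is a minimal sketch under stated assumptions: `ModuleParallelism` (with fields `tp`, `pp`, `dp`, `rank_offset`), `GridSketch`, and `build_grids_sketch` are hypothetical stand-ins, not the real `MimoParallelismConfig` or `HyperCommGrid` API, and each module is assumed to occupy the rank range `[rank_offset, rank_offset + tp * pp * dp)`.

```python
from dataclasses import dataclass
from typing import Dict, Tuple


@dataclass(frozen=True)
class ModuleParallelism:
    """Hypothetical per-module settings; not the real MimoParallelismConfig."""
    tp: int
    pp: int
    dp: int
    rank_offset: int = 0

    @property
    def size(self) -> int:
        # Number of ranks the module occupies.
        return self.tp * self.pp * self.dp


@dataclass(frozen=True)
class GridSketch:
    """Hypothetical stand-in for a HyperCommGrid: a shape and a rank range."""
    shape: Tuple[int, int, int]
    rank_offset: int
    size: int

    def contains(self, rank: int) -> bool:
        # Only ranks in [rank_offset, rank_offset + size) take part in this
        # grid's operations, even though every rank holds the object.
        return self.rank_offset <= rank < self.rank_offset + self.size


def build_grids_sketch(config: Dict[str, ModuleParallelism]) -> Dict[str, GridSketch]:
    # Every rank runs this same loop over every module in the same order.
    # This sketch issues no collectives, but in a real builder that uniform
    # order is what keeps collective calls consistent across ranks.
    return {
        name: GridSketch(shape=(p.tp, p.pp, p.dp), rank_offset=p.rank_offset, size=p.size)
        for name, p in config.items()
    }


# Example: two modules placed side by side on 8 ranks.
config = {
    "vision": ModuleParallelism(tp=2, pp=1, dp=2, rank_offset=0),    # ranks 0-3
    "language": ModuleParallelism(tp=2, pp=2, dp=1, rank_offset=4),  # ranks 4-7
}
grids = build_grids_sketch(config)
print(sorted(grids))                                           # ['language', 'vision']
print([r for r in range(8) if grids["language"].contains(r)])  # [4, 5, 6, 7]
```

The (tp, pp, dp) axis order is an arbitrary choice for illustration; the point is that rank 2, for example, still holds a "language" grid object even though it never participates in it.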

bridge.models.mimo.mimo_builder.populate_embedding_and_position_groups(
    pp_group: torch.distributed.ProcessGroup,
) → Tuple[Optional[torch.distributed.ProcessGroup], Optional[torch.distributed.ProcessGroup]]

Create embedding-related process groups from the ranks of the PP group.

Following MCore semantics:

  • pos_embd_pg: only PP rank 0 (the first stage), used for position embeddings.

  • embd_pg: PP ranks 0 and -1 (the first and last stages), used for tied word embeddings.

IMPORTANT: This calls dist.new_group, which is a collective operation, so it must be called on all ranks that could participate.

Parameters:

pp_group – The pipeline parallel process group.

Returns:

Tuple of (pos_embd_pg, embd_pg). Returns (None, None) if pp_group is None.
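The rank selection described above can be sketched as a pure function over the global ranks of one PP group. The name `embedding_group_ranks` is hypothetical, and listing the rank only once when the pipeline has a single stage is a choice made in this sketch. In the real function the resulting lists would be handed to `dist.new_group`; reading the rank list from a live group (for example with `torch.distributed.get_process_group_ranks`) is an assumption about how the ranks are obtained, not something this page states.

```python
from typing import List, Optional, Sequence, Tuple


def embedding_group_ranks(
    pp_ranks: Optional[Sequence[int]],
) -> Tuple[Optional[List[int]], Optional[List[int]]]:
    """Return (pos_embd_ranks, embd_ranks) for the global ranks of one PP group."""
    if pp_ranks is None:
        # Mirrors the documented (None, None) result when pp_group is None.
        return None, None
    first, last = pp_ranks[0], pp_ranks[-1]
    # Position embeddings: first pipeline stage only.
    pos_embd_ranks = [first]
    # Tied word embeddings: first and last stages. With a single stage the
    # two are the same rank, so it is listed once.
    embd_ranks = [first] if first == last else [first, last]
    return pos_embd_ranks, embd_ranks


print(embedding_group_ranks([4, 5, 6, 7]))  # ([4], [4, 7])
print(embedding_group_ranks([3]))           # ([3], [3])
print(embedding_group_ranks(None))          # (None, None)
```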

bridge.models.mimo.mimo_builder.is_pp_first_stage(
    pp_group: Optional[torch.distributed.ProcessGroup],
) → bool

Check whether the current rank is the first stage of the pipeline.
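A minimal sketch of the first-stage check, assuming the group exposes a `rank()` method returning the calling process's rank within the group (as `torch.distributed.ProcessGroup` objects are assumed to here). The `None` handling is an assumption of this sketch (no PP group is treated as a single stage); this page does not say what the real function returns for `None`. `is_pp_first_stage_sketch` and `FakeGroup` are hypothetical names.

```python
from typing import Optional, Protocol


class RankedGroup(Protocol):
    """The one capability this sketch needs from a PP group."""
    def rank(self) -> int: ...


def is_pp_first_stage_sketch(pp_group: Optional[RankedGroup]) -> bool:
    if pp_group is None:
        # Assumption: no PP group means a single stage, which is the first one.
        return True
    # The first stage is the process whose rank within the PP group is 0.
    return pp_group.rank() == 0


class FakeGroup:
    """Hypothetical test double for a process group."""
    def __init__(self, rank: int) -> None:
        self._rank = rank

    def rank(self) -> int:
        return self._rank


print(is_pp_first_stage_sketch(FakeGroup(0)))  # True
print(is_pp_first_stage_sketch(FakeGroup(2)))  # False
print(is_pp_first_stage_sketch(None))          # True
```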

bridge.models.mimo.mimo_builder.is_pp_last_stage(
    pp_group: Optional[torch.distributed.ProcessGroup],
) → bool

Check whether the current rank is the last stage of the pipeline.
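The last-stage check can be sketched the same way, assuming the group exposes `rank()` and `size()` for the calling process (assumed here to match `torch.distributed.ProcessGroup`). Treating a `None` group as a single stage is again an assumption of this sketch, and `is_pp_last_stage_sketch` is a hypothetical name; `SimpleNamespace` objects stand in for real groups.

```python
from types import SimpleNamespace
from typing import Optional, Protocol


class SizedGroup(Protocol):
    """The two capabilities this sketch needs from a PP group."""
    def rank(self) -> int: ...
    def size(self) -> int: ...


def is_pp_last_stage_sketch(pp_group: Optional[SizedGroup]) -> bool:
    if pp_group is None:
        # Assumption: no PP group means a single stage, which is also the last.
        return True
    # The last stage is the process with the highest rank in the PP group.
    return pp_group.rank() == pp_group.size() - 1


# SimpleNamespace objects act as hypothetical stand-ins for real groups.
last = SimpleNamespace(rank=lambda: 3, size=lambda: 4)
middle = SimpleNamespace(rank=lambda: 1, size=lambda: 4)
single = SimpleNamespace(rank=lambda: 0, size=lambda: 1)
print(is_pp_last_stage_sketch(last))    # True
print(is_pp_last_stage_sketch(middle))  # False
print(is_pp_last_stage_sketch(single))  # True
```

In a one-stage pipeline the same rank is both the first and the last stage, which is why that rank belongs to both embedding groups described for populate_embedding_and_position_groups.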