core.pipeline_parallel.multimodule_communicator#

Module Contents#

Classes#

RankModuleInfo

Information about a rank in a module.

MultiModulePipelineCommunicator

Communicator for a multi-module pipeline.

Data#

API#

core.pipeline_parallel.multimodule_communicator.Shape#

None

class core.pipeline_parallel.multimodule_communicator.RankModuleInfo#

Information about a rank in a module.

Attributes:

  • pp_rank – The stage index of the current rank within the module's pipeline.

  • pp_size – The total number of pipeline stages (ranks) in the module.

  • p2p_communicator – Intra-module point-to-point communicator.

  • bridge_comms_as_src_module – Bridge communicators for outgoing connections from this module to downstream modules. One module may have multiple bridge communicators if it has multiple outgoing connections.

  • bridge_comms_as_dest_module – Bridge communicators for incoming connections to this module from upstream modules. One module may have multiple bridge communicators if it has multiple incoming connections.

  • is_source_stage – True if this rank is at the absolute first stage in the overall model (no incoming connections).

  • is_terminal_stage – True if this rank is at the absolute last stage in the overall model (no outgoing connections).

pp_rank: int#

None

pp_size: int#

None

p2p_communicator: Optional[megatron.core.pipeline_parallel.p2p_communication.P2PCommunicator]#

None

bridge_comms_as_src_module: Optional[List[megatron.core.pipeline_parallel.bridge_communicator.BridgeCommunicator]]#

None

bridge_comms_as_dest_module: Optional[List[megatron.core.pipeline_parallel.bridge_communicator.BridgeCommunicator]]#

None

is_source_stage: Optional[bool]#

True

is_terminal_stage: Optional[bool]#

True
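A minimal construction sketch, assuming RankModuleInfo is a dataclass-style container that accepts its documented fields as keyword arguments; the communicator objects in the example are hypothetical placeholders, and the import path follows the documented module location:

    from megatron.core.pipeline_parallel.multimodule_communicator import RankModuleInfo

    # Describe a rank that holds stage 0 of a 4-stage 'llm' module.  The
    # communicator objects are assumed to have been built elsewhere.
    info = RankModuleInfo(
        pp_rank=0,
        pp_size=4,
        p2p_communicator=llm_p2p_communicator,          # intra-module P2P communicator
        bridge_comms_as_src_module=[llm_to_generator],  # outgoing bridge(s)
        bridge_comms_as_dest_module=[encoders_to_llm],  # incoming bridge(s)
        is_source_stage=False,    # upstream encoders feed this module
        is_terminal_stage=False,  # a generator module follows this module
    )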

class core.pipeline_parallel.multimodule_communicator.MultiModulePipelineCommunicator(
module_to_grid_map: Dict[str, megatron.core.hyper_comm_grid.HyperCommGrid],
topology: Dict[str, List[str]],
config: megatron.core.model_parallel_config.ModelParallelConfig,
dim_mapping: Dict[str, List[int]] = None,
)#

Communicator for a multi-module pipeline.

Initialization

Initialize the MultiModulePipelineCommunicator.

Parameters:
  • module_to_grid_map (dict) – A dictionary mapping module names to HyperCommGrids (see the construction sketch after this parameter list). Example:

    module_to_grid_map = {
        'image_encoder': image_encoder_grid,
        'audio_encoder': audio_encoder_grid,
        'llm': llm_grid,
        'generator': generator_grid,
    }

  • topology (dict) – A dictionary mapping module names to lists of outgoing modules. Example:

    topology = {
        'image_encoder': ['llm'],
        'audio_encoder': ['llm'],
        'llm': ['generator'],
        'generator': [],
    }

  • config (ModelParallelConfig) – A ModelParallelConfig object.

  • dim_mapping (Dict[str, List[int]]) – Dimension mapping for the sequence ('s'), batch ('b'), and hidden ('h') dimensions. Example:

    dim_mapping = {'s': 0, 'h': 2, 'b': 1}

    Default: None

_build_bridge_comms()#

Construct and store BridgeCommunicator objects that describe the outgoing communication relationships for all of the modules.

property is_pp_first_stage#

Return True if the current rank holds the absolute first stage in the overall model.

The absolute first stage is defined as:

  1. The current rank must be in the first PP stage (pp_rank == 0) of some module

  2. That module must be a source module (no incoming connections in topology)

property is_pp_last_stage#

Return True if the current rank holds the absolute last stage in the overall model (a usage sketch for both properties follows the conditions below).

The absolute last stage is defined as:

  1. The current rank must be in the last PP stage of some module

  2. That module must be a sink module (no outgoing connections in topology)
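A usage sketch of the two properties inside a training step; communicator is the instance from the construction sketch above, while data_iterator, compute_loss, and output_dict are hypothetical:

    # Only the absolute first stage reads raw data; only the absolute last
    # stage computes the loss.  Intermediate stages rely on pipeline
    # communication alone.
    if communicator.is_pp_first_stage:
        batch = next(data_iterator)
    if communicator.is_pp_last_stage:
        loss = compute_loss(output_dict)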

_is_source_module(module_name: str) bool#

Check if a module is a source module (has no incoming connections).

_is_sink_module(module_name: str) bool#

Check if a module is a sink module (has no outgoing connections).
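A conceptual sketch of the two checks, derived from the definitions above (not necessarily the exact implementation):

    def is_source(module_name, topology):
        # A source module never appears as a destination of any other module.
        return all(module_name not in dests for dests in topology.values())

    def is_sink(module_name, topology):
        # A sink module has no outgoing connections.
        return len(topology.get(module_name, [])) == 0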

is_current_rank_in_grid(
grid: megatron.core.hyper_comm_grid.HyperCommGrid,
) bool#

Check if the current rank is in the grid.

property num_warmup_microbatches#

Calculate the number of warmup microbatches for the current rank.

Uses the same simple logic as P2PCommunicator: total_pipeline_stages - current_rank_stage - 1

Returns:

Number of warmup microbatches for this rank

Return type:

int
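For illustration (assumed numbers): with 8 total pipeline stages, a rank at global stage index 3 performs 8 - 3 - 1 = 4 warmup microbatches, while the absolute last stage (index 7) performs 0.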

_build_rank_module_info_map()#

For each module hosted on the current rank, initialize its P2P communicator and build its bridge communicator info. A rank may hold multiple modules when modules are colocated.

recv_forward(
tensor_shape: Optional[core.pipeline_parallel.multimodule_communicator.Shape] = None,
is_first_stage: bool = False,
) Dict[str, torch.Tensor]#

Receive forward activation tensor.

Parameters:

tensor_shape – Expected activation tensor shape

Returns:

A dictionary mapping module names to tensors.

send_forward(
output_dict: Dict[str, torch.Tensor],
is_last_stage: bool = False,
)#

Send forward activation tensor.

Parameters:

output_dict – A dictionary mapping module names to tensors.
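A sketch of the dict-based interface on a rank that hosts only the 'llm' module; module names follow the construction sketch above, and the tensor shape and model_chunk forward function are hypothetical:

    seq_len, micro_batch, hidden = 4096, 1, 8192   # illustrative dimensions
    inputs = communicator.recv_forward(tensor_shape=(seq_len, micro_batch, hidden))
    hidden_states = inputs['llm']                  # activation received for this module
    outputs = {'llm': model_chunk(hidden_states)}  # run this rank's model chunk
    communicator.send_forward(outputs)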

send_forward_recv_backward(
output_dict: Dict[str, torch.Tensor],
tensor_shape: Optional[core.pipeline_parallel.multimodule_communicator.Shape] = None,
is_last_stage: bool = False,
) Dict[str, torch.Tensor]#

Send forward activation tensor and receive backward gradient tensor.

Parameters:
  • output_dict – A dictionary mapping module names to tensors.

  • tensor_shape – Expected gradient tensor shape

Returns:

A dictionary mapping module names to tensors.

send_backward_recv_forward(
grad_dict: Dict[str, torch.Tensor],
tensor_shape: Optional[core.pipeline_parallel.multimodule_communicator.Shape] = None,
is_first_stage: bool = False,
) Dict[str, torch.Tensor]#

Send backward gradient tensor and receive forward activation tensor.

Parameters:
  • grad_dict – A dictionary mapping module names to tensors.

  • tensor_shape – Expected activation tensor shape

Returns:

A dictionary mapping module names to tensors.

recv_backward(
tensor_shape: Optional[core.pipeline_parallel.multimodule_communicator.Shape] = None,
is_last_stage: bool = False,
) Dict[str, torch.Tensor]#

Receive backward gradient tensor.

Parameters:

tensor_shape – Expected gradient tensor shape

Returns:

A dictionary mapping module names to tensors.

send_backward(
grad_dict: Dict[str, torch.Tensor],
is_first_stage: bool = False,
)#

Send backward gradient tensor.

Parameters:

grad_dict – A dictionary mapping module names to tensors.
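A hedged sketch of how these calls might compose into a 1F1B-style schedule on one rank; communicator is the instance from the construction sketch, forward_step and backward_step are hypothetical user functions, and num_microbatches, act_shape, and grad_shape are assumed to be provided by the schedule:

    num_warmup = communicator.num_warmup_microbatches
    num_remaining = num_microbatches - num_warmup

    # Warmup: forward-only microbatches.
    for _ in range(num_warmup):
        inp = communicator.recv_forward(act_shape, is_first_stage=communicator.is_pp_first_stage)
        out = forward_step(inp)
        communicator.send_forward(out, is_last_stage=communicator.is_pp_last_stage)

    # Steady state: one forward and one backward per iteration.
    if num_remaining > 0:
        inp = communicator.recv_forward(act_shape, is_first_stage=communicator.is_pp_first_stage)
    for i in range(num_remaining):
        out = forward_step(inp)
        grad_in = communicator.send_forward_recv_backward(
            out, grad_shape, is_last_stage=communicator.is_pp_last_stage
        )
        grad_out = backward_step(inp, out, grad_in)
        if i == num_remaining - 1:
            communicator.send_backward(grad_out, is_first_stage=communicator.is_pp_first_stage)
        else:
            inp = communicator.send_backward_recv_forward(
                grad_out, act_shape, is_first_stage=communicator.is_pp_first_stage
            )

    # Cooldown: drain the backward passes left over from warmup.
    for _ in range(num_warmup):
        grad_in = communicator.recv_backward(grad_shape, is_last_stage=communicator.is_pp_last_stage)
        grad_out = backward_step_for_stored_activations(grad_in)   # hypothetical helper
        communicator.send_backward(grad_out, is_first_stage=communicator.is_pp_first_stage)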

static compute_total_pipeline_stages(
topology: Dict[str, List[str]],
module_to_grid_map: Dict[str, megatron.core.hyper_comm_grid.HyperCommGrid],
rank: Optional[int] = None,
module_name: Optional[str] = None,
) int#

Compute the total number of pipeline stages across a multi-module chain.

Interprets topology as a directed acyclic graph (DAG) where nodes are modules and edges indicate forward data flow from source to destination modules. Each node is assigned a weight equal to its pipeline parallel size (number of PP stages).

The total number of stages is defined as the length of the longest path in this DAG under node weights.

If rank is None (default), returns the maximum over all terminal (sink) modules of the sum of PP sizes along a path ending at that terminal. For example, given:

image_encoder --\
                 --> llm --> generator
audio_encoder --/

the total is: max(pp(image_encoder), pp(audio_encoder)) + pp(llm) + pp(generator).

If rank is provided, the result is the total number of pipeline stages up to (and including) the PP stage that rank occupies inside its module. In this case, the weight of the target module equals (pp_rank_index(rank) + 1) instead of the module’s full PP size; other modules still contribute their full PP sizes. If the rank belongs to multiple modules (colocation), pass module_name to disambiguate; otherwise the maximum across all candidate modules containing the rank is returned.
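A worked example with illustrative PP sizes of 2, 1, 4, and 2 for image_encoder, audio_encoder, llm, and generator respectively; topology and module_to_grid_map are the dictionaries from the construction sketch:

    # Longest node-weighted path: max(2, 1) + 4 + 2 = 8 total pipeline stages.
    total = MultiModulePipelineCommunicator.compute_total_pipeline_stages(
        topology=topology,
        module_to_grid_map=module_to_grid_map,
    )

    # For a rank sitting at pp_rank 1 inside 'llm', the rank-specific count is
    # max(2, 1) + (1 + 1) = 4 stages up to and including that rank's stage.
    stages_up_to_rank = MultiModulePipelineCommunicator.compute_total_pipeline_stages(
        topology=topology,
        module_to_grid_map=module_to_grid_map,
        rank=some_llm_rank,      # hypothetical global rank id
        module_name='llm',
    )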

Parameters:
  • topology – Mapping from a module to its list of outgoing modules.

  • module_to_grid_map – Mapping from module name to its HyperCommGrid.

  • rank – Optional global rank. If provided, the count includes only the stages up to (and including) the PP stage this rank occupies inside its module.

  • module_name – Optional module name used to disambiguate when the rank is colocated in multiple modules.

Returns:

The total number of pipeline stages along the longest path given the constraints.

Raises:

ValueError – If the topology contains cycles, or if it has no terminal (sink) modules when rank is None.