core.pipeline_parallel.bridge_communicator#
Module Contents#
Classes#
| CommRole | Communication role for ranks in bridge communication. |
| RankCommInfo | Explicit communication plan for a single rank. |
| BridgeCommunicator | Pipeline communicator between two modules with different parallelism layouts (TP/DP/PP/CP). |
API#
- class core.pipeline_parallel.bridge_communicator.CommRole(*args, **kwds)#
Bases: enum.Enum
Communication role for ranks in bridge communication.
- SENDER: Leader tp-cp rank within each DP replica of the source grid. Sends data to destination grid receivers.
- RECEIVER: Leader tp-cp rank within each DP replica of the destination grid. Receives data from source grid senders.
- MEMBER: Non-leader ranks within DP replicas. Participate in broadcasts from their local leader.
Initialization
- SENDER#
‘SENDER’
- RECEIVER#
‘RECEIVER’
- MEMBER#
‘MEMBER’
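For orientation, a minimal sketch of how these roles partition the ranks. The classify_rank helper and its arguments are hypothetical; the real role assignment happens inside BridgeCommunicator.build_comm_map.

```python
# Assumes Megatron-LM is installed; the full module path is inferred from this page.
from megatron.core.pipeline_parallel.bridge_communicator import CommRole

def classify_rank(rank: int, src_leaders: set, dest_leaders: set) -> CommRole:
    """Hypothetical helper: map a global rank to its role in the bridge."""
    if rank in src_leaders:
        return CommRole.SENDER    # leader on the source grid: sends activations
    if rank in dest_leaders:
        return CommRole.RECEIVER  # leader on the destination grid: receives activations
    return CommRole.MEMBER        # everyone else: waits for the local leader's broadcast

print(classify_rank(4, src_leaders={4, 6}, dest_leaders={8, 12}))  # CommRole.SENDER
```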
- class core.pipeline_parallel.bridge_communicator.RankCommInfo#
Explicit communication plan for a single rank.
- send_to_ranks: List[int]#
‘field(…)’
- recv_from_ranks: List[int]#
‘field(…)’
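A populated plan is just two rank lists. A minimal sketch, assuming the dataclass exposes only the two fields documented above:

```python
from megatron.core.pipeline_parallel.bridge_communicator import RankCommInfo

# Plan for a source-grid leader that feeds one destination leader and receives nothing forward.
plan = RankCommInfo(send_to_ranks=[8], recv_from_ranks=[])
assert plan.send_to_ranks == [8] and plan.recv_from_ranks == []
```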
- class core.pipeline_parallel.bridge_communicator.BridgeCommunicator(
- src_grid: megatron.core.hyper_comm_grid.HyperCommGrid,
- dest_grid: megatron.core.hyper_comm_grid.HyperCommGrid,
- dim_mapping: Optional[Dict[str, int]] = None,
- comm_dtype: Optional[torch.dtype] = None,
- src_module_name: Optional[str] = None,
- dest_module_name: Optional[str] = None,
Pipeline communicator between two modules with different parallelism layouts (TP/DP/PP/CP).
A BridgeCommunicator:
- Initializes the communicator between a pair of source and destination grids.
- Builds a communication schedule for each rank.
- Provides the public methods send_forward, recv_forward, send_forward_recv_backward, and send_backward_recv_forward to be used by the pipeline schedule.
Initialization
Initialize the bridge communicator between source and destination grids.
CP is not supported yet; it will be added in a follow-up PR.
- Parameters:
src_grid – Source HyperCommGrid
dest_grid – Destination HyperCommGrid
dim_mapping – Dictionary mapping logical dimensions to tensor axes. Expected keys: ‘s’ (sequence), ‘b’ (batch), ‘h’ (hidden). Defaults to {‘s’: 1, ‘b’: 0, ‘h’: 2} if None.
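A construction sketch. The HyperCommGrid arguments (shape, dim_names, rank_offset) are assumptions, since that class is documented elsewhere, and torch.distributed is assumed to be initialized already:

```python
import torch
from megatron.core.hyper_comm_grid import HyperCommGrid
from megatron.core.pipeline_parallel.bridge_communicator import BridgeCommunicator

# Source module on ranks 0-7 (TP=2, CP=1, DP=2, PP=2); destination module on ranks 8-15 (TP=4, CP=1, DP=2, PP=1).
src_grid = HyperCommGrid(shape=[2, 1, 2, 2], dim_names=["tp", "cp", "dp", "pp"])                  # assumed signature
dest_grid = HyperCommGrid(shape=[4, 1, 2, 1], dim_names=["tp", "cp", "dp", "pp"], rank_offset=8)  # assumed signature

bridge = BridgeCommunicator(
    src_grid,
    dest_grid,
    dim_mapping={"s": 1, "b": 0, "h": 2},  # the documented default layout
    comm_dtype=torch.bfloat16,
)
```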
- get_leader_rank(
- grid: megatron.core.hyper_comm_grid.HyperCommGrid,
- is_src: bool,
Get the leader rank for a given grid and direction.
A leader rank is elected for each DP replica: the first tp-cp rank in the group at the last PP stage (for the source grid) or the first PP stage (for the destination grid) is the leader.
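Conceptually, leader election just takes the first tp-cp rank per DP replica at the boundary PP stage (the nested list returned by get_boundary_pp_stage_ranks). A standalone sketch of that rule, not the method's actual implementation:

```python
from typing import List

def pick_leaders(boundary_ranks_per_dp: List[List[int]]) -> List[int]:
    """boundary_ranks_per_dp: one list of tp-cp ranks per DP replica, taken at the
    last PP stage (source grid) or the first PP stage (destination grid)."""
    return [tp_cp_ranks[0] for tp_cp_ranks in boundary_ranks_per_dp]

print(pick_leaders([[4, 5], [6, 7]]))  # [4, 6] -> one leader per DP replica
```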
- get_boundary_pp_stage_ranks(
- grid: megatron.core.hyper_comm_grid.HyperCommGrid,
- is_src: bool,
Get TP-CP ranks at boundary PP stage for each DP replica.
Returns the ranks at the last PP stage (for the source grid) or the first PP stage (for the destination grid) for each DP replica, ordered by DP index.
- is_current_rank_in_grid(
- grid: megatron.core.hyper_comm_grid.HyperCommGrid,
Check if the current rank is in the grid.
- build_comm_map(
- src_tp_leaders: List[int],
- dest_tp_leaders: List[int],
Given the source and destination TP leader ranks, populate comm_map for each rank.
This method analyzes the source and destination grids to determine which ranks need to send/receive data and builds the communication schedule accordingly.
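The resulting schedule can be pictured as a per-rank mapping to RankCommInfo entries. A hypothetical picture for two DP replicas, pairing source and destination leaders by replica index (the pairing rule is an assumption for illustration):

```python
from megatron.core.pipeline_parallel.bridge_communicator import RankCommInfo

# Hypothetical outcome of build_comm_map(src_tp_leaders=[4, 6], dest_tp_leaders=[8, 12]):
comm_map = {
    4:  RankCommInfo(send_to_ranks=[8],  recv_from_ranks=[]),   # source leader, DP replica 0
    6:  RankCommInfo(send_to_ranks=[12], recv_from_ranks=[]),   # source leader, DP replica 1
    8:  RankCommInfo(send_to_ranks=[],   recv_from_ranks=[4]),  # destination leader, DP replica 0
    12: RankCommInfo(send_to_ranks=[],   recv_from_ranks=[6]),  # destination leader, DP replica 1
}
```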
- send_forward(tensor_to_send: torch.Tensor)#
Send forward activation tensor.
- Parameters:
tensor_to_send – The tensor to send to the destination grid
- recv_forward() → torch.Tensor#
Receive forward activation tensor.
- Parameters:
tensor_shape – Expected tensor shape (None if using shape communication)
- Returns:
The received activation tensor
- Return type:
torch.Tensor
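A usage sketch for the forward hand-off, assuming bridge, src_grid, and dest_grid from the construction example above and that the call is made on every rank of both grids (tensor shape follows the default dim_mapping, i.e. [batch, sequence, hidden]):

```python
import torch

hidden = torch.randn(8, 1024, 4096, dtype=torch.bfloat16, device="cuda")  # [b, s, h]

if bridge.is_current_rank_in_grid(src_grid):
    bridge.send_forward(hidden)          # source side: ship activations across the bridge
elif bridge.is_current_rank_in_grid(dest_grid):
    activations = bridge.recv_forward()  # destination side: receive (leaders) / get broadcast (members)
```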
- send_backward(grad_tensor: torch.Tensor)#
Send backward gradient tensor.
Note: Gradient senders are activation ‘RECEIVERS’
- Parameters:
grad_tensor – The gradient tensor to send back
- recv_backward() → torch.Tensor#
Receive backward gradient tensor.
Note: Gradient receivers are activation ‘SENDERS’
- Parameters:
tensor_shape – Expected gradient tensor shape
- Returns:
The received gradient tensor
- Return type:
torch.Tensor
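The backward pair mirrors the forward one with the roles reversed, since activation receivers are the ones holding gradients to send back (grad here stands for the gradient produced by the destination module's backward pass):

```python
if bridge.is_current_rank_in_grid(dest_grid):
    bridge.send_backward(grad)               # destination side: gradients flow back across the bridge
elif bridge.is_current_rank_in_grid(src_grid):
    grad_from_dest = bridge.recv_backward()  # source side: activation senders receive gradients
```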
- send_forward_recv_backward(
- input_tensor: torch.Tensor,
- grad_shape: Optional[Tuple[int, ...]] = None,
Combined operation: send forward activation and receive backward gradient.
- Parameters:
input_tensor – The tensor to send forward
grad_shape – Expected gradient tensor shape
- Returns:
The received gradient tensor
- Return type:
torch.Tensor
- send_backward_recv_forward(
- grad_tensor: torch.Tensor,
- forward_shape: Optional[Tuple[int, ...]] = None,
Combined operation: send backward gradient and receive forward activation.
- Parameters:
grad_tensor – The gradient tensor to send backward
forward_shape – Expected forward tensor shape
- Returns:
The received activation tensor
- Return type:
torch.Tensor
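In a 1F1B-style schedule the combined calls let each side overlap the two directions in a single step. A sketch under the same assumptions as above, with hidden and grad coming from the surrounding training loop:

```python
if bridge.is_current_rank_in_grid(src_grid):
    # Send this microbatch's activations and receive its gradient in one call.
    grad_back = bridge.send_forward_recv_backward(hidden, grad_shape=tuple(hidden.shape))
elif bridge.is_current_rank_in_grid(dest_grid):
    # Return the previous microbatch's gradient and receive the next activations.
    next_hidden = bridge.send_backward_recv_forward(grad, forward_shape=(8, 1024, 4096))
```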
- _communicate_shapes(
- tensor_to_send_next: Optional[torch.Tensor] = None,
- recv_next: bool = False,
- recv_prev: bool = False,
- tensor_to_send_prev: Optional[torch.Tensor] = None,
Communicate tensor shapes between sender and receiver ranks in the bridge.
This is used to communicate tensor shapes before actual tensor communication when dealing with variable sequence lengths or dynamic shapes.
- Parameters:
tensor_to_send_next – The tensor to send to the next rank (None if not sending)
tensor_to_send_prev – The tensor to send to the previous rank (None if not sending)
recv_next – Whether to receive from the next rank
recv_prev – Whether to receive from the previous rank
- Returns:
A tuple containing the list of forward shapes that will be received (empty if not a receiver) and the list of gradient shapes that will be received (empty if not expecting gradients).
- Return type:
Tuple
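The underlying pattern is the usual shape-before-payload exchange. A generic torch.distributed sketch of the idea, not this method's implementation (buffer device depends on the backend, e.g. CUDA tensors for NCCL):

```python
import torch
import torch.distributed as dist

def send_shape(t: torch.Tensor, dst: int) -> None:
    # Ship the shape first so the peer can pre-allocate a receive buffer of the right size.
    dist.send(torch.tensor(t.shape, dtype=torch.int64, device=t.device), dst=dst)

def recv_shape(src: int, ndim: int = 3, device: str = "cuda") -> torch.Size:
    buf = torch.empty(ndim, dtype=torch.int64, device=device)
    dist.recv(buf, src=src)
    return torch.Size(buf.tolist())
```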
- _split_tensor_at_batch_dim(
- aggregated_tensor: torch.Tensor,
- num_splits: int,
Split an aggregated tensor into multiple tensors at the batch dimension.
- Parameters:
aggregated_tensor – The tensor to split
num_splits – The number of splits to create
- Returns:
List of tensors split at the batch dimension
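Functionally this is a split along the batch axis (dim 0 under the default dim_mapping); a plain-PyTorch equivalent of the idea:

```python
import torch

aggregated = torch.randn(8, 1024, 4096)           # [b, s, h] with the default dim_mapping
splits = list(torch.chunk(aggregated, 4, dim=0))  # four tensors of shape [2, 1024, 4096]
```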