core.pipeline_parallel.bridge_communicator#

Module Contents#

Classes#

CommRole

Communication role for ranks in bridge communication.

RankCommInfo

Explicit communication plan for a single rank.

BridgeCommunicator

Pipeline communicator between two modules with different parallel layouts (TP/DP/PP/CP).

API#

class core.pipeline_parallel.bridge_communicator.CommRole(*args, **kwds)#

Bases: enum.Enum

Communication role for ranks in bridge communication.

  • SENDER: Leader TP-CP rank within each DP replica of the source grid. Sends data to destination-grid receivers.

  • RECEIVER: Leader TP-CP rank within each DP replica of the destination grid. Receives data from source-grid senders.

  • MEMBER: Non-leader ranks within DP replicas. Participate in broadcasts from their local leader.

Initialization

SENDER#

‘SENDER’

RECEIVER#

‘RECEIVER’

MEMBER#

‘MEMBER’

class core.pipeline_parallel.bridge_communicator.RankCommInfo#

Explicit communication plan for a single rank.

role: core.pipeline_parallel.bridge_communicator.CommRole#

None

send_to_ranks: List[int]#

‘field(…)’

recv_from_ranks: List[int]#

‘field(…)’
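
For orientation, the sketch below hand-builds a plan for a toy layout; in practice these entries are produced internally by the communicator (see build_comm_map below), so the rank numbers and the import path are assumptions for illustration only.

```python
from megatron.core.pipeline_parallel.bridge_communicator import (
    CommRole,
    RankCommInfo,
)

# Hypothetical layout: ranks 0-3 form the source grid (DP=2, TP=2) and
# ranks 4-7 the destination grid. Ranks 0 and 2 are the TP-CP leaders of
# their DP replicas; ranks 1 and 3 only take part in local broadcasts.
toy_plan = {
    0: RankCommInfo(role=CommRole.SENDER, send_to_ranks=[4], recv_from_ranks=[]),
    1: RankCommInfo(role=CommRole.MEMBER, send_to_ranks=[], recv_from_ranks=[]),
    2: RankCommInfo(role=CommRole.SENDER, send_to_ranks=[6], recv_from_ranks=[]),
    3: RankCommInfo(role=CommRole.MEMBER, send_to_ranks=[], recv_from_ranks=[]),
}
```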

class core.pipeline_parallel.bridge_communicator.BridgeCommunicator(
src_grid: megatron.core.hyper_comm_grid.HyperCommGrid,
dest_grid: megatron.core.hyper_comm_grid.HyperCommGrid,
dim_mapping: Optional[Dict[str, int]] = None,
comm_dtype: Optional[torch.dtype] = None,
src_module_name: Optional[str] = None,
dest_module_name: Optional[str] = None,
)#

Pipeline communicator between two modules with different parallel layouts (TP/DP/PP/CP).

BridgeCommunicator:

  • Initialize the communicator between a pair of source and destination grids

  • Build a communication schedule for each rank

  • Provide public methods: send_forward, recv_forward, send_forward_recv_backward, send_backward_recv_forward to be used by the pipeline schedule.

Initialization

Initialize the bridge communicator between source and destination grids.

CP is not supported yet; it will be added in a follow-up PR.

Parameters:
  • src_grid – Source HyperCommGrid

  • dest_grid – Destination HyperCommGrid

  • dim_mapping – Dictionary mapping logical dimensions to tensor axes. Expected keys: ‘s’ (sequence), ‘b’ (batch), ‘h’ (hidden). Defaults to {‘s’: 1, ‘b’: 0, ‘h’: 2} if None.
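
A minimal construction sketch. Only the BridgeCommunicator parameters are taken from the signature above; the HyperCommGrid arguments shown (a shape plus dimension names) and the module-name labels are assumptions for illustration, so consult HyperCommGrid's own documentation for its actual constructor.

```python
import torch

from megatron.core.hyper_comm_grid import HyperCommGrid
from megatron.core.pipeline_parallel.bridge_communicator import BridgeCommunicator

# Source module: TP=2, CP=1, DP=2, PP=2. Destination module: TP=1, CP=1, DP=4, PP=2.
# (HyperCommGrid construction is shown schematically here.)
src_grid = HyperCommGrid(shape=[2, 1, 2, 2], dim_names=["tp", "cp", "dp", "pp"])
dest_grid = HyperCommGrid(shape=[1, 1, 4, 2], dim_names=["tp", "cp", "dp", "pp"])

bridge = BridgeCommunicator(
    src_grid,
    dest_grid,
    dim_mapping={"s": 1, "b": 0, "h": 2},  # the documented default layout
    comm_dtype=torch.bfloat16,
    src_module_name="encoder",   # optional labels, purely illustrative
    dest_module_name="decoder",
)
```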

get_leader_rank(
grid: megatron.core.hyper_comm_grid.HyperCommGrid,
is_src: bool,
) List[int]#

Get the leader ranks (one per DP replica) for a given grid and direction.

A leader rank is elected for each DP replica: the first TP-CP rank at the last PP stage (for the source grid) or at the first PP stage (for the destination grid) is the leader.

get_boundary_pp_stage_ranks(
grid: megatron.core.hyper_comm_grid.HyperCommGrid,
is_src: bool,
)#

Get TP-CP ranks at boundary PP stage for each DP replica.

Returns the ranks at the last PP stage (for the source grid) or the first PP stage (for the destination grid) for each DP replica, ordered by DP index.

is_current_rank_in_grid(
grid: megatron.core.hyper_comm_grid.HyperCommGrid,
) bool#

Check if the current rank is in the grid.

build_comm_map(
src_tp_leaders: List[int],
dest_tp_leaders: List[int],
)#

Given the source and destination TP leader ranks, populate comm_map for each rank.

This method analyzes the source and destination grids to determine which ranks need to send/receive data and builds the communication schedule accordingly.
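
Roughly, the helpers above fit together as sketched below, reusing bridge and the grids from the construction sketch; the communicator already runs this flow during initialization, so calling it by hand is only for illustration.

```python
# Sketch of the plan-building flow, using only the methods documented above.
if bridge.is_current_rank_in_grid(src_grid) or bridge.is_current_rank_in_grid(dest_grid):
    src_leaders = bridge.get_leader_rank(src_grid, is_src=True)    # one leader per DP replica
    dest_leaders = bridge.get_leader_rank(dest_grid, is_src=False)
    bridge.build_comm_map(src_leaders, dest_leaders)               # assigns each rank a RankCommInfo
```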

send_forward(tensor_to_send: torch.Tensor)#

Send forward activation tensor.

Parameters:

tensor_to_send – The tensor to send to the destination grid

recv_forward() torch.Tensor#

Receive forward activation tensor.

The expected tensor shape is not passed as an argument; it is obtained via shape communication before the receive.

Returns:

The received activation tensor

Return type:

torch.Tensor
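
A sketch of the forward hand-off as a pipeline schedule might drive it: every rank calls the helper, and the grid-membership check decides which side of the bridge it is on (bridge, src_grid, dest_grid as constructed earlier; the wrapper function itself is not part of this module).

```python
from typing import Optional

import torch


def bridge_forward_handoff(bridge, src_grid, dest_grid,
                           activation: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
    """Ship activations from the source module to the destination module."""
    if bridge.is_current_rank_in_grid(src_grid):
        bridge.send_forward(activation)   # source side: push the activation across the bridge
        return None
    if bridge.is_current_rank_in_grid(dest_grid):
        return bridge.recv_forward()      # destination side: pull the activation in
    return None
```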

send_backward(grad_tensor: torch.Tensor)#

Send backward gradient tensor.

Note: Gradient senders are activation ‘RECEIVERS’

Parameters:

grad_tensor – The gradient tensor to send back

recv_backward() torch.Tensor#

Receive backward gradient tensor.

Note: Gradient receivers are activation ‘SENDERS’

The expected gradient shape is not passed as an argument; it is obtained via shape communication before the receive.

Returns:

The received gradient tensor

Return type:

torch.Tensor
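
The backward hand-off mirrors the forward one, with the direction reversed (recall that gradient senders are activation receivers and vice versa); again, the wrapper is only an illustrative sketch.

```python
from typing import Optional

import torch


def bridge_backward_handoff(bridge, src_grid, dest_grid,
                            grad_output: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
    """Ship gradients from the destination module back to the source module."""
    if bridge.is_current_rank_in_grid(dest_grid):
        bridge.send_backward(grad_output)  # destination side: gradients flow back over the bridge
        return None
    if bridge.is_current_rank_in_grid(src_grid):
        return bridge.recv_backward()      # source side: receive gradients for its activations
    return None
```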

send_forward_recv_backward(
input_tensor: torch.Tensor,
grad_shape: Optional[Tuple[int, ...]] = None,
) torch.Tensor#

Combined operation: send forward activation and receive backward gradient.

Parameters:
  • input_tensor – The tensor to send forward

  • grad_shape – Expected gradient tensor shape

Returns:

The received gradient tensor

Return type:

torch.Tensor

send_backward_recv_forward(
grad_tensor: torch.Tensor,
forward_shape: Optional[Tuple[int, ...]] = None,
) torch.Tensor#

Combined operation: send backward gradient and receive forward activation.

Parameters:
  • grad_tensor – The gradient tensor to send backward

  • forward_shape – Expected forward tensor shape

Returns:

The received activation tensor

Return type:

torch.Tensor
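
In a 1F1B-style steady state the fused calls pair the two transfers in a single operation. A rough sketch for a source-side rank follows; run_forward_step, run_backward_step, and microbatches stand in for the surrounding schedule and are not part of this module, and the warmup/cooldown phases of a real schedule are omitted.

```python
# Hypothetical steady-state loop on a rank in the source grid.
for microbatch in microbatches:
    activation = run_forward_step(microbatch)
    # Send the fresh activation and, in the same call, receive the gradient
    # for an activation that was sent earlier.
    grad = bridge.send_forward_recv_backward(activation)
    run_backward_step(grad)

# A destination-side rank would use the mirror image:
#     activation = bridge.send_backward_recv_forward(grad_input)
```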

_communicate_shapes(
tensor_to_send_next: Optional[torch.Tensor] = None,
recv_next: bool = False,
recv_prev: bool = False,
tensor_to_send_prev: Optional[torch.Tensor] = None,
) Tuple[List[Tuple[int, ...]], List[Tuple[int, ...]]]#

Communicate tensor shapes between sender and receiver ranks in the bridge.

This is used to communicate tensor shapes before actual tensor communication when dealing with variable sequence lengths or dynamic shapes.

Parameters:
  • tensor_to_send_next – The tensor to send to the next rank (None if not sending)

  • tensor_to_send_prev – The tensor to send to the previous rank (None if not sending)

  • recv_next – Whether to receive from the next rank

  • recv_prev – Whether to receive from the previous rank

Returns:

  • List of forward shapes that will be received (empty if not a receiver)

  • List of gradient shapes that will be received (empty if not expecting gradients)

Return type:

Tuple[List[Tuple[int, ...]], List[Tuple[int, ...]]]
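
Although this is a private helper, a short sketch may clarify the flag semantics: a rank that is about to send an activation toward the destination grid and later expects a gradient back could advertise and collect shapes as below. This usage pattern is an assumption drawn from the parameter descriptions, not a documented recipe.

```python
# Advertise the activation's shape and ask for the gradient shapes in return.
fwd_shapes, grad_shapes = bridge._communicate_shapes(
    tensor_to_send_next=activation,  # shape of the tensor we are about to send forward
    recv_next=True,                  # we expect gradient shapes back from the "next" side
)
```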

_split_tensor_at_batch_dim(
aggregated_tensor: torch.Tensor,
num_splits: int,
) List[torch.Tensor]#

Split an aggregated tensor into multiple tensors at the batch dimension.

Parameters:
  • aggregated_tensor – The tensor to split

  • num_splits – The number of splits to create

Returns:

List of tensors split at the batch dimension
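
The split is essentially a chunk along the batch axis named by dim_mapping (index 0 with the default mapping). A stand-alone illustrative equivalent, not necessarily the exact implementation:

```python
from typing import List

import torch


def split_at_batch_dim(aggregated: torch.Tensor, num_splits: int,
                       batch_dim: int = 0) -> List[torch.Tensor]:
    """Illustrative equivalent: cut an aggregated tensor into equal batch shards."""
    assert aggregated.size(batch_dim) % num_splits == 0, "batch size must divide evenly"
    return list(torch.chunk(aggregated, num_splits, dim=batch_dim))


# e.g. a [batch=8, seq=16, hidden=32] activation split for 4 destination DP replicas
shards = split_at_batch_dim(torch.randn(8, 16, 32), num_splits=4)
assert all(s.shape == (2, 16, 32) for s in shards)
```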