core.models.mimo.comm.colocated_communicator#
Module Contents#
Classes#

| Class | Description |
| --- | --- |
| SliceInfo | Batch dimension slice information for a rank’s data partition. |
| BridgeDirection | Which side of the bridge scales up, if any. |
| ColocatedBridgeCommunicator | Bridges tensors between colocated modules with different TP/DP layouts. |
| _ColocatedCommunicate | Autograd function for colocated communication with correct backward pass. |
Functions#

| Function | Description |
| --- | --- |
| _all_gather_along_batch_dim | All-gather tensor along an arbitrary batch dim into a single tensor. |
API#
- class core.models.mimo.comm.colocated_communicator.SliceInfo#
Batch dimension slice information for a rank’s data partition.
- start: int#
None
- size: int#
None
- class core.models.mimo.comm.colocated_communicator.BridgeDirection#
Bases: str, enum.Enum

Which side of the bridge scales up, if any.

- FAN_IN — src has more DP replicas than dest; forward all-gathers src outputs along the batch dim, backward narrows the sibling dest gradient down to this src rank’s slot.
- FAN_OUT — dest has more DP replicas; forward narrows, backward all-gathers across the sibling dest DP ranks (the adjoint of narrow is not zero-pad-and-scatter because every dest rank consumes a different slice of the same src activation).
- EQUAL — matching DP; the bridge is a pure passthrough.

(See the sketch after the member list below for how the direction follows from the two DP sizes.)

Initialization
Initialize self. See help(type(self)) for accurate signature.
- FAN_IN#
‘fan_in’
- FAN_OUT#
‘fan_out’
- EQUAL#
‘equal’
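The mapping from DP sizes to a direction can be sketched in a few lines. `pick_direction` below is an illustrative helper, not part of this module; only the three member values are taken from the enum above.

```python
from enum import Enum


class BridgeDirection(str, Enum):
    FAN_IN = "fan_in"    # src DP > dest DP: forward all-gathers
    FAN_OUT = "fan_out"  # src DP < dest DP: forward narrows
    EQUAL = "equal"      # matching DP: passthrough


def pick_direction(src_dp: int, dest_dp: int) -> BridgeDirection:
    """Illustrative only: choose the bridge direction from the two DP sizes."""
    if src_dp > dest_dp:
        return BridgeDirection.FAN_IN
    if src_dp < dest_dp:
        return BridgeDirection.FAN_OUT
    return BridgeDirection.EQUAL


# src has 4 DP replicas, dest has 1 -> forward must all-gather (FAN_IN).
assert pick_direction(4, 1) is BridgeDirection.FAN_IN
```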
- class core.models.mimo.comm.colocated_communicator.ColocatedBridgeCommunicator(
- src_grid: megatron.core.hyper_comm_grid.HyperCommGrid,
- dest_grid: megatron.core.hyper_comm_grid.HyperCommGrid,
- src_module_name: str = 'src',
- dest_module_name: str = 'dest',
- dim_mapping: Optional[Dict[str, int]] = None,
)#
Bridges tensors between colocated modules with different TP/DP layouts.
Default dim_mapping assumes 3D (b, s, h). Callers bridging MimoModel’s pre-flattened (s*b, h) encoder output should pass dim_mapping={'b': 0, 'h': 1}; this relies on a uniform token count per sample so dim 0 divides evenly by the DP scale.

Precondition: the input must be TP-replicated across the src TP group — i.e. all TP ranks inside a src DP replica hold the same tensor on the batch dim. The bridge never gathers along TP; violating this silently produces wrong results.
Initialization
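A construction-and-usage sketch, assuming two HyperCommGrid instances (`encoder_grid`, `llm_grid`) and an `encoder_output` tensor already exist on this rank; the module names and the (s*b, h) dim_mapping follow the docstring above, everything else is placeholder.

```python
from core.models.mimo.comm.colocated_communicator import ColocatedBridgeCommunicator

# encoder_grid / llm_grid: pre-built megatron.core.hyper_comm_grid.HyperCommGrid
# objects describing each module's TP/DP layout (placeholders here).
bridge = ColocatedBridgeCommunicator(
    src_grid=encoder_grid,
    dest_grid=llm_grid,
    src_module_name="encoder",
    dest_module_name="llm",
    # Bridging MimoModel's pre-flattened (s*b, h) encoder output:
    dim_mapping={"b": 0, "h": 1},
)

# encoder_output must already be identical across the src TP group.
llm_input = bridge.communicate(encoder_output)

# Release the cached NCCL subgroup once the bridge is no longer needed.
bridge.destroy()
```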
- _validate_grids()#
- _extract_parallelism_info()#
- _build_rank_mappings()#
- static _build_gather_groups(
- iter_size: int,
- sibling_tp_size: int,
- scale: int,
- rank_to_pos: Dict[int, Tuple[int, int]],
Build iter_size * sibling_tp_size gather groups of scale ranks.

For each slot on the “iterating” side and each TP shard on the sibling side, collect the scale sibling ranks whose DP indices map into that slot. Append order equals group-local-rank order, which all_gather_into_tensor uses to concatenate outputs — do not sort.
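A rough re-implementation for intuition only: it assumes sibling DP index `dp` maps into slot `dp // scale` and that `rank_to_pos` maps a global rank to its `(dp, tp)` position on the sibling side; the real mapping comes from `_build_rank_mappings` and may differ.

```python
from typing import Dict, List, Tuple


def build_gather_groups_sketch(
    iter_size: int,
    sibling_tp_size: int,
    scale: int,
    rank_to_pos: Dict[int, Tuple[int, int]],  # rank -> (dp index, tp index)
) -> List[List[int]]:
    groups: List[List[int]] = []
    for slot in range(iter_size):
        for tp in range(sibling_tp_size):
            members: List[int] = []
            # Walk DP indices in order so the gathered batch slices concatenate
            # in batch order; do not re-sort members by rank afterwards.
            for dp in range(slot * scale, (slot + 1) * scale):
                for rank, pos in rank_to_pos.items():
                    if pos == (dp, tp):
                        members.append(rank)
            groups.append(members)
    return groups


# 2 slots x 1 TP shard, scale 2: ranks laid out as rank == dp, tp = 0.
assert build_gather_groups_sketch(2, 1, 2, {r: (r, 0) for r in range(4)}) == [
    [0, 1],
    [2, 3],
]
```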
- is_fan_in() bool#
True if src DP > dest DP (forward all-gathers).
- is_fan_out() bool#
True if src DP < dest DP (forward narrows).
- get_slice_info(
- batch_size: int,
Compute this rank’s slice of batch_size on the narrowing side.

For FAN_OUT this is the forward narrow; for FAN_IN it is the backward narrow against the post-gather batch. EQUAL returns the identity slice.

Raises ValueError if batch_size is not divisible by scale.
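The slice arithmetic is simple enough to sketch. `slot` below stands for this rank’s position among the `scale` replicas on the narrowing side and is an assumed parameter; the real method reads it from the communicator’s rank mappings.

```python
from dataclasses import dataclass


@dataclass
class SliceInfo:
    start: int
    size: int


def get_slice_info_sketch(batch_size: int, scale: int, slot: int) -> SliceInfo:
    """Illustrative version of the per-rank slice computation."""
    if batch_size % scale != 0:
        raise ValueError(f"batch_size={batch_size} not divisible by scale={scale}")
    size = batch_size // scale
    return SliceInfo(start=slot * size, size=size)


# A batch of 8 split over 4 replicas: slot 2 owns rows 4..5.
assert get_slice_info_sketch(8, 4, 2) == SliceInfo(start=4, size=2)
```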
- _check_divisible(batch_size: int) None#
- communicate(tensor: torch.Tensor) torch.Tensor#
Transform tensor from src TP/DP layout to dest TP/DP layout.

Raises ValueError when FAN_OUT and the batch dim is not divisible by scale; FAN_IN only slices on the backward pass and re-checks via get_slice_info there.
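As a shape illustration of the FAN_OUT forward (the numbers and the local narrow call are illustrative, not the communicator’s internals): with an assumed scale of 4, each dest rank ends up with one contiguous quarter of the batch dim.

```python
import torch

scale, slot = 4, 2                        # assumed: 4 dest DP replicas, this rank in slot 2
src_activation = torch.randn(4096, 1024)  # pre-flattened (s*b, h), TP-replicated on src

size = src_activation.shape[0] // scale   # 4096 % 4 == 0, otherwise communicate() raises
dest_slice = src_activation.narrow(0, slot * size, size)
print(dest_slice.shape)                   # torch.Size([1024, 1024])
```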
- destroy() None#
Release the NCCL subgroup created by this communicator.
NCCL caps concurrent communicators; long-lived or repeated construction leaks PGs without this call.
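A lifetime pattern that keeps the process-group budget bounded; `bridge`, `dataloader`, and `encode` are placeholders (the bridge constructed as in the sketch above).

```python
try:
    for batch in dataloader:
        hidden = bridge.communicate(encode(batch))
        # ... rest of the training step ...
finally:
    bridge.destroy()  # without this, each bridge leaks its NCCL subgroup
```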
- class core.models.mimo.comm.colocated_communicator._ColocatedCommunicate#
Bases: torch.autograd.Function

Autograd function for colocated communication with correct backward pass.
- static forward(
- ctx,
- tensor: torch.Tensor,
- comm: core.models.mimo.comm.colocated_communicator.ColocatedBridgeCommunicator,
Reshape the batch dim across the bridge: narrow on fan-out, all-gather on fan-in.
- static backward(
- ctx,
- grad_output: torch.Tensor,
Adjoint of forward: narrow for fan-in, all-gather for fan-out.
Fan-out’s forward is narrow, whose naive adjoint is zero-pad. That would leave each src rank with only its own dest rank’s slice of the gradient, missing the contributions from every other dest rank that consumed a different slice of the same src activation. Instead we all-gather across the fan-out sibling group, reconstructing the full src-batch gradient (symmetric with the fan-in forward’s all-gather).
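A single-process illustration of that adjoint, assuming scale = 2 dest replicas that each narrow a different half of the same replicated src activation; `torch.cat` stands in for the sibling-group all-gather.

```python
import torch

src = torch.randn(4, 3, requires_grad=True)  # replicated src activation (batch = 4)

# "Forward": each dest replica consumes a different half of the batch.
out0 = src.narrow(0, 0, 2) * 1.0  # dest rank 0
out1 = src.narrow(0, 2, 2) * 2.0  # dest rank 1
(out0.sum() + out1.sum()).backward()

# Each dest rank only holds the gradient for its own slice...
grad0 = torch.ones(2, 3) * 1.0
grad1 = torch.ones(2, 3) * 2.0
# ...so concatenating the gathered slices reconstructs the full src-batch
# gradient; zero-padding rank 0's slice alone would miss the bottom half.
assert torch.equal(src.grad, torch.cat([grad0, grad1], dim=0))
```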
- core.models.mimo.comm.colocated_communicator._all_gather_along_batch_dim(
- tensor: torch.Tensor,
- group: torch.distributed.ProcessGroup,
- batch_dim: int,
All-gather tensor along an arbitrary batch dim into a single tensor.

all_gather_into_tensor concatenates along dim 0, so when the batch dim is not 0 we move it, gather, then restore.
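A minimal sketch of the move-gather-restore pattern, written against the public torch.distributed API; the real helper may handle contiguity and allocation details differently.

```python
import torch
import torch.distributed as dist


def all_gather_along_batch_dim_sketch(
    tensor: torch.Tensor,
    group: dist.ProcessGroup,
    batch_dim: int,
) -> torch.Tensor:
    world = dist.get_world_size(group=group)
    # all_gather_into_tensor concatenates along dim 0, so put the batch dim there.
    moved = tensor.movedim(batch_dim, 0).contiguous()
    out = torch.empty(
        (world * moved.shape[0], *moved.shape[1:]),
        dtype=moved.dtype,
        device=moved.device,
    )
    dist.all_gather_into_tensor(out, moved, group=group)
    # Restore the original dim order with the gathered batch dim back in place.
    return out.movedim(0, batch_dim)
```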