core.models.mimo.comm.colocated_communicator#

Module Contents#

Classes#

SliceInfo

Batch dimension slice information for a rank’s data partition.

BridgeDirection

Which side of the bridge scales up, if any.

ColocatedBridgeCommunicator

Bridges tensors between colocated modules with different TP/DP layouts.

_ColocatedCommunicate

Autograd function for colocated communication with correct backward pass.

Functions#

_all_gather_along_batch_dim

All-gather tensor along an arbitrary batch dim into a single tensor.

API#

class core.models.mimo.comm.colocated_communicator.SliceInfo#

Batch dimension slice information for a rank’s data partition.

start: int#

size: int#

class core.models.mimo.comm.colocated_communicator.BridgeDirection#

Bases: str, enum.Enum

Which side of the bridge scales up, if any.

FAN_IN — src has more DP replicas than dest; forward all-gathers src outputs along the batch dim, backward narrows the sibling dest gradient down to this src rank’s slot.

FAN_OUT — dest has more DP replicas; forward narrows, backward all-gathers across the sibling dest DP ranks (the adjoint of narrow is not zero-pad-and-scatter because every dest rank consumes a different slice of the same src activation).

EQUAL — matching DP; the bridge is a pure passthrough.


FAN_IN#

‘fan_in’

FAN_OUT#

‘fan_out’

EQUAL#

‘equal’
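The member values mirror the DP-size comparison described above. A minimal sketch of how the direction could be derived; `infer_direction` is an illustrative helper, not part of the module:

```python
from enum import Enum

class BridgeDirection(str, Enum):
    FAN_IN = "fan_in"
    FAN_OUT = "fan_out"
    EQUAL = "equal"

def infer_direction(src_dp: int, dest_dp: int) -> BridgeDirection:
    """Pick the bridge direction from the two sides' DP sizes."""
    if src_dp > dest_dp:
        return BridgeDirection.FAN_IN   # forward all-gathers src outputs
    if src_dp < dest_dp:
        return BridgeDirection.FAN_OUT  # forward narrows to this rank's slice
    return BridgeDirection.EQUAL        # pure passthrough
```

Because the enum subclasses `str`, members compare equal to their string values, which keeps serialized configs readable.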

class core.models.mimo.comm.colocated_communicator.ColocatedBridgeCommunicator(
src_grid: megatron.core.hyper_comm_grid.HyperCommGrid,
dest_grid: megatron.core.hyper_comm_grid.HyperCommGrid,
src_module_name: str = 'src',
dest_module_name: str = 'dest',
dim_mapping: Optional[Dict[str, int]] = None,
)#

Bridges tensors between colocated modules with different TP/DP layouts.

Default dim_mapping assumes 3D (b, s, h). Callers bridging MimoModel’s pre-flattened (s*b, h) encoder output should pass dim_mapping={'b': 0, 'h': 1}; this relies on a uniform token count per sample so dim 0 divides evenly by the DP scale.

Precondition: the input must be TP-replicated across the src TP group — i.e. all TP ranks inside a src DP replica hold the same tensor on the batch dim. The bridge never gathers along TP; violating this silently produces wrong results.


_validate_grids()#
_extract_parallelism_info()#
_build_rank_mappings()#
static _build_gather_groups(
iter_size: int,
sibling_tp_size: int,
scale: int,
rank_to_pos: Dict[int, Tuple[int, int]],
) List[List[int]]#

Build iter_size * sibling_tp_size gather groups of scale ranks.

For each slot on the “iterating” side and each TP shard on the sibling side, collect the scale sibling ranks whose DP indices map into that slot. Append order equals group-local-rank order, which all_gather_into_tensor uses to concatenate outputs — do not sort.
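The grouping logic can be sketched in pure Python. This is an illustrative reconstruction, not the module's code: it assumes a rank maps into slot `dp_index // scale` and that iterating ranks in ascending global order reproduces the required group-local-rank order; the real mapping may differ.

```python
from typing import Dict, List, Tuple

def build_gather_groups(
    iter_size: int,
    sibling_tp_size: int,
    scale: int,
    rank_to_pos: Dict[int, Tuple[int, int]],  # rank -> (dp_index, tp_index)
) -> List[List[int]]:
    groups: List[List[int]] = []
    for slot in range(iter_size):
        for tp in range(sibling_tp_size):
            group = [
                # Append order defines each rank's group-local rank, which
                # all_gather_into_tensor uses to concatenate outputs.
                rank
                for rank in sorted(rank_to_pos)
                if rank_to_pos[rank][1] == tp
                and rank_to_pos[rank][0] // scale == slot  # assumed slot mapping
            ]
            groups.append(group)
    return groups
```

For example, four sibling ranks with DP indices 0..3, one TP shard, and scale 2 yield the two groups `[[0, 1], [2, 3]]`.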

is_fan_in() bool#

True if src DP > dest DP (forward all-gathers).

is_fan_out() bool#

True if src DP < dest DP (forward narrows).

get_slice_info(
batch_size: int,
) core.models.mimo.comm.colocated_communicator.SliceInfo#

Compute this rank’s slice of batch_size on the narrowing side.

For FAN_OUT this is the forward narrow; for FAN_IN it is the backward narrow against the post-gather batch. EQUAL returns the identity slice.

Raises ValueError if batch_size is not divisible by scale.
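The slice arithmetic reduces to even division plus an offset. A standalone sketch, assuming the rank's position in its sibling group is passed in as `slot_index` (the real method presumably reads it from the communicator's internal rank mappings):

```python
from dataclasses import dataclass

@dataclass
class SliceInfo:
    start: int
    size: int

def get_slice_info(batch_size: int, scale: int, slot_index: int) -> SliceInfo:
    """Each of the `scale` sibling ranks owns one contiguous batch slice."""
    if batch_size % scale != 0:
        raise ValueError(
            f"batch_size {batch_size} is not divisible by scale {scale}"
        )
    size = batch_size // scale
    return SliceInfo(start=slot_index * size, size=size)
```

With `batch_size=12`, `scale=4`, the rank in slot 2 narrows to `start=6, size=3`; `scale=1` degenerates to the identity slice.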

_check_divisible(batch_size: int) None#
communicate(tensor: torch.Tensor) torch.Tensor#

Transform tensor from src TP/DP layout to dest TP/DP layout.

Raises ValueError on FAN_OUT when the batch dim is not divisible by scale. FAN_IN only slices on the backward pass and re-checks divisibility via get_slice_info there.

destroy() None#

Release the NCCL subgroup created by this communicator.

NCCL caps the number of concurrent communicators, so repeatedly constructing communicators (or keeping many alive) leaks process groups unless this is called.

class core.models.mimo.comm.colocated_communicator._ColocatedCommunicate#

Bases: torch.autograd.Function

Autograd function for colocated communication with correct backward pass.

static forward(
ctx,
tensor: torch.Tensor,
comm: core.models.mimo.comm.colocated_communicator.ColocatedBridgeCommunicator,
) torch.Tensor#

Reshape the batch dim across the bridge: narrow on fan-out, all-gather on fan-in.

static backward(
ctx,
grad_output: torch.Tensor,
) Tuple[torch.Tensor, None]#

Adjoint of forward: narrow for fan-in, all-gather for fan-out.

Fan-out’s forward is narrow, whose naive adjoint is zero-pad. That would leave each src rank with only its own dest rank’s slice of the gradient, missing the contributions from every other dest rank that consumed a different slice of the same src activation. Instead we all-gather across the fan-out sibling group, reconstructing the full src-batch gradient (symmetric with the fan-in forward’s all-gather).
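A toy single-process numeric sketch of the argument above (scale 2, batch of 4; values are illustrative, no distributed setup involved):

```python
# One src rank's full-batch activation, split across two dest "ranks".
src_batch = [1.0, 2.0, 3.0, 4.0]
dest0_slice = src_batch[:2]  # dest rank 0's forward narrow
dest1_slice = src_batch[2:]  # dest rank 1's forward narrow

# In backward, each dest rank holds a gradient only for its own slice.
grad_dest0 = [0.1, 0.2]
grad_dest1 = [0.3, 0.4]

# Naive zero-pad adjoint on dest rank 0: dest rank 1's contribution is lost.
zero_padded = grad_dest0 + [0.0, 0.0]

# All-gather across the fan-out sibling group (concatenation in
# group-local-rank order) reconstructs the full src-batch gradient.
full_grad = grad_dest0 + grad_dest1

assert zero_padded != full_grad
assert full_grad == [0.1, 0.2, 0.3, 0.4]
```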

core.models.mimo.comm.colocated_communicator._all_gather_along_batch_dim(
tensor: torch.Tensor,
group: torch.distributed.ProcessGroup,
batch_dim: int,
) torch.Tensor#

All-gather tensor along an arbitrary batch dim into a single tensor.

all_gather_into_tensor concatenates along dim 0, so when the batch dim is not 0 we move it, gather, then restore.
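The move-gather-restore pattern can be shown without a process group. A single-process stand-in using numpy, where plain concatenation plays the role of `all_gather_into_tensor` (this sketch is not the module's implementation):

```python
import numpy as np

def gather_along_batch_dim(shards, batch_dim):
    """Concatenate per-rank shards along batch_dim by routing through dim 0,
    mirroring all_gather_into_tensor's dim-0-only concatenation."""
    moved = [np.moveaxis(s, batch_dim, 0) for s in shards]  # batch dim -> 0
    gathered = np.concatenate(moved, axis=0)                # the "gather"
    return np.moveaxis(gathered, 0, batch_dim)              # restore layout
```

For two `(b, s, h) = (2, 3, 4)` shards gathered along `batch_dim=1`, the result has shape `(2, 6, 4)` with each shard's rows kept contiguous in group order.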