core.resharding.execution

Module Contents

Classes

_Writeback

Tagged-union for what to do with a received tensor after service.run().

Functions

_get_mxfp8_accumulator

Get or lazily allocate the BF16 accumulation buffer for an MXFP8 dest param.

execute_reshard_plan

Execute a reshard plan produced by the centralized controller. A communication service must be provided to abstract the transport; it is expected to expose submit_send(tensor, dest_rank, task_id), submit_recv(tensor, src_rank, task_id), and run().

Data

API

core.resharding.execution.logger

‘getLogger(…)’

class core.resharding.execution._Writeback

Tagged-union for what to do with a received tensor after service.run().

Exactly one of three kinds applies, and fields not used by that kind are ignored:

  • direct: the data landed in its final destination during the recv, so nothing needs to be copied.

  • copy: a staging recv_buffer is copied into a slice of dst_param (deferring to MXFP8 accumulation when the destination is quantized).

  • transform: the received buffers are handed to a ReshardTransform.finalize_recv call.

kind: str

None

recv_buffer: Optional[torch.Tensor]

None

dst_param: Optional[torch.Tensor]

None

dst_slice: Optional[tuple]

None

param_name: Optional[str]

None

recv_bufs: Optional[list[torch.Tensor]]

None
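The tagged-union idea can be illustrated with a small dispatcher that runs once the service has finished. This is a hypothetical sketch, not the module's code: Writeback and apply_writeback are illustrative names, plain Python lists stand in for torch.Tensor so it runs without torch, dst_slice is a single slice object rather than the real tuple of per-dimension indices, finalize_recv is modeled as a plain callback with an assumed (param_name, bufs) signature, and the MXFP8 accumulation path for "copy" is omitted.

```python
from dataclasses import dataclass
from typing import Any, Callable, Optional


# Hypothetical stand-in for _Writeback; lists replace torch.Tensor.
@dataclass
class Writeback:
    kind: str  # "direct" | "copy" | "transform"
    recv_buffer: Optional[list] = None
    dst_param: Optional[list] = None
    dst_slice: Optional[slice] = None
    param_name: Optional[str] = None
    recv_bufs: Optional[list] = None


def apply_writeback(wb: Writeback, finalize_recv: Callable[[str, list], Any]) -> None:
    """Act on one received tensor after service.run() (assumed flow)."""
    if wb.kind == "direct":
        # Data was received straight into its destination: nothing to do.
        return
    if wb.kind == "copy":
        # Copy the staging buffer into the destination slice.
        wb.dst_param[wb.dst_slice] = wb.recv_buffer
        return
    if wb.kind == "transform":
        # Hand the received buffers to the transform (callback signature assumed).
        finalize_recv(wb.param_name, wb.recv_bufs)
        return
    raise ValueError(f"unknown writeback kind: {wb.kind!r}")


dst = [0, 0, 0, 0]
apply_writeback(
    Writeback("copy", recv_buffer=[7, 8], dst_param=dst, dst_slice=slice(1, 3)),
    finalize_recv=lambda name, bufs: None,
)
print(dst)  # [0, 7, 8, 0]
```

Because only the fields relevant to kind are read, a single record type can be queued for every pending receive and processed in one pass after run().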

core.resharding.execution._get_mxfp8_accumulator(
pending: dict[int, tuple],
dst_param: torch.Tensor,
) -> tuple[torch.Tensor, list]

Get or lazily allocate the BF16 accumulation buffer for an MXFP8 dest param.

Every slice destined for the same dst_param is written into this buffer, and quantize_ is called only once, after all slices have been written. The buffer is allocated empty rather than dequantized from dst_param, because every slice will be overwritten.
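The get-or-allocate pattern can be sketched in plain Python. This is an illustration under stated assumptions, not the real function: lists stand in for the BF16 buffer and the MXFP8 parameter, keying pending by id(dst_param) and treating the returned list as a record of written slices are guesses, and get_accumulator, finalize, and fake_quantize_ are hypothetical names (rounding stands in for MXFP8 quantization).

```python
from typing import Callable

# Hypothetical: pending maps id(dst_param) -> (buffer, written_slices).
Pending = dict[int, tuple[list, list]]


def get_accumulator(pending: Pending, dst_param: list) -> tuple[list, list]:
    """Return (buffer, written_slices) for dst_param, allocating on first use."""
    key = id(dst_param)
    if key not in pending:
        # Do not read dst_param: every slice will be overwritten, so
        # dequantizing its current contents would be wasted work.
        pending[key] = ([0.0] * len(dst_param), [])
    return pending[key]


def finalize(pending: Pending, params: list[list], quantize_: Callable[[list, list], None]) -> None:
    """Quantize each pending destination exactly once, then drop its buffer."""
    by_id = {id(p): p for p in params}
    for key, (buf, _written) in pending.items():
        quantize_(by_id[key], buf)
    pending.clear()


# Usage: two slices of one parameter share a buffer; quantize_ runs once.
quantize_calls = []


def fake_quantize_(param: list, buf: list) -> None:
    quantize_calls.append(id(param))
    param[:] = [round(x) for x in buf]  # rounding stands in for MXFP8 quantization


param = [0, 0, 0, 0]
pending: Pending = {}
buf, written = get_accumulator(pending, param)
buf[0:2] = [1.2, 2.7]
written.append(slice(0, 2))
buf, written = get_accumulator(pending, param)  # same buffer as before
buf[2:4] = [3.1, 4.9]
written.append(slice(2, 4))
finalize(pending, [param], fake_quantize_)
print(param, len(quantize_calls))  # [1, 3, 3, 5] 1
```

Deferring quantization this way avoids re-quantizing a block-scaled parameter once per incoming slice.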

core.resharding.execution.execute_reshard_plan(
plan: core.resharding.utils.ReshardPlan,
src_module: torch.nn.Module,
dst_module: torch.nn.Module,
service: core.resharding.copy_services.base.CopyService,
group=None,
transform: Optional[core.resharding.transforms.ReshardTransform] = None,
) -> None

Execute a reshard plan produced by the centralized controller. A communication service must be provided to abstract the transport; it is expected to expose submit_send(tensor, dest_rank, task_id), submit_recv(tensor, src_rank, task_id), and run().
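To make the expected service contract concrete, here is a hypothetical single-process stand-in that implements only the three methods named above. LoopbackService is an illustrative name, not the real CopyService; it assumes sends and recvs are paired by task_id and that all copies happen when run() is called, and it uses lists in place of tensors.

```python
from dataclasses import dataclass, field


@dataclass
class LoopbackService:
    """Hypothetical in-process stand-in for a CopyService, for illustration only.

    Queues sends and recvs, pairs them by task_id, and performs every copy
    when run() is called. Ranks are recorded but unused, since all "ranks"
    live in one process here.
    """

    sends: dict = field(default_factory=dict)
    recvs: dict = field(default_factory=dict)

    def submit_send(self, tensor: list, dest_rank: int, task_id: int) -> None:
        self.sends[task_id] = (tensor, dest_rank)

    def submit_recv(self, tensor: list, src_rank: int, task_id: int) -> None:
        self.recvs[task_id] = (tensor, src_rank)

    def run(self) -> None:
        for task_id, (dst, _src_rank) in self.recvs.items():
            if task_id not in self.sends:
                raise RuntimeError(f"recv for task {task_id} has no matching send")
            src, _dest_rank = self.sends.pop(task_id)
            if len(src) != len(dst):
                raise ValueError(f"size mismatch for task {task_id}")
            dst[:] = src  # write in place into the caller's recv buffer
        self.recvs.clear()


svc = LoopbackService()
recv_buf = [0, 0, 0]
svc.submit_recv(recv_buf, src_rank=0, task_id=7)
svc.submit_send([1, 2, 3], dest_rank=0, task_id=7)
svc.run()
print(recv_buf)  # [1, 2, 3]
```

Submitting everything first and executing in run() lets a real transport batch or reorder the transfers, which is one plausible reason for the split API.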

Either src_module or dst_module may be None, which lets a rank take part in only one side of the transfer (non-collocated mode):

  • src_module=None: Rank only receives data (destination-only)

  • dst_module=None: Rank only sends data (source-only)

  • Both provided: Rank participates in both send and recv (collocated mode)
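The three cases above reduce to a simple classification on which modules a rank holds. The sketch below is hypothetical (rank_role is not part of the module), and raising an error when both modules are None is an assumption, since the docstring does not cover that case.

```python
from typing import Optional


def rank_role(src_module: Optional[object], dst_module: Optional[object]) -> str:
    """Classify a rank by which modules it holds (mirrors the list above)."""
    if src_module is not None and dst_module is not None:
        return "collocated"        # sends and receives
    if dst_module is not None:
        return "destination-only"  # src_module is None: only receives
    if src_module is not None:
        return "source-only"       # dst_module is None: only sends
    # Assumption: the docstring does not say what happens with neither module.
    raise ValueError("at least one of src_module and dst_module must be provided")


print(rank_role(object(), None))  # source-only
```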

When transform is provided, parameters for which transform.should_transform(param_name) returns True use the transform’s prepare_send / prepare_recv / finalize_recv methods instead of the default slice-and-copy logic.
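The per-parameter routing decision can be sketched as follows. Only should_transform(param_name) returning a bool is taken from the docstring; choose_path, ShouldTransform, PrefixTransform, and the parameter names are hypothetical, and the prepare_send / prepare_recv / finalize_recv calls are represented only by the returned label because their signatures are not documented here.

```python
from typing import Optional, Protocol


class ShouldTransform(Protocol):
    """The one ReshardTransform method whose signature the docstring states."""

    def should_transform(self, param_name: str) -> bool: ...


def choose_path(param_name: str, transform: Optional[ShouldTransform]) -> str:
    if transform is not None and transform.should_transform(param_name):
        # prepare_send / prepare_recv / finalize_recv would handle this parameter.
        return "transform"
    return "slice-and-copy"


class PrefixTransform:
    """Hypothetical transform that claims every parameter under one name prefix."""

    def __init__(self, prefix: str) -> None:
        self.prefix = prefix

    def should_transform(self, param_name: str) -> bool:
        return param_name.startswith(self.prefix)


t = PrefixTransform("experts.")
for name in ["experts.w1", "attn.qkv"]:
    print(name, choose_path(name, t))
# experts.w1 transform
# attn.qkv slice-and-copy
```

Parameters the transform does not claim fall back to the default slice-and-copy path, so a transform only needs to handle the layouts it actually changes.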