core.resharding.execution
Module Contents
Classes
_Writeback | Tagged-union for what to do with a received tensor after service.run().
Functions
_get_mxfp8_accumulator | Get or lazily allocate the BF16 accumulation buffer for an MXFP8 dest param.
execute_reshard_plan | Execute a reshard plan (from centralized controller). A communication service must be provided to abstract transport. Expected service API: submit_send(tensor, dest_rank, task_id), submit_recv(tensor, src_rank, task_id), run().
Data
API
- core.resharding.execution.logger
‘getLogger(…)’
- class core.resharding.execution._Writeback
Tagged-union for what to do with a received tensor after service.run().
Exactly one of the three kinds applies; the other fields are unused for that kind.
direct means the data landed in its final destination during recv and there’s nothing to copy.
copy copies a staging recv_buffer into a slice of dst_param (deferring to MXFP8 accumulation when the dest is quantized).
transform hands the received buffers to a ReshardTransform.finalize_recv call.
- kind: str
None
- recv_buffer: Optional[torch.Tensor]
None
- dst_param: Optional[torch.Tensor]
None
- dst_slice: Optional[tuple]
None
- param_name: Optional[str]
None
- recv_bufs: Optional[list[torch.Tensor]]
None
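The three kinds can be pictured as a small dispatch over the `kind` field. The following is a minimal sketch, not the module's implementation: plain Python lists stand in for tensors, `Writeback` and `apply_writeback` are hypothetical names, the `(start, stop)` form of `dst_slice` and the `finalize_recv(param_name, recv_bufs)` callback signature are assumptions, and the MXFP8 deferral path is omitted.

```python
from dataclasses import dataclass
from typing import Callable, Optional


@dataclass
class Writeback:
    """Illustrative stand-in for _Writeback; field names follow the listing above."""
    kind: str                            # "direct", "copy", or "transform"
    recv_buffer: Optional[list] = None   # staging buffer, read by "copy"
    dst_param: Optional[list] = None     # destination storage, written by "copy"
    dst_slice: Optional[tuple] = None    # assumed (start, stop) range into dst_param
    param_name: Optional[str] = None     # assumed to be passed to the transform
    recv_bufs: Optional[list] = None     # received buffers, read by "transform"


def apply_writeback(wb: Writeback, finalize_recv: Optional[Callable] = None) -> None:
    """Apply one writeback after the service has finished running."""
    if wb.kind == "direct":
        return  # data already landed in its final destination
    if wb.kind == "copy":
        start, stop = wb.dst_slice
        wb.dst_param[start:stop] = wb.recv_buffer  # staging buffer -> slice of dest
        return
    if wb.kind == "transform":
        finalize_recv(wb.param_name, wb.recv_bufs)  # hypothetical callback signature
        return
    raise ValueError(f"unknown writeback kind: {wb.kind!r}")
```

Keeping the unused fields at `None` mirrors the note above that only the fields relevant to the active kind carry data.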
- core.resharding.execution._get_mxfp8_accumulator(
- pending: dict[int, tuple],
- dst_param: torch.Tensor,
)
Get or lazily allocate the BF16 accumulation buffer for an MXFP8 dest param.
All slices for the same dst_param land in this buffer; quantize_ is called once after all slices have been written. The buffer is allocated empty (not dequantized) because every slice will be overwritten.
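The get-or-allocate pattern described here might look roughly like the sketch below. It is an illustration under assumptions: Python lists stand in for tensors, `get_accumulator` is a hypothetical name, keying `pending` by `id(dst_param)` and storing a `(dst_param, accumulator)` tuple are guesses at how the `dict[int, tuple]` is used, and the final quantization step is not shown.

```python
def get_accumulator(pending: dict, dst_param: list) -> list:
    """Return the accumulation buffer for dst_param, allocating it on first use."""
    key = id(dst_param)  # assumption: the int key identifies the destination param
    entry = pending.get(key)
    if entry is None:
        # Allocate "empty": the current contents of dst_param are not copied in,
        # because every slice is expected to be overwritten before the final pass.
        accumulator = [None] * len(dst_param)
        entry = (dst_param, accumulator)
        pending[key] = entry
    return entry[1]
```

Because every slice for a given destination resolves to the same buffer, a single quantization pass over `pending` after all writes is enough.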
- core.resharding.execution.execute_reshard_plan(
- plan: core.resharding.utils.ReshardPlan,
- src_module: torch.nn.Module,
- dst_module: torch.nn.Module,
- service: core.resharding.copy_services.base.CopyService,
- group=None,
- transform: Optional[core.resharding.transforms.ReshardTransform] = None,
)
Execute a reshard plan (from centralized controller). A communication service must be provided to abstract transport. Expected service API: submit_send(tensor, dest_rank, task_id), submit_recv(tensor, src_rank, task_id), run().
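To make the expected service API concrete, here is a minimal single-process sketch of an object exposing submit_send, submit_recv, and run. It is not the library's CopyService base class: `LoopbackCopyService` is a hypothetical name, lists stand in for tensors, and pairing sends with receives by task_id is an assumption about how tasks are matched.

```python
class LoopbackCopyService:
    """Hypothetical in-memory service that queues transfers and executes them in run()."""

    def __init__(self):
        self._sends = {}  # task_id -> (tensor, dest_rank)
        self._recvs = {}  # task_id -> (tensor, src_rank)

    def submit_send(self, tensor, dest_rank, task_id):
        # Queue an outgoing buffer; nothing is transferred yet.
        self._sends[task_id] = (tensor, dest_rank)

    def submit_recv(self, tensor, src_rank, task_id):
        # Queue a receive buffer to be filled by the matching send.
        self._recvs[task_id] = (tensor, src_rank)

    def run(self):
        # Execute all queued transfers by copying each send into its matching recv.
        for task_id, (recv_buf, _src_rank) in self._recvs.items():
            send_buf, _dest_rank = self._sends[task_id]
            recv_buf[:] = send_buf
        self._sends.clear()
        self._recvs.clear()
```

The split between submitting and running lets a caller queue every transfer in a plan first and then execute them in one batch.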
Supports None for src_module and/or dst_module to allow ranks in non-collocated mode:
src_module=None: Rank only receives data (destination-only)
dst_module=None: Rank only sends data (source-only)
Both provided: Rank participates in both send and recv (collocated mode)
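The three cases above reduce to a check on which module arguments are None. A small sketch, with `rank_role` as a hypothetical helper; the behavior when both arguments are None is not described in this page, so the "idle" result is an assumption.

```python
def rank_role(src_module, dst_module) -> str:
    """Classify a rank's role from which modules it was given."""
    if src_module is not None and dst_module is not None:
        return "collocated"        # participates in both send and recv
    if dst_module is not None:
        return "destination-only"  # src_module is None: only receives data
    if src_module is not None:
        return "source-only"       # dst_module is None: only sends data
    return "idle"                  # assumption: both None is not covered here
```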
When transform is provided, parameters for which transform.should_transform(param_name) returns True use the transform’s prepare_send / prepare_recv / finalize_recv methods instead of the default slice-and-copy logic.
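The per-parameter routing can be sketched as follows. Only should_transform(param_name) is taken from the description above; `PrefixTransform` and `select_path` are hypothetical names, and the prepare_send / prepare_recv / finalize_recv signatures are not shown because this page does not give them.

```python
class PrefixTransform:
    """Hypothetical transform that claims parameters whose name starts with a prefix."""

    def __init__(self, prefix: str):
        self.prefix = prefix

    def should_transform(self, param_name: str) -> bool:
        return param_name.startswith(self.prefix)


def select_path(param_name: str, transform=None) -> str:
    """Decide which path a parameter takes during resharding."""
    if transform is not None and transform.should_transform(param_name):
        return "transform"       # prepare_send / prepare_recv / finalize_recv
    return "slice-and-copy"      # default logic
```

With `transform=None`, every parameter falls through to the default path, which matches the optional argument in the signature.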