core.resharding.transforms#
Module Contents#
Classes#
- ReshardTransform – Hook for custom send/recv/writeback during reshard execution.
- MXFP8ReshardTransform – MXFP8 format-conversion transform for reshard.
Functions#
- _scale_slice_from_data_slice – Convert an MXFP8 data slice to the corresponding scale slice.
- _ensure_sendable – Return a standard-dtype tensor suitable for wire transmission.
API#
- class core.resharding.transforms.ReshardTransform#
Hook for custom send/recv/writeback during reshard execution.
Implementations override the four methods below. When an instance is passed to execute_reshard_plan, each TransferOp is checked via should_transform(param_name); if True, the transform methods are used instead of the default send/recv/writeback logic.
The transform may change the wire format (e.g. send MXFP8 data + scale instead of BF16) or keep the same wire format and only post-process on the receive side (e.g. receive BF16, convert to MXFP8 in finalize_recv). The only constraint is that prepare_send and prepare_recv must return the same number of tensors for a given parameter so that send/recv pairs match.
- should_transform(param_name: str) → bool#
Return True if param_name should use the transform path.
- abstractmethod prepare_send(
- param_name: str,
- src_slice: tuple[slice, ...],
- src_param: torch.nn.Parameter,
- )#
Produce tensor(s) to send for param_name.
May return multiple tensors (e.g. data + scale when converting to MXFP8 on the sender side). The default implementation sends the BF16 slice unchanged (single tensor).
- abstractmethod prepare_recv(
- param_name: str,
- dst_slice: tuple[slice, ...],
- )#
Allocate receive buffer(s). Count must match prepare_send output.
- abstractmethod finalize_recv(
- param_name: str,
- dst_slice: tuple[slice, ...],
- recv_buffers: list[torch.Tensor],
- )#
Write received data into final destination (e.g. persistent buffers).
This is where receiver-side format conversion can happen (e.g. converting a BF16 recv buffer to MXFP8 before writing into persistent storage).
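As an illustration of the contract above, here is a minimal, self-contained sketch of a transform that keeps the default wire format (one BF16 tensor per op) and only post-processes on the receive side. The class name `CastOnRecvTransform` and the destination dict are invented for the example; the real base class and its integration with execute_reshard_plan live in this module.

```python
# Hypothetical sketch, not the module's implementation: a transform
# that receives plain BF16 and casts into an FP16 destination dict.
# Method names and signatures mirror the documented ReshardTransform
# interface; everything else is assumed for illustration.
import torch


class CastOnRecvTransform:
    """Receive BF16, write FP16 into a destination dict keyed by name."""

    def __init__(self, dst: dict[str, torch.Tensor], names: set[str]):
        self.dst = dst        # final destination tensors (assumed FP16)
        self.names = names    # parameter names that take the transform path

    def should_transform(self, param_name: str) -> bool:
        return param_name in self.names

    def prepare_send(self, param_name, src_slice, src_param):
        # Wire format unchanged: one BF16 tensor, same as the default path.
        return [src_param.data[src_slice].contiguous()]

    def prepare_recv(self, param_name, dst_slice):
        # Must allocate the same number of buffers prepare_send returns.
        shape = tuple(s.stop - s.start for s in dst_slice)
        return [torch.empty(shape, dtype=torch.bfloat16)]

    def finalize_recv(self, param_name, dst_slice, recv_buffers):
        # Receiver-side format conversion happens here (BF16 -> FP16).
        self.dst[param_name][dst_slice] = recv_buffers[0].to(torch.float16)
```

Note that prepare_send and prepare_recv each return a single-element list, satisfying the matching-count constraint for send/recv pairs.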
- core.resharding.transforms._scale_slice_from_data_slice(
- data_slice: tuple[slice, ...],
- block_size: int = 32,
- )#
Convert an MXFP8 data slice to the corresponding scale slice.
In MXFP8, each group of block_size elements along the last (K) dimension shares a single scale value. All dimensions except the last are passed through unchanged; the last slice has its start/stop divided by block_size. An integer index on the last dim is converted to a scale index as idx // block_size.
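The mapping can be sketched in a few lines. This is an illustrative re-implementation of the rule just described, not the module's code; the name `scale_slice_from_data_slice` merely mirrors the private helper.

```python
# Illustrative sketch: map an MXFP8 data slice to its scale slice.
# All dims except the last pass through; the last dim's start/stop are
# divided by block_size, and an integer index i becomes i // block_size.
def scale_slice_from_data_slice(data_slice, block_size=32):
    *head, last = data_slice
    if isinstance(last, slice):
        last = slice(
            None if last.start is None else last.start // block_size,
            None if last.stop is None else last.stop // block_size,
        )
    else:  # integer index on the K dimension
        last = last // block_size
    return (*head, last)
```

For example, a data slice `(slice(0, 8), slice(64, 128))` maps to the scale slice `(slice(0, 8), slice(2, 4))` with the default block size of 32.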
- core.resharding.transforms._ensure_sendable(param: torch.Tensor) torch.Tensor#
Return a standard-dtype tensor suitable for wire transmission.
Quantized parameter types (e.g., Transformer Engine MXFP8Tensor) are dequantized to their original precision (usually BF16). Standard parameters are returned via .data (unwrapped from autograd).
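A hedged sketch of this behavior follows. The class-name check is an assumption standing in for however the real helper detects Transformer Engine's quantized wrapper; the actual detection logic may differ.

```python
import torch


# Sketch only: detect the (assumed) MXFP8Tensor wrapper by class name
# and dequantize it; otherwise unwrap the plain tensor from autograd.
def ensure_sendable(param: torch.Tensor) -> torch.Tensor:
    if type(param).__name__ == "MXFP8Tensor":   # assumed TE wrapper name
        return param.dequantize()               # back to original precision
    return param.data                           # plain tensor, no autograd
```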
- class core.resharding.transforms.MXFP8ReshardTransform(
- convertible_params: set[str],
- persistent_buffers: dict,
- buffer_key_prefix: str = '',
- convert_on_send: bool = False,
- )#
Bases: core.resharding.transforms.ReshardTransform
MXFP8 format-conversion transform for reshard.
Writes received weight data directly into persistent MXFP8Tensor buffers so that CUDA-graph device-pointer captures remain valid across refits.
Two modes are supported, controlled by convert_on_send:
- convert_on_send=False (default, receiver-side conversion): the sender transmits plain BF16 (one tensor per op, identical to the default reshard path). The receiver allocates a BF16 receive buffer, then finalize_recv converts BF16 → MXFP8 and writes into the persistent buffers. Because the wire format is unchanged, the sender does not need a transform; only the receiver creates one. This is the simplest mode and avoids any sender/receiver coordination.
- convert_on_send=True (sender-side conversion): the sender converts each BF16 slice to MXFP8 and sends two tensors (data + scale) per op. The receiver allocates matching MXFP8 buffers and finalize_recv copies them directly. Both sender and receiver must use the transform so that tensor counts match. This mode halves wire bandwidth (~1 byte/elem vs 2). Caveat: CopyService backends that match local (same-rank) transfers by task_id (Gloo, NVSHMEM) will break if multiple tensors share the same task_id. This mode is therefore only safe for non-colocated setups where sender and receiver are on different ranks. A future fix could generate unique sub-IDs.
- Parameters:
convertible_params – set of fully-qualified parameter names that should use this transform.
persistent_buffers – dict mapping parameter names (without buffer_key_prefix) to MXFP8Tensor objects that hold the receiver's persistent data/scale storage. Empty on the sender when using convert_on_send=True.
buffer_key_prefix – prefix to strip from param_name when looking up entries in persistent_buffers (e.g. "decoder.").
convert_on_send – if True, convert BF16 → MXFP8 on the sender and transmit two tensors (data + scale). If False (default), transmit BF16 and convert on the receiver in finalize_recv.
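The buffer_key_prefix behavior can be illustrated as follows. `lookup_buffer` is a hypothetical helper written for this example, not part of the module's API.

```python
# Hypothetical illustration of the documented lookup: strip
# buffer_key_prefix from param_name, then index persistent_buffers.
def lookup_buffer(persistent_buffers: dict, param_name: str,
                  buffer_key_prefix: str = ""):
    key = param_name.removeprefix(buffer_key_prefix)
    return persistent_buffers[key]
```

With `buffer_key_prefix="decoder."`, a parameter named `"decoder.layers.0.weight"` is looked up under the key `"layers.0.weight"`.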
Initialization
- should_transform(param_name: str) → bool#
- prepare_send(param_name, src_slice, src_param)#
- prepare_recv(param_name, dst_slice)#
- finalize_recv(param_name, dst_slice, recv_buffers)#