core.resharding.transforms#

Module Contents#

Classes#

ReshardTransform

Hook for custom send/recv/writeback during reshard execution.

MXFP8ReshardTransform

MXFP8 format-conversion transform for reshard.

Functions#

_scale_slice_from_data_slice

Convert an MXFP8 data slice to the corresponding scale slice.

_ensure_sendable

Return a standard-dtype tensor suitable for wire transmission.

API#

class core.resharding.transforms.ReshardTransform#

Hook for custom send/recv/writeback during reshard execution.

Implementations override the four methods below. When an instance is passed to execute_reshard_plan, each TransferOp is checked via should_transform(param_name); if True the transform methods are used instead of the default send/recv/writeback logic.

The transform may change the wire format (e.g. send MXFP8 data+scale instead of BF16) or keep the same wire format and only post-process on the receive side (e.g. receive BF16, convert to MXFP8 in finalize_recv). The only constraint is that prepare_send and prepare_recv must return the same number of tensors for a given parameter so that send/recv pairs match.
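The four-method contract can be illustrated with a minimal sketch that uses plain Python lists as stand-in tensors (no torch dependency); the class name, destination store, and "doubling" conversion are purely illustrative, not part of the real API:

```python
class DoubleOnRecvTransform:
    """Illustrative transform: wire format unchanged (one tensor per op),
    with a receiver-side conversion (here: doubling) in finalize_recv."""

    def __init__(self, param_names, dest):
        self.param_names = set(param_names)
        self.dest = dest  # param_name -> destination list (stand-in for persistent storage)

    def should_transform(self, param_name):
        return param_name in self.param_names

    def prepare_send(self, param_name, src_slice, src_param):
        # One tensor per op: the sliced source data, unchanged on the wire.
        (s,) = src_slice
        return [src_param[s]]

    def prepare_recv(self, param_name, dst_slice):
        # Allocate one receive buffer; the count must match prepare_send.
        (s,) = dst_slice
        return [[0] * (s.stop - s.start)]

    def finalize_recv(self, param_name, dst_slice, recv_buffers):
        # Receiver-side format conversion happens here, before writeback.
        (s,) = dst_slice
        self.dest[param_name][s] = [2 * v for v in recv_buffers[0]]
```

A real implementation subclasses ReshardTransform and operates on torch.Tensor slices; the send/recv pairing itself is driven by execute_reshard_plan.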

should_transform(param_name: str) → bool#

Return True if param_name should use the transform path.

abstractmethod prepare_send(
param_name: str,
src_slice: tuple[slice, ...],
src_param: torch.nn.Parameter,
) → list[torch.Tensor]#

Produce tensor(s) to send for param_name.

May return multiple tensors (e.g. data + scale when converting to MXFP8 on the sender side). The default implementation sends the BF16 slice unchanged (single tensor).

abstractmethod prepare_recv(
param_name: str,
dst_slice: tuple[slice, ...],
) → list[torch.Tensor]#

Allocate receive buffer(s). Count must match prepare_send output.

abstractmethod finalize_recv(
param_name: str,
dst_slice: tuple[slice, ...],
recv_buffers: list[torch.Tensor],
) → None#

Write received data into final destination (e.g. persistent buffers).

This is where receiver-side format conversion can happen (e.g. converting a BF16 recv buffer to MXFP8 before writing into persistent storage).

core.resharding.transforms._scale_slice_from_data_slice(
data_slice: tuple[slice, ...],
block_size: int = 32,
) → tuple[slice, ...]#

Convert an MXFP8 data slice to the corresponding scale slice.

In MXFP8, each group of block_size elements along the last (K) dimension shares a single scale value. All dimensions except the last are passed through unchanged; the last slice has its start/stop divided by block_size. Integer index on the last dim is converted to scale index as idx // block_size.
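The mapping described above can be sketched in pure Python; this is a reconstruction from the documented behavior, not the module's source:

```python
def scale_slice_from_data_slice(data_slice, block_size=32):
    """Map an MXFP8 data slice to its scale slice: leading dims pass
    through unchanged, last-dim indices are divided by block_size."""
    *outer, last = data_slice
    if isinstance(last, slice):
        last = slice(
            None if last.start is None else last.start // block_size,
            None if last.stop is None else last.stop // block_size,
        )
    else:
        # Integer index on the last dim -> scale index.
        last = last // block_size
    return (*outer, last)
```

For example, a data slice of columns 64:256 maps to scale columns 2:8 with the default block size of 32.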

core.resharding.transforms._ensure_sendable(param: torch.Tensor) → torch.Tensor#

Return a standard-dtype tensor suitable for wire transmission.

Quantized parameter types (e.g., Transformer Engine MXFP8Tensor) are dequantized to their original precision (usually BF16). Standard parameters are returned via .data (unwrapped from autograd).
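A minimal sketch of this dispatch, assuming quantized tensor subclasses expose a dequantize() method (true of Transformer Engine's MXFP8Tensor, but the duck-typed check here is an assumption, not the module's actual implementation):

```python
def ensure_sendable(param):
    # Assumption: quantized tensor subclasses expose dequantize(),
    # which returns the original-precision (usually BF16) tensor.
    if hasattr(param, "dequantize"):
        return param.dequantize()
    # Standard parameter: strip the autograd wrapper via .data.
    return param.data
```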

class core.resharding.transforms.MXFP8ReshardTransform(
convertible_params: set[str],
persistent_buffers: dict,
buffer_key_prefix: str = '',
convert_on_send: bool = False,
)#

Bases: core.resharding.transforms.ReshardTransform

MXFP8 format-conversion transform for reshard.

Writes received weight data directly into persistent MXFP8Tensor buffers so that CUDA-graph device-pointer captures remain valid across refits.

Two modes are supported, controlled by convert_on_send:

convert_on_send=False (default — receiver-side conversion): The sender transmits plain BF16 (one tensor per op, identical to the default reshard path). The receiver allocates a BF16 receive buffer, then finalize_recv converts BF16 → MXFP8 and writes into the persistent buffers. Because the wire format is unchanged the sender does not need a transform — only the receiver creates one. This is the simplest mode and avoids any sender/receiver coordination.

convert_on_send=True (sender-side conversion): The sender converts each BF16 slice to MXFP8 and sends two tensors (data + scale) per op. The receiver allocates matching MXFP8 buffers and finalize_recv copies them directly. Both sender and receiver must use the transform so that tensor counts match. This mode halves wire bandwidth (~1 byte/elem vs 2).
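The bandwidth claim follows from quick arithmetic, assuming one byte of FP8 data per element plus one scale byte per block_size elements along the last dimension (the exact scale layout is an implementation detail):

```python
def wire_bytes_bf16(numel):
    return 2 * numel  # 2 bytes per BF16 element

def wire_bytes_mxfp8(numel, block_size=32):
    # 1 byte of FP8 data per element, plus one scale byte per block of
    # block_size elements (ceil division).
    return numel + -(-numel // block_size)
```

For a 4096x4096 weight slice this gives 33/64 ≈ 0.52x the BF16 wire volume, i.e. roughly half.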

Caveat: CopyService backends that match local (same-rank) transfers by task_id (Gloo, NVSHMEM) will break if multiple tensors share the same task_id. This mode is therefore only safe for non-colocated setups where sender and receiver are on different ranks. A future fix could generate unique sub-IDs.

Parameters:
  • convertible_params – set of fully-qualified parameter names that should use this transform.

  • persistent_buffers – dict mapping parameter names (without buffer_key_prefix) to MXFP8Tensor objects that hold the receiver’s persistent data/scale storage. Empty on the sender when using convert_on_send=True.

  • buffer_key_prefix – prefix to strip from param_name when looking up entries in persistent_buffers (e.g. "decoder.").

  • convert_on_send – if True, convert BF16 → MXFP8 on the sender and transmit two tensors (data + scale). If False (default), transmit BF16 and convert on the receiver in finalize_recv.

Initialization

should_transform(param_name: str) → bool#
prepare_send(param_name, src_slice, src_param)#
prepare_recv(param_name, dst_slice)#
finalize_recv(param_name, dst_slice, recv_buffers)#