core.optimizer.param_layout#

Parameter layout dataclasses for optimizer-driven buffer layout.

These dataclasses describe how parameters are laid out in contiguous buffers. Each distributed optimizer implementation (e.g., DistributedOptimizer) is responsible for computing these layouts via a _compute_per_buffer_param_layout method, applying its own padding, alignment, and bucket-splitting rules. DDP and its buffers consume the resulting layouts without any optimizer-specific knowledge.
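As a rough sketch of that division of responsibility (the _compute_per_buffer_param_layout call follows the description above; its arguments and the optimizer internals are elided):

```python
# Hypothetical consumer-side sketch: the optimizer produces a FullParamLayout,
# and DDP sizes its contiguous buffers from it without knowing the padding rules.
full_layout = optimizer._compute_per_buffer_param_layout(...)  # -> FullParamLayout

for buffer_key, layout in full_layout.layouts.items():
    # The buffer's total (padded) size is the end index of its last bucket.
    buffer_numel = layout.bucket_indices[-1][1]
    # ... allocate a contiguous buffer of buffer_numel elements with dtype
    # buffer_key.param_dtype, then map each param to its (start, end) slice.
```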

Module Contents#

Classes#

BufferKey

Identifies a distinct parameter buffer.

PerBufferParamLayout

Layout for parameters within a single contiguous buffer.

FullParamLayout

Layout for all parameters across all buffer groups in a model chunk.

Functions#

pad_to_divisor

Round up value to the nearest multiple of divisor.

pad_param_start

Align parameter start index to a 64-element boundary.

pad_bucket_end

Pad bucket end for DP-divisibility (and optionally high NCCL bus bandwidth).

API#

core.optimizer.param_layout.pad_to_divisor(value: int, divisor: int) → int#

Round up value to the nearest multiple of divisor.
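The rounding is plain ceiling arithmetic; a minimal equivalent sketch:

```python
def pad_to_divisor(value: int, divisor: int) -> int:
    # Round value up to the nearest multiple of divisor.
    return ((value + divisor - 1) // divisor) * divisor

assert pad_to_divisor(100, 64) == 128
assert pad_to_divisor(128, 64) == 128  # already-aligned values are unchanged
```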

core.optimizer.param_layout.pad_param_start(param_start_index: int) → int#

Align parameter start index to a 64-element boundary.
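This is a rounding to the fixed 64-element boundary; a minimal sketch in terms of pad_to_divisor:

```python
def pad_param_start(param_start_index: int) -> int:
    # 64 elements keeps each param's start on a vector/cache-friendly
    # boundary (rationale assumed here; the constant is from the docstring).
    return pad_to_divisor(param_start_index, 64)

assert pad_param_start(100) == 128
assert pad_param_start(0) == 0
```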

core.optimizer.param_layout.pad_bucket_end(
bucket_end_index: int,
data_parallel_world_size: int,
pad_for_high_nccl_busbw: bool,
) → int#

Pad bucket end for DP-divisibility (and optionally high NCCL bus bandwidth).
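A plausible sketch of the two-step padding; the DP-divisibility part follows directly from the docstring, while the high-bus-bandwidth multiplier below is an illustrative assumption, not the module's actual rule:

```python
def pad_bucket_end(
    bucket_end_index: int,
    data_parallel_world_size: int,
    pad_for_high_nccl_busbw: bool,
) -> int:
    divisor = data_parallel_world_size
    if pad_for_high_nccl_busbw:
        # Assumed for illustration: pad to a larger multiple of the DP size so
        # each rank's shard is big enough for NCCL to reach peak bus bandwidth.
        divisor *= 128
    # DP-divisibility lets the bucket shard evenly across data-parallel ranks.
    return ((bucket_end_index + divisor - 1) // divisor) * divisor
```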

class core.optimizer.param_layout.BufferKey#

Identifies a distinct parameter buffer.

Each unique combination of these fields corresponds to a separate contiguous buffer in DDP. Parameters are grouped into buffers by these dimensions.

.. attribute:: param_dtype

Storage dtype (torch.uint8 for FP8/NVFP4 parameters, else param.dtype).

.. attribute:: grad_dtype

Gradient reduction dtype.

.. attribute:: is_expert_parallel

Whether the buffer holds expert-parallel parameters, which use a separate data-parallel group.

param_dtype: torch.dtype#

None

grad_dtype: torch.dtype#

None

is_expert_parallel: bool#

None
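For example, grouping parameters into buffers might look like the sketch below. BufferKey is assumed hashable (a frozen dataclass) so it can key a dict; is_fp8, is_expert, and grad_reduce_dtype are hypothetical helpers, not part of this module:

```python
from collections import defaultdict

import torch

groups = defaultdict(list)
for param in model.parameters():  # model is assumed in scope
    key = BufferKey(
        param_dtype=torch.uint8 if is_fp8(param) else param.dtype,
        grad_dtype=grad_reduce_dtype,  # e.g. torch.float32 reductions
        is_expert_parallel=is_expert(param),
    )
    groups[key].append(param)  # each group becomes one contiguous buffer
```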

class core.optimizer.param_layout.PerBufferParamLayout#

Layout for parameters within a single contiguous buffer.

Describes where each parameter lives in the buffer and how the buffer is partitioned into buckets.

.. attribute:: param_index_map

Mapping from parameter to (start_index, end_index, bucket_id) in buffer.

.. attribute:: bucket_indices

List of (start_index, end_index) for each bucket.

.. attribute:: per_bucket_numel_unpadded

Number of unpadded elements per bucket.

.. attribute:: param_indices

The index of each param among same-dtype params (using the “fake” high-precision dtype for FP8/NVFP4 params). Needed for loading non-native-fp8 checkpoints in native-fp8 mode. Order matches param_index_map iteration order.

param_index_map: Dict[torch.nn.Parameter, Tuple[int, int, int]]#

‘field(…)’

bucket_indices: List[Tuple[int, int]]#

‘field(…)’

per_bucket_numel_unpadded: List[int]#

‘field(…)’

param_indices: List[int]#

‘field(…)’
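A hand-built toy instance for two 100-element params sharing one bucket, assuming the 64-element start alignment above and that unpadded numel counts only real param elements:

```python
import torch

p0 = torch.nn.Parameter(torch.empty(100))
p1 = torch.nn.Parameter(torch.empty(100))

layout = PerBufferParamLayout(
    # p0 occupies [0, 100); p1 starts at the next 64-aligned index, 128.
    param_index_map={p0: (0, 100, 0), p1: (128, 228, 0)},
    bucket_indices=[(0, 256)],        # bucket end padded, e.g. 228 -> 256
    per_bucket_numel_unpadded=[200],  # 100 + 100 real elements, padding excluded
    param_indices=[0, 1],             # both params share the same dtype
)
```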

class core.optimizer.param_layout.FullParamLayout#

Layout for all parameters across all buffer groups in a model chunk.

Maps BufferKey to per-buffer PerBufferParamLayout objects. Each PerBufferParamLayout has its own independent index space since different buffer groups are physically separate buffers.

.. attribute:: layouts

Mapping from BufferKey to PerBufferParamLayout.

layouts: Dict[core.optimizer.param_layout.BufferKey, core.optimizer.param_layout.PerBufferParamLayout]#

‘field(…)’
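Consuming a full layout is then a flat iteration over buffer groups; a sketch using the toy objects from the examples above:

```python
full = FullParamLayout(layouts={key: layout})

for buffer_key, per_buffer in full.layouts.items():
    total = per_buffer.bucket_indices[-1][1]  # padded size of this buffer group
    for param, (start, end, bucket_id) in per_buffer.param_index_map.items():
        # Indices are local to this buffer group's own index space.
        assert 0 <= start < end <= total
```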