core.optimizer.layer_wise_optimizer

Module Contents

Classes

LayerWiseDistributedOptimizer

Layer-wise distributed optimizer for Megatron-core models.

Functions

is_managed_by_layer_wise_optimizer

Whether a parameter is managed by LayerWiseDistributedOptimizer.

_bucket_is_managed_by_layer_wise_optimizer

Whether a DDP bucket belongs to a LayerWise-managed buffer.

tag_params_for_buffer_routing

Tag every requires-grad param with is_managed_by_layer_wise_optimizer.

Data

API

core.optimizer.layer_wise_optimizer.logger

‘getLogger(…)’

core.optimizer.layer_wise_optimizer.is_managed_by_layer_wise_optimizer(param: torch.nn.Parameter) → bool

Whether a parameter is managed by LayerWiseDistributedOptimizer.

Returns True for the 2D matrix-like weight parameters that Muon orthogonalizes via Newton-Schulz, and False for embeddings, biases, LayerNorm weights, and any other non-matrix parameter (which are handled by Adam through a separate DistributedOptimizer).

Mirrors the routing rule applied by _get_param_groups / default_param_overrides for Muon.
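A minimal sketch of this routing rule in plain Python, assuming the caller can see each parameter's name and shape. Spotting embeddings by the substring "embedding" in the name is a hypothetical heuristic chosen for illustration; the actual rule lives in _get_param_groups / default_param_overrides and may identify embeddings differently.

```python
from typing import Sequence

# Hypothetical heuristic: the real rule may identify embeddings differently.
EMBEDDING_NAME_HINT = "embedding"


def is_muon_matrix(name: str, shape: Sequence[int]) -> bool:
    """Return True for 2D weight matrices that Muon would orthogonalize."""
    if len(shape) != 2:
        # Biases, LayerNorm weights (1D), and any other non-matrix
        # parameter are routed to Adam.
        return False
    if EMBEDDING_NAME_HINT in name:
        # Embeddings are 2D but are still routed to Adam.
        return False
    return True


print(is_muon_matrix("decoder.layers.0.mlp.linear_fc1.weight", (4096, 1024)))  # True
print(is_muon_matrix("decoder.layers.0.mlp.linear_fc1.bias", (4096,)))  # False
print(is_muon_matrix("embedding.word_embeddings.weight", (50304, 1024)))  # False
```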

core.optimizer.layer_wise_optimizer._bucket_is_managed_by_layer_wise_optimizer(
bucket,
default_for_untagged: bool = True,
) → bool

Whether a DDP bucket belongs to a LayerWise-managed buffer.

Buckets are built from params that share a BufferKey, so checking the first param’s tag is sufficient. default_for_untagged controls the legacy (no-tagging) case: callers asking “is this mine?” from the LayerWise side pass True (legacy LayerWise owns everything); callers asking from the DistOpt side pass False (legacy DistOpt also owns everything, so untagged buckets are not LayerWise-managed).
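A sketch of the first-param check and the default_for_untagged fallback, using plain-Python stand-ins for parameters and buckets. The attribute name layerwise_managed (with None meaning "never tagged") is an assumption for illustration; the real tag attribute is not named in this section.

```python
from dataclasses import dataclass, field
from typing import List, Optional


@dataclass
class FakeParam:
    name: str
    # Assumed tag attribute; None stands for the legacy "never tagged" case.
    layerwise_managed: Optional[bool] = None


@dataclass
class FakeBucket:
    params: List[FakeParam] = field(default_factory=list)


def bucket_is_layerwise(bucket: FakeBucket, default_for_untagged: bool = True) -> bool:
    """Decide bucket ownership from the first param's tag.

    All params in a bucket share one buffer key, so the first tag speaks for
    the whole bucket.
    """
    tag = bucket.params[0].layerwise_managed
    if tag is None:
        # Legacy (no tagging): fall back to the caller's default.
        return default_for_untagged
    return tag


tagged = FakeBucket([FakeParam("fc1.weight", True), FakeParam("fc2.weight", True)])
legacy = FakeBucket([FakeParam("fc1.weight")])
print(bucket_is_layerwise(tagged))  # True
print(bucket_is_layerwise(legacy, default_for_untagged=True))  # True (LayerWise side)
print(bucket_is_layerwise(legacy, default_for_untagged=False))  # False (DistOpt side)
```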

core.optimizer.layer_wise_optimizer.tag_params_for_buffer_routing(model_chunks) → None

Tag every requires-grad param with is_managed_by_layer_wise_optimizer.

Run this once on the un-DDP-wrapped model chunks before DistributedDataParallel constructs its grad/param buffers: the grouping function group_params_for_buffers reads this attribute to decide which buffer each param lands in (LayerWise shard-aligned buffer vs. DistOpt-style byte-level buffer).
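A minimal sketch of the tagging pass, assuming each model chunk can be iterated as a list of parameters. FakeParam, the attribute name layerwise_managed, and the default 2D-only predicate are hypothetical stand-ins; in Megatron-core the decision comes from is_managed_by_layer_wise_optimizer.

```python
from typing import Callable, Iterable, List, Tuple


class FakeParam:
    """Hypothetical stand-in for torch.nn.Parameter (shape and requires_grad only)."""

    def __init__(self, shape: Tuple[int, ...], requires_grad: bool = True):
        self.shape = shape
        self.requires_grad = requires_grad


def tag_params(
    model_chunks: Iterable[List[FakeParam]],
    predicate: Callable[[FakeParam], bool] = lambda p: len(p.shape) == 2,
    attr: str = "layerwise_managed",  # assumed attribute name
) -> None:
    """Stamp every trainable param with its routing decision before DDP wrapping."""
    for chunk in model_chunks:
        for param in chunk:
            if param.requires_grad:
                setattr(param, attr, predicate(param))


weight = FakeParam((8, 8))
bias = FakeParam((8,))
frozen = FakeParam((8, 8), requires_grad=False)
tag_params([[weight, bias, frozen]])
print(weight.layerwise_managed, bias.layerwise_managed, hasattr(frozen, "layerwise_managed"))
# True False False
```

The ordering matters because the buffer-grouping step reads the tag while the buffers are being built; tagging after wrapping would not change the layout.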

class core.optimizer.layer_wise_optimizer.LayerWiseDistributedOptimizer(
optimizers: List[core.optimizer.optimizer.MegatronOptimizer],
config: core.optimizer.optimizer_config.OptimizerConfig,
pg_collection: Optional[megatron.core.process_groups_config.ProcessGroupCollection] = None,
init_state_fn_list: Optional[List[Callable]] = None,
model_chunks: Optional[List] = None,
)

Bases: core.optimizer.optimizer.ChainedOptimizer

Layer-wise distributed optimizer for Megatron-core models.

Experimental distributed-optimizer wrapper that distributes weights across DP ranks layer by layer. It is implemented as a ChainedOptimizer to support multiple optimizers (e.g. Muon + AdamW). When using it, keep all Megatron distributed-optimizer options OFF.

How LayerWiseDistributedOptimizer works:

  1. Weights are split into per-rank lists, and each rank keeps only its own shard in its optimizer.

  2. Megatron DDP handles the gradient all-reduce; note that every rank holds the full model and full gradients.

  3. The optimizer has already been narrowed so that only params belonging to this DP rank are updated.

  4. grad_norm and zero counting reduce their metrics globally in the step function.

  5. A regular update runs through the chained optimizers; each modified optimizer updates only its shard.

  6. The updated params are all-gathered to every rank.
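The steps above can be simulated in a single process to show why every rank ends up with identical weights even though each rank updates only its own shard. This is a sketch under stated assumptions: the round-robin owner map is hypothetical, plain SGD stands in for the chained Muon/AdamW update, and step 4 (global metric reduction) is omitted.

```python
from typing import Dict, List


def layerwise_step(
    params_per_rank: List[List[float]],
    local_grads_per_rank: List[List[float]],
    owner: Dict[int, int],
    lr: float = 0.1,
) -> List[List[float]]:
    """Simulate one LayerWise step across all DP ranks in one process."""
    dp_size = len(params_per_rank)
    num_params = len(params_per_rank[0])

    # Step 2: DDP all-reduce (mean), so every rank holds the full gradient.
    full_grad = [
        sum(local_grads_per_rank[r][i] for r in range(dp_size)) / dp_size
        for i in range(num_params)
    ]

    # Steps 3 and 5: each rank updates only the params it owns (its step-1 shard).
    updated = [list(p) for p in params_per_rank]
    for r in range(dp_size):
        for i in range(num_params):
            if owner[i] == r:
                updated[r][i] -= lr * full_grad[i]

    # Step 6: all-gather, i.e. every rank copies each param from its owner.
    gathered = [updated[owner[i]][i] for i in range(num_params)]
    return [list(gathered) for _ in range(dp_size)]


params = [[1.0, 2.0, 3.0, 4.0] for _ in range(2)]
grads = [[1.0] * 4, [3.0] * 4]  # each rank saw different micro-batches
owner = {0: 0, 1: 1, 2: 0, 3: 1}  # hypothetical round-robin ownership
result = layerwise_step(params, grads, owner)
print(result[0] == result[1])  # True
```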

Initialization

Initialize LayerWiseDistributedOptimizer.

Parameters:
  • optimizers – List of MegatronOptimizers.

  • config – OptimizerConfig.

  • pg_collection – ProcessGroupCollection.

  • init_state_fn_list – List of init state functions.

  • model_chunks – DDP-wrapped model chunks.

static _shard_divisor(data_parallel_world_size: int, ddp_config) → int

Per-shard alignment divisor.

Guarantees that dp_size * shard_size satisfies bucket-end alignment and that every shard start is 64-element aligned (required by pad_param_start).
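The two alignment constraints combine into a single divisor with a little modular arithmetic: dp_size * s is a multiple of the bucket-end alignment A exactly when s is a multiple of A / gcd(A, dp_size), and 64-aligned shard starts (given a 64-aligned bucket start) need s to be a multiple of 64. The sketch below assumes A is passed in explicitly, for example lcm(dp_size, 128); how Megatron-core actually derives A from ddp_config is not shown here.

```python
import math


def shard_divisor(dp_size: int, bucket_end_alignment: int, start_alignment: int = 64) -> int:
    """Smallest positive d such that every multiple s of d satisfies:
    (a) s % start_alignment == 0, so shard starts stay aligned
        (assuming the bucket itself starts aligned), and
    (b) (dp_size * s) % bucket_end_alignment == 0.
    """
    # (dp_size * s) % A == 0  <=>  s % (A // gcd(A, dp_size)) == 0
    per_shard_end = bucket_end_alignment // math.gcd(bucket_end_alignment, dp_size)
    return math.lcm(start_alignment, per_shard_end)


def pad_up(n: int, multiple: int) -> int:
    """Round n up to the next multiple of `multiple`."""
    return -(-n // multiple) * multiple


for dp in (3, 8):
    align = math.lcm(dp, 128)  # assumed bucket-end alignment, for illustration
    d = shard_divisor(dp, align)
    print(dp, align, d, pad_up(1000, d))
# 3 384 128 1024
# 8 128 64 1024
```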

static _compute_per_buffer_param_layout(
params: List[torch.nn.Parameter],
bucket_size: Optional[int],
data_parallel_world_size: int,
ddp_config,
param_indices: Optional[List[int]] = None,
) → core.optimizer.param_layout.PerBufferParamLayout

Compute parameter layout with shard-aligned buckets via LPT bin-packing.

Assigns parameters to dp_size shards within each bucket so that no parameter is split across a shard boundary, while keeping each bucket confined to a contiguous range in backprop order.

Algorithm (operates in reverse model / backprop order):

  1. Walk parameters in backprop order, accumulating them into a chunk. A shared (tied) embedding immediately finalises the current chunk and is then placed in an isolated bucket of its own.

  2. When the chunk’s total numel reaches bucket_size (or all params have been consumed), bin-pack the chunk into dp_size shards via greedy LPT (longest processing time first): sort by numel in descending order and assign each param to the shard with the smallest current load.

  3. Pad each shard to max(shard_cursors), rounded up to a multiple of _shard_divisor, then emit the bucket.

Each bucket therefore spans a contiguous backprop range, so overlap_grad_reduce can dispatch the bucket’s reduce-scatter as soon as the bucket’s backward segment finishes, preserving the original DDP overlap semantics. LPT bin-packing keeps shards close to balanced; for uniform transformer blocks where params_per_layer * num_layers is a multiple of dp_size, the packing is perfect.
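A minimal sketch of steps 1-3 over plain element counts, so the packing can be inspected without a model. It assumes numels_forward lists parameter sizes in forward order and that the divisor has already been computed; the tied-embedding special case and the param_indices bookkeeping are omitted, and the dict output is purely illustrative rather than a PerBufferParamLayout.

```python
import heapq
from typing import Dict, List, Optional, Tuple


def pad_up(n: int, multiple: int) -> int:
    """Round n up to the next multiple of `multiple`."""
    return -(-n // multiple) * multiple


def lpt_pack(numels: List[int], dp_size: int) -> Tuple[List[List[int]], List[int]]:
    """Greedy LPT: largest item first, always onto the currently lightest shard."""
    heap = [(0, shard) for shard in range(dp_size)]  # (load, shard index)
    shards: List[List[int]] = [[] for _ in range(dp_size)]
    loads = [0] * dp_size
    for i in sorted(range(len(numels)), key=lambda j: numels[j], reverse=True):
        load, shard = heapq.heappop(heap)
        shards[shard].append(i)
        loads[shard] = load + numels[i]
        heapq.heappush(heap, (loads[shard], shard))
    return shards, loads


def build_buckets(
    numels_forward: List[int], bucket_size: Optional[int], dp_size: int, divisor: int
) -> List[Dict[str, object]]:
    """Walk params in backprop order and emit shard-aligned buckets."""
    buckets: List[Dict[str, object]] = []
    chunk: List[int] = []  # indices into numels_forward
    chunk_numel = 0
    backprop_order = list(reversed(range(len(numels_forward))))
    for pos, idx in enumerate(backprop_order):
        chunk.append(idx)
        chunk_numel += numels_forward[idx]
        is_last = pos == len(backprop_order) - 1
        if is_last or (bucket_size is not None and chunk_numel >= bucket_size):
            shards, loads = lpt_pack([numels_forward[i] for i in chunk], dp_size)
            buckets.append({
                # Whole params per shard: nothing straddles a shard boundary.
                "shards": [[chunk[j] for j in shard] for shard in shards],
                # Every shard is padded to the heaviest one, rounded to the divisor.
                "shard_size": pad_up(max(loads), divisor),
            })
            chunk, chunk_numel = [], 0
    return buckets


print(build_buckets([4, 4, 4, 4, 2, 2], bucket_size=None, dp_size=2, divisor=1))
# [{'shards': [[3, 1, 5], [2, 0, 4]], 'shard_size': 10}]
```

Because a bucket is only cut when the running total reaches bucket_size, each bucket covers a contiguous run of the backprop order, which is what keeps reduce-scatter overlap intact.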

Parameters:
  • params – Parameters in model-definition (forward) order.

  • bucket_size – Approximate elements per bucket (None → single bucket).

  • data_parallel_world_size – Size of the data-parallel group.

  • ddp_config – DistributedDataParallelConfig.

  • param_indices – Optional per-param dtype indices (passed through).

Returns:

PerBufferParamLayout with shard-aligned buckets.

static compute_full_param_layout(
params: List[torch.nn.Parameter],
bucket_size: Optional[int],
data_parallel_world_size: int,
ddp_config,
expert_data_parallel_world_size: Optional[int] = None,
) → core.optimizer.param_layout.FullParamLayout

Compute parameter layouts for all buffer groups.

Groups parameters by BufferKey via group_params_for_buffers and produces a layer-wise, shard-aligned, size-matching layout for each buffer. Every parameter stays within a single shard, so the local optimizer step (e.g. the Newton-Schulz iteration for Muon) can run on whole tensors.

Parameters:
  • params – All parameters to lay out.

  • bucket_size – Approximate elements per bucket (None → single bucket).

  • data_parallel_world_size – DP group size for dense parameters.

  • ddp_config – DistributedDataParallelConfig.

  • expert_data_parallel_world_size – Expert DP group size (defaults to data_parallel_world_size).

Returns:

FullParamLayout with a PerBufferParamLayout per buffer group.
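A sketch of the grouping step, assuming a buffer key made of dtype, expert-parallel flag, and LayerWise tag. HypotheticalBufferKey and its fields are illustrative only; the real BufferKey and group_params_for_buffers may use different fields. It also shows the documented fallback of the expert DP size to data_parallel_world_size.

```python
from dataclasses import dataclass
from typing import Dict, List, NamedTuple, Optional


class HypotheticalBufferKey(NamedTuple):
    dtype: str
    is_expert_parallel: bool
    layerwise_managed: bool


@dataclass
class FakeParam:
    name: str
    dtype: str
    is_expert_parallel: bool
    layerwise_managed: bool


def group_by_buffer_key(params: List[FakeParam]) -> Dict[HypotheticalBufferKey, List[FakeParam]]:
    """Bucket params by key, preserving their original order inside each group."""
    groups: Dict[HypotheticalBufferKey, List[FakeParam]] = {}
    for p in params:
        key = HypotheticalBufferKey(p.dtype, p.is_expert_parallel, p.layerwise_managed)
        groups.setdefault(key, []).append(p)
    return groups


def dp_size_for(key: HypotheticalBufferKey, dp_size: int, expert_dp_size: Optional[int] = None) -> int:
    """Expert buffers use the expert DP size, which defaults to the dense DP size."""
    if key.is_expert_parallel:
        return dp_size if expert_dp_size is None else expert_dp_size
    return dp_size


params = [
    FakeParam("attn.qkv.weight", "bf16", False, True),
    FakeParam("attn.qkv.bias", "bf16", False, False),
    FakeParam("experts.fc1.weight", "bf16", True, True),
    FakeParam("mlp.fc2.weight", "bf16", False, True),
]
for key, members in group_by_buffer_key(params).items():
    print(key, [p.name for p in members], dp_size_for(key, dp_size=8, expert_dp_size=2))
```

Each group would then be laid out independently with the per-buffer routine, using the DP size returned for its key.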

shard_params(optimizers, full_param_layouts=None)

Shard params across ranks according to the computed param layout.

Each param’s shard assignment is derived from the FullParamLayout stored on the DDP model chunks. Within each bucket, the buffer is divided into dp_size equal shards; a param’s shard index is determined by its position in the buffer.

Falls back to the legacy ping-pong-by-numel strategy when no layout is available (e.g. dp_size == 1 or no DDP wrapper).

Parameters:
  • optimizers – Optimizers whose param groups will be narrowed to the local rank’s shard.

  • full_param_layouts – List of FullParamLayout (one per model chunk). None triggers the legacy fallback.
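Because every shard in a bucket has the same size, the owning rank follows from a parameter's offset with one integer division. A minimal sketch, assuming the caller knows the bucket's start offset and total numel in the buffer (shard_index and its argument names are hypothetical); the check mirrors the layout guarantee that no param crosses a shard boundary.

```python
def shard_index(
    param_start: int, param_numel: int, bucket_start: int, bucket_numel: int, dp_size: int
) -> int:
    """Return the DP rank whose shard contains the whole parameter."""
    if bucket_numel % dp_size != 0:
        raise ValueError("bucket must split into dp_size equal shards")
    shard_size = bucket_numel // dp_size
    offset = param_start - bucket_start
    rank = offset // shard_size
    last_rank = (offset + param_numel - 1) // shard_size
    if last_rank != rank:
        raise ValueError("param straddles a shard boundary; layout is not shard-aligned")
    return rank


# Bucket of 4 shards x 128 elements starting at buffer offset 1024.
print(shard_index(param_start=1324, param_numel=50, bucket_start=1024, bucket_numel=512, dp_size=4))  # 2
```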

_shard_params_from_layout(
optimizers,
full_param_layouts,
dp_cp_size,
expt_dp_size,
)

Derive shard assignments from the param layout.

_shard_params_ping_pong(optimizers, dp_cp_size, expt_dp_size)

Legacy ping-pong-by-numel shard assignment (no layout available).

Legacy: this method is a fallback for when no full_param_layout is provided. Once all call sites supply a layout, this can be removed in favor of _shard_params_from_layout.

The parameters are sorted by numel and assigned to ranks in ping-pong order. For example, with 4 ranks and 10 parameters p0-p9 (after sorting), dp_cp_params_list will be [[p0, p7, p8], [p1, p6, p9], [p2, p5], [p3, p4]].
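A short sketch that reproduces the 4-rank, 10-parameter example, assuming the input is already sorted by numel (the sort direction is not stated here, so the function leaves sorting to the caller). ping_pong_assign is a hypothetical helper name.

```python
from typing import List, Sequence, TypeVar

T = TypeVar("T")


def ping_pong_assign(sorted_params: Sequence[T], num_ranks: int) -> List[List[T]]:
    """Deal params to ranks 0..n-1, then n-1..0, then 0..n-1, and so on."""
    per_rank: List[List[T]] = [[] for _ in range(num_ranks)]
    for i, param in enumerate(sorted_params):
        sweep, pos = divmod(i, num_ranks)
        # Even sweeps go left-to-right, odd sweeps go right-to-left.
        rank = pos if sweep % 2 == 0 else num_ranks - 1 - pos
        per_rank[rank].append(param)
    return per_rank


print(ping_pong_assign([f"p{i}" for i in range(10)], num_ranks=4))
# [['p0', 'p7', 'p8'], ['p1', 'p6', 'p9'], ['p2', 'p5'], ['p3', 'p4']]
```

Reversing direction on every sweep pairs large and small params on the same rank, which keeps per-rank numel roughly balanced when the input is sorted.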

set_bucket_layerwise_params_list(model_chunks)

Map sharded params to DDP buckets for async all-gather.

Legacy: only used by the variable-size all-gather path (use_buffer_param_sync=False). Once all call sites supply a full_param_layout, this can be removed: the standard distributed-optimizer buffer all-gather handles param sync without per-bucket param lists.

For each bucket in each model chunk’s bucket groups, build per-rank param lists by cross-referencing the layer-wise sharded param lists with the bucket’s params.

Parameters:

model_chunks – DDP-wrapped model chunks with bucket_groups.
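The cross-referencing step amounts to filtering each rank's sharded list down to the params that live in a given bucket. A minimal sketch with strings standing in for parameters (torch parameters hash by identity, so a set lookup works the same way); bucket_rank_params is a hypothetical name.

```python
from typing import Iterable, List, Sequence, TypeVar

T = TypeVar("T")


def bucket_rank_params(bucket_params: Iterable[T], rank_param_lists: Sequence[List[T]]) -> List[List[T]]:
    """For one bucket, list the params each rank owns and must all-gather."""
    in_bucket = set(bucket_params)
    return [[p for p in rank_list if p in in_bucket] for rank_list in rank_param_lists]


# Per-rank lists from the ping-pong example, and a bucket holding p4-p7.
ranks = [["p0", "p7", "p8"], ["p1", "p6", "p9"], ["p2", "p5"], ["p3", "p4"]]
print(bucket_rank_params(["p4", "p5", "p6", "p7"], ranks))
# [['p7'], ['p6'], ['p5'], ['p4']]
```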

allgather_params() → None

All-gather updated params from all ranks.

Legacy: only used when use_buffer_param_sync=False. Once all call sites supply a full_param_layout, this can be removed: the standard distributed-optimizer buffer all-gather (via start_param_sync) replaces this flatten/unflatten path.

broadcast_params()

Each rank broadcasts its updated local params to all other ranks.

get_grad_norm()
has_grad_norm_group(grad_norm_group: str) → bool

Whether any global rank owns params for a registered grad-norm group.

Overrides ChainedOptimizer to use a single global all-reduce (group=None), matching the scope of get_grad_norm and _get_grad_norm_for_group which also reduce globally. All LayerWise grad-stats reductions are global (identical to DistributedOptimizer’s pattern), so the existence check must be too — using a per-sub-optimizer group here would create a collective mismatch.
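A single-process simulation of why these reductions are global: each rank contributes a local flag or a local sum of squared gradients, and one world-wide reduction gives every rank the same answer, including ranks that own no params in a group. The per-rank ownership table and the allreduce helper are hypothetical stand-ins for torch.distributed collectives issued with group=None.

```python
import math
from typing import Dict, List, Optional

# Hypothetical ownership: per rank, grad-norm group name -> locally owned grads.
RANK_SHARDS: List[Dict[str, List[float]]] = [
    {"muon": [3.0], "adam": []},
    {"muon": [], "adam": [4.0]},
    {"muon": [], "adam": []},  # this rank owns nothing at all
]


def allreduce(values: List[float], op: str) -> float:
    """Simulated world-wide all-reduce: one value per rank, same result on all ranks."""
    return sum(values) if op == "sum" else max(values)


def has_grad_norm_group(group: str) -> bool:
    local_flags = [1.0 if shard.get(group) else 0.0 for shard in RANK_SHARDS]
    return allreduce(local_flags, "max") > 0


def grad_norm(group: Optional[str] = None) -> float:
    local_sq = []
    for shard in RANK_SHARDS:
        grads = [g for name, gs in shard.items() if group in (None, name) for g in gs]
        local_sq.append(sum(g * g for g in grads))
    # The shards partition the params, so summing squared norms gives the full norm.
    return math.sqrt(allreduce(local_sq, "sum"))


print(has_grad_norm_group("muon"))  # True on every rank, even ranks 1 and 2
print(grad_norm())  # 5.0
```

In a real job every rank must enter the same collective; issuing the existence check on a per-sub-optimizer group instead is what the docstring above warns would cause a collective mismatch.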

_get_grad_norm_for_group(grad_norm_group: str)
count_zeros()
start_param_sync_for_bucket_group_subset() → None

Trigger start_param_sync on LayerWise-managed bucket groups only.

Walks each model chunk’s dense and expert-parallel bucket groups and skips any group not managed by LayerWise, so that a sibling DistributedOptimizer’s own start_param_sync call does not double-sync the same buckets. Uses DistributedDataParallel._start_bucket_group_param_sync so that FP8 post-all-gather processing (and the MXFP8 copy) still runs.
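A sketch of how the two sides split the bucket groups so that each group is synced exactly once. FakeBucketGroup carries the first param's tag directly, the attribute names bucket_groups and expert_parallel_bucket_groups are assumed, and appending to synced_by stands in for the real _start_bucket_group_param_sync call.

```python
from dataclasses import dataclass, field
from typing import List, Optional


@dataclass
class FakeBucketGroup:
    name: str
    first_param_tag: Optional[bool]  # None: legacy, untagged
    synced_by: List[str] = field(default_factory=list)


@dataclass
class FakeChunk:
    bucket_groups: List[FakeBucketGroup]
    expert_parallel_bucket_groups: List[FakeBucketGroup]


def is_layerwise(group: FakeBucketGroup, default_for_untagged: bool) -> bool:
    tag = group.first_param_tag
    return default_for_untagged if tag is None else tag


def start_param_sync_subset(chunks: List[FakeChunk], side: str) -> None:
    """Sync only the groups owned by `side` ("layerwise" or "distopt")."""
    for chunk in chunks:
        for group in chunk.bucket_groups + chunk.expert_parallel_bucket_groups:
            if side == "layerwise":
                owned = is_layerwise(group, default_for_untagged=True)
            else:
                owned = not is_layerwise(group, default_for_untagged=False)
            if owned:
                group.synced_by.append(side)  # stand-in for the real param sync


chunk = FakeChunk(
    bucket_groups=[FakeBucketGroup("dense-matrix", True), FakeBucketGroup("dense-other", False)],
    expert_parallel_bucket_groups=[FakeBucketGroup("expert-matrix", True)],
)
start_param_sync_subset([chunk], "layerwise")
start_param_sync_subset([chunk], "distopt")
for g in chunk.bucket_groups + chunk.expert_parallel_bucket_groups:
    print(g.name, g.synced_by)
# dense-matrix ['layerwise']
# dense-other ['distopt']
# expert-matrix ['layerwise']
```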

step_with_ready_grads() → bool

Step then all-gather LayerWise-managed param buffers.

Placed on step_with_ready_grads (not step) so the param sync also runs when this optimizer is a child of an outer ChainedOptimizer, which calls step_with_ready_grads directly on each child and bypasses step.
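A toy model of the call path described above, assuming (as stated) that an outer ChainedOptimizer calls step_with_ready_grads on each child. The classes are hypothetical stand-ins; the point is that logic attached only to step would be skipped on the chained path, while logic in step_with_ready_grads runs on both paths.

```python
from typing import List


class ChildOptimizer:
    def __init__(self) -> None:
        self.log: List[str] = []

    def step_with_ready_grads(self) -> bool:
        self.log.append("update")
        return True

    def step(self) -> bool:
        # Standalone entry point; an outer chain bypasses it.
        return self.step_with_ready_grads()


class LayerWiseLikeChild(ChildOptimizer):
    def step_with_ready_grads(self) -> bool:
        ok = super().step_with_ready_grads()
        self.log.append("param_all_gather")  # runs on both call paths
        return ok


class OuterChain:
    def __init__(self, children: List[ChildOptimizer]) -> None:
        self.children = children

    def step(self) -> bool:
        # Step every child (no short-circuit), calling step_with_ready_grads directly.
        results = [child.step_with_ready_grads() for child in self.children]
        return all(results)


standalone = LayerWiseLikeChild()
standalone.step()
chained = LayerWiseLikeChild()
OuterChain([ChildOptimizer(), chained]).step()
print(standalone.log, chained.log)
# ['update', 'param_all_gather'] ['update', 'param_all_gather']
```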

load_state_dict(state_dict)
sharded_state_dict(
model_sharded_state_dict: megatron.core.dist_checkpointing.mapping.ShardedStateDict,
is_loading: bool = False,
**kwargs,
)

Sharded state dict for torch_dist-format checkpointing. For fixed-DP usage only: replica_id is set to 0 for all ShardedTensors.

save_state_dict_to_file(filename: str) → None

Save the parameter state of the optimizer. For torch format only.

Parameters:

filename – The file to which the parameter state is saved.

load_state_dict_from_file(filename: str) → None

Load the parameter state of the optimizer. For torch format only.