core.optimizer.layer_wise_optimizer#
Module Contents#
Classes#
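LayerWiseDistributedOptimizer | Layer-wise distributed optimizer for Megatron-core models.
Functions#
is_managed_by_layer_wise_optimizer | Whether a parameter is managed by LayerWiseDistributedOptimizer.
_bucket_is_managed_by_layer_wise_optimizer | Whether a DDP bucket belongs to a LayerWise-managed buffer.
tag_params_for_buffer_routing | Tag every requires-grad param with is_managed_by_layer_wise_optimizer.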
Data#
API#
- core.optimizer.layer_wise_optimizer.logger#
‘getLogger(…)’
- core.optimizer.layer_wise_optimizer.is_managed_by_layer_wise_optimizer(param: torch.nn.Parameter) → bool#
Whether a parameter is managed by LayerWiseDistributedOptimizer.
Returns True for the 2D matrix-like weight parameters that Muon orthogonalizes via Newton-Schulz, and False for embeddings, biases, LayerNorm weights, and any other non-matrix parameter (these are handled by Adam through a separate DistributedOptimizer).
Mirrors the routing rule applied by _get_param_groups / default_param_overrides for Muon.
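As a rough illustration of the routing rule described above, the sketch below classifies parameters by shape and role. The `ParamInfo` record, its `is_embedding` flag, and `route_to_layer_wise` are hypothetical stand-ins introduced for this example; the actual predicate in Megatron-core may consult different attributes.

```python
from dataclasses import dataclass
from typing import Tuple


@dataclass
class ParamInfo:
    """Hypothetical stand-in for a torch.nn.Parameter plus routing metadata."""
    name: str
    shape: Tuple[int, ...]
    is_embedding: bool = False  # assumed flag; not part of the real API


def route_to_layer_wise(p: ParamInfo) -> bool:
    """True for 2D matrix-like weights, False for everything else."""
    return len(p.shape) == 2 and not p.is_embedding


params = [
    ParamInfo("attn.qkv.weight", (3072, 1024)),
    ParamInfo("attn.qkv.bias", (3072,)),
    ParamInfo("ln.weight", (1024,)),
    ParamInfo("word_embeddings.weight", (50304, 1024), is_embedding=True),
]
routing = {p.name: route_to_layer_wise(p) for p in params}
```

Only the attention weight matrix is routed to the layer-wise (Muon) side here; the bias, LayerNorm weight, and embedding fall through to the Adam side.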
- core.optimizer.layer_wise_optimizer._bucket_is_managed_by_layer_wise_optimizer(
- bucket,
- default_for_untagged: bool = True,
Whether a DDP bucket belongs to a LayerWise-managed buffer.
Buckets are built from params that share a BufferKey, so checking the first param’s tag is sufficient.
default_for_untagged controls the legacy (no-tagging) case: callers asking “is this mine?” from the LayerWise side pass True (legacy LayerWise owns everything); callers asking from the DistOpt side pass False (legacy DistOpt also owns everything, so untagged buckets are not LayerWise-managed).
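The effect of default_for_untagged can be sketched in a few lines. This is a minimal illustration, assuming a bucket object that exposes a `params` list and a tag stored as a plain attribute; both are assumptions made for the example, and `bucket_is_layer_wise` is a hypothetical name.

```python
from types import SimpleNamespace

TAG = "is_managed_by_layer_wise_optimizer"


def bucket_is_layer_wise(bucket, default_for_untagged=True):
    """Check only the first param; all params in a bucket share a buffer key."""
    first = bucket.params[0]
    # Untagged params (legacy path) fall back to the caller-supplied default.
    return getattr(first, TAG, default_for_untagged)


# One bucket whose params were tagged as not LayerWise-managed, one never tagged.
tagged = SimpleNamespace(params=[SimpleNamespace(**{TAG: False})])
untagged = SimpleNamespace(params=[SimpleNamespace()])

from_layer_wise_side = bucket_is_layer_wise(untagged, default_for_untagged=True)
from_dist_opt_side = bucket_is_layer_wise(untagged, default_for_untagged=False)
```

With an untagged bucket, the same question yields opposite answers depending on which side is asking, which is how the legacy behaviour is preserved for both optimizers.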
- core.optimizer.layer_wise_optimizer.tag_params_for_buffer_routing(model_chunks) → None#
Tag every requires-grad param with is_managed_by_layer_wise_optimizer.
Run this once on the un-DDP-wrapped model chunks before DistributedDataParallel constructs its grad/param buffers. The grouping function group_params_for_buffers reads this attribute to decide which buffer each param lands in (LayerWise shard-aligned buffer vs DistOpt-style byte-level buffer).
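The tag-then-group ordering can be pictured with the following sketch. It assumes each model chunk is simply an iterable of parameter-like objects and uses a shape test as the predicate; `tag_params` and `group_by_tag` are hypothetical helpers and do not reproduce the real grouping function.

```python
from types import SimpleNamespace

TAG = "is_managed_by_layer_wise_optimizer"


def tag_params(model_chunks, predicate):
    """Attach the routing tag to every param that requires grad."""
    for chunk in model_chunks:
        for param in chunk:
            if param.requires_grad:
                setattr(param, TAG, predicate(param))


def group_by_tag(params):
    """Split trainable params into (layer-wise, dist-opt) lists by reading the tag."""
    trainable = [p for p in params if p.requires_grad]
    layer_wise = [p.name for p in trainable if getattr(p, TAG, False)]
    dist_opt = [p.name for p in trainable if not getattr(p, TAG, False)]
    return layer_wise, dist_opt


chunk = [
    SimpleNamespace(name="fc.weight", ndim=2, requires_grad=True),
    SimpleNamespace(name="fc.bias", ndim=1, requires_grad=True),
    SimpleNamespace(name="frozen.weight", ndim=2, requires_grad=False),
]
tag_params([chunk], predicate=lambda p: p.ndim == 2)  # tag first...
layer_wise, dist_opt = group_by_tag(chunk)            # ...then group
```

Frozen params are never tagged and never land in either buffer in this sketch.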
- class core.optimizer.layer_wise_optimizer.LayerWiseDistributedOptimizer(
- optimizers: List[core.optimizer.optimizer.MegatronOptimizer],
- config: core.optimizer.optimizer_config.OptimizerConfig,
- pg_collection: Optional[megatron.core.process_groups_config.ProcessGroupCollection] = None,
- init_state_fn_list: Optional[List[Callable]] = None,
- model_chunks: Optional[List] = None,
Bases: core.optimizer.optimizer.ChainedOptimizer
Layer-wise distributed optimizer for Megatron-core models.
Experimental distributed optimizer wrapper that distributes weights to DP ranks by layer. It is implemented as a ChainedOptimizer to support multiple optimizers (e.g. Muon + AdamW). When using it, keep all Megatron distributed-optimizer related options OFF.
How LayerWiseDistributedOptimizer works:
1. Weights are split into lists, and each rank keeps only its shard in its optimizer.
2. Megatron DDP handles the gradient all-reduce; note that each rank holds the full model and the full gradients.
3. The optimizer is already modified so that only params belonging to this DP rank are updated.
4. grad_norm and zero counting reduce their metrics globally in the step function.
5. A regular update runs with the chained optimizers; each modified optimizer only updates its shard.
6. Updated params are all-gathered to every rank.
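The flow above can be simulated without any distributed runtime. The sketch below is a toy model, assuming two ranks, one scalar weight per layer, round-robin ownership, and plain SGD as a placeholder update; none of these choices come from the actual implementation.

```python
import math

DP_SIZE = 2
LR = 0.1

# Every rank starts with a full replica of the weights (one float per layer here).
replicas = [{"layer0": 1.0, "layer1": 2.0, "layer2": 3.0} for _ in range(DP_SIZE)]

# Sharding: assign whole layers to ranks (hypothetical round-robin ownership).
owner = {name: i % DP_SIZE for i, name in enumerate(sorted(replicas[0]))}

# Per-rank local gradients, as if computed from different micro-batches.
local_grads = [
    {"layer0": 0.2, "layer1": 0.4, "layer2": 0.6},
    {"layer0": 0.4, "layer1": 0.0, "layer2": 0.2},
]

# DDP all-reduce (mean): every rank ends up with the full averaged gradient.
grads = {n: sum(g[n] for g in local_grads) / DP_SIZE for n in owner}

# Grad norm: each rank contributes the squared norm of its shard; the sum is global.
partial_sq = [sum(grads[n] ** 2 for n in owner if owner[n] == r) for r in range(DP_SIZE)]
grad_norm = math.sqrt(sum(partial_sq))

# Update: each rank touches only the layers it owns.
for rank, weights in enumerate(replicas):
    for name in weights:
        if owner[name] == rank:
            weights[name] -= LR * grads[name]

# All-gather: every rank copies each layer from the rank that owns it.
for weights in replicas:
    for name in weights:
        weights[name] = replicas[owner[name]][name]
```

After the final loop both replicas are identical again, even though each rank computed only part of the update.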
Initialization
Initialize LayerWiseDistributedOptimizer.
- Parameters:
optimizers – List of MegatronOptimizers.
config – OptimizerConfig.
pg_collection – ProcessGroupCollection.
init_state_fn_list – List of init state functions.
model_chunks – DDP-wrapped model chunks.
- static _shard_divisor(data_parallel_world_size: int, ddp_config) → int#
Per-shard alignment divisor.
Guarantees that dp_size * shard_size satisfies bucket-end alignment and that every shard start is 64-element aligned (required by pad_param_start).
- static _compute_per_buffer_param_layout(
- params: List[torch.nn.Parameter],
- bucket_size: Optional[int],
- data_parallel_world_size: int,
- ddp_config,
- param_indices: Optional[List[int]] = None,
Compute parameter layout with shard-aligned buckets via LPT bin-packing.
Assigns parameters to dp_size shards within each bucket so that no parameter is split across a shard boundary, while keeping each bucket confined to a contiguous range in backprop order.
Algorithm (operates in reverse model / backprop order):
1. Walk parameters in backprop order, accumulating them into a chunk. A shared (tied) embedding triggers an immediate finalisation followed by an isolated bucket for that embedding alone.
2. When the chunk’s total numel reaches bucket_size (or all params have been consumed), bin-pack the chunk into dp_size shards via greedy LPT: sort by numel descending and assign each param to the shard with the smallest current load.
3. Pad each shard to max(shard_cursors) aligned to _shard_divisor, then emit the bucket.
Each bucket therefore spans a contiguous backprop range so that overlap_grad_reduce can dispatch the bucket’s reduce-scatter as soon as the bucket’s backward segment finishes, preserving the original DDP overlap semantics. LPT bin-packing keeps shards close to balanced; for uniform transformer blocks where params_per_layer * num_layers is a multiple of dp_size the packing is perfect.
- Parameters:
params – Parameters in model-definition (forward) order.
bucket_size – Approximate elements per bucket (None → single bucket).
data_parallel_world_size – Size of the data-parallel group.
ddp_config – DistributedDataParallelConfig.
param_indices – Optional per-param dtype indices (passed through).
- Returns:
PerBufferParamLayout with shard-aligned buckets.
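The bin-packing and padding steps for a single chunk can be sketched as below. LPT here means longest-processing-time-first, applied to element counts. The function name `lpt_pack`, the parameter names, and the divisor of 64 are assumptions for illustration; the real method also handles bucket boundaries, tied embeddings, and per-param start alignment, which this sketch omits.

```python
from typing import Dict, List, Tuple


def lpt_pack(numels: Dict[str, int], dp_size: int, divisor: int) -> Tuple[List[List[str]], int]:
    """Greedy LPT: largest param first, always into the least-loaded shard."""
    shards: List[List[str]] = [[] for _ in range(dp_size)]
    loads = [0] * dp_size
    for name in sorted(numels, key=lambda n: numels[n], reverse=True):
        target = loads.index(min(loads))  # ties go to the lowest shard index
        shards[target].append(name)
        loads[target] += numels[name]
    # Pad every shard to the largest load, rounded up to the alignment divisor.
    shard_size = -(-max(loads) // divisor) * divisor
    return shards, shard_size


chunk = {"w_qkv": 300, "w_proj": 100, "w_fc1": 400, "w_fc2": 400}
shards, shard_size = lpt_pack(chunk, dp_size=2, divisor=64)
```

In this example the two shards carry 700 and 500 elements before padding, and both are padded to 704 so that every rank sees an equally sized, aligned shard.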
- static compute_full_param_layout(
- params: List[torch.nn.Parameter],
- bucket_size: Optional[int],
- data_parallel_world_size: int,
- ddp_config,
- expert_data_parallel_world_size: Optional[int] = None,
Compute parameter layouts for all buffer groups.
Groups parameters by BufferKey via group_params_for_buffers and produces a layerwise shard-aligned size-matching layout per buffer. Every parameter stays within a single shard so the local optimizer step (e.g. Newton-Schulz iteration for Muon) can run on whole tensors.
- Parameters:
params – All parameters to lay out.
bucket_size – Approximate elements per bucket (None → single bucket).
data_parallel_world_size – DP group size for dense parameters.
ddp_config – DistributedDataParallelConfig.
expert_data_parallel_world_size – Expert DP group size (defaults to data_parallel_world_size).
- Returns:
FullParamLayout with a PerBufferParamLayout per buffer group.
- shard_params(optimizers, full_param_layouts=None)#
Shard params across ranks according to the computed param layout.
Each param’s shard assignment is derived from the FullParamLayout stored on the DDP model chunks. Within each bucket the buffer is divided into dp_size equal shards; a param’s shard index is determined by its position in the buffer.
Falls back to the legacy ping-pong-by-numel strategy when no layout is available (e.g. dp_size == 1 or no DDP wrapper).
- Parameters:
optimizers – Optimizers whose param groups will be narrowed to the local rank’s shard.
full_param_layouts – List of FullParamLayout (one per model chunk). None triggers the legacy fallback.
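Deriving a shard index from a buffer position amounts to an integer division. The helper below is a minimal sketch, assuming each param is described by its start offset and each bucket by its start offset and element count; `shard_index` and its arguments are hypothetical names, not fields of FullParamLayout.

```python
def shard_index(param_start: int, bucket_start: int, bucket_numel: int, dp_size: int) -> int:
    """Index of the equal-sized shard that contains a param's first element."""
    assert bucket_numel % dp_size == 0, "bucket must divide evenly into shards"
    shard_size = bucket_numel // dp_size
    return (param_start - bucket_start) // shard_size


# A 4096-element bucket starting at offset 8192, split across 4 ranks.
owners = [shard_index(start, 8192, 4096, 4) for start in (8192, 9216, 10240, 12224)]
```

Because the layout never splits a param across a shard boundary, the shard of the first element is the shard of the whole param.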
- _shard_params_from_layout(
- optimizers,
- full_param_layouts,
- dp_cp_size,
- expt_dp_size,
Derive shard assignments from the param layout.
- _shard_params_ping_pong(optimizers, dp_cp_size, expt_dp_size)#
Legacy ping-pong-by-numel shard assignment (no layout available).
Legacy: this method is a fallback for when no full_param_layout is provided. Once all call sites supply a layout, this can be removed in favor of _shard_params_from_layout.
The list of parameters is sorted by numel and assigned to ranks in ping-pong style. For example, with 4 ranks and 10 parameters p0-p9 after sorting, dp_cp_params_list will be [[p0, p7, p8], [p1, p6, p9], [p2, p5], [p3, p4]].
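The ping-pong pattern from the example can be reproduced with a short function. This is a sketch of the assignment order only, assuming the params are already sorted by numel; `ping_pong_assign` is a hypothetical name and works on labels rather than tensors.

```python
from typing import List


def ping_pong_assign(sorted_params: List[str], num_ranks: int) -> List[List[str]]:
    """Deal params to ranks 0..n-1, then n-1..0, and repeat."""
    per_rank: List[List[str]] = [[] for _ in range(num_ranks)]
    for i, param in enumerate(sorted_params):
        sweep, pos = divmod(i, num_ranks)
        # Even sweeps run forward; odd sweeps run backward.
        rank = pos if sweep % 2 == 0 else num_ranks - 1 - pos
        per_rank[rank].append(param)
    return per_rank


assignment = ping_pong_assign([f"p{i}" for i in range(10)], num_ranks=4)
# assignment == [['p0', 'p7', 'p8'], ['p1', 'p6', 'p9'], ['p2', 'p5'], ['p3', 'p4']]
```

Reversing direction on every sweep pairs each rank's larger params with smaller ones, which keeps per-rank totals closer than plain round-robin would.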
- set_bucket_layerwise_params_list(model_chunks)#
Map sharded params to DDP buckets for async all-gather.
Legacy: only used by the variable-size all-gather path (use_buffer_param_sync=False). Once all call sites supply a full_param_layout, this can be removed, since the standard distributed optimizer buffer all-gather handles param sync without per-bucket param lists.
For each bucket in each model chunk’s bucket groups, build per-rank param lists by cross-referencing the layer-wise sharded param lists with the bucket’s params.
- Parameters:
model_chunks – DDP-wrapped model chunks with bucket_groups.
- allgather_params() → None#
All-gather updated params from all ranks.
Legacy: only used when use_buffer_param_sync=False. Once all call sites supply a full_param_layout, this can be removed, since the standard distributed optimizer buffer all-gather (via start_param_sync) replaces this flatten/unflatten path.
- broadcast_params()#
Every rank broadcasts its updated local params.
- get_grad_norm()#
- has_grad_norm_group(grad_norm_group: str) → bool#
Whether any global rank owns params for a registered grad-norm group.
Overrides ChainedOptimizer to use a single global all-reduce (group=None), matching the scope of get_grad_norm and _get_grad_norm_for_group, which also reduce globally. All LayerWise grad-stats reductions are global (identical to DistributedOptimizer’s pattern), so the existence check must be too; using a per-sub-optimizer group here would create a collective mismatch.
- _get_grad_norm_for_group(grad_norm_group: str)#
- count_zeros()#
- start_param_sync_for_bucket_group_subset() → None#
Trigger start_param_sync on LayerWise-managed bucket groups only.
Walks each model chunk’s dense + expert-parallel bucket groups and skips any group not managed by LayerWise, so a sibling DistributedOptimizer’s own start_param_sync call does not double-sync the same buckets. Uses DistributedDataParallel._start_bucket_group_param_sync so FP8 post-all-gather processing (and MXFP8 copy) still runs.
- step_with_ready_grads() → bool#
Step then all-gather LayerWise-managed param buffers.
Placed on step_with_ready_grads (not step) so the param sync also runs when this optimizer is a child of an outer ChainedOptimizer, which calls step_with_ready_grads directly on each child and bypasses step.
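The reason for hooking the sync into step_with_ready_grads rather than step can be shown with two toy classes. `Child` and `Outer` are hypothetical stand-ins for a nested optimizer and its parent chain; they only model the call path, not any real update or communication.

```python
class Child:
    """Hypothetical child optimizer that syncs params after its update."""

    def __init__(self):
        self.synced = False

    def _update(self) -> bool:
        return True  # placeholder for the real parameter update

    def step_with_ready_grads(self) -> bool:
        ok = self._update()
        self.synced = True  # sync lives here, so every caller triggers it
        return ok

    def step(self) -> bool:
        # Grad-norm and clipping work would precede this call in practice.
        return self.step_with_ready_grads()


class Outer:
    """Hypothetical outer chain that bypasses each child's step()."""

    def __init__(self, children):
        self.children = children

    def step(self) -> bool:
        return all([c.step_with_ready_grads() for c in self.children])


child = Child()
outer_ok = Outer([child]).step()
```

Had the sync been placed in `Child.step`, the outer chain would have updated the params without ever synchronizing them.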
- load_state_dict(state_dict)#
- sharded_state_dict(
- model_sharded_state_dict: megatron.core.dist_checkpointing.mapping.ShardedStateDict,
- is_loading: bool = False,
- **kwargs,
Sharded state dict for torch_dist format checkpointing. For fixed DP usage only, set replica_id to 0 for all ShardedTensor.
- save_state_dict_to_file(filename: str) → None#
Save the parameter state of the optimizer. For torch format only.
- Parameters:
filename – The filename to save the parameter state.
- load_state_dict_from_file(filename: str) → None#
Load the parameter state of the optimizer. For torch format only.