core.optimizer.layer_wise_optimizer#
Module Contents#
Classes#
LayerWiseDistributedOptimizer – Layer-wise distributed optimizer for Megatron-core models.
Data#
API#
- core.optimizer.layer_wise_optimizer.logger#
‘getLogger(…)’
- class core.optimizer.layer_wise_optimizer.LayerWiseDistributedOptimizer(
- optimizers: List[core.optimizer.optimizer.MegatronOptimizer],
- config: core.optimizer.optimizer_config.OptimizerConfig,
- pg_collection: Optional[megatron.core.process_groups_config.ProcessGroupCollection] = None,
- init_state_fn_list: Optional[List[Callable]] = None,
- )#
Bases:
core.optimizer.optimizer.ChainedOptimizer
Layer-wise distributed optimizer for Megatron-core models.
Experimental distributed optimizer wrapper that distributes weights to DP ranks by layer. It is implemented as a ChainedOptimizer to support multiple optimizers (e.g. Muon + AdamW). When using it, keep all Megatron distributed-optimizer related options OFF.
How LayerWiseDistributedOptimizer works (see the sketch after this list):
- Weights are split into per-rank lists, and each rank keeps only its own shard in its optimizer.
- Megatron DDP handles the gradient all-reduce; note that each rank still holds the full model and full gradients.
- Each wrapped optimizer is modified so that only the parameters belonging to this DP rank are updated.
- Grad-norm and zero counting reduce their metrics globally in the step function.
- A regular update is performed through the chained optimizers; each modified optimizer updates only its shard.
- The updated parameters are all-gathered to every rank.
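The flow can be condensed into a short sketch. This is illustrative only, not the actual implementation; the assumption that step() itself performs the global metric reduction and the parameter re-sync is taken from the list above.

```python
# Illustrative sketch of one training iteration with the layer-wise optimizer.
# Not the actual implementation: the call order and the assumption that step()
# performs the metric reduction and the param re-sync come from the list above.
def train_iteration(model, layer_wise_optimizer, batch):
    # 1. Forward/backward: Megatron DDP all-reduces the gradients, so every
    #    DP rank ends up with the full model and the full gradients.
    loss = model(batch)
    loss.backward()

    # 2. step() runs the chained optimizers; each inner optimizer updates only
    #    the parameters owned by this DP rank, while grad-norm / num-zeros are
    #    reduced globally across DP ranks.
    layer_wise_optimizer.step()

    # 3. The updated shards are then all-gathered (or broadcast) so that every
    #    rank again holds identical full parameters before the next iteration.
    return loss
```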
Initialization
Initialize LayerWiseDistributedOptimizer (a construction sketch follows the parameter list below).
- Parameters:
optimizers – List of MegatronOptimizers.
config – OptimizerConfig.
pg_collection – ProcessGroupCollection.
init_state_fn_list – List of init state functions.
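A hedged construction sketch, following the signature documented above. The import paths mirror the API path shown on this page, and muon_opt, adam_opt, and pgs are hypothetical placeholders for objects built elsewhere.

```python
# Hypothetical construction sketch; muon_opt / adam_opt / pgs are placeholders
# for a Muon optimizer, an AdamW optimizer (both MegatronOptimizer instances),
# and a ProcessGroupCollection built elsewhere.
from core.optimizer.layer_wise_optimizer import LayerWiseDistributedOptimizer
from core.optimizer.optimizer_config import OptimizerConfig

config = OptimizerConfig()  # populate fields (lr, weight decay, ...) as needed

optimizer = LayerWiseDistributedOptimizer(
    optimizers=[muon_opt, adam_opt],  # e.g. Muon for 2D weights, AdamW for the rest
    config=config,
    pg_collection=pgs,                # optional ProcessGroupCollection
    init_state_fn_list=None,
)
# Remember: keep all Megatron distributed-optimizer options OFF when using this wrapper.
```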
- shard_params(optimizers)#
Shard all params into lists by rank.
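A minimal illustration of what a by-layer split could look like; the real shard_params may balance shards differently (e.g. by parameter size), so the round-robin scheme and the helper name below are assumptions.

```python
# Illustrative only: round-robin split of per-layer parameter lists across
# DP ranks. The real shard_params may use a different balancing heuristic.
def shard_params_by_layer(layers_params, dp_world_size):
    """layers_params: list whose i-th element is the list of params of layer i."""
    shards = [[] for _ in range(dp_world_size)]
    for layer_idx, layer_params in enumerate(layers_params):
        shards[layer_idx % dp_world_size].extend(layer_params)
    return shards

# Example: 5 layers on 2 DP ranks -> rank 0 owns layers 0, 2, 4; rank 1 owns 1, 3.
```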
- allgather_params() → None#
All-gather updated params from all ranks.
- broadcast_params()#
Each rank broadcasts its updated local params.
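A minimal sketch of the broadcast-style re-sync, assuming param_shards maps each DP-group rank to the list of parameters it owns; the real methods may instead gather flattened buffers.

```python
import torch.distributed as dist

# Sketch of the broadcast-based re-sync: every rank broadcasts the parameters
# it owns so that all DP ranks end up with identical full weights again.
# `param_shards` (group rank -> owned parameter tensors) is an assumed layout.
def broadcast_updated_params(param_shards, dp_group):
    for owner, params in enumerate(param_shards):
        src = dist.get_global_rank(dp_group, owner)  # map group rank to global rank
        for p in params:
            dist.broadcast(p.data, src=src, group=dp_group)
```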
- get_grad_norm()#
- count_zeros()#
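These reductions can be pictured as each rank computing its metric over the shard it owns and then summing across the data-parallel group; the sketch below shows the grad-norm case under that assumption (dp_group is hypothetical).

```python
import torch
import torch.distributed as dist

# Sketch of a globally reduced L2 grad norm: each rank contributes only the
# gradients of the parameters it owns, and the squared norms are summed across
# the data-parallel group (assumed reduction scheme; not the actual code).
def global_grad_norm(local_params, dp_group):
    local_sq = torch.zeros(1, device="cuda")
    for p in local_params:
        if p.grad is not None:
            local_sq += p.grad.detach().float().norm(2) ** 2
    dist.all_reduce(local_sq, op=dist.ReduceOp.SUM, group=dp_group)
    return local_sq.sqrt().item()
```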
- step()#
Step function for the layer-wise optimizer.
- load_state_dict(state_dict)#
- sharded_state_dict(
- model_sharded_state_dict: megatron.core.dist_checkpointing.mapping.ShardedStateDict,
- is_loading: bool = False,
- **kwargs,
- )#
Sharded state dict for torch_dist-format checkpointing. Intended for fixed-DP usage only; replica_id is set to 0 for all ShardedTensors.
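The replica_id handling can be pictured as a post-processing pass over the returned mapping. This is only an illustration, assuming ShardedTensor objects expose a mutable replica_id attribute as the docstring implies.

```python
# Hedged illustration of the replica_id handling described above: walk the
# sharded state dict and force replica_id = 0 on every entry that has one,
# i.e. the checkpoint assumes a fixed data-parallel layout with no replicas.
def zero_out_replica_ids(sharded_state_dict):
    def visit(obj):
        if isinstance(obj, dict):
            for value in obj.values():
                visit(value)
        elif isinstance(obj, (list, tuple)):
            for value in obj:
                visit(value)
        elif hasattr(obj, "replica_id"):
            obj.replica_id = 0
    visit(sharded_state_dict)
    return sharded_state_dict
```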
- save_state_dict_to_file(filename: str) → None#
Save the parameter state of the optimizer. For torch format only.
- Parameters:
filename – The filename to save the parameter state.
- load_state_dict_from_file(filename: str) → None#
Load the parameter state of the optimizer. For torch format only.
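Usage sketch for the torch-format helpers; the per-rank filename is an assumption made here because each rank keeps only the optimizer state of its own shard.

```python
import torch.distributed as dist

# Hypothetical usage of the torch-format helpers; the per-rank filename is an
# assumption (each rank holds only the optimizer state of its own shard).
rank = dist.get_rank()
optimizer.save_state_dict_to_file(f"optim_state_rank{rank}.pt")
# ...later, after rebuilding the same model and optimizer layout:
optimizer.load_state_dict_from_file(f"optim_state_rank{rank}.pt")
```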