core.optimizer.layer_wise_optimizer#

Module Contents#

Classes#

LayerWiseDistributedOptimizer

Layer-wise distributed optimizer for Megatron-core models.

Data#

API#

core.optimizer.layer_wise_optimizer.logger#

‘getLogger(…)’

class core.optimizer.layer_wise_optimizer.LayerWiseDistributedOptimizer(
optimizers: List[core.optimizer.optimizer.MegatronOptimizer],
config: core.optimizer.optimizer_config.OptimizerConfig,
pg_collection: Optional[megatron.core.process_groups_config.ProcessGroupCollection] = None,
init_state_fn_list: Optional[List[Callable]] = None,
)#

Bases: core.optimizer.optimizer.ChainedOptimizer

Layer-wise distributed optimizer for Megatron-core models.

Experimental distributed optimizer wrapper that distributes weights to DP ranks layer by layer. It is implemented as a ChainedOptimizer so that multiple optimizers (e.g. Muon + AdamW) can be combined. When using it, keep all Megatron distributed-optimizer related options OFF.

How LayerWiseDistributedOptimizer works (a conceptual sketch follows this list):

  1. Weights are split into per-rank lists; each rank keeps only its own shard in its optimizer.

  2. Megatron DDP handles the gradient all-reduce; note that every rank holds the full model and full gradients.

  3. The wrapped optimizers are modified so that only the parameters belonging to this DP rank are updated.

  4. Gradient-norm and zero counting reduce their metrics globally in the step function.

  5. The regular update runs through the chained optimizers; each modified optimizer updates only its shard.

  6. The updated parameters are all-gathered to every rank.
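
The following is a minimal conceptual sketch of this flow built from plain PyTorch collectives. It is not the Megatron-core implementation: the round-robin ownership rule, the helper name layer_wise_step, and the per-parameter broadcasts used in place of the real all-gather are illustrative assumptions.

```python
# Conceptual sketch only: plain PyTorch collectives, not the Megatron-core code.
# The ownership rule (round-robin by parameter index) and the helper name are
# illustrative assumptions.
import torch
import torch.distributed as dist


def layer_wise_step(model: torch.nn.Module, optimizer: torch.optim.Optimizer) -> None:
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    params = [p for p in model.parameters() if p.requires_grad]

    # (2) Every rank holds the full model and full gradients; average the grads.
    for p in params:
        if p.grad is not None:
            dist.all_reduce(p.grad)
            p.grad /= world_size

    # (1)/(3)/(5) Each rank updates only the parameters it owns; dropping the
    # gradient makes the local optimizer skip the unowned ones.
    for i, p in enumerate(params):
        if i % world_size != rank:
            p.grad = None
    optimizer.step()

    # (6) Re-synchronize: every parameter is sent from its owner to all ranks.
    for i, p in enumerate(params):
        dist.broadcast(p.data, src=i % world_size)
```

In the real class, the gradient all-reduce in (2) is done by Megatron DDP outside the optimizer, and the re-synchronization in (6) is performed by allgather_params() / broadcast_params() described below.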

Initialization

Initialize LayerWiseDistributedOptimizer.

Parameters:
  • optimizers – List of MegatronOptimizers.

  • config – OptimizerConfig.

  • pg_collection – ProcessGroupCollection.

  • init_state_fn_list – List of init state functions.
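
A minimal usage sketch. adamw_opt and muon_opt are placeholders for already constructed MegatronOptimizer instances, and the OptimizerConfig fields shown are only an assumed minimal configuration.

```python
# Usage sketch; `adamw_opt` and `muon_opt` are placeholders for already built
# MegatronOptimizer instances and are not constructed here.
from megatron.core.optimizer.optimizer_config import OptimizerConfig
from megatron.core.optimizer.layer_wise_optimizer import LayerWiseDistributedOptimizer

config = OptimizerConfig(
    optimizer="adam",
    lr=1e-4,
    use_distributed_optimizer=False,  # keep Megatron distributed-optimizer options OFF
)

opt = LayerWiseDistributedOptimizer(
    optimizers=[adamw_opt, muon_opt],  # e.g. Muon + AdamW, as mentioned above
    config=config,
    pg_collection=None,        # or pass a ProcessGroupCollection for custom process groups
    init_state_fn_list=None,
)

# Behaves like a ChainedOptimizer: each rank updates its own shard, then the
# updated parameters are gathered onto every rank.
opt.step()
```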

shard_params(optimizers)#

Shard all parameters into per-rank lists; an illustrative split is sketched below.
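
As an illustration of what a per-rank split can look like, the sketch below greedily assigns each parameter tensor to the currently least-loaded rank. The helper name and the assignment strategy are assumptions; the actual shard_params works on the wrapped optimizers' parameters and may partition differently.

```python
# Illustrative sketch of building per-rank shards; the assignment strategy is
# an assumption, not the actual logic of shard_params.
from typing import List

import torch


def split_params_by_rank(
    params: List[torch.nn.Parameter], world_size: int
) -> List[List[torch.nn.Parameter]]:
    """Greedy split: each tensor goes to the rank with the fewest elements so far."""
    shards: List[List[torch.nn.Parameter]] = [[] for _ in range(world_size)]
    loads = [0] * world_size
    for p in sorted(params, key=lambda t: t.numel(), reverse=True):
        r = loads.index(min(loads))
        shards[r].append(p)
        loads[r] += p.numel()
    return shards
```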

allgather_params() None#

All-gather updated params from all ranks.

broadcast_params()#

Each rank broadcasts its updated local params to all other ranks.

get_grad_norm()#
count_zeros()#
step()#

Step function for the layer-wise optimizer.
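
A sketch of the global metric reduction mentioned in step 4 above: each rank computes the gradient norm of its own shard, and the squared partial norms are summed across the data-parallel group. The helper name and arguments are assumptions, not the internals of step().

```python
# Sketch of a globally reduced grad norm over a layer-wise shard; helper name
# and arguments are assumptions, not the actual internals of step().
import torch
import torch.distributed as dist


def global_grad_norm(local_shard_params, group=None) -> float:
    """L2 norm over all shards: sum the squared local norms, then take the sqrt."""
    device = local_shard_params[0].device if local_shard_params else "cpu"
    local_sq = torch.zeros(1, device=device)
    for p in local_shard_params:
        if p.grad is not None:
            local_sq += p.grad.detach().float().norm(2) ** 2
    dist.all_reduce(local_sq, op=dist.ReduceOp.SUM, group=group)
    return local_sq.sqrt().item()
```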

load_state_dict(state_dict)#
sharded_state_dict(
model_sharded_state_dict: megatron.core.dist_checkpointing.mapping.ShardedStateDict,
is_loading: bool = False,
**kwargs,
)#

Sharded state dict for torch_dist format checkpointing. Intended for fixed DP usage only; replica_id is set to 0 for every ShardedTensor.
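
A small sketch of the fixed-DP convention described above: walk a (possibly nested) sharded state dict and force replica_id = 0 on anything that carries one. This is generic container traversal, not the method's actual implementation.

```python
# Sketch of the fixed-DP convention: force replica_id = 0 on every ShardedTensor
# found in a nested state dict. Not the actual implementation of the method.
def zero_replica_ids(obj) -> None:
    if isinstance(obj, dict):
        for value in obj.values():
            zero_replica_ids(value)
    elif isinstance(obj, (list, tuple)):
        for value in obj:
            zero_replica_ids(value)
    elif hasattr(obj, "replica_id"):
        obj.replica_id = 0  # every rank claims the "main" replica for fixed DP
```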

save_state_dict_to_file(filename: str) None#

Save the parameter state of the optimizer. For torch format only.

Parameters:

filename – The filename to save the parameter state.

load_state_dict_from_file(filename: str) None#

Load the parameter state of the optimizer. For torch format only.

Parameters:

filename – The filename to load the parameter state from.
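
A usage sketch for the two file-based helpers, assuming opt is an already constructed LayerWiseDistributedOptimizer; the per-rank filename pattern is an assumption, not a requirement of the API.

```python
# Usage sketch; `opt` is an already built LayerWiseDistributedOptimizer and the
# per-rank filename pattern is an assumption, not a requirement of the API.
import torch.distributed as dist

rank = dist.get_rank()

# Save this rank's optimizer parameter state (torch format only).
opt.save_state_dict_to_file(f"optim_state_rank{rank}.pt")

# ... later, e.g. when resuming training ...
opt.load_state_dict_from_file(f"optim_state_rank{rank}.pt")
```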