core.models.mimo.optimizer#
Optimizer for MIMO models with heterogeneous parallelism.
Module Contents#
Classes#
- ModuleOptimizerInfo: Optimizer info for a single module.
- MimoOptimizer: Optimizer for MimoModel with heterogeneous parallelism.
Functions#
- _iter_optimizer_sub_dicts: Yield (sub_state_dict, inner_optimizer) pairs.
- _extract_param_groups: Save: extract param_groups from optimizer sub-dict into a ShardedObject.
- _extract_grad_scaler: Save: extract grad_scaler into a ShardedObject.
- _restore_param_groups: Load: restore param_groups with current param IDs from the inner optimizer.
- _restore_grad_scaler: Load: restore grad_scaler from ShardedObject key.
- _get_replica_id: Build replica_id tuple for ShardedObject deduplication.
- _get_pg_collection_for_optimizer: Create ProcessGroupCollection from HyperCommGrid for optimizer use.
- get_mimo_optimizer: Create optimizer for MimoModel with heterogeneous parallelism.
API#
- class core.models.mimo.optimizer.ModuleOptimizerInfo#
Optimizer info for a single module.
- optimizer: Optional[megatron.core.optimizer.optimizer.MegatronOptimizer]#
None
- grid: Optional[megatron.core.hyper_comm_grid.HyperCommGrid]#
None
- pg_collection: Optional[megatron.core.process_groups_config.ProcessGroupCollection]#
None
- is_active: bool#
None
- class core.models.mimo.optimizer.MimoOptimizer(module_infos: Dict[str, core.models.mimo.optimizer.ModuleOptimizerInfo], config: megatron.core.optimizer.optimizer_config.OptimizerConfig)#

Bases: megatron.core.optimizer.optimizer.MegatronOptimizer

Optimizer for MimoModel with heterogeneous parallelism.
Each module gets its own optimizer. Global gradient norm is computed across all modules via all_reduce MAX.
Initialization
Input optimizer is the base optimizer (e.g., Adam).
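A minimal construction sketch; the per-module optimizers, grids, and process-group collections below are placeholders, and only the ModuleOptimizerInfo fields and the MimoOptimizer(module_infos, config) shape come from this API:

```python
# Hypothetical sketch: llm_optimizer, llm_grid, etc. are placeholders for
# objects built elsewhere; module names are illustrative.
module_infos = {
    "llm": ModuleOptimizerInfo(
        optimizer=llm_optimizer,          # MegatronOptimizer wrapping the base Adam
        grid=llm_grid,                    # this module's HyperCommGrid
        pg_collection=llm_pg_collection,  # from _get_pg_collection_for_optimizer
        is_active=True,                   # False on ranks that do not host the module
    ),
    "vision": ModuleOptimizerInfo(
        optimizer=vision_optimizer,
        grid=vision_grid,
        pg_collection=vision_pg_collection,
        is_active=True,
    ),
}
optimizer = MimoOptimizer(module_infos, config)
```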
- prepare_grads() → bool#
- get_grad_norm() → float#
Compute global gradient norm across all modules via all_reduce MAX.
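A simplified sketch of the MAX-reduction pattern, not the actual implementation; it assumes each per-module MegatronOptimizer exposes get_grad_norm():

```python
import torch
import torch.distributed as dist

def global_grad_norm_sketch(module_infos):
    # Each rank takes the largest norm among its local active modules, then
    # an all_reduce(MAX) makes every rank (including ones hosting no module,
    # which contribute 0.0) agree on the global value.
    local_max = 0.0
    for info in module_infos.values():
        if info.is_active and info.optimizer is not None:
            local_max = max(local_max, info.optimizer.get_grad_norm())
    norm = torch.tensor([local_max], dtype=torch.float32, device="cuda")
    dist.all_reduce(norm, op=dist.ReduceOp.MAX)
    return norm.item()
```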
- step() → Tuple[bool, Optional[float], Optional[int]]#
- step_with_ready_grads() → bool#
- zero_grad(set_to_none: bool = True)#
- get_loss_scale() → torch.Tensor#
- count_zeros() → int#
- property param_groups: List[dict]#
Combined param groups from all active module optimizers.
- state_dict()#
- load_state_dict(state_dict: Dict)#
Load per-module optimizer state dicts.
Reassembles param_groups and grad_scaler that were extracted and saved as ShardedObjects by sharded_state_dict(), then delegates to each per-module optimizer’s load_state_dict.
- sharded_state_dict(model_sharded_state_dict, is_loading: bool = False, **kwargs)#
Build sharded state dict, routing param_groups and grad_scaler through distributed save as ShardedObjects (common.pt is rank-0 only, which misses LLM optimizer state in non-colocated mode).
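A hedged sketch of the routing idea, using the megatron.core.dist_checkpointing ShardedObject dataclass; the key layout and surrounding variables are assumptions:

```python
from megatron.core.dist_checkpointing.mapping import ShardedObject

# Wrapping param_groups in a ShardedObject routes it through the distributed
# checkpoint, so optimizer metadata living on non-zero ranks (non-colocated
# mode) is still saved; plain objects would land in rank-0-only common.pt.
sub_sd["param_groups"] = ShardedObject(
    f"optimizer.{module_name}.param_groups",  # hypothetical key scheme
    sub_sd["param_groups"],
    (1,),                    # one logical object
    (0,),                    # single shard at offset 0
    replica_id=replica_id,   # non-primary replicas are deduplicated at save time
)
```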
- reload_model_params(state_dict=None)#
- core.models.mimo.optimizer._iter_optimizer_sub_dicts(module_sd, optimizer)#
Yield (sub_state_dict, inner_optimizer) pairs.
For a single optimizer, yields (module_sd, optimizer) once. For ChainedOptimizer with N>1 inner optimizers, yields (module_sd[i], chained_optimizers[i]) for each.
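A minimal sketch of that contract, assuming ChainedOptimizer exposes its inner optimizers as chained_optimizers (as the description implies):

```python
def iter_optimizer_sub_dicts_sketch(module_sd, optimizer):
    # Single optimizer: the whole module state dict belongs to it.
    inner = getattr(optimizer, "chained_optimizers", None)
    if inner is None or len(inner) <= 1:
        yield module_sd, optimizer
    else:
        # ChainedOptimizer: module_sd is indexed in step with the inner optimizers.
        for i, inner_opt in enumerate(inner):
            yield module_sd[i], inner_opt
```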
- core.models.mimo.optimizer._extract_param_groups(sub_sd, module_name, suffix, replica_id)#
Save: extract param_groups from optimizer sub-dict into a ShardedObject.
- core.models.mimo.optimizer._extract_grad_scaler(sub_sd, module_name, suffix, replica_id)#
Save: extract grad_scaler into a ShardedObject.
- core.models.mimo.optimizer._restore_param_groups(sub_sd, inner_optimizer, module_name)#
Load: restore param_groups with current param IDs from the inner optimizer.
- core.models.mimo.optimizer._restore_grad_scaler(sub_sd)#
Load: restore grad_scaler from ShardedObject key.
- core.models.mimo.optimizer._get_replica_id(pg_collection: Optional[megatron.core.process_groups_config.ProcessGroupCollection])#
Build replica_id tuple for ShardedObject deduplication.
Includes pp_rank so only one PP stage writes the metadata, and dp_rank so only dp_rank=0 writes (others are replicas).
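A sketch of the intent; the exact tuple layout and the pg_collection attribute names (pp, dp) are assumptions, and only the roles of pp_rank and dp_rank come from the description above:

```python
import torch.distributed as dist

def get_replica_id_sketch(pg_collection):
    # A shard is persisted only where the replica_id coordinates are all
    # zero, so including pp_rank and dp_rank means a single PP stage and
    # dp_rank=0 do the writing; everything else is treated as a replica.
    if pg_collection is None:
        return 0
    pp_rank = dist.get_rank(group=pg_collection.pp)
    dp_rank = dist.get_rank(group=pg_collection.dp)
    return (0, pp_rank, dp_rank)
```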
- core.models.mimo.optimizer._get_pg_collection_for_optimizer(grid)#
Create ProcessGroupCollection from HyperCommGrid for optimizer use.
Only fetches the process groups required by the optimizer. Assumes all groups are pre-created in the grid via grid.create_pg(); this function does not create any new groups.

The following groups must be pre-created in the grid before calling this function:

- grid.create_pg(["dp"])
- grid.create_pg(["dp", "cp"])
- grid.create_pg(["tp"])
- grid.create_pg(["pp"])
- grid.create_pg(["tp", "pp"])
- grid.create_pg(["tp", "ep", "pp"])
- grid.create_pg(["dp", "ep"])
- grid.create_pg(["tp", "cp", "ep", "pp", "dp"])
- Parameters:
grid – HyperCommGrid with pre-created process groups.
- Returns:
dp: Data parallel group
dp_cp: Data parallel with context parallel
tp: Tensor parallel group
mp: Model parallel group (tp × pp)
tp_ep_pp: Expert tensor-model-pipeline group
expt_dp: Expert data parallel group
- Return type:
ProcessGroupCollection containing optimizer-required groups
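A hedged usage sketch; the HyperCommGrid constructor arguments shown are assumptions, while the create_pg calls mirror the required-group list above:

```python
from megatron.core.hyper_comm_grid import HyperCommGrid

# Illustrative 8-rank grid; real dimension sizes depend on your setup, and
# the constructor signature here is an assumption.
grid = HyperCommGrid(shape=[2, 1, 1, 2, 2], dim_names=["tp", "cp", "ep", "pp", "dp"])

# All groups the optimizer needs must exist before the call.
for dims in (["dp"], ["dp", "cp"], ["tp"], ["pp"], ["tp", "pp"],
             ["tp", "ep", "pp"], ["dp", "ep"], ["tp", "cp", "ep", "pp", "dp"]):
    grid.create_pg(dims)

pg_collection = _get_pg_collection_for_optimizer(grid)
```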
- core.models.mimo.optimizer.get_mimo_optimizer(mimo_model: MimoModel, config: megatron.core.optimizer.optimizer_config.OptimizerConfig)#
Create optimizer for MimoModel with heterogeneous parallelism.
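A minimal end-to-end sketch; mimo_model is assumed to be an already-built MimoModel, and the OptimizerConfig fields shown are standard ones:

```python
from megatron.core.optimizer import OptimizerConfig

# Placeholder hyperparameters; mimo_model is built elsewhere.
config = OptimizerConfig(optimizer="adam", lr=1e-4, weight_decay=0.01)
optimizer = get_mimo_optimizer(mimo_model, config)

# Typical update, matching the step() signature documented above.
update_successful, grad_norm, num_zeros_in_grad = optimizer.step()
optimizer.zero_grad(set_to_none=True)
```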