core.models.mimo.optimizer#

Optimizer for MIMO models with heterogeneous parallelism.

Module Contents#

Classes#

ModuleOptimizerInfo

Optimizer info for a single module.

MimoOptimizer

Optimizer for MimoModel with heterogeneous parallelism.

Functions#

_iter_optimizer_sub_dicts

Yield (sub_state_dict, inner_optimizer) pairs.

_extract_param_groups

Save: extract param_groups from optimizer sub-dict into a ShardedObject.

_extract_grad_scaler

Save: extract grad_scaler into a ShardedObject.

_restore_param_groups

Load: restore param_groups with current param IDs from the inner optimizer.

_restore_grad_scaler

Load: restore grad_scaler from ShardedObject key.

_get_replica_id

Build replica_id tuple for ShardedObject deduplication.

_get_pg_collection_for_optimizer

Create ProcessGroupCollection from HyperCommGrid for optimizer use.

get_mimo_optimizer

Create optimizer for MimoModel with heterogeneous parallelism.

API#

class core.models.mimo.optimizer.ModuleOptimizerInfo#

Optimizer info for a single module.

optimizer: Optional[megatron.core.optimizer.optimizer.MegatronOptimizer]#

None

grid: Optional[megatron.core.hyper_comm_grid.HyperCommGrid]#

None

pg_collection: Optional[megatron.core.process_groups_config.ProcessGroupCollection]#

None

is_active: bool#

None

class core.models.mimo.optimizer.MimoOptimizer(
module_infos: Dict[str, core.models.mimo.optimizer.ModuleOptimizerInfo],
config: megatron.core.optimizer.optimizer_config.OptimizerConfig,
)#

Bases: megatron.core.optimizer.optimizer.MegatronOptimizer

Optimizer for MimoModel with heterogeneous parallelism.

Each module gets its own optimizer. The global gradient norm is computed across all modules via an all_reduce with the MAX op.

Initialization

The optimizer carried by each ModuleOptimizerInfo is that module's base optimizer (e.g., Adam).
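A minimal construction sketch, assuming ModuleOptimizerInfo is a plain dataclass and that the per-module optimizer, grid, process-group collection, and OptimizerConfig were built elsewhere; the module names are illustrative:

```python
from core.models.mimo.optimizer import MimoOptimizer, ModuleOptimizerInfo

module_infos = {
    "llm": ModuleOptimizerInfo(
        optimizer=llm_optimizer,    # MegatronOptimizer (e.g., wrapping Adam)
        grid=llm_grid,              # HyperCommGrid this module runs on
        pg_collection=llm_pgs,      # process groups derived from llm_grid
        is_active=True,             # this rank holds LLM parameters
    ),
    # Non-colocated mode: a rank that holds no parameters for a module
    # marks it inactive and carries no optimizer for it.
    "vision_encoder": ModuleOptimizerInfo(
        optimizer=None, grid=None, pg_collection=None, is_active=False,
    ),
}

mimo_opt = MimoOptimizer(module_infos, config)  # config: OptimizerConfig
```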

prepare_grads() bool#
get_grad_norm() float#

Compute the global gradient norm across all modules via an all_reduce with the MAX op.
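A minimal sketch of the MAX-reduction idea (not the actual implementation): each rank combines the norms reported by its locally active module optimizers, then a world-wide MAX all-reduce makes every rank agree on the result even when modules live on disjoint rank subsets:

```python
import torch
import torch.distributed as dist

def global_grad_norm(module_infos) -> float:
    local_sq = 0.0
    for info in module_infos.values():
        if info.is_active and info.optimizer is not None:
            # Each inner optimizer's norm is already reduced over its own group.
            local_sq += info.optimizer.get_grad_norm() ** 2
    norm = torch.tensor(local_sq ** 0.5, device="cuda")
    dist.all_reduce(norm, op=dist.ReduceOp.MAX)  # agree across all ranks
    return norm.item()
```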

step() Tuple[bool, Optional[float], Optional[int]]#
step_with_ready_grads() bool#
zero_grad(set_to_none: bool = True)#
get_loss_scale() torch.Tensor#
count_zeros() int#
property param_groups: List[dict]#

Combined param groups from all active module optimizers.

state_dict()#
load_state_dict(state_dict: Dict)#

Load per-module optimizer state dicts.

Reassembles param_groups and grad_scaler that were extracted and saved as ShardedObjects by sharded_state_dict(), then delegates to each per-module optimizer’s load_state_dict.

sharded_state_dict(
model_sharded_state_dict,
is_loading: bool = False,
**kwargs,
)#

Build the sharded state dict, routing param_groups and grad_scaler through distributed save as ShardedObjects (common.pt is written by rank 0 only, which would miss the LLM optimizer state in non-colocated mode).
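A sketch of the checkpoint round trip implied by these two methods, using the public megatron.core.dist_checkpointing save/load entry points; the checkpoint path is a placeholder:

```python
from megatron.core import dist_checkpointing

# Save: param_groups and grad_scaler become ShardedObjects inside this dict.
sharded_sd = mimo_opt.sharded_state_dict(model_sharded_sd)
dist_checkpointing.save(sharded_sd, "/ckpt/iter_0001000")

# Load: is_loading=True builds the template with restore hooks, then
# load_state_dict reassembles param_groups/grad_scaler per module.
template = mimo_opt.sharded_state_dict(model_sharded_sd, is_loading=True)
state_dict = dist_checkpointing.load(template, "/ckpt/iter_0001000")
mimo_opt.load_state_dict(state_dict)
```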

reload_model_params(state_dict=None)#
core.models.mimo.optimizer._iter_optimizer_sub_dicts(module_sd, optimizer)#

Yield (sub_state_dict, inner_optimizer) pairs.

For a single optimizer, yields (module_sd, optimizer) once. For a ChainedOptimizer with N > 1 inner optimizers, yields (module_sd[i], chained_optimizers[i]) for each i.
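A rough sketch of that pairing rule, assuming ChainedOptimizer exposes its inner optimizers as chained_optimizers and that the module's state dict is indexable per inner optimizer, as the (module_sd[i], chained_optimizers[i]) pairing implies:

```python
from megatron.core.optimizer.optimizer import ChainedOptimizer

def iter_sub_dicts_sketch(module_sd, optimizer):
    if isinstance(optimizer, ChainedOptimizer) and len(optimizer.chained_optimizers) > 1:
        # One sub state dict per inner optimizer.
        for i, inner in enumerate(optimizer.chained_optimizers):
            yield module_sd[i], inner
    else:
        yield module_sd, optimizer
```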

core.models.mimo.optimizer._extract_param_groups(sub_sd, module_name, suffix, replica_id)#

Save: extract param_groups from optimizer sub-dict into a ShardedObject.

core.models.mimo.optimizer._extract_grad_scaler(sub_sd, module_name, suffix, replica_id)#

Save: extract grad_scaler into a ShardedObject.

core.models.mimo.optimizer._restore_param_groups(sub_sd, inner_optimizer, module_name)#

Load: restore param_groups with current param IDs from the inner optimizer.

core.models.mimo.optimizer._restore_grad_scaler(sub_sd)#

Load: restore grad_scaler from ShardedObject key.

core.models.mimo.optimizer._get_replica_id(
pg_collection: Optional[megatron.core.process_groups_config.ProcessGroupCollection],
) tuple#

Build replica_id tuple for ShardedObject deduplication.

Includes pp_rank so only one PP stage writes the metadata, and dp_rank so only dp_rank=0 writes (others are replicas).
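An illustrative sketch of the tuple construction; the component ordering and the pg_collection attribute names are assumptions. A nonzero component marks the rank as a replica, so only the rank with pp_rank == 0 and dp_rank == 0 writes the object:

```python
import torch.distributed as dist

def replica_id_sketch(pg_collection) -> tuple:
    if pg_collection is None:
        return (0, 0, 0)
    pp_rank = dist.get_rank(group=pg_collection.pp)  # assumed attribute name
    dp_rank = dist.get_rank(group=pg_collection.dp)  # assumed attribute name
    return (0, pp_rank, dp_rank)
```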

core.models.mimo.optimizer._get_pg_collection_for_optimizer(
grid,
) megatron.core.process_groups_config.ProcessGroupCollection#

Create ProcessGroupCollection from HyperCommGrid for optimizer use.

Only fetches the process groups required by the optimizer. Assumes all groups are pre-created in the grid via grid.create_pg(); this function does not create any new groups.

The following groups must be pre-created in the grid before calling this function (see the setup sketch after this entry):

  • grid.create_pg(["dp"])

  • grid.create_pg(["dp", "cp"])

  • grid.create_pg(["tp"])

  • grid.create_pg(["pp"])

  • grid.create_pg(["tp", "pp"])

  • grid.create_pg(["tp", "ep", "pp"])

  • grid.create_pg(["dp", "ep"])

  • grid.create_pg(["tp", "cp", "ep", "pp", "dp"])

Parameters:

grid – HyperCommGrid with pre-created process groups.

Returns:

  • dp: Data parallel group

  • dp_cp: Data parallel with context parallel

  • tp: Tensor parallel group

  • mp: Model parallel group (tp × pp)

  • tp_ep_pp: Expert tensor-model-pipeline group

  • expt_dp: Expert data parallel group

Return type:

ProcessGroupCollection containing optimizer-required groups
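A hypothetical setup sketch showing the pre-creation contract; the HyperCommGrid constructor arguments (shape and dimension-name order) are assumptions, not part of this module's API:

```python
from megatron.core.hyper_comm_grid import HyperCommGrid

# Assumed constructor: 8 ranks laid out as tp=1, cp=1, ep=1, pp=2, dp=4.
grid = HyperCommGrid([1, 1, 1, 2, 4], ["tp", "cp", "ep", "pp", "dp"])

# Pre-create every group the optimizer will fetch; the function itself
# creates nothing.
for dims in (["dp"], ["dp", "cp"], ["tp"], ["pp"], ["tp", "pp"],
             ["tp", "ep", "pp"], ["dp", "ep"],
             ["tp", "cp", "ep", "pp", "dp"]):
    grid.create_pg(dims)

pg_collection = _get_pg_collection_for_optimizer(grid)
```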

core.models.mimo.optimizer.get_mimo_optimizer(
mimo_model: MimoModel,
config: megatron.core.optimizer.optimizer_config.OptimizerConfig,
) core.models.mimo.optimizer.MimoOptimizer#

Create optimizer for MimoModel with heterogeneous parallelism.
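An end-to-end usage sketch; mimo_model, the batch, and the loss computation are placeholders, and only well-known OptimizerConfig fields are set:

```python
from megatron.core.optimizer.optimizer_config import OptimizerConfig
from core.models.mimo.optimizer import get_mimo_optimizer

config = OptimizerConfig(lr=1e-4, weight_decay=0.01)
optimizer = get_mimo_optimizer(mimo_model, config)

loss = compute_loss(mimo_model, batch)  # placeholder forward pass
loss.backward()

# step() returns (update_successful, grad_norm, num_zeros_in_grad).
update_ok, grad_norm, num_zeros = optimizer.step()
optimizer.zero_grad()
```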