core.optimizer.emerging_optimizers#

Emerging optimizer registry.

To add a new emerging optimizer:

  1. Define its optimizer class (or import it).

  2. Write its _<name>_init_state_fn and _<name>_config_to_kwargs.

  3. Add an EmergingOptimizerEntry to _EMERGING_OPTIMIZERS at the bottom.

Module Contents#

Classes#

EmergingOptimizerEntry

Everything needed to create and configure an emerging optimizer.

TensorParallelMuon

Tensor Parallel Muon optimizer.

TensorParallelAdaptiveMuon

Tensor Parallel Adaptive Muon optimizer.

Functions#

get_supported_coefficient_types

Return the coefficient types supported by the installed emerging_optimizers.

validate_coefficient_type

Raise ValueError if coefficient_type is not supported.

_eopt_init_state_fn

Initialize emerging optimizer state for torch_dist checkpoint format.

_default_param_overrides_factory

Default param overrides: route non-linear/embedding params to Adam.

_create_emerging_optimizer

Instantiate an emerging optimizer and return it with its init_state_fn.

_is_nonlinear_or_embedding

True for parameters that should NOT use the emerging optimizer.

_get_qkv_split_shapes

Compute QKV split shapes from model config.

_kwargs_from_config

Match optimizer_cls.__init__ parameters to config attributes.

_muon_config_to_kwargs

Convert OptimizerConfig to TensorParallelMuon constructor kwargs.

_adaptive_muon_config_to_kwargs

Convert OptimizerConfig to TensorParallelAdaptiveMuon constructor kwargs.

_default_adam_based_eopt_config_to_kwargs

Convert OptimizerConfig to default emerging optimizer constructor kwargs.

Data#

API#

core.optimizer.emerging_optimizers.logger#

‘getLogger(…)’

core.optimizer.emerging_optimizers.get_supported_coefficient_types() tuple[str, ...]#

Return the coefficient types supported by the installed emerging_optimizers.

Reads the members of the NSCoeffT Literal type so that new types added upstream are automatically available without code changes here.

core.optimizer.emerging_optimizers.validate_coefficient_type(coefficient_type: str) None#

Raise ValueError if coefficient_type is not supported.
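Both helpers can be sketched with `typing.get_args`, which is the standard way to read a `Literal` type's members. The `NSCoeffT` definition below is a stand-in; the installed `emerging_optimizers` package defines the real member set.

```python
from typing import Literal, get_args

# Stand-in for the upstream NSCoeffT Literal; actual members may differ.
NSCoeffT = Literal["quintic", "cubic"]

def get_supported_coefficient_types() -> tuple:
    # get_args() returns the Literal's members, so coefficient types
    # added upstream are picked up here without code changes.
    return get_args(NSCoeffT)

def validate_coefficient_type(coefficient_type: str) -> None:
    supported = get_supported_coefficient_types()
    if coefficient_type not in supported:
        raise ValueError(
            f"coefficient_type {coefficient_type!r} is not supported; "
            f"expected one of {supported}"
        )
```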

core.optimizer.emerging_optimizers._eopt_init_state_fn(opt, config=None)#

Initialize emerging optimizer state for torch_dist checkpoint format.

core.optimizer.emerging_optimizers._default_param_overrides_factory() Dict[core.optimizer.optimizer_config.ParamKey, Dict[str, Any]]#

Default param overrides: route non-linear/embedding params to Adam.

class core.optimizer.emerging_optimizers.EmergingOptimizerEntry#

Everything needed to create and configure an emerging optimizer.

.. attribute:: optimizer_cls

The torch optimizer class.

.. attribute:: init_state_fn

Lazily initializes optimizer state (needed by some checkpoint formats).

.. attribute:: config_to_kwargs

(config, model_chunks, pg_collection) -> dict of constructor kwargs.

.. attribute:: default_param_overrides

Per-parameter config overrides applied automatically (e.g. route non-linear params to Adam).

optimizer_cls: type#

None

init_state_fn: Callable#

None

config_to_kwargs: Callable | None#

None

default_param_overrides: Dict[core.optimizer.optimizer_config.ParamKey, Dict[str, Any]]#

‘field(…)’

core.optimizer.emerging_optimizers._create_emerging_optimizer(
config,
param_groups,
eopt_name,
model_chunks,
pg_collection,
)#

Instantiate an emerging optimizer and return it with its init_state_fn.

core.optimizer.emerging_optimizers._is_nonlinear_or_embedding(param)#

True for parameters that should NOT use the emerging optimizer.

core.optimizer.emerging_optimizers._get_qkv_split_shapes(model_cfg) List[int]#

Compute QKV split shapes from model config.
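A fused QKV weight stacks query, key, and value rows along dim 0; with grouped-query attention, K and V carry fewer heads than Q. A simplified sketch of the arithmetic (the attribute names are illustrative, and Megatron's actual fused layout interleaves Q/K/V per query group, which this ignores):

```python
def get_qkv_split_shapes(num_attention_heads: int,
                         num_query_groups: int,
                         head_dim: int) -> list:
    # Q has num_attention_heads heads; K and V each have
    # num_query_groups heads under grouped-query attention.
    q_rows = num_attention_heads * head_dim
    kv_rows = num_query_groups * head_dim
    return [q_rows, kv_rows, kv_rows]
```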

core.optimizer.emerging_optimizers._EMERGING_OPTIMIZERS: Dict[str, core.optimizer.emerging_optimizers.EmergingOptimizerEntry]#

None

class core.optimizer.emerging_optimizers.TensorParallelMuon(
params: torch.optim.optimizer.ParamsT,
lr: float = 0.0003,
momentum: float = 0.95,
nesterov: bool = True,
weight_decay: float = 0.01,
use_decoupled_weight_decay: bool = True,
split_qkv: bool = False,
is_qkv_fn: Callable[[torch.Tensor], bool] | None = None,
qkv_split_shapes: tuple[int, int, int] | None = None,
fp32_matmul_prec: str = 'medium',
coefficient_type: str = 'quintic',
num_ns_steps: int = 5,
scale_mode: str = 'spectral',
extra_scale_factor: float = 1.0,
pg_collection: Optional[megatron.core.process_groups_config.ProcessGroupCollection] = None,
tp_mode: Literal['blockwise', 'duplicated', 'distributed'] = 'duplicated',
)#

Bases: emerging_optimizers.orthogonalized_optimizers.OrthogonalizedOptimizer

Tensor Parallel Muon optimizer.

Initialization

orthogonalize(
p: torch.Tensor,
grad: torch.Tensor,
**kwargs: Any,
) torch.Tensor#

Orthogonalize the momentum.

Parameters:
  • p – The parameter tensor. The parameter must be passed in addition to the momentum because much of the needed information (attributes, for example) is only available on the parameter tensor.

  • grad – The momentum tensor.

Returns:

The orthogonalized gradient tensor.
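Muon's orthogonalization drives the momentum matrix toward a semi-orthogonal one via a Newton-Schulz iteration. A NumPy sketch using the widely cited quintic coefficients from the original Muon write-up; the installed library's coefficient sets (selected via `coefficient_type`) and its tensor-parallel handling may differ.

```python
import numpy as np

def newton_schulz(G: np.ndarray, steps: int = 5) -> np.ndarray:
    # Quintic iteration coefficients from the original Muon write-up;
    # an odd matrix polynomial pushes every singular value toward 1.
    a, b, c = 3.4445, -4.7750, 2.0315
    X = G / (np.linalg.norm(G) + 1e-7)  # ensure spectral norm <= 1
    transposed = X.shape[0] > X.shape[1]
    if transposed:  # iterate on the wide orientation for efficiency
        X = X.T
    for _ in range(steps):
        A = X @ X.T
        X = a * X + (b * A + c * (A @ A)) @ X
    return X.T if transposed else X
```

After a few steps the singular values of the output cluster near 1, so the update direction is approximately orthogonal without an explicit (and expensive) SVD.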

class core.optimizer.emerging_optimizers.TensorParallelAdaptiveMuon(
params: torch.optim.optimizer.ParamsT,
lr: float = 0.0003,
momentum: float = 0.95,
nesterov: bool = True,
weight_decay: float = 0.01,
use_decoupled_weight_decay: bool = True,
split_qkv: bool = False,
is_qkv_fn: Callable[[torch.Tensor], bool] | None = None,
qkv_split_shapes: tuple[int, int, int] | None = None,
fp32_matmul_prec: str = 'medium',
coefficient_type: str = 'quintic',
num_ns_steps: int = 5,
scale_mode: str = 'spectral',
extra_scale_factor: float = 1.0,
pg_collection: Optional[megatron.core.process_groups_config.ProcessGroupCollection] = None,
tp_mode: Literal['blockwise', 'duplicated', 'distributed'] = 'duplicated',
moment2_method: Literal['adamuon', 'normuon'] = 'adamuon',
beta2: float = 0.95,
eps: float = 1e-08,
)#

Bases: core.optimizer.emerging_optimizers.TensorParallelMuon, emerging_optimizers.orthogonalized_optimizers.AdaptiveMuon

Tensor Parallel Adaptive Muon optimizer.

This class extends Muon by adding AdamW-style or NorMuon-style second-moment accumulation after orthogonalization. The idea was first explored in D.E. Carlson, E. Collins, Y.-P. Hsieh, L. Carin, and V. Cevher, "Preconditioned Spectral Descent for Deep Learning," Advances in Neural Information Processing Systems 28 (2015). The step() method is overridden to include the second-moment normalization logic.

Parameters:
  • params – Iterable of parameters to optimize or dicts defining parameter groups.

  • lr – Learning rate.

  • momentum – The exponential decay rate for momentum.

  • nesterov – Whether to use Nesterov momentum.

  • weight_decay – Weight decay coefficient.

  • use_decoupled_weight_decay – Whether to use decoupled weight decay.

  • split_qkv – Whether to split QKV weights for orthogonalization.

  • is_qkv_fn – Function to determine if a tensor is a QKV weight.

  • qkv_split_shapes – Shapes for splitting QKV weights.

  • fp32_matmul_prec – Precision for FP32 matrix multiplication.

  • coefficient_type – The type of coefficient set to use for the Newton-Schulz iteration.

  • num_ns_steps – The number of iteration steps to use in the Newton-Schulz iteration.

  • scale_mode – The type of scale factor to use for the update.

  • extra_scale_factor – The additional scale factor to use for the update.

  • pg_collection – Process group collection for distributed training.

  • tp_mode – Tensor parallel mode (“blockwise”, “duplicated”, or “distributed”).

  • moment2_method – Method for second moment accumulation (“adamuon” or “normuon”).

  • beta2 – The exponential decay rate for second moment.

  • eps – Small constant for numerical stability.
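The `adamuon`-style second-moment logic can be sketched as AdamW-style accumulation applied *after* orthogonalization; this is an illustrative reduction, not the library's implementation, and the `normuon` variant (per-row normalization) is not shown.

```python
import numpy as np

def apply_second_moment(update, v, beta2=0.95, eps=1e-8):
    # v <- beta2 * v + (1 - beta2) * update^2, then rescale elementwise,
    # applied to the already-orthogonalized update.
    v = beta2 * v + (1.0 - beta2) * update ** 2
    return update / (np.sqrt(v) + eps), v
```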

Initialization

step(
closure: Optional[Callable] = None,
) Optional[float]#

Perform a single optimization step.

core.optimizer.emerging_optimizers._kwargs_from_config(
optimizer_cls: type,
prefix: str,
config,
) Dict[str, Any]#

Match optimizer_cls.__init__ parameters to config attributes.

For each init parameter, looks for {prefix}_{name} on config first, then falls back to {name} (unprefixed). self and params are always skipped.

core.optimizer.emerging_optimizers._muon_config_to_kwargs(
config,
model_chunks,
pg_collection,
) Dict[str, Any]#

Convert OptimizerConfig to TensorParallelMuon constructor kwargs.

core.optimizer.emerging_optimizers._adaptive_muon_config_to_kwargs(
config,
model_chunks,
pg_collection,
) Dict[str, Any]#

Convert OptimizerConfig to TensorParallelAdaptiveMuon constructor kwargs.

core.optimizer.emerging_optimizers._default_adam_based_eopt_config_to_kwargs(
eopt_name,
config,
model_chunks,
pg_collection,
) Dict[str, Any]#

Convert OptimizerConfig to default emerging optimizer constructor kwargs.