core.optimizer.emerging_optimizers#
Emerging optimizer registry.
To add a new emerging optimizer:

1. Define its optimizer class (or import it).
2. Write its `_<name>_init_state_fn` and `_<name>_config_to_kwargs`.
3. Add an `EmergingOptimizerEntry` to `_EMERGING_OPTIMIZERS` at the bottom.
Module Contents#
Classes#

| Class | Description |
|---|---|
| `EmergingOptimizerEntry` | Everything needed to create and configure an emerging optimizer. |
| `TensorParallelMuon` | Tensor Parallel Muon optimizer. |
| `TensorParallelAdaptiveMuon` | Tensor Parallel Adaptive Muon optimizer. |

Functions#

| Function | Description |
|---|---|
| `get_supported_coefficient_types` | Return the coefficient types supported by the installed emerging_optimizers. |
| `validate_coefficient_type` | Raise `ValueError` if coefficient_type is not supported. |
| `_eopt_init_state_fn` | Initialize emerging optimizer state for torch_dist checkpoint format. |
| `_default_param_overrides_factory` | Default param overrides: route non-linear/embedding params to Adam. |
| `_create_emerging_optimizer` | Instantiate an emerging optimizer and return it with its init_state_fn. |
| `_is_nonlinear_or_embedding` | True for parameters that should NOT use the emerging optimizer. |
| `_get_qkv_split_shapes` | Compute QKV split shapes from model config. |
| `_kwargs_from_config` | Match `optimizer_cls.__init__` parameters to config attributes. |
| `_muon_config_to_kwargs` | Convert OptimizerConfig to TensorParallelMuon constructor kwargs. |
| `_adaptive_muon_config_to_kwargs` | Convert OptimizerConfig to TensorParallelAdaptiveMuon constructor kwargs. |
| `_default_adam_based_eopt_config_to_kwargs` | Convert OptimizerConfig to default emerging optimizer constructor kwargs. |
Data#
API#
- core.optimizer.emerging_optimizers.logger#
‘getLogger(…)’
- core.optimizer.emerging_optimizers.get_supported_coefficient_types() tuple[str, ...]#
Return the coefficient types supported by the installed emerging_optimizers.
Reads the members of the `NSCoeffT` Literal type so that new types added upstream are automatically available without code changes here.
- core.optimizer.emerging_optimizers.validate_coefficient_type(coefficient_type: str) None#
Raise `ValueError` if coefficient_type is not supported.
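A minimal sketch of this pattern, reading a Literal's members so that validation automatically tracks upstream additions. The `NSCoeffT` definition below is a local stand-in; the real type lives in the emerging_optimizers package.

```python
from typing import Literal, get_args

# Stand-in for the upstream NSCoeffT Literal type.
NSCoeffT = Literal["quintic", "cubic"]

def get_supported_coefficient_types() -> tuple:
    # get_args reads the Literal's members, so new upstream entries
    # are picked up without any code changes here.
    return get_args(NSCoeffT)

def validate_coefficient_type(coefficient_type: str) -> None:
    supported = get_supported_coefficient_types()
    if coefficient_type not in supported:
        raise ValueError(
            f"Unsupported coefficient_type {coefficient_type!r}; "
            f"expected one of {supported}"
        )
```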
- core.optimizer.emerging_optimizers._eopt_init_state_fn(opt, config=None)#
Initialize emerging optimizer state for torch_dist checkpoint format.
- core.optimizer.emerging_optimizers._default_param_overrides_factory() Dict[core.optimizer.optimizer_config.ParamKey, Dict[str, Any]]#
Default param overrides: route non-linear/embedding params to Adam.
- class core.optimizer.emerging_optimizers.EmergingOptimizerEntry#
Everything needed to create and configure an emerging optimizer.
.. attribute:: optimizer_cls
The torch optimizer class.
.. attribute:: init_state_fn
Lazily initialises optimizer state (needed for checkpoint formats).
.. attribute:: config_to_kwargs
`(config, model_chunks, pg_collection) -> dict` of constructor kwargs.
.. attribute:: default_param_overrides
Per-parameter config overrides applied automatically (e.g. route non-linear params to Adam).
- init_state_fn: Callable#
None
- config_to_kwargs: Callable | None#
None
- default_param_overrides: Dict[core.optimizer.optimizer_config.ParamKey, Dict[str, Any]]#
‘field(…)’
- core.optimizer.emerging_optimizers._create_emerging_optimizer(config, param_groups, eopt_name, model_chunks, pg_collection)#
Instantiate an emerging optimizer and return it with its init_state_fn.
- core.optimizer.emerging_optimizers._is_nonlinear_or_embedding(param)#
True for parameters that should NOT use the emerging optimizer.
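The routing predicate can be sketched as below. The heuristic shown (anything that is not a 2-D weight matrix, plus embedding tables, skips the emerging optimizer) follows the usual Muon convention but is an assumption for illustration, as is the toy `FakeParam` record; the real function inspects `torch.nn.Parameter` attributes.

```python
from dataclasses import dataclass

# Toy parameter record standing in for torch.nn.Parameter.
@dataclass
class FakeParam:
    shape: tuple
    is_embedding: bool = False

def is_nonlinear_or_embedding(param: FakeParam) -> bool:
    """True for parameters that should NOT use the emerging optimizer.

    Muon-style optimizers orthogonalize 2-D weight matrices, so 1-D
    tensors (biases, norm scales) and embedding tables are routed to
    Adam instead.
    """
    return len(param.shape) < 2 or param.is_embedding
```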
- core.optimizer.emerging_optimizers._get_qkv_split_shapes(model_cfg) List[int]#
Compute QKV split shapes from model config.
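The computation can be sketched as follows; the grouped-query-attention layout assumed here (Q gets one slice per attention head, K and V one slice each per query group) and the parameter names are illustrative assumptions, not necessarily the exact fields the real model config exposes.

```python
def get_qkv_split_shapes(hidden_size: int, num_attention_heads: int,
                         num_query_groups: int) -> list:
    """Sketch: row split of a fused QKV weight under an assumed GQA layout.

    The fused weight stacks [Q; K; V] along the output dimension, so the
    split is [q_rows, kv_rows, kv_rows].
    """
    head_dim = hidden_size // num_attention_heads
    q_rows = num_attention_heads * head_dim
    kv_rows = num_query_groups * head_dim
    return [q_rows, kv_rows, kv_rows]
```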
- core.optimizer.emerging_optimizers._EMERGING_OPTIMIZERS: Dict[str, core.optimizer.emerging_optimizers.EmergingOptimizerEntry]#
None
- class core.optimizer.emerging_optimizers.TensorParallelMuon(
- params: torch.optim.optimizer.ParamsT,
- lr: float = 0.0003,
- momentum: float = 0.95,
- nesterov: bool = True,
- weight_decay: float = 0.01,
- use_decoupled_weight_decay: bool = True,
- split_qkv: bool = False,
- is_qkv_fn: Callable[[torch.Tensor], bool] | None = None,
- qkv_split_shapes: tuple[int, int, int] | None = None,
- fp32_matmul_prec: str = 'medium',
- coefficient_type: str = 'quintic',
- num_ns_steps: int = 5,
- scale_mode: str = 'spectral',
- extra_scale_factor: float = 1.0,
- pg_collection: Optional[megatron.core.process_groups_config.ProcessGroupCollection] = None,
- tp_mode: Literal['blockwise', 'duplicated', 'distributed'] = 'duplicated',
- )
Bases: `emerging_optimizers.orthogonalized_optimizers.OrthogonalizedOptimizer`

Tensor Parallel Muon optimizer.
Initialization
- orthogonalize(p: torch.Tensor, grad: torch.Tensor, **kwargs: Any)#
Orthogonalize the momentum.
- Parameters:
p – The parameter tensor. It is necessary to pass the param tensor in addition to the momentum because a lot of information (attributes, for example) is only available on the param tensor.
grad – The momentum tensor.
- Returns:
The orthogonalized gradient tensor.
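The orthogonalization is based on a Newton-Schulz iteration. The NumPy sketch below shows the idea with the quintic coefficients popularized by Muon; the coefficient values, normalization, and transpose handling are illustrative assumptions, not necessarily what the installed emerging_optimizers package does.

```python
import numpy as np

def newton_schulz_orthogonalize(g: np.ndarray, steps: int = 5) -> np.ndarray:
    """Approximately push g's singular values toward 1 (sketch).

    Iterates X <- a*X + (b*(X X^T) + c*(X X^T)^2) X with quintic
    coefficients; these values are the ones popularized by Muon and are
    an illustrative assumption here.
    """
    a, b, c = 3.4445, -4.7750, 2.0315
    x = g / (np.linalg.norm(g) + 1e-7)  # Frobenius norm bounds spectral norm by 1
    if x.shape[0] > x.shape[1]:         # iterate on the smaller Gram matrix
        x = x.T
    for _ in range(steps):
        gram = x @ x.T
        x = a * x + (b * gram + c * gram @ gram) @ x
    if g.shape[0] > g.shape[1]:
        x = x.T
    return x
```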
- class core.optimizer.emerging_optimizers.TensorParallelAdaptiveMuon(
- params: torch.optim.optimizer.ParamsT,
- lr: float = 0.0003,
- momentum: float = 0.95,
- nesterov: bool = True,
- weight_decay: float = 0.01,
- use_decoupled_weight_decay: bool = True,
- split_qkv: bool = False,
- is_qkv_fn: Callable[[torch.Tensor], bool] | None = None,
- qkv_split_shapes: tuple[int, int, int] | None = None,
- fp32_matmul_prec: str = 'medium',
- coefficient_type: str = 'quintic',
- num_ns_steps: int = 5,
- scale_mode: str = 'spectral',
- extra_scale_factor: float = 1.0,
- pg_collection: Optional[megatron.core.process_groups_config.ProcessGroupCollection] = None,
- tp_mode: Literal['blockwise', 'duplicated', 'distributed'] = 'duplicated',
- moment2_method: Literal['adamuon', 'normuon'] = 'adamuon',
- beta2: float = 0.95,
- eps: float = 1e-08,
- )
Bases: `core.optimizer.emerging_optimizers.TensorParallelMuon`, `emerging_optimizers.orthogonalized_optimizers.AdaptiveMuon`

Tensor Parallel Adaptive Muon optimizer.
This class extends Muon by adding AdamW-style or NorMuon-style second-moment accumulation after orthogonalization. This idea was first explored in D.E. Carlson, E. Collins, Ya-Ping Hsieh, L. Carin, and V. Cevher, "Preconditioned spectral descent for deep learning", Advances in Neural Information Processing Systems 28 (2015). The step() method is overridden to include the second-moment normalization logic.
- Parameters:
params – Iterable of parameters to optimize or dicts defining parameter groups.
lr – Learning rate.
momentum – The exponential decay rate for momentum.
nesterov – Whether to use Nesterov momentum.
weight_decay – Weight decay coefficient.
use_decoupled_weight_decay – Whether to use decoupled weight decay.
split_qkv – Whether to split QKV weights for orthogonalization.
is_qkv_fn – Function to determine if a tensor is a QKV weight.
qkv_split_shapes – Shapes for splitting QKV weights.
fp32_matmul_prec – Precision for FP32 matrix multiplication.
coefficient_type – The type of coefficient set to use for the Newton-Schulz iteration.
num_ns_steps – The number of iteration steps to use in the Newton-Schulz iteration.
scale_mode – The type of scale factor to use for the update.
extra_scale_factor – The additional scale factor to use for the update.
pg_collection – Process group collection for distributed training.
tp_mode – Tensor parallel mode (“blockwise”, “duplicated”, or “distributed”).
moment2_method – Method for second moment accumulation (“adamuon” or “normuon”).
beta2 – The exponential decay rate for second moment.
eps – Small constant for numerical stability.
Initialization
- step(closure: Optional[Callable] = None)#
Perform a single optimization step.
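The "adamuon"-style second-moment normalization applied after orthogonalization can be sketched as below. The exact placement inside the real step(), bias correction, and any extra scaling are assumptions here; this only illustrates the AdamW-style denominator applied to the orthogonalized update.

```python
import numpy as np

def apply_second_moment(update: np.ndarray, v: np.ndarray,
                        beta2: float = 0.95, eps: float = 1e-8):
    """Normalize an orthogonalized update by an EMA of its square (sketch).

    v is the running second-moment buffer kept in optimizer state. Unlike
    plain AdamW, the EMA is taken over the already orthogonalized update
    rather than the raw gradient.
    """
    v = beta2 * v + (1.0 - beta2) * update ** 2
    normalized = update / (np.sqrt(v) + eps)
    return normalized, v
```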
- core.optimizer.emerging_optimizers._kwargs_from_config(optimizer_cls: type, prefix: str, config)#
Match `optimizer_cls.__init__` parameters to config attributes.

For each init parameter, looks for `{prefix}_{name}` on config first, then falls back to `{name}` (unprefixed). `self` and `params` are always skipped.
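The prefix-then-fallback matching described above can be sketched with `inspect.signature`; `ToyOpt` and the config object are illustrative stand-ins, not the real classes.

```python
import inspect
from types import SimpleNamespace

def kwargs_from_config(optimizer_cls: type, prefix: str, config) -> dict:
    """Sketch of matching __init__ parameters to config attributes."""
    kwargs = {}
    for name in inspect.signature(optimizer_cls.__init__).parameters:
        if name in ("self", "params"):
            continue  # always skipped
        for attr in (f"{prefix}_{name}", name):  # prefixed name wins
            if hasattr(config, attr):
                kwargs[name] = getattr(config, attr)
                break
    return kwargs

class ToyOpt:  # illustrative optimizer, not the real class
    def __init__(self, params, lr=1e-3, momentum=0.9):
        pass

# muon_lr is found under the prefix; momentum falls back to the bare name.
cfg = SimpleNamespace(muon_lr=0.02, momentum=0.95)
```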
- core.optimizer.emerging_optimizers._muon_config_to_kwargs(config, model_chunks, pg_collection)#
Convert OptimizerConfig to TensorParallelMuon constructor kwargs.
- core.optimizer.emerging_optimizers._adaptive_muon_config_to_kwargs(config, model_chunks, pg_collection)#
Convert OptimizerConfig to TensorParallelAdaptiveMuon constructor kwargs.
- core.optimizer.emerging_optimizers._default_adam_based_eopt_config_to_kwargs(eopt_name, config, model_chunks, pg_collection)#
Convert OptimizerConfig to default emerging optimizer constructor kwargs.