core.optimizer.distrib_optimizer
Megatron distributed optimizer.
Module Contents
Classes
Range – A range represents the start and end points for indexing a shard from a full tensor.
DistributedOptimizer – Distributed optimizer, for all data types (fp16, bf16, and fp32).
Data
API
- core.optimizer.distrib_optimizer.HAVE_APEX_OR_TE
True
- core.optimizer.distrib_optimizer.USING_TE_OPTIMIZER
False
- core.optimizer.distrib_optimizer.USING_APEX_OPTIMIZER
False
- core.optimizer.distrib_optimizer.logger
‘getLogger(…)’
- class core.optimizer.distrib_optimizer.Range(start: int, end: int)
A range represents the start and end points for indexing a shard from a full tensor.
- Parameters:
start (int) – Start index.
end (int) – End index.
Initialization
- normalize(start: int = 0)
Shift the start/end indexes so that the range begins at a new start index.
Both start and end indexes will be shifted by [new start] - [old start].
- Parameters:
start (int) – New start index.
- __str__()
- __repr__()
- __len__()
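The Range semantics above can be sketched in a few lines of Python. This is a minimal sketch, not the Megatron source: it assumes that `normalize()` returns a new Range of the same length instead of mutating the existing one, and the `__str__`/`__repr__` formats are illustrative.

```python
class Range:
    """Half-open index range [start, end) into a flattened tensor (sketch)."""

    def __init__(self, start: int, end: int):
        self.start = start
        self.end = end
        self.size = end - start

    def normalize(self, start: int = 0) -> "Range":
        # Shift both ends by (new start - old start); the length is unchanged.
        # Assumption: a new Range is returned rather than mutating self.
        return Range(start, start + self.size)

    def __str__(self) -> str:
        return f"{self.start},{self.end} [{self.size}]"

    def __repr__(self) -> str:
        return f"Range({self.start}, {self.end})"

    def __len__(self) -> int:
        return self.end - self.start


# A shard covering elements 8..13 of a flattened buffer, re-expressed
# relative to the start of the shard itself.
shard = Range(8, 14)
local = shard.normalize()
print(local, len(local))  # prints "0,6 [6] 6"
```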
- class core.optimizer.distrib_optimizer.DistributedOptimizer(
    optimizer: torch.optim.Optimizer,
    config: core.optimizer.optimizer_config.OptimizerConfig,
    grad_scaler: core.optimizer.grad_scaler.MegatronGradScaler,
    init_state_fn: Optional[Callable],
    model_chunks: List[core.transformer.module.MegatronModule],
    per_model_buffers: Dict[int, List[core.distributed.param_and_grad_buffer._ParamAndGradBuffer]],
    data_parallel_group: torch.distributed.ProcessGroup,
    data_parallel_group_gloo: Optional[torch.distributed.ProcessGroup],
    data_parallel_group_idx: int,
    distributed_optimizer_instance_id: int,
)
Bases: core.optimizer.optimizer.MixedPrecisionOptimizer
Distributed optimizer, for all data types (fp16, bf16, and fp32).
See __init__() below for argument details.
Initialization
Distributed optimizer, for all data types (fp16, bf16, and fp32).
The steps in this method create the core mapping between param and grad buffers, parameters, and parameter shard ranges that is needed for converting between model param indexes and main parameter shard indexes. This method also updates the optimizer parameter groups with the newly created shards.
- Parameters:
optimizer (torch.optim.Optimizer) – base optimizer such as Adam or SGD.
config (OptimizerConfig) – configuration object for optimizer.
grad_scaler (MegatronGradScaler) – used for scaling gradients. Note that this can be None. This case happens when bf16 = True and we don’t use any loss scale. Note that for bf16 = True, we can have a constant gradient scaler. Also for bf16 = False, we always require a grad scaler.
init_state_fn (Callable, optional) – function to initialize state in the optimizer.
model_chunks (List[MegatronModule]) – list of model chunks.
per_model_buffers (Dict[int, List[_ParamAndGradBuffer]]) – the implementation of the distributed optimizer is centered on using a contiguous buffer for communicating grads & params between the model state and the optimizer state. You can find a more detailed description in https://github.com/NVIDIA/Megatron-LM/blob/main/docs/source/distrib_optimizer.md.
data_parallel_group (torch.distributed.ProcessGroup) – data-parallel group to use to all-gather params after optimizer.step().
data_parallel_group_gloo (torch.distributed.ProcessGroup) – gloo data-parallel group (used in checkpoint loading and saving).
data_parallel_group_idx (int) – index in data-parallel group (used by distributed checkpointing logic).
distributed_optimizer_instance_id (int) – index of the Distributed Optimizer instance.
- checkpoint_fully_reshardable_formats: set[str]
None
- classmethod _build_model_gbuf_param_range_map(
    param_world_index_map: Dict[torch.nn.Parameter, Tuple],
    gbuf_world_range: core.optimizer.distrib_optimizer.Range,
    bucket_offset: int,
)
Build mapping from param reference to grad buffer shard ranges.
This method builds a mapping from parameter references to grad buffer shard ranges, specific to each data-parallel (DP) rank’s set of ‘owned’ parameters. Each grad buffer (padded to be an even multiple of DP-world-size) is conceptually divided into DP-world-size contiguous regions, where each DP rank ‘owns’ a contiguous region. Ownership in this sense means the DP rank is responsible for reducing the relevant subset of grads, and updating the relevant subset of params.
This conceptual partitioning of the grad buffer does NOT respect parameter boundaries, and as such it is assumed that each created range references a shard (or subset) of the full parameter. It is easiest to think of each DP rank as operating (i.e., reducing, gathering) purely on views into the grad buffer, for all model-to-main & main-to-model operations.
This method creates four ranges:
The param’s range within the entire grad buffer (i.e., world index).
The param’s range within the relevant grad bucket’s buffer.
The param’s range within the DP rank’s local view of the grad buffer.
The param’s range within itself (i.e., its shard).
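The range arithmetic can be illustrated in plain Python. This is a sketch under assumptions, not Megatron's implementation: params are hypothetical string names, each world index is a `(start, end)` pair (the real `param_world_index_map` values are tuples whose exact layout is not documented here), and `build_param_range_map` and its dict keys are hypothetical stand-ins for `_build_model_gbuf_param_range_map` and its output.

```python
from dataclasses import dataclass
from typing import Dict, Tuple


@dataclass(frozen=True)
class Range:
    start: int
    end: int

    @property
    def size(self) -> int:
        return self.end - self.start


def build_param_range_map(
    param_world_index_map: Dict[str, Tuple[int, int]],
    gbuf_world_range: Range,
    bucket_offset: int,
) -> Dict[str, Dict[str, Range]]:
    """Compute the four ranges for every param that overlaps this rank's shard."""
    range_map = {}
    for name, (world_start, world_end) in param_world_index_map.items():
        # Overlap of the param with this rank's shard, in shard-local indexes.
        local_start = max(0, world_start - gbuf_world_range.start)
        local_end = min(gbuf_world_range.size, world_end - gbuf_world_range.start)
        if local_end <= local_start:
            continue  # this rank owns no part of the param
        # (3) Range within the DP rank's local view of the grad buffer.
        gbuf_local = Range(local_start, local_end)
        # (1) The same elements, indexed within the entire grad buffer.
        gbuf_world = Range(gbuf_world_range.start + local_start, gbuf_world_range.start + local_end)
        # (2) The same elements, indexed within the bucket's buffer.
        in_bucket = Range(gbuf_world.start - bucket_offset, gbuf_world.end - bucket_offset)
        # (4) The same elements, indexed within the param itself (its shard).
        sub_start = max(0, gbuf_world_range.start - world_start)
        param_shard = Range(sub_start, sub_start + gbuf_local.size)
        range_map[name] = {
            "gbuf_world": gbuf_world,
            "gbuf_world_in_bucket": in_bucket,
            "gbuf_local": gbuf_local,
            "param": param_shard,
        }
    return range_map


# Three params packed into one 16-element bucket at offset 0; with two DP
# ranks, rank 1 owns world elements [8, 16).
ranges = build_param_range_map(
    {"a": (0, 6), "b": (6, 10), "c": (10, 16)}, Range(8, 16), bucket_offset=0
)
print(sorted(ranges))  # prints ['b', 'c']
# Param "a" lies entirely inside rank 0's region, so it is skipped here.
print(ranges["b"]["param"], ranges["c"]["param"])
# prints Range(start=2, end=4) Range(start=0, end=6)
```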
- classmethod _build_model_gbuf_range(
    param_and_grad_buffer: core.distributed.param_and_grad_buffer._ParamAndGradBuffer,
    bucket_index: int,
)
Build mapping between params and their grad buffers.
This method does the initial setup for the method above. This setup includes determining the shard ranges into the param_and_grad_buffer for each data-parallel (DP) rank. Each DP rank keeps range info for all other DP ranks, for the purpose of creating args for reduce-scatter and all-gather.
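A sketch of the per-rank shard computation for a single bucket, assuming the bucket has already been padded to a multiple of the DP world size and that `bucket_offset` is the bucket's start in the full buffer. `dp_shard_ranges` is a hypothetical helper, not part of Megatron.

```python
from typing import List, Tuple


def dp_shard_ranges(bucket_numel: int, bucket_offset: int, dp_world_size: int) -> List[Tuple[int, int]]:
    """World-index (start, end) range owned by each DP rank for one bucket."""
    # Assumption: the bucket is padded to a multiple of the DP world size.
    if bucket_numel % dp_world_size != 0:
        raise ValueError("bucket must be padded to a multiple of dp_world_size")
    shard_size = bucket_numel // dp_world_size
    ranges = []
    for rank in range(dp_world_size):
        start = bucket_offset + rank * shard_size
        ranges.append((start, start + shard_size))
    return ranges


# Every rank computes the ranges of *all* ranks, which is what it needs to
# build the arguments of reduce-scatter and all-gather.
print(dp_shard_ranges(bucket_numel=16, bucket_offset=32, dp_world_size=4))
# prints [(32, 36), (36, 40), (40, 44), (44, 48)]
```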
- classmethod _build_gbuf_range_map(
    param_and_grad_buffer: core.distributed.param_and_grad_buffer._ParamAndGradBuffer,
)
Build mapping between params and their grad buffers. These mappings are partitioned according to data type.
Iterate through all buckets of the grad buffer to construct param ranges that this rank “owns” (the dp_rank’th shard of each bucket, where each shard is 1/dp_world_size of the bucket).
- Parameters:
param_and_grad_buffer (_ParamAndGradBuffer) – buffer to build mapping for.
- classmethod _build_model_param_gbuf_map(
    gbuf_ranges: List[Dict],
)
Create a reverse of the gbuf_ranges, for referencing in the opposite direction.
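The reverse map can be sketched as a nested loop. The nesting used here (buffer index, then dtype key, then a list of per-bucket maps each holding a `param_map`) is an assumption for illustration; dtype keys are plain strings, param names are hypothetical, and the range values are elided as `None`.

```python
from typing import Dict, List, Tuple


def build_param_gbuf_map(gbuf_ranges: List[Dict]) -> Dict[str, Tuple[int, str, int]]:
    """Invert buffer -> dtype -> bucket -> params into param -> location."""
    param_gbuf_map = {}
    for gbuf_index, ranges_by_dtype in enumerate(gbuf_ranges):
        for dtype, bucket_maps in ranges_by_dtype.items():
            for bucket_index, bucket_map in enumerate(bucket_maps):
                for param in bucket_map["param_map"]:
                    param_gbuf_map[param] = (gbuf_index, dtype, bucket_index)
    return param_gbuf_map


# None stands in for the per-param range dicts, which the inversion ignores.
gbuf_ranges = [
    {"bf16": [{"param_map": {"w1": None, "b1": None}}, {"param_map": {"w2": None}}]},
    {"fp32": [{"param_map": {"router": None}}]},
]
print(build_param_gbuf_map(gbuf_ranges))
# prints {'w1': (0, 'bf16', 0), 'b1': (0, 'bf16', 0), 'w2': (0, 'bf16', 1), 'router': (1, 'fp32', 0)}
```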
- classmethod _build_optimizer_group_ranges(
    param_groups: List[Dict],
    gbuf_ranges: List[Dict],
)
Create optimizer groups.
Given the set of parameter shard ranges that are owned by the current data-parallel (DP) rank, gather the set of parameters that will be used (in the method below) to create the current DP rank’s optimizer groups.
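A minimal sketch of this grouping step, assuming params are identified by hypothetical string names and that the rank's owned params are passed in as a flat list. `build_optimizer_group_ranges` and the `orig_group` key are hypothetical; the point is only that each owned param lands in the local counterpart of its original group, so that group's hyperparameters (e.g. `weight_decay`) still apply.

```python
from typing import Dict, List, Tuple


def build_optimizer_group_ranges(
    param_groups: List[Dict], owned_params: List[str]
) -> Tuple[List[Dict], Dict[str, Tuple[int, int]]]:
    """Keep only the params this DP rank owns, grouped like the original groups."""
    # Which original optimizer group each (full) param belongs to.
    world_group_of = {}
    for group_index, group in enumerate(param_groups):
        for param in group["params"]:
            world_group_of[param] = group_index

    group_ranges = [{"params": [], "orig_group": group} for group in param_groups]
    local_position = {}  # param -> (group index, position within local group)
    for param in owned_params:
        group_index = world_group_of[param]
        group_ranges[group_index]["params"].append(param)
        local_position[param] = (group_index, len(group_ranges[group_index]["params"]) - 1)
    return group_ranges, local_position


param_groups = [
    {"params": ["w1", "w2"], "weight_decay": 0.1},
    {"params": ["b1", "norm"], "weight_decay": 0.0},
]
groups, positions = build_optimizer_group_ranges(param_groups, owned_params=["w2", "norm"])
print([g["params"] for g in groups], positions)
# prints [['w2'], ['norm']] {'w2': (0, 0), 'norm': (1, 0)}
```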
- classmethod _build_model_and_main_param_groups(
    gbuf_ranges: List[Dict],
    param_gbuf_map: Dict[torch.nn.Parameter, Tuple],
    opt_group_ranges: List,
    config: core.optimizer.optimizer_config.OptimizerConfig,
)
Create main parameter groups needed for the optimizer step.
These groups encompass both: 1) groups used by this class, for reducing/gathering, and 2) groups used by the inner optimizer for the parameter update. Given that the conceptual grad buffer partitioning (created in an earlier method) doesn’t respect parameter boundaries, the optimizer operates on shards of the model parameters, rather than the full parameters.
- _get_model_param_range_map(param: torch.nn.Parameter)
Given a model param, get the index sub-range of the param that this data-parallel rank owns.
- get_grad_stats_parallel_group() → torch.distributed.ProcessGroup
With the distributed optimizer, gradient statistics (num_zeros & norm) are reduced over all ranks in the distributed optimizer instance (versus only the model-parallel ranks with the non-distributed optimizer).
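To see why the reduction must span every rank of the distributed-optimizer instance, consider a single-process simulation: each rank holds only a shard of the gradient, so its local sum of squares (and its zero count) is partial. Plain lists stand in for tensors and a Python `sum` stands in for the collective; this is not Megatron's code, and both helper names are hypothetical.

```python
import math
from typing import List


def local_sum_of_squares(grad_shard: List[float]) -> float:
    # Each rank only holds its own shard of the main grads.
    return sum(g * g for g in grad_shard)


def global_grad_norm(per_rank_sums: List[float]) -> float:
    # Stands in for an all-reduce (SUM) of the per-rank partial sums over all
    # ranks of the distributed-optimizer instance, followed by a square root.
    return math.sqrt(sum(per_rank_sums))


full_grad = [3.0, 0.0, 4.0, 0.0, 12.0, 0.0]
shards = [full_grad[0:3], full_grad[3:6]]  # two DP ranks, equal shards

norm = global_grad_norm([local_sum_of_squares(s) for s in shards])
num_zeros = sum(s.count(0.0) for s in shards)  # counts are summed the same way
print(norm, num_zeros)  # prints 13.0 3
```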
- state_dict()
The state dict contains all non-DP-rank-dependent (i.e., non-parameter-related) optimizer variables. The returned state dict can be stored in the standard model/RNG checkpoint file. The parameter and dependent optimizer state (e.g., exp_avg, exp_avg_sq) are stored in a separate checkpoint file by calling ‘save_parameter_state()’.
- load_state_dict(state_dict)
Load the state dict.
As detailed in state_dict(), the state dict contains all non-parameter-related variables. This method is notably longer than state_dict(), because the Torch optimizer’s state has yet to be allocated at this point, and so we must cross-reference the optimizer’s state (and the ordering it expects for parameter state) with this DP rank’s shards. The optimizer at this point does not contain any tensor dimension information, so we must get these dimensions from the DP shards mapped during DistributedOptimizer.__init__().
The tensor parameter state is loaded via load_parameter_state(), and so this method also must populate the loaded state dict with dummy tensor data (i.e., via torch.empty() below). This will be overwritten during load_parameter_state().
Note: Torch optimizer’s state structure. The Torch optimizer stores its state in two levels. The top level is a list of groups, where each group contains a list of integer indexes (corresponding to parameters) that index into a master parameter list that is shared by all groups. As such, three values are necessary for maintaining this ordering:
group_index : The group to which a parameter belongs.
group_order : The index of a parameter within its group.
state_order : The index of a parameter within the shared parameter list.
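The three values can be computed from a list of param groups as below. This sketch assumes params are identified by hypothetical string names, and `param_orderings` is a hypothetical helper; the consecutive numbering across groups mirrors how `torch.optim.Optimizer.state_dict()` assigns integer ids to params.

```python
from typing import Dict, List, Tuple


def param_orderings(param_groups: List[Dict]) -> Dict[str, Tuple[int, int, int]]:
    """Map each param to (group_index, group_order, state_order)."""
    orderings = {}
    state_order = 0  # position in the parameter list shared by all groups
    for group_index, group in enumerate(param_groups):
        for group_order, param in enumerate(group["params"]):
            orderings[param] = (group_index, group_order, state_order)
            state_order += 1
    return orderings


groups = [{"params": ["w1", "w2"]}, {"params": ["b1"]}]
orderings = param_orderings(groups)
print(orderings)  # prints {'w1': (0, 0, 0), 'w2': (0, 1, 1), 'b1': (1, 0, 2)}
# For the same groups, torch.optim.Optimizer.state_dict() would list the
# integer ids [0, 1] and [2] under "param_groups", i.e. the state_order values.
```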
- _get_main_param_and_optimizer_states(model_param)
Return a dict containing the main param and optimizer states corresponding to the input model_param.
The structure of the returned dict: tensors = {"param": torch.Tensor, "exp_avg": torch.Tensor, "exp_avg_sq": torch.Tensor}
- _set_main_param_and_optimizer_states(model_param, tensors)
Set the main param and optimizer states corresponding to the input model_param.
The structure of the input tensors: tensors = {"param": torch.Tensor, "exp_avg": torch.Tensor, "exp_avg_sq": torch.Tensor}
- get_parameter_state_dp_reshardable()
Get internal representation of parameter state without any copies and modifications.
This is referred to as “fully sharded bucket space” because the optimizer state is fully sharded (i.e., no gather is involved) and bucket-centric (the state follows the internal structure of the Distributed Optimizer buckets), as opposed to model-centric (the typical structure of PyTorch optimizers).
- get_parameter_state_dp_zero(
    use_gloo_comm: bool = True,
    empty_data: bool = False,
    return_on_all_ranks: bool = False,
)
Get parameter state (i.e., parameter & optimizer tensors).
This method performs two steps:
For each DP rank, copy param & optimizer shards to contiguous CPU buffers (e.g., one buffer each for main_param, exp_avg, and exp_avg_sq).
Gather contiguous buffers on DP rank 0 and concatenate to world buffers.
- Parameters:
use_gloo_comm (bool, optional) – Whether to use Gloo communication for the tensor gather. Defaults to True. Only has an effect in the non-FSDP case.
empty_data (bool, optional) – Whether to skip filling the world tensors with actual data. Empty world tensors are used during checkpoint loading. Defaults to False. Only has an effect in the non-FSDP case.
return_on_all_ranks (bool, optional) – Whether to return the state dict on all ranks. If False, DP != 0 ranks will return None. Defaults to False. Only has an effect in the non-FSDP case. Returning the whole state dict on all ranks enables parallel saving and loading when used for sharded state dict creation.
- Returns:
optimizer state dict on DP rank 0, or all ranks if return_on_all_ranks. Returns None on non-zero DP ranks when return_on_all_ranks=False.
- Return type:
dict or None
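The two steps of get_parameter_state_dp_zero() described above can be simulated in a single process. Python lists stand in for CPU tensors and a loop stands in for the Gloo gather to DP rank 0; per-bucket handling, padding, and the real communication calls are omitted, and both helper names are hypothetical.

```python
from typing import Dict, List

STATE_KEYS = ("param", "exp_avg", "exp_avg_sq")


def pack_local_state(shards: List[Dict[str, List[float]]]) -> Dict[str, List[float]]:
    """Step 1: copy this rank's per-param shards into one contiguous buffer per key."""
    buffers = {key: [] for key in STATE_KEYS}
    for shard in shards:
        for key in STATE_KEYS:
            buffers[key].extend(shard[key])
    return buffers


def gather_on_rank0(per_rank_buffers: List[Dict[str, List[float]]]) -> Dict[str, List[float]]:
    """Step 2: stands in for a gather to DP rank 0, then concatenation in rank order."""
    world = {key: [] for key in STATE_KEYS}
    for buffers in per_rank_buffers:
        for key in STATE_KEYS:
            world[key].extend(buffers[key])
    return world


rank0 = pack_local_state([{"param": [1.0, 2.0], "exp_avg": [0.1, 0.2], "exp_avg_sq": [0.01, 0.02]}])
rank1 = pack_local_state([{"param": [3.0, 4.0], "exp_avg": [0.3, 0.4], "exp_avg_sq": [0.03, 0.04]}])
world = gather_on_rank0([rank0, rank1])
print(world["param"])  # prints [1.0, 2.0, 3.0, 4.0]
```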
- save_parameter_state(filename: str)
Save the distributed parameter state on DP rank 0.
- Parameters:
filename (str) – path to save parameter state to.
- _init_optimizer_states_with_dummy_values()
- _param_name(param: torch.nn.Parameter) → str
Get the name of the parameter.
- sharded_state_dict(
    model_sharded_state_dict: core.dist_checkpointing.mapping.ShardedStateDict = {},
    is_loading: bool = False,
    sharding_type: Optional[str] = None,
    metadata: Optional[dict] = None,
)
Chooses between 3 param state sharding implementations as requested by metadata['distrib_optim_sharding_type']. The sharding type can be one of:
‘dp_reshardable’: Sharded state dict where each noncontiguous buffer is a separate ShardedTensor. Results in fully parallel save and load without any inter-process communication or intermediate buffers/copies. Since the format relies on the internal DistributedOptimizer structure, it allows checkpoint resharding only in the DP dimension.
‘fully_reshardable’: During checkpoint save (is_loading=False), gathers all DistributedOptimizer buffers on DP rank 0 and transforms them into a canonical state representation similar to a regular optimizer, where each model param corresponds to one or more optimizer state tensors of the same shape (possibly different precision). During checkpoint load, each rank loads a superset of the required state and does rank-specific flattening and slicing.
‘fsdp_dtensor’: Sharded state dict where each parameter is a separate PyTorch DTensor. This is the default and recommended implementation for the distributed optimizer when using Megatron FSDP training.
Deprecated sharding formats:
‘dp_zero_gather_scatter’: Naive implementation which reuses gather/scatter from the legacy ckpt format. During saving, gathers the parameters state on DP rank 0 and saves a ShardedObject with fixed TPxPP structure. During loading, loads the saved data on DP rank 0 (None on other ranks). Relies on the parameters scatter done in load_state_dict.
‘fully_sharded_model_space’: Sharded state dict where each parameter is a separate ShardedTensor, which is a flattened subset of the canonical state representation. Results in fully parallel save and load without any inter-process communication or intermediate buffers/copies.
Regular state dict parameters are saved on DP rank 0 and loaded on all ranks.
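Judging by the method docstrings in this module, each sharding type appears to correspond to one of the sharded_param_state_* methods documented below. The table in this sketch records that assumed correspondence; `select_implementation` is a hypothetical helper, and how the real method treats its `sharding_type` argument or a missing metadata key is not covered here.

```python
from typing import Tuple

# Sharding type -> (documented method name, deprecated?). Assumed mapping,
# inferred from the docstrings rather than from the dispatch code itself.
SHARDING_IMPLEMENTATIONS = {
    "dp_reshardable": ("sharded_param_state_dp_reshardable", False),
    "fully_reshardable": ("sharded_param_state_fully_reshardable", False),
    "fsdp_dtensor": ("sharded_param_state_fsdp_dtensor", False),
    "dp_zero_gather_scatter": ("sharded_param_state_dp_zero", True),
    "fully_sharded_model_space": ("sharded_param_state_fs_model_space", True),
}


def select_implementation(metadata: dict) -> Tuple[str, bool]:
    """Hypothetical helper: resolve metadata['distrib_optim_sharding_type']."""
    sharding_type = metadata["distrib_optim_sharding_type"]
    if sharding_type not in SHARDING_IMPLEMENTATIONS:
        raise ValueError(f"unknown sharding type: {sharding_type!r}")
    return SHARDING_IMPLEMENTATIONS[sharding_type]


print(select_implementation({"distrib_optim_sharding_type": "dp_reshardable"}))
# prints ('sharded_param_state_dp_reshardable', False)
```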
- _param_groups_to_param2group_meta(
    param_groups: list[dict[str, Any]],
)
Convert a list of parameter groups to a mapping of parameter names to group metadata.
- _param2group_meta_to_param_groups(
    param_to_group_meta: dict[str, Any],
    param_groups: list[dict[str, Any]],
    strict: bool = True,
)
Convert a mapping of parameter names to group metadata to a list of parameter groups.
- sharded_param_state_fsdp_dtensor(is_loading: bool = False)
Sharded state dict where each parameter is a separate PyTorch DTensor.
- sharded_param_state_dp_zero(
    model_sharded_state_dict: core.dist_checkpointing.mapping.ShardedStateDict,
    is_loading: bool = False,
    metadata: Optional[dict] = None,
)
Naive implementation which reuses gather/scatter from the legacy ckpt format.
During saving, gathers the parameters state on DP rank 0 and saves a ShardedObject with fixed TPxPP structure. During loading, loads the saved data on DP rank 0 (None on other ranks). Relies on the parameters scatter done in load_state_dict.
- sharded_param_state_fully_reshardable(
    model_sharded_state_dict: core.dist_checkpointing.mapping.ShardedStateDict,
    is_loading: bool = False,
    metadata: Optional[dict] = None,
)
Exchange-based format in model space representation.
The fully_reshardable format involves gathering the tensors on DP rank 0 during save. Flat DistOpt buffers are unflattened and reshaped into model-param-like sizes. This results in a state dict similar to a regular optimizer one, where each param of shape (X, Y, Z) has corresponding ‘param’, ‘exp_avg’ and ‘exp_avg_sq’ tensors of shape (X, Y, Z) in the optimizer state dict.
During loading there is no data exchange: each rank requests to load the whole state dict (and flattens and trims the tensors afterwards). It is recommended to use fully parallel loading, which parallelizes the load and avoids duplicated reads from storage.
- Parameters:
model_sharded_state_dict (ShardedStateDict) – model sharded state dict
is_loading (bool, optional) – Whether the optimizer sharded state dict is used for loading or saving. Defaults to False.
metadata (dict, optional) – metadata passed to sharded_state_dict method. Allows some detailed control over the sharded state dict creation with the distrib_optim_fully_reshardable_mem_efficient flag, which enables memory-efficient exchange. By default (False), data is all-gathered with NCCL to all ranks, which allows further parallelization of save and load but can use more memory. In the memory-efficient version (True), data is gathered with Gloo and returned only on DP rank 0 (which prevents save/load parallelization along DP). The checkpoint storage structure can differ between those two flags, but from the MCore perspective they are interchangeable.
- Returns:
optimizer sharded state dict if memory-efficient mode is off (see the distrib_optim_fully_reshardable_mem_efficient flag explanation above) or during checkpoint loading (is_loading). Otherwise, the sharded state dict is returned only on DP rank 0 (None on other ranks).
- Return type:
ShardedStateDict or None
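The save-side unflattening and load-side trimming of the fully_reshardable format can be sketched with flat Python lists. The assumptions are loud here: a flat list plus a shape tuple stands in for a reshaped tensor, bucket structure and padding are ignored, and `unflatten` and `flatten_and_trim` are hypothetical helpers.

```python
import math
from typing import Dict, List, Tuple

Shape = Tuple[int, ...]


def unflatten(flat: List[float], shapes: Dict[str, Shape]) -> Dict[str, Tuple[Shape, List[float]]]:
    """Save side: cut a gathered flat buffer into one chunk per param, tagged with its shape."""
    out, offset = {}, 0
    for name, shape in shapes.items():
        numel = math.prod(shape)  # a param of shape (X, Y, Z) holds X*Y*Z elements
        out[name] = (shape, flat[offset:offset + numel])
        offset += numel
    return out


def flatten_and_trim(param_state: List[float], param_range: Tuple[int, int]) -> List[float]:
    """Load side: every rank reads a param's full state, then keeps only its own shard."""
    start, end = param_range
    return param_state[start:end]


flat = [float(i) for i in range(10)]
state = unflatten(flat, {"weight": (2, 3), "bias": (4,)})

# Suppose this rank owns buffer elements [4, 8): elements 4..5 of "weight"
# and elements 0..1 of "bias".
print(flatten_and_trim(state["weight"][1], (4, 6)), flatten_and_trim(state["bias"][1], (0, 2)))
# prints [4.0, 5.0] [6.0, 7.0]
```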
- sharded_param_state_dp_reshardable(
    model_sharded_state_dict: core.dist_checkpointing.mapping.ShardedStateDict,
    is_loading: bool = False,
    metadata: Optional[dict] = None,
)
Sharded state dict where each noncontiguous buffer is a separate ShardedTensor.
Results in fully parallel save and load without any inter-process communication or intermediate buffers/copies.
Stores optimizer state in the format that corresponds to the internal Distributed Optimizer format, i.e., in buckets. Each bucket consists of state parameters and potentially some padding:
intra-param padding
param 1
intra-param padding
param 2
intra-param padding
param …
intra-param padding
param N
intra-param padding
bucket padding to some DP multiple
Different buckets are assigned a different ShardedTensor key. Within each bucket, each param and each padding above is represented with a different ShardedTensor object sharing the same key (so, corresponding to the same tensor in the checkpoint).
For checkpointing, we include the intra-param padding for correctness, but we must discard the last padding to a DP multiple, because that might change during DP resharding. We want the checkpoint tensor to always have size gbuf_world_numel_unpadded, which means everything except for the last padding above.
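The final padding has to be dropped because it depends on the DP size. This sketch assumes the bucket is rounded up to the smallest multiple of the DP world size (the real multiple may be larger); the unpadded size, which is what the checkpoint stores, stays fixed while the padded size changes. `padded_numel` is a hypothetical helper and the bucket size is made up.

```python
def padded_numel(unpadded_numel: int, dp_world_size: int) -> int:
    """Round a bucket size up to the next multiple of the DP world size."""
    return -(-unpadded_numel // dp_world_size) * dp_world_size  # ceiling division


unpadded = 1000  # params plus intra-param padding in one bucket (made-up size)
for dp in (8, 48, 64):
    padded = padded_numel(unpadded, dp)
    print(dp, padded, padded - unpadded)
# prints:
# 8 1000 0
# 48 1008 8
# 64 1024 24
```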
- sharded_param_state_fs_model_space(
    model_sharded_state_dict: core.dist_checkpointing.mapping.ShardedStateDict,
    is_loading: bool = False,
    metadata: Optional[dict] = None,
)
Sharded state dict where each buffer is mapped to corresponding model param.
In this approach the optimizer state tensors are directly related to model parameters by linking them with metadata from model_sharded_state_dict. This will allow changing TP and PP while using DistOpt (as with other optimizers).
- load_parameter_state_from_dp_reshardable(state_dict)
Loads the parameter state from an internal representation.
Inverse of the get_parameter_state_dp_reshardable method.
- load_parameter_state_from_fs_model_space(state_dict)
Loads the parameter state from a “model space” representation.
Inverse of the sharded_param_state_fs_model_space method.
- classmethod _update_legacy_world_tensors(old_tensors, new_numels)
Reshard buckets (where each bucket is a tensor) to new target numels, where the total numel remains the same.
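Because the total numel is preserved, resharding can be pictured as concatenating the old buckets in order and cutting the result at the new boundaries. This is a sketch with Python lists in place of tensors; `reshard_buckets` is a hypothetical name, and the real method may move data differently.

```python
from itertools import chain
from typing import List


def reshard_buckets(old_tensors: List[List[float]], new_numels: List[int]) -> List[List[float]]:
    """Re-split the concatenated contents of the old buckets into new bucket sizes."""
    flat = list(chain.from_iterable(old_tensors))
    if sum(new_numels) != len(flat):
        raise ValueError("total numel must stay the same")
    new_tensors, offset = [], 0
    for numel in new_numels:
        new_tensors.append(flat[offset:offset + numel])
        offset += numel
    return new_tensors


print(reshard_buckets([[1, 2, 3], [4, 5]], [2, 3]))  # prints [[1, 2], [3, 4, 5]]
```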
- load_parameter_state_from_dp_zero_legacy(state_dict)
Load parameter state (i.e., parameter & optimizer tensors) from DP rank 0, using the legacy checkpoint format as described below.
The difference between this method and load_parameter_state_from_dp_zero_modern() is that this method is used for updating the format of checkpoints that were saved using code from before Feb 13, 2024. Starting on this date, a new format was used (i.e., a different format for the parameter mapping and bucket sharding).
Use arg --ckpt-convert-update-legacy-dist-opt-format to call this method, along with --ckpt-convert-format and --ckpt-convert-save to update a legacy-format checkpoint to the modern format.
- load_parameter_state_from_dp_zero(
    state_dict,
    *,
    update_legacy_format=False,
)
Load parameter state (i.e., parameter & optimizer tensors) from DP rank 0, using the new checkpoint format with coalesced state across buckets.
This method performs the reverse of get_parameter_state_dp_zero():
Scatter contiguous buffers from DP rank 0 to each DP rank (each DP rank receives its relevant subset of the world buffers).
For each DP rank, copy param & optimizer shards from contiguous CPU buffers (e.g., one buffer each for main_param, exp_avg, and exp_avg_sq).
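The scatter step can be simulated as the inverse of the gather: each DP rank receives one contiguous slice of every world buffer. This single-process sketch assumes equal-size slices and uses lists in place of tensors; `scatter_from_rank0` is a hypothetical helper, not the real communication call.

```python
from typing import Dict, List


def scatter_from_rank0(world: Dict[str, List[float]], dp_world_size: int) -> List[Dict[str, List[float]]]:
    """Split each world buffer into equal contiguous slices, one per DP rank."""
    per_rank = [dict() for _ in range(dp_world_size)]
    for key, buffer in world.items():
        if len(buffer) % dp_world_size != 0:
            raise ValueError(f"{key} buffer is not a multiple of dp_world_size")
        shard = len(buffer) // dp_world_size
        for rank in range(dp_world_size):
            per_rank[rank][key] = buffer[rank * shard:(rank + 1) * shard]
    return per_rank


world = {"param": [1.0, 2.0, 3.0, 4.0], "exp_avg": [0.1, 0.2, 0.3, 0.4]}
local = scatter_from_rank0(world, dp_world_size=2)
print(local[1])  # prints {'param': [3.0, 4.0], 'exp_avg': [0.3, 0.4]}
```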
- load_parameter_state_from_fully_reshardable(state_dict: dict)
Load counterpart of sharded_param_state_fully_reshardable.
Iterates over the state_dict tensors (in the same order as sharded_param_state_fully_reshardable, which determines the state dict tensor order), flattens them, and trims them according to local param ranges.
- Parameters:
state_dict (dict) – loaded optimizer state dict
- split_state_dict_if_needed(state_dict)
When “--fp8-param-gather” is disabled, weights and biases are stored in the same _ParamAndGradBuffer. So, when saving a checkpoint, the optimizer’s main parameters are saved in a single contiguous tensor (this also applies to “exp_avg” and “exp_avg_sq”).
However, when “--fp8-param-gather” is enabled, weights (in fp8 dtype) and biases (in bf16/fp16 dtype) are stored in separate _ParamAndGradBuffers. Therefore, when “--fp8-param-gather” is enabled and we want to load a checkpoint saved without it, we need to split the weights (fp8) and biases (bf16/fp16) in the state_dict into two separate tensors.
- load_parameter_state(filename: str, *, update_legacy_format=False)
Load the distributed parameter state from disk.
- Parameters:
filename (str) – path to load parameter state from.
- zero_grad(set_to_none: bool = True)
Zeroes grads for the model-related parameters, i.e., model_float16_groups and model_fp32_groups. We additionally zero the remaining groups as a memory optimization to reduce fragmentation; in the case of set_to_none==True, the space used by this field can be safely deallocated.
- Parameters:
set_to_none (bool) – if True, set grads to None.
- _collect_main_grad_data_for_unscaling()
Note: this should be equivalent to the float-16 optimizer’s method, but written differently, so the two should be combined.
- _get_model_and_main_params_data_float16()
Get aligned list of model and main params.
- _get_fp8_params_and_shard_fp32_from_fp8()
Get lists of FP8 model params, corresponding shard main params, and the starting index of the shard main param in the FP8 param. Parameters in all three lists are in the same order.
- _copy_model_grads_to_main_grads()
Copy model grads to main grads.
Since this step follows a reduce-scatter through the DDP’s grad buffer, this method is responsible for copying the updated grads from the grad buffer to the main shard’s grad field.
- _copy_main_params_to_model_params()
Copy main params to model params.
Since this step is followed by an all-gather through the DDP’s grad buffer, this method is responsible for copying the updated params from the main shards into the correct position in the grad buffer.
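Both copies above are slice operations over per-param ranges of a flat buffer. This is a single-process sketch with lists in place of tensors, assuming the hypothetical `(start, end)` world ranges shown; real code operates on tensor views, casts between precisions, and runs the reduce-scatter and all-gather as separate collectives. Both helper names are hypothetical.

```python
from typing import Dict, List, Tuple

# Hypothetical per-param world ranges owned by one DP rank, as (start, end)
# offsets into a flat buffer.
ParamRanges = Dict[str, Tuple[int, int]]


def copy_model_grads_to_main_grads(grad_buffer: List[float], ranges: ParamRanges) -> Dict[str, List[float]]:
    """After reduce-scatter: copy this rank's reduced grad slices into main grads."""
    return {name: grad_buffer[start:end] for name, (start, end) in ranges.items()}


def copy_main_params_to_buffer(main_params: Dict[str, List[float]], buffer: List[float], ranges: ParamRanges) -> None:
    """Before all-gather: write updated main-param shards back into their buffer slots."""
    for name, (start, end) in ranges.items():
        buffer[start:end] = main_params[name]


grad_buffer = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
ranges = {"b": (4, 6), "c": (6, 8)}  # rank 1 of 2 owns elements [4, 8)
main_grads = copy_model_grads_to_main_grads(grad_buffer, ranges)

param_buffer = [0.0] * 8
copy_main_params_to_buffer({"b": [0.5, 0.5], "c": [0.25, 0.25]}, param_buffer, ranges)
print(param_buffer)  # prints [0.0, 0.0, 0.0, 0.0, 0.5, 0.5, 0.25, 0.25]
# An all-gather would then fill elements [0, 4) from rank 0.
```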
- _copy_main_params_to_param_buffer()
This function is only used for MXFP8 params. Copy FP32 main params directly to the param buffer for the param all-gather, since the param buffer is not mapped to model params in the MXFP8 case.
- _build_model_param_to_state_dict_param_map(state_dict)
Create a map from model params to tensors in state_dict based on their names.
- _copy_model_params_to_main_params(state_dict=None)
Copy model params to main params.
During finetuning, this method is used to reload the main params from the model params. This copy does not make use of the grad buffer as an intermediary.
- step_with_ready_grads() → bool
Step the optimizer with ready gradients and return whether the step was successful. Under the hood, this either launches synchronous param all-gathers or prepares to launch asynchronous all-gathers that are overlapped with the next forward pass.