core.distributed.param_and_grad_buffer#
Module Contents#
Classes#
- BufferType – Enumeration for buffer type.
- _ParamAndGradBucket – Bucket to keep track of a subset of the model’s parameters and gradients.
- _ParamAndGradBucketGroup – Put multiple buckets into a group so that their communications can be aggregated together. Provides functionality to register when params in the bucket group have grads ready to be synced; an asynchronous communication call is automatically launched when all params in the bucket group have grads ready.
- _ParamAndGradBuffer – Groups parameters and gradients into a contiguous buffer, and then breaks the buffer into buckets with roughly bucket_size parameters each.
Functions#
- shard_buffer – Shard buffer into data_parallel_world_size chunks of equal size.
- partition_buckets – Automatically regroup the buckets of input buffers and return a list of bucket groups.
Data#
- logger
API#
- core.distributed.param_and_grad_buffer.logger#
‘getLogger(…)’
- class core.distributed.param_and_grad_buffer.BufferType(*args, **kwds)#
Bases: enum.Enum
Enumeration for buffer type.
Initialization
- PARAM#
1
- GRAD#
2
- core.distributed.param_and_grad_buffer.shard_buffer(buffer: torch.Tensor, data_parallel_world_size: int)#
Shard buffer into data_parallel_world_size chunks of equal size.
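A minimal sketch of the sharding behavior, assuming the buffer’s element count is evenly divisible by data_parallel_world_size (a toy reimplementation for illustration, not the module’s actual function):

```python
import torch

def shard_buffer_sketch(buffer: torch.Tensor, data_parallel_world_size: int):
    # Assumes buffer.numel() is evenly divisible by data_parallel_world_size.
    assert buffer.numel() % data_parallel_world_size == 0
    shard_size = buffer.numel() // data_parallel_world_size
    # Each shard is a view into the original 1-D buffer, not a copy.
    return [
        buffer[rank * shard_size : (rank + 1) * shard_size]
        for rank in range(data_parallel_world_size)
    ]

flat = torch.arange(8.0)
shards = shard_buffer_sketch(flat, data_parallel_world_size=4)
print([s.tolist() for s in shards])  # [[0.0, 1.0], [2.0, 3.0], [4.0, 5.0], [6.0, 7.0]]
```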
- class core.distributed.param_and_grad_buffer._ParamAndGradBucket(
- params: List[torch.nn.Parameter],
- param_data: Optional[torch.Tensor],
- grad_data: torch.Tensor,
- offset: int,
- numel_unpadded: int,
- gradient_scaling_factor: float,
- bucket_id: int,
Bucket to keep track of a subset of the model’s parameters and gradients.
- Parameters:
params – List of parameters whose gradients are collated in this bucket.
param_data – View in _ParamAndGradBuffer.param_data that this bucket is responsible for.
grad_data – View in _ParamAndGradBuffer.grad_data that this bucket is responsible for.
offset – Offset of this bucket’s view in the larger _ParamAndGradBuffer.
numel_unpadded – Number of unpadded elements in bucket.
gradient_scaling_factor – Factor used to scale gradients prior to their communication; it serves both to average gradients and to scale gradients in the Mixture of Experts (MoE) case.
bucket_id – Index of bucket in buffer.
Initialization
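As an illustration of gradient_scaling_factor, consider the common averaging setup where the factor is 1 / data_parallel_world_size: pre-scaling each rank’s gradients turns a sum-reducing collective into a mean. The sketch below is illustrative only; the bucket itself merely stores the factor and applies it before communication.

```python
import torch

data_parallel_world_size = 4
gradient_scaling_factor = 1.0 / data_parallel_world_size  # assumed averaging setup

# Stand-ins for the per-rank grad_data views held by a bucket.
per_rank_grads = [torch.full((3,), float(rank + 1)) for rank in range(data_parallel_world_size)]

# Scale before communication, as the bucket does ...
scaled = [g * gradient_scaling_factor for g in per_rank_grads]
# ... so a sum-reducing all-reduce yields the average gradient.
reduced = torch.stack(scaled).sum(dim=0)
print(reduced)  # tensor([2.5, 2.5, 2.5]) == mean of 1, 2, 3, 4
```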
- class core.distributed.param_and_grad_buffer._ParamAndGradBucketGroup(
- buckets: List[core.distributed.param_and_grad_buffer._ParamAndGradBucket],
- ddp_config: core.distributed.distributed_data_parallel_config.DistributedDataParallelConfig,
- collective_group: torch.distributed.ProcessGroup,
- collective_group_size: int,
Put multiple buckets into a group so that their communications can be aggregated together. Provides functionality to register when params in the bucket group have grads ready to be synced; an asynchronous communication call is automatically launched when all params in the bucket group have grads ready.
- Parameters:
buckets – A list of buckets.
ddp_config – DistributedDataParallel config object.
collective_group – intra_distributed_optimizer_instance_group if using distributed optimizer, data_parallel_group if not.
collective_group_size – World size of the collective group being used (i.e., of the intra data-parallel group).
Initialization
- reset()#
Reset metadata in bucket group in preparation for the next iteration of training.
- check_grads(check_for_nan_or_inf, check_for_large)#
Make sure the norm of the grads in the bucket is not NaN/inf (and, optionally, not unexpectedly large) prior to the data-parallel all-reduce / reduce-scatter.
- start_param_sync(force_sync: bool = False)#
Initiates all necessary param all-gathers for this bucket group.
When ddp_config.overlap_param_gather is set to True, dispatches an asynchronous communication call (unless force_sync is True). When ddp_config.overlap_param_gather is set to False, makes a synchronous call.
- Parameters:
force_sync (bool, optional) – if true, force a synchronous collective regardless of other settings.
- finish_param_sync(skip_next_bucket_dispatch: bool = False)#
Finishes the param sync communication operation for this bucket group. Dispatches the next bucket’s param sync if available, unless skip_next_bucket_dispatch is True.
When ddp_config.overlap_param_gather is set to True, waits for the asynchronous communication call to complete (and dispatches one if one is not already outstanding). Throws an assertion error if ddp_config.overlap_param_gather is set to False.
- Parameters:
skip_next_bucket_dispatch (bool, optional) – if true, skip dispatching the next bucket’s communication even if it is available.
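A hedged sketch of how start_param_sync and finish_param_sync are typically paired when ddp_config.overlap_param_gather is enabled; the surrounding scaffolding (the bucket_groups list and loop) is illustrative, not this class’s actual driver code.

```python
# Illustrative only: bucket_groups would come from partition_buckets(buffers).
def all_gather_params_with_overlap(bucket_groups):
    # Kick off the first group's param all-gather asynchronously.
    bucket_groups[0].start_param_sync()
    for bucket_group in bucket_groups:
        # Wait for this group's all-gather to finish; by default this also
        # dispatches the next bucket's param sync, so communication overlaps
        # with whatever compute runs between these waits.
        bucket_group.finish_param_sync()
```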
- start_grad_sync()#
Initiates grad sync (all-reduce or reduce-scatter) communication operations for all buckets in the bucket group.
When ddp_config.overlap_grad_reduce is set to True, dispatches an asynchronous communication call. When ddp_config.overlap_grad_reduce is set to False, makes a synchronous call.
- finish_grad_sync()#
Finishes grad sync (all-reduce or reduce-scatter) communication operations for all buckets in the bucket group.
When ddp_config.overlap_grad_reduce is set to True, waits for the asynchronous communication call to complete. When ddp_config.overlap_grad_reduce is set to False, makes a synchronous call.
- register_grad_ready(param: torch.nn.Parameter)#
Registers grads for the passed-in param to be “ready” for grad sync.
When the number of microbatches is greater than 1, we only want to register grads as ready when processing the last microbatch and ddp_config.overlap_grad_reduce is True.
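A hedged sketch of the gradient-sync lifecycle these methods describe. In practice register_grad_ready is invoked from backward hooks; the explicit loop below is illustrative only.

```python
# Illustrative only: bucket_group and its params are assumed to exist.
def sync_grads_after_last_microbatch(bucket_group, params):
    # With overlap_grad_reduce enabled, registering the final param's grad as
    # ready automatically launches the group's async all-reduce/reduce-scatter.
    for param in params:
        bucket_group.register_grad_ready(param)

    # Before the optimizer step, wait for that communication to complete
    # (or issue a synchronous call if overlap_grad_reduce is disabled).
    bucket_group.finish_grad_sync()
```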
- class core.distributed.param_and_grad_buffer._ParamAndGradBuffer(
- ddp_config: core.distributed.distributed_data_parallel_config.DistributedDataParallelConfig,
- param_dtype: torch.dtype,
- grad_dtype: torch.dtype,
- params: List[torch.nn.Parameter],
- data_parallel_group: torch.distributed.ProcessGroup,
- bucket_size: int,
- param_to_name: Dict[torch.nn.Parameter, str],
- gradient_scaling_factor: float,
- param_indices: List[int],
- nccl_ub: bool,
- pg_collection: Optional[megatron.core.process_groups_config.ProcessGroupCollection] = None,
Groups parameters and gradients into a contiguous buffer, and then breaks the buffer into buckets with roughly bucket_size parameters each.
- Parameters:
ddp_config – DistributedDataParallel config object.
param_dtype – Type of param tensor.
grad_dtype – Type of grad tensor.
params – List of parameters whose values and gradients are collated in the underlying tensors.
data_parallel_group – Data-parallel process group.
bucket_size – The rough size of each bucket in terms of number of parameters.
param_to_name – Mapping from torch.nn.Parameter to name (for logging purposes).
gradient_scaling_factor – Factor used to scale gradients prior to their communication; it serves both to average gradients and to scale gradients in the Mixture of Experts (MoE) case.
param_indices – The index of each param among the params with the same dtype. If a param is fp8, its “fake” high-precision dtype is used to determine which params share its dtype. These indices are needed when loading a non-native-fp8 checkpoint in native-fp8 mode.
Initialization
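The central idea of the buffer, independent of the exact bucketing logic, is that every parameter’s gradient becomes a view into one flat tensor, so a bucket’s collective can operate on a single contiguous slice. A minimal, hedged sketch of that packing (the main_grad attribute name and the layout below are illustrative, not the class’s actual algorithm):

```python
import torch

params = [torch.nn.Parameter(torch.randn(4, 4)), torch.nn.Parameter(torch.randn(16))]

# One flat grad buffer covering all params (padding and dtype handling omitted).
total_numel = sum(p.numel() for p in params)
grad_data = torch.zeros(total_numel)

# Give each param a flat-buffer view to accumulate its gradient into.
offset = 0
for p in params:
    p.main_grad = grad_data[offset : offset + p.numel()].view(p.shape)
    offset += p.numel()

# Accumulating into p.main_grad writes directly into grad_data, so a single
# collective over grad_data (or a bucket-sized slice of it) syncs every param.
params[0].main_grad += 1.0
print(grad_data[:16].sum())  # tensor(16.)
```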
- scale_gradients(scaling_factor: float) → None#
Scale the gradient data by scaling_factor.
- _get(
- shape: torch.Size,
- start_index: int,
- buffer_type: core.distributed.param_and_grad_buffer.BufferType,
Return a tensor with the input shape as a view into the 1-D data starting at start_index.
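A hedged sketch of the view semantics _get provides over the 1-D data (a toy reconstruction, not the actual method):

```python
import torch

data = torch.arange(12.0)           # stands in for the flat param_data / grad_data
shape, start_index = torch.Size([2, 3]), 4

# Slice out shape.numel() elements starting at start_index, then reshape without copying.
view = data[start_index : start_index + shape.numel()].view(shape)
view += 100                         # writes through to the underlying buffer
print(data[4:10])                   # tensor([104., 105., 106., 107., 108., 109.])
```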
- _new_bucket(
- bucket_params: List[torch.nn.Parameter],
- start_index: int,
- end_index: int,
- numel_unpadded: int,
- bucket_id: int,
Helper function that creates a new bucket. Also updates param->bucket mapping.
- reset()#
Zero out the underlying grad_buffer.
- core.distributed.param_and_grad_buffer.partition_buckets(
- buffers: List[core.distributed.param_and_grad_buffer._ParamAndGradBuffer],
- force_single_bucket_group: bool = False,
Automatically regroup the buckets of input buffers and return a list of bucket groups.
In some scenarios, we need to put buckets from different buffers into a group so that their communication can be aggregated.
For example, when the model has both fp8 weights and bf16 biases and virtual pipeline parallelism is enabled, each model chunk gets an fp8 bucket and a bf16 bucket, which doubles the number of communication kernels. Because CUDA_DEVICE_MAX_CONNECTIONS=1 is used, multiple back-to-back communications would prevent communication kernels from overlapping with computation kernels.
The grouping strategy is:
- If force_single_bucket_group is True, put all buckets across all buffers into a single bucket group.
- If force_single_bucket_group is False and there is no fp8 buffer among the input buffers, let each bucket group have only one bucket.
- If force_single_bucket_group is False and fp8 params are used, merge all non-fp8 buckets into the last fp8 bucket group.
Since the non-fp8 parameters (typically the biases of various layers) are relatively small, they are likely to be grouped into a single non-fp8 bucket. The fp8 buckets start from the end of the model, i.e., the first bucket corresponds to the end of the model, while the last bucket corresponds to the beginning. If we combined the non-fp8 bucket with the first fp8 bucket, we could not initiate the reduce-scatter to synchronize gradients as soon as the backward pass at the end of the model completes, because we would need to wait for the non-fp8 params from the beginning layers to obtain their gradients. Combining the non-fp8 bucket with the last fp8 bucket avoids this issue (see the sketch after the parameter list below).
- Parameters:
buffers (list) – list of input buffers.
force_single_bucket_group (bool, optional) – if true, force all buckets across all buffers into a single bucket group.
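A hedged sketch of the grouping rules listed above, operating on toy stand-ins for buffers and buckets (the real function takes _ParamAndGradBuffer objects and returns _ParamAndGradBucketGroup objects):

```python
def partition_buckets_sketch(buffers, force_single_bucket_group=False):
    # Each "buffer" here is a dict with an is_fp8 flag and a list of buckets.
    if force_single_bucket_group:
        # Rule 1: one group holding every bucket from every buffer.
        return [[b for buf in buffers for b in buf["buckets"]]]

    fp8_buffers = [buf for buf in buffers if buf["is_fp8"]]
    if not fp8_buffers:
        # Rule 2: no fp8 buffers -> one group per bucket.
        return [[b] for buf in buffers for b in buf["buckets"]]

    # Rule 3: merge all non-fp8 buckets into the *last* fp8 bucket group,
    # since fp8 buckets are ordered from the end of the model backwards.
    groups = [[b] for buf in fp8_buffers for b in buf["buckets"]]
    for buf in buffers:
        if not buf["is_fp8"]:
            groups[-1].extend(buf["buckets"])
    return groups

buffers = [
    {"is_fp8": True, "buckets": ["fp8_bucket_0", "fp8_bucket_1"]},
    {"is_fp8": False, "buckets": ["bf16_bias_bucket"]},
]
print(partition_buckets_sketch(buffers))
# [['fp8_bucket_0'], ['fp8_bucket_1', 'bf16_bias_bucket']]
```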