distributed package
This package contains various utilities to finalize model weight gradients on each rank before the optimizer step. This includes a distributed data parallelism wrapper to all-reduce or reduce-scatter the gradients across data-parallel replicas, and a finalize_model_grads method to synchronize gradients across different parallelism modes (e.g., ‘tied’ layers on different pipeline stages, or gradients for experts in a MoE on different ranks due to expert parallelism).
Model wrapper for distributed data parallelism. Stores gradients in a contiguous buffer, and supports the option of overlapping communication (all-reduce or reduce-scatter) with backprop computation by breaking up the full model’s gradients into smaller buckets and running all-reduce / reduce-scatter on each bucket asynchronously.
- class core.distributed.distributed_data_parallel.DistributedDataParallel(*args: Any, **kwargs: Any)
Bases:
core.distributed.data_parallel_base._BaseDataParallel
DDP wrapper which stores grads in contiguous buffers. Also has the option of overlapping communication with backprop computation by breaking up the full model’s gradients into smaller buckets and running all-reduce / reduce-scatter on each bucket asynchronously. This class also provides the option to accumulate gradients in a type other than the param type (e.g., fp32 for a bf16 model).
- Parameters
config – Transformer config object.
ddp_config – DistributedDataParallel config object.
module – Underlying model.
disable_bucketing – If true, force-assign all parameters to a single bucket. If false, use the standard bucketing policy: assign parameters to smaller buckets and all-reduce per bucket _if_ overlap_grad_reduce is True and pp_rank is 0.
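As a minimal construction sketch (the import paths, config fields, and build_transformer_model helper are assumptions based on Megatron-Core conventions, not part of this reference):

```python
# Minimal sketch of wrapping a model with this DDP class. Import paths and
# config fields are assumptions; adjust to the installed Megatron-Core version.
# build_transformer_model is a hypothetical helper returning an nn.Module.
from megatron.core.distributed import (
    DistributedDataParallel,
    DistributedDataParallelConfig,
)
from megatron.core.transformer.transformer_config import TransformerConfig

config = TransformerConfig(num_layers=2, hidden_size=128, num_attention_heads=8)
ddp_config = DistributedDataParallelConfig(
    grad_reduce_in_fp32=True,   # accumulate grads in fp32 for a bf16 model
    overlap_grad_reduce=True,   # bucket grads and reduce each bucket asynchronously
)

module = build_transformer_model(config)      # hypothetical model constructor
ddp_model = DistributedDataParallel(
    config=config,
    ddp_config=ddp_config,
    module=module,
    disable_bucketing=False,    # keep the standard bucketing policy
)
```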
- broadcast_params()
Syncs parameters across all DP ranks.
- disable_forward_pre_hook(param_sync: bool = True)
Disable forward pre-hooks needed for param all-gather overlap with forward compute. Skip synchronous param all-gather if param_sync is False.
- enable_forward_pre_hook()
Enable forward pre-hooks needed for param all-gather overlap with forward compute.
- finish_grad_sync()
Finishes grad sync (all-reduce or reduce-scatter) communication operations for all model gradients.
When overlap_grad_reduce is set to True, waits for asynchronous communication calls to complete. When overlap_grad_reduce is set to False, calls synchronous communication ops.
- no_sync()
Context manager that turns off gradient synchronization.
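For illustration, a hedged sketch of gradient accumulation with no_sync follows; ddp_model, microbatches, and compute_loss are placeholders, and the exact point at which communication runs depends on the overlap settings described above.

```python
# Hedged sketch: suppress gradient synchronization for all but the last
# microbatch so that communication happens only once per accumulation window.
# `ddp_model`, `microbatches`, and `compute_loss` are placeholders.
for i, batch in enumerate(microbatches):
    if i < len(microbatches) - 1:
        with ddp_model.no_sync():                      # grads accumulate locally, no communication
            compute_loss(ddp_model(batch)).backward()
    else:
        compute_loss(ddp_model(batch)).backward()      # sync is allowed for the last microbatch
```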
- scale_gradients(scaling_factor: float)
Scale all gradients inside the buffers by scaling_factor.
- start_grad_sync(*unused)
Initiates grad sync (all-reduce or reduce-scatter) communication operations for all model gradients.
When overlap_grad_reduce is set to True, dispatches asynchronous communication calls. When overlap_grad_reduce is set to False, calls synchronous communication ops.
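A hedged sketch of how start_grad_sync and finish_grad_sync pair up, assuming overlap_grad_reduce is True; when overlap is disabled, both calls instead issue synchronous communication (per the descriptions above), so finish_grad_sync alone is typically sufficient.

```python
# Hedged sketch, assuming overlap_grad_reduce=True: start_grad_sync dispatches
# asynchronous per-bucket all-reduce / reduce-scatter calls, and finish_grad_sync
# waits for them to complete before the optimizer consumes the gradients.
loss.backward()                 # backprop for the final microbatch
ddp_model.start_grad_sync()     # kick off async reduction of the grad buckets
# ... any remaining computation here can overlap with communication ...
ddp_model.finish_grad_sync()    # block until every bucket has been reduced
optimizer.step()
```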
- start_param_sync(*unused, force_sync: bool = False, force_dispatch: bool = False)
Initiates param sync (all-gather) communication operations for all model parameters.
By default, when overlap_param_gather is set to True, dispatches asynchronous communication calls; when overlap_param_gather is set to False, calls synchronous communication ops. Can override this default behavior using flags below.
- Parameters
force_sync (bool, optional) – force synchronous collective regardless of other settings.
force_dispatch (bool, optional) – force dispatch regardless of other settings.
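A brief hedged sketch; the pre-evaluation scenario and run_validation helper are illustrative assumptions, and only the flag comes from the signature above.

```python
# Hedged sketch: force a blocking param all-gather regardless of the
# overlap_param_gather setting, e.g. before running evaluation.
# run_validation is a hypothetical helper.
ddp_model.start_param_sync(force_sync=True)
run_validation(ddp_model)
```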
- zero_grad_buffer()
Zeros out all grad buffers. Needs to be called at the beginning of each training iteration.
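A hedged skeleton showing where zero_grad_buffer sits in the training loop; num_iterations and train_one_iteration are placeholders.

```python
# Hedged skeleton: grad buffers must be zeroed at the start of every iteration
# so that gradient accumulation begins from clean buffers.
for iteration in range(num_iterations):        # placeholder iteration count
    ddp_model.zero_grad_buffer()               # reset the contiguous grad buffers
    train_one_iteration(ddp_model, optimizer)  # hypothetical helper: microbatch loop,
                                               # grad sync, and optimizer step
```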
Finalize model gradients for the optimizer step across all used parallelism modes. Synchronizes the all-reduce / reduce-scatter of model gradients across DP replicas, and all-reduces the layernorm gradients for sequence parallelism, the embedding gradients across the first and last pipeline stages (if not tied), and the expert gradients for expert parallelism.
- core.distributed.finalize_model_grads.finalize_model_grads(model: List[torch.nn.Module], num_tokens: Optional[torch.Tensor] = None)
All-reduce all model grads across DP replicas, layernorm grads for sequence parallelism, and embedding grads across the first and last pipeline stages (if not tied), then scale gradients by num_tokens.
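A hedged usage sketch; the import path, device placement, and token-count bookkeeping are assumptions, and the model argument is a list so that multiple model chunks can be passed.

```python
# Hedged sketch: finalize gradients across all parallelism modes after the
# backward passes of an iteration, then step the optimizer. The import path
# and num_tokens bookkeeping are assumptions.
import torch
from megatron.core.distributed import finalize_model_grads

num_tokens = torch.tensor(float(tokens_this_iteration), device="cuda")  # placeholder count
finalize_model_grads([ddp_model], num_tokens=num_tokens)  # list of model chunks
optimizer.step()
```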
Contains functionality to synchronize gradients across different ranks before the optimizer step.