core.distributed.distributed_data_parallel#
Module Contents#
Classes#
DistributedDataParallel – DDP wrapper that stores grads in contiguous buffers. It also has the option of overlapping communication with backprop computation by breaking up the full model’s gradients into smaller buckets and running all-reduce / reduce-scatter on each bucket asynchronously. This class also provides the option to do the gradient accumulation in a type other than the param type (e.g., fp32 for a bf16 model).
Data#
API#
- core.distributed.distributed_data_parallel.logger#
‘getLogger(…)’
- class core.distributed.distributed_data_parallel.DistributedDataParallel(
- config: core.transformer.transformer_config.TransformerConfig,
- ddp_config: core.distributed.distributed_data_parallel_config.DistributedDataParallelConfig,
- module: torch.nn.Module,
- disable_bucketing: bool = False,
- pg_collection: Optional[core.process_groups_config.ProcessGroupCollection] = None,
- )#
Bases:
core.distributed.data_parallel_base._BaseDataParallel
DDP wrapper that stores grads in contiguous buffers. It also has the option of overlapping communication with backprop computation by breaking up the full model’s gradients into smaller buckets and running all-reduce / reduce-scatter on each bucket asynchronously. This class also provides the option to do the gradient accumulation in a type other than the param type (e.g., fp32 for a bf16 model).
- Parameters:
config – Transformer config object.
ddp_config – DistributedDataParallel config object.
module – Underlying model.
disable_bucketing – If true, force all parameters into a single bucket. If false, use the standard bucketing policy: assign parameters to smaller buckets and all-reduce per bucket if overlap_grad_reduce is True and pp_rank is 0.
pg_collection – Optional unified process group for distributed training.
Initialization
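A minimal construction sketch, assuming the megatron.core package layout implied by the module paths on this page and that torch.distributed / Megatron parallel state is already initialized; build_model is a hypothetical helper, and the config values are illustrative.

```python
from megatron.core.transformer.transformer_config import TransformerConfig
from megatron.core.distributed import (
    DistributedDataParallel,
    DistributedDataParallelConfig,
)

# Minimal transformer config; values are illustrative only.
config = TransformerConfig(num_layers=2, hidden_size=128, num_attention_heads=8)

# DDP config; overlap_grad_reduce enables per-bucket async reduce during backprop.
ddp_config = DistributedDataParallelConfig(overlap_grad_reduce=True)

model = build_model(config)      # hypothetical helper returning a torch.nn.Module
ddp_model = DistributedDataParallel(
    config=config,
    ddp_config=ddp_config,
    module=model,
    disable_bucketing=False,     # keep the standard bucketing policy
)
```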
- enable_forward_pre_hook()#
Enable forward pre-hooks needed for param all-gather overlap with forward compute.
- disable_forward_pre_hook(param_sync: bool = True)#
Disable forward pre-hooks needed for param all-gather overlap with forward compute. Skip the synchronous param all-gather if param_sync is False.
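A usage sketch, assuming these methods are called from the training script while param all-gather overlap is enabled, e.g. around a standalone evaluation pass; run_evaluation is a hypothetical helper.

```python
# Temporarily leave the overlapped all-gather path; by default
# disable_forward_pre_hook() performs a synchronous param all-gather
# unless param_sync=False is passed.
ddp_model.disable_forward_pre_hook()
run_evaluation(ddp_model)             # hypothetical evaluation helper
ddp_model.enable_forward_pre_hook()   # restore overlap for subsequent training steps
```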
- _make_forward_pre_hook()#
Create a forward pre-hook to wait on all-gather handles when necessary (i.e., when a module uses a parameter in a bucket with a still incomplete all-gather).
- _make_backward_post_hook(param: torch.nn.Parameter)#
Creates a backward post-hook to dispatch an all-reduce / reduce-scatter when ready (i.e., when all grads in a bucket have been computed in all microbatches in a batch).
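For context, a minimal sketch of the general PyTorch technique such a hook relies on: registering a callback on a parameter's gradient-accumulation node so it fires once that parameter's grad has been accumulated. The names and bucket logic below are illustrative, not the class's actual implementation.

```python
import torch

def make_backward_post_hook(param: torch.nn.Parameter, on_grad_ready):
    """Register a hook that fires after `param`'s grad has been accumulated."""
    def hook(*unused):
        on_grad_ready(param)  # e.g. mark param ready; reduce the bucket when full

    # expand_as() builds a tiny graph whose next_functions[0][0] is the
    # AccumulateGrad node for this parameter (requires param.requires_grad
    # and grad mode enabled).
    accumulate_grad = param.expand_as(param).grad_fn.next_functions[0][0]
    return accumulate_grad.register_hook(hook)
```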
- no_sync()#
Context manager that turns off gradient synchronization.
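A gradient-accumulation sketch, assuming microbatch losses are combined externally; loss_fn and microbatches are placeholders.

```python
ddp_model.zero_grad_buffer()                  # clear grad buffers for this iteration
with ddp_model.no_sync():                     # suppress grad sync for early microbatches
    for microbatch in microbatches[:-1]:
        loss_fn(ddp_model(microbatch)).backward()
# Last microbatch runs outside no_sync(), so grad sync may be dispatched here.
loss_fn(ddp_model(microbatches[-1])).backward()
ddp_model.finish_grad_sync()                  # ensure reduced grads before optimizer.step()
```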
- start_param_sync(
- *unused,
- force_sync: bool = False,
- force_dispatch: bool = False,
- )#
Initiates param sync (all-gather) communication operations for all model parameters.
By default, when overlap_param_gather is set to True, dispatches asynchronous communication calls; when overlap_param_gather is set to False, calls synchronous communication ops. This default behavior can be overridden with the flags below; see the sketch after the parameter list.
- Parameters:
force_sync (bool, optional) – force synchronous collective regardless of other settings.
force_dispatch (bool, optional) – force dispatch regardless of other settings.
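A sketch of the override flags, assuming parameter values were just changed outside the normal optimizer step (e.g. after loading a checkpoint); load_checkpoint_into is a hypothetical helper.

```python
load_checkpoint_into(ddp_model)              # hypothetical: params changed out-of-band
# Force a blocking all-gather so every rank sees the updated weights before the
# next forward pass, regardless of overlap_param_gather.
ddp_model.start_param_sync(force_sync=True)
```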
- start_grad_sync(*unused)#
Initiates grad sync (all-reduce or reduce-scatter) communication operations for all model gradients.
When overlap_grad_reduce is set to True, dispatches asynchronous communication calls. When overlap_grad_reduce is set to False, calls synchronous communication ops.
- finish_grad_sync()#
Finishes grad sync (all-reduce or reduce-scatter) communication operations for all model gradients.
When overlap_grad_reduce is set to True, waits for asynchronous communication calls to complete. When overlap_grad_reduce is set to False, calls synchronous communication ops.
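A sketch of where the grad-sync calls sit in a plain (single-microbatch) step; loss_fn, batch, and optimizer are placeholders.

```python
ddp_model.zero_grad_buffer()
# With overlap_grad_reduce=True, buckets are reduced asynchronously during backward.
loss_fn(ddp_model(batch)).backward()
# Wait for the async handles (overlap on) or issue the collectives synchronously
# (overlap off); grads are fully reduced after this call.
ddp_model.finish_grad_sync()
optimizer.step()
```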
- scale_gradients(scaling_factor: float)#
Scale all gradients inside the buffers by scaling_factor.
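An illustrative use, assuming per-microbatch losses were summed rather than averaged; the microbatch count is a placeholder value.

```python
num_microbatches = 8                               # illustrative value
ddp_model.scale_gradients(1.0 / num_microbatches)  # rescale accumulated grads in place
```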
- zero_grad_buffer()#
Zeros out all grad buffers. Needs to be called at the beginning of each training iteration.
- broadcast_params()#
Syncs parameters across all DP ranks.
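A sketch of a typical call site, assuming some weights were (re)initialized independently on each rank; reinit_embeddings is a hypothetical helper.

```python
reinit_embeddings(ddp_model.module)  # hypothetical: params may now differ across DP ranks
ddp_model.broadcast_params()         # re-sync so all data-parallel ranks hold identical params
```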