core.distributed.distributed_data_parallel#

Module Contents#

Classes#

DistributedDataParallel

DDP wrapper that stores gradients in contiguous buffers. It can also overlap communication with backprop computation by breaking the full model’s gradients into smaller buckets and running all-reduce / reduce-scatter on each bucket asynchronously. The class additionally provides the option to accumulate gradients in a type other than the parameter type (e.g., fp32 for a bf16 model).

Data#

API#

core.distributed.distributed_data_parallel.logger#

‘getLogger(…)’

class core.distributed.distributed_data_parallel.DistributedDataParallel(
config: core.transformer.transformer_config.TransformerConfig,
ddp_config: core.distributed.distributed_data_parallel_config.DistributedDataParallelConfig,
module: torch.nn.Module,
disable_bucketing: bool = False,
pg_collection: Optional[core.process_groups_config.ProcessGroupCollection] = None,
)#

Bases: core.distributed.data_parallel_base._BaseDataParallel

DDP wrapper that stores gradients in contiguous buffers. It can also overlap communication with backprop computation by breaking the full model’s gradients into smaller buckets and running all-reduce / reduce-scatter on each bucket asynchronously. The class additionally provides the option to accumulate gradients in a type other than the parameter type (e.g., fp32 for a bf16 model).

Parameters:
  • config – Transformer config object.

  • ddp_config – DistributedDataParallel config object.

  • module – Underlying model.

  • disable_bucketing – If True, force-assign all parameters to a single bucket. If False, use the standard bucketing policy: assign parameters to smaller buckets and all-reduce per bucket if overlap_grad_reduce is True and pp_rank is 0.

  • pg_collection – Optional unified process group for distributed training.

Initialization
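
A minimal construction sketch. The import paths mirror this page and may need adjusting to your installation; the TransformerConfig values, the grad_reduce_in_fp32 field name, and the stand-in torch.nn.Linear module are illustrative assumptions, and torch.distributed (plus any model-parallel state) is assumed to be initialized already.

import torch

# Module paths mirror this page; adjust them to your installation.
from core.distributed.distributed_data_parallel import DistributedDataParallel
from core.distributed.distributed_data_parallel_config import DistributedDataParallelConfig
from core.transformer.transformer_config import TransformerConfig

# Assumption: a small Linear stands in for the real transformer module.
module = torch.nn.Linear(64, 64).bfloat16().cuda()

config = TransformerConfig(num_layers=2, hidden_size=64, num_attention_heads=4)
ddp_config = DistributedDataParallelConfig(
    overlap_grad_reduce=True,   # dispatch grad reductions from backward hooks
    grad_reduce_in_fp32=True,   # assumed field: accumulate grads in fp32 for a bf16 model
)

ddp_model = DistributedDataParallel(
    config=config,
    ddp_config=ddp_config,
    module=module,
)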

enable_forward_pre_hook()#

Enable forward pre-hooks needed for param all-gather overlap with forward compute.

disable_forward_pre_hook(param_sync: bool = True)#

Disable forward pre-hooks needed for param all-gather overlap with forward compute. Skip synchronous param all-gather if param_sync is False.
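
A hedged usage sketch, assuming ddp_model is the wrapped model from the construction example above and run_validation is a hypothetical evaluation helper: temporarily disable the pre-hooks around evaluation, forcing a synchronous param all-gather first so weights are current, then re-enable them for training.

ddp_model.disable_forward_pre_hook(param_sync=True)  # gather params synchronously first
run_validation(ddp_model)                            # hypothetical evaluation helper
ddp_model.enable_forward_pre_hook()                  # restore overlap for training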

_make_forward_pre_hook()#

Creates a forward pre-hook to wait on all-gather handles when necessary (i.e., when a module uses a parameter in a bucket with a still-incomplete all-gather).

_make_backward_post_hook(param: torch.nn.Parameter)#

Creates a backward post-hook to dispatch an all-reduce / reduce-scatter when ready (i.e., when all grads in a bucket have been computed in all microbatches in a batch).

no_sync()#

Context manager that turns off gradient synchronization.
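
A minimal gradient-accumulation sketch, assuming ddp_model from the examples above and a hypothetical microbatches list of input tensors (with .sum() standing in for a real loss): gradients accumulate locally inside the context manager, and only the final backward pass outside it triggers gradient reduction.

with ddp_model.no_sync():
    for batch in microbatches[:-1]:
        ddp_model(batch).sum().backward()   # grads accumulate locally, no communication

# Outside the context manager, this backward pass dispatches grad reduction.
ddp_model(microbatches[-1]).sum().backward()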

start_param_sync(
*unused,
force_sync: bool = False,
force_dispatch: bool = False,
)#

Initiates param sync (all-gather) communication operations for all model parameters.

By default, when overlap_param_gather is set to True, dispatches asynchronous communication calls; when overlap_param_gather is set to False, calls synchronous communication ops. This default behavior can be overridden with the flags below.

Parameters:
  • force_sync (bool, optional) – force synchronous collective regardless of other settings.

  • force_dispatch (bool, optional) – force dispatch regardless of other settings.
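
A hedged sketch, assuming ddp_model from the examples above: force a synchronous parameter all-gather (for example before evaluation or checkpointing), independent of the overlap_param_gather setting.

ddp_model.start_param_sync(force_sync=True)   # blocking all-gather of all params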

start_grad_sync(*unused)#

Initiates grad sync (all-reduce or reduce-scatter) communication operations for all model gradients.

When overlap_grad_reduce is set to True, dispatches asynchronous communication calls. When overlap_grad_reduce is set to False, calls synchronous communication ops.

finish_grad_sync()#

Finishes grad sync (all-reduce or reduce-scatter) communication operations for all model gradients.

When overlap_grad_reduce is set to True, waits for asynchronous communication calls to complete. When overlap_grad_reduce is set to False, calls synchronous communication ops.
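
A hedged sketch of the call ordering around the optimizer step, with batch and optimizer as hypothetical stand-ins from the surrounding examples. With overlap_grad_reduce=True the reductions are dispatched from the backward hooks and finish_grad_sync() only waits for them to complete.

ddp_model(batch).sum().backward()   # stand-in loss and backward pass
ddp_model.finish_grad_sync()        # wait for (or issue) grad reductions
optimizer.step()                    # grads in the buffers are now reduced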

scale_gradients(scaling_factor: float)#

Scale all gradients inside the buffers by scaling_factor.
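
A hedged sketch of averaging gradients accumulated over several microbatches by scaling the buffers in place before the optimizer step; the microbatch count is illustrative.

num_microbatches = 8
ddp_model.scale_gradients(1.0 / num_microbatches)   # turn summed grads into a mean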

zero_grad_buffer()#

Zeros out all grad buffers. Needs to be called at the beginning of each training iteration.
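
A hedged per-iteration skeleton tying the buffer methods together, with num_iterations, get_batch, and optimizer as hypothetical stand-ins.

for _ in range(num_iterations):
    ddp_model.zero_grad_buffer()             # reset grad buffers before backward
    ddp_model(get_batch()).sum().backward()  # stand-in forward/backward
    ddp_model.finish_grad_sync()             # ensure reduced grads are ready
    optimizer.step()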

broadcast_params()#

Syncs parameters across all DP ranks.
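
A hedged sketch: after constructing or re-initializing the wrapped model on each rank, broadcast parameters so every data-parallel replica starts from identical weights.

ddp_model.broadcast_params()   # one-time sync of params across DP ranks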