core.distributed.data_parallel_base
Module Contents
Classes
| _BaseDataParallel | A template class for DistributedDataParallel implementations. |
API
- class core.distributed.data_parallel_base._BaseDataParallel(config: core.transformer.transformer_config.TransformerConfig, module: torch.nn.Module)
Bases: core.transformer.module.MegatronModule
A template class for DistributedDataParallel implementations.
Initialization
- forward(*inputs, **kwargs)
Calls the wrapped module’s forward() method.
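The following framework-free sketch illustrates this delegation. ToyModule and ToyDataParallel are hypothetical stand-ins, not Megatron or PyTorch classes. The wrapper passes every positional and keyword argument to the wrapped callable and returns its result.

```python
class ToyModule:
    """Hypothetical stand-in for a torch.nn.Module computing y = w * x + bias."""

    def __init__(self, w: float) -> None:
        self.w = w

    def __call__(self, x: float, *, bias: float = 0.0) -> float:
        return self.w * x + bias


class ToyDataParallel:
    """Hypothetical wrapper mirroring the delegation described for forward()."""

    def __init__(self, module) -> None:
        self.module = module

    def forward(self, *inputs, **kwargs):
        # Pass positional and keyword arguments to the wrapped module unchanged.
        return self.module(*inputs, **kwargs)


wrapped = ToyDataParallel(ToyModule(w=2.0))
print(wrapped.forward(3.0, bias=1.0))  # 7.0
```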
- no_sync()
Context manager that turns off gradient synchronization.
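Gradient accumulation is a typical use for such a context manager. The backward passes for all but the last micro-batch run inside no_sync(), so gradients are communicated only once per iteration. The toy class below is hypothetical. It only counts how many synchronizations would be triggered, and it assumes synchronization is restored when the block exits.

```python
from contextlib import contextmanager


class ToyGradSync:
    """Hypothetical model of a DP wrapper that counts gradient synchronizations."""

    def __init__(self) -> None:
        self.sync_enabled = True
        self.num_syncs = 0

    @contextmanager
    def no_sync(self):
        # Disable synchronization for the duration of the block, then restore it.
        previous = self.sync_enabled
        self.sync_enabled = False
        try:
            yield
        finally:
            self.sync_enabled = previous

    def backward(self) -> None:
        # A real DDP wrapper would launch gradient communication here.
        if self.sync_enabled:
            self.num_syncs += 1


model = ToyGradSync()
num_microbatches = 4
with model.no_sync():
    for _ in range(num_microbatches - 1):
        model.backward()  # accumulate locally, no communication
model.backward()  # the last micro-batch triggers the single sync
print(model.num_syncs)  # 1
```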
- start_grad_sync(*unused)
Initiates grad sync (all-reduce or reduce-scatter) communication operations for all model gradients.
When overlap_grad_reduce is set to True, dispatches asynchronous communication calls. When overlap_grad_reduce is set to False, calls synchronous communication ops.
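The two collectives named above differ in what each rank holds afterwards. The pure-Python functions below are illustrative models, not torch.distributed calls. They assume a sum reduction and a buffer length divisible by the number of ranks. All-reduce leaves the full summed gradient on every rank. Reduce-scatter leaves each rank only its own shard of the sum, which is the layout a sharded optimizer would work on.

```python
def all_reduce(per_rank_grads):
    """Every rank ends up with the elementwise sum over all ranks."""
    summed = [sum(vals) for vals in zip(*per_rank_grads)]
    return [list(summed) for _ in per_rank_grads]


def reduce_scatter(per_rank_grads):
    """Rank r ends up with only shard r of the elementwise sum."""
    world = len(per_rank_grads)
    summed = [sum(vals) for vals in zip(*per_rank_grads)]
    shard = len(summed) // world  # assumes the length divides evenly
    return [summed[r * shard:(r + 1) * shard] for r in range(world)]


grads = [[1.0, 2.0, 3.0, 4.0],  # rank 0
         [5.0, 6.0, 7.0, 8.0]]  # rank 1
print(all_reduce(grads))      # [[6.0, 8.0, 10.0, 12.0], [6.0, 8.0, 10.0, 12.0]]
print(reduce_scatter(grads))  # [[6.0, 8.0], [10.0, 12.0]]
```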
- scale_gradients(scaling_factor: float) -> None
Scales all gradients inside the buffers by scaling_factor.
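Assuming the buffers are flat sequences of gradient values, scaling is an in-place elementwise multiply. The helper below is a minimal sketch. The choice of 1 / dp_size as the factor, which turns a sum over 4 ranks into a mean, is an illustrative assumption and does not come from this page.

```python
def scale_gradients(grad_buffers, scaling_factor: float) -> None:
    """Hypothetical helper: multiply every gradient in every buffer in place."""
    for buffer in grad_buffers:
        for i, g in enumerate(buffer):
            buffer[i] = g * scaling_factor


dp_size = 4
buffers = [[4.0, 8.0], [12.0]]  # hypothetical summed gradients
scale_gradients(buffers, 1.0 / dp_size)
print(buffers)  # [[1.0, 2.0], [3.0]]
```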
- finish_grad_sync()
Finishes grad sync (all-reduce or reduce-scatter) communication operations for all model gradients.
When overlap_grad_reduce is set to True, waits for asynchronous communication calls to complete. When overlap_grad_reduce is set to False, calls synchronous communication ops.
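The start/finish split can be modelled with a thread pool standing in for asynchronous communication. In this hypothetical sketch, ToyGradSyncer and reduce_bucket are illustrative names. Each bucket holds one gradient list per rank, and the reduction is an elementwise sum. With overlap_grad_reduce=True, start_grad_sync() returns immediately and finish_grad_sync() waits on the outstanding handles. With False, each call performs the reduction synchronously, as described above.

```python
from concurrent.futures import ThreadPoolExecutor


def reduce_bucket(bucket):
    """Stand-in for a collective: elementwise sum of one bucket across ranks."""
    return [sum(vals) for vals in zip(*bucket)]


class ToyGradSyncer:
    """Hypothetical model of the start_grad_sync / finish_grad_sync split."""

    def __init__(self, buckets, overlap_grad_reduce: bool) -> None:
        self.buckets = buckets  # each bucket holds one gradient list per rank
        self.overlap_grad_reduce = overlap_grad_reduce
        self.pool = ThreadPoolExecutor(max_workers=2)
        self.handles = []
        self.results = None

    def start_grad_sync(self) -> None:
        if self.overlap_grad_reduce:
            # Dispatch asynchronously and return immediately.
            self.handles = [self.pool.submit(reduce_bucket, b) for b in self.buckets]
        else:
            # Blocking: results are ready when this call returns.
            self.results = [reduce_bucket(b) for b in self.buckets]

    def finish_grad_sync(self) -> None:
        if self.overlap_grad_reduce:
            # Wait for every outstanding asynchronous operation to complete.
            self.results = [h.result() for h in self.handles]
            self.handles = []
        else:
            # Blocking path: perform the reduction synchronously here.
            self.results = [reduce_bucket(b) for b in self.buckets]


buckets = [[[1.0, 2.0], [3.0, 4.0]],  # bucket 0: rank 0 and rank 1 gradients
           [[5.0], [6.0]]]            # bucket 1
for overlap in (True, False):
    syncer = ToyGradSyncer(buckets, overlap_grad_reduce=overlap)
    syncer.start_grad_sync()
    # With overlap, backward computation could continue here.
    syncer.finish_grad_sync()
    print(overlap, syncer.results)  # [[4.0, 6.0], [11.0]] in both modes
    syncer.pool.shutdown()
```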
- zero_grad_buffer()
Zeros out all grad buffers. Must be called at the beginning of each training iteration.
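The reset is needed because gradients accumulate into the buffers across backward passes. In the hypothetical toy below, each iteration runs two micro-batches. Zeroing at the start of every iteration keeps each iteration's gradients from including those left over from the previous one.

```python
class ToyGradBuffer:
    """Hypothetical flat gradient buffer that accumulates across backward passes."""

    def __init__(self, size: int) -> None:
        self.data = [0.0] * size

    def accumulate(self, grads) -> None:
        # Each backward pass adds into the buffer instead of overwriting it.
        for i, g in enumerate(grads):
            self.data[i] += g

    def zero_grad_buffer(self) -> None:
        # Reset every entry in place.
        for i in range(len(self.data)):
            self.data[i] = 0.0


buf = ToyGradBuffer(2)
history = []
for iteration in range(3):
    buf.zero_grad_buffer()      # start of each training iteration
    buf.accumulate([1.0, 2.0])  # micro-batch 1
    buf.accumulate([1.0, 2.0])  # micro-batch 2
    history.append(list(buf.data))
print(history)  # [[2.0, 4.0], [2.0, 4.0], [2.0, 4.0]]
```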
- broadcast_params()
Syncs parameters across all data-parallel (DP) ranks.
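Conceptually, this kind of sync is a broadcast from one source rank. The sketch below models per-rank parameters as lists and assumes rank 0 is the source. This page does not specify the actual source rank or the communication call.

```python
def broadcast_params(per_rank_params, src: int = 0):
    """Hypothetical model: every rank receives a copy of rank src's parameters."""
    source = per_rank_params[src]
    return [list(source) for _ in per_rank_params]


params = [[0.10, -0.20],  # rank 0
          [0.35, 0.90],   # rank 1 (e.g. initialized with a different seed)
          [0.00, 0.00]]   # rank 2
synced = broadcast_params(params)
print(synced)  # [[0.1, -0.2], [0.1, -0.2], [0.1, -0.2]]
```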
- state_dict(prefix='', keep_vars=False, destination=None)
Returns a dictionary containing references to the whole state of the wrapped module.
Both parameters and persistent buffers (e.g. running averages) are included. Keys are corresponding parameter and buffer names. Parameters and buffers set to None are not included.
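This hypothetical sketch models the described filtering with plain dicts. Parameters and persistent buffers are included, while entries set to None and non-persistent buffers are skipped. Values are references, not copies. The toy wrapper forwards the call and its prefix unchanged to the wrapped module. That forwarding is an assumption, since this page does not say whether the real wrapper adds a key prefix of its own.

```python
class ToyModule:
    """Hypothetical module holding parameters and buffers in plain dicts."""

    def __init__(self) -> None:
        self.params = {"weight": [1.0, 2.0], "bias": None}
        self.buffers = {"running_mean": [0.5], "scratch": [9.9]}
        self.non_persistent = {"scratch"}

    def state_dict(self, prefix: str = "") -> dict:
        out = {}
        for name, value in self.params.items():
            if value is not None:  # entries set to None are skipped
                out[prefix + name] = value  # a reference, not a copy
        for name, value in self.buffers.items():
            if value is not None and name not in self.non_persistent:
                out[prefix + name] = value
        return out


class ToyWrapper:
    """Hypothetical DP wrapper that forwards state_dict() to the wrapped module."""

    def __init__(self, module: ToyModule) -> None:
        self.module = module

    def state_dict(self, prefix: str = "") -> dict:
        return self.module.state_dict(prefix=prefix)


wrapper = ToyWrapper(ToyModule())
sd = wrapper.state_dict(prefix="model.")
print(sorted(sd))  # ['model.running_mean', 'model.weight']
```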
- state_dict_for_save_checkpoint(prefix='', keep_vars=False)
Returns wrapped module’s state_dict for checkpoint saving.
- load_state_dict(state_dict, strict=True)
Copies parameters and buffers from state_dict into the wrapped module and its descendants. If strict is True, then the keys of state_dict must exactly match the keys returned by this module’s state_dict() function.
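A dict-based toy can illustrate the strict flag. The load_state_dict function below is a hypothetical sketch, not the module method. With strict=True, it raises RuntimeError (as torch.nn.Module.load_state_dict does) when keys are missing or unexpected, and leaves the target unchanged. With strict=False, it copies the matching entries and reports the mismatched keys.

```python
def load_state_dict(target: dict, state_dict: dict, strict: bool = True):
    """Copy matching entries of state_dict into target, enforcing exact keys if strict."""
    missing = sorted(set(target) - set(state_dict))
    unexpected = sorted(set(state_dict) - set(target))
    if strict and (missing or unexpected):
        raise RuntimeError(f"missing keys: {missing}, unexpected keys: {unexpected}")
    for key in target.keys() & state_dict.keys():
        target[key] = state_dict[key]
    return missing, unexpected


weights = {"weight": [0.0], "bias": [0.0]}
print(load_state_dict(weights, {"weight": [1.0]}, strict=False))  # (['bias'], [])
print(weights)  # {'weight': [1.0], 'bias': [0.0]}
try:
    load_state_dict(weights, {"weight": [1.0]}, strict=True)
except RuntimeError as err:
    print(err)  # missing keys: ['bias'], unexpected keys: []
```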