Distributed Optimizer#

The distributed optimizer saves memory by distributing the optimizer state evenly across data-parallel ranks (https://arxiv.org/abs/1910.02054), rather than replicating the full optimizer state on every data-parallel rank as the naive method does.

Theoretical memory savings depend on the combination of the datatype of the model's parameters (param_dtype) and the datatype of the main gradients accumulated across data-parallel replicas (grad_dtype). We always use fp32 main parameters for optimizer steps. In the current implementation, the theoretical number of bytes per parameter is as follows, where d is the data-parallel size:

| | Non-distributed optim | Distributed optim |
|---|---|---|
| fp16 parameters, fp16 gradients | 20 | 4 + 16/d |
| bf16 parameters, fp32 gradients | 18 | 6 + 12/d |
| fp32 parameters, fp32 gradients | 16 | 8 + 8/d |
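To make the table concrete, the following minimal sketch evaluates the per-parameter byte counts for a given data-parallel size. The constants are copied from the table; the names `BYTES_PER_PARAM` and `bytes_per_param` are hypothetical helpers for illustration and are not part of the library.

```python
# Per-parameter byte counts copied from the table, stored as (fixed, sharded).
# Distributed optimizer:     fixed + sharded / d
# Non-distributed optimizer: fixed + sharded
BYTES_PER_PARAM = {
    ("fp16", "fp16"): (4, 16),
    ("bf16", "fp32"): (6, 12),
    ("fp32", "fp32"): (8, 8),
}


def bytes_per_param(param_dtype, grad_dtype, d=1, distributed=True):
    """Return the theoretical bytes per parameter for one table row.

    d is the data-parallel size; it is ignored when distributed is False.
    """
    if d < 1:
        raise ValueError("data-parallel size d must be at least 1")
    fixed, sharded = BYTES_PER_PARAM[(param_dtype, grad_dtype)]
    if not distributed:
        return float(fixed + sharded)
    return fixed + sharded / d


if __name__ == "__main__":
    # Compare both optimizers for every table row at a few values of d.
    for (p, g) in BYTES_PER_PARAM:
        baseline = bytes_per_param(p, g, distributed=False)
        for d in (1, 8, 64):
            print(p, g, d, baseline, bytes_per_param(p, g, d=d))
```

For example, with d = 8, bf16 parameters and fp32 gradients come to 6 + 12/8 = 7.5 bytes per parameter, versus 18 without the distributed optimizer. At d = 1 the two columns coincide, and as d grows the distributed cost approaches the fixed term.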

Our implementation of the distributed optimizer uses contiguous buffers for parameters and main gradients; model gradients are copied over to the main gradients as soon as they are fully computed.
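As a rough illustration of how a contiguous buffer can be divided evenly across data-parallel ranks, the sketch below splits a flat buffer into d equal shards and reports which slice of each parameter falls into each rank's shard. It assumes that the buffer is padded up to a multiple of d and that shard boundaries may cut through individual parameters; both are assumptions made for this example, and the function name `shard_ranges` is hypothetical. The actual implementation may differ.

```python
def shard_ranges(param_sizes, d):
    """Map each rank to the parameter slices that fall inside its shard.

    param_sizes: element counts of the parameters, in buffer order.
    d: data-parallel size (number of equal shards).
    Returns one list per rank of (param_index, start, end) tuples, where
    start and end are offsets within that parameter.
    """
    total = sum(param_sizes)
    # Ceiling division: the buffer is assumed padded to shard_size * d elements.
    shard_size = -(-total // d)

    # Compute each parameter's [begin, end) range inside the flat buffer.
    offsets = []
    begin = 0
    for size in param_sizes:
        offsets.append((begin, begin + size))
        begin += size

    plan = []
    for rank in range(d):
        lo, hi = rank * shard_size, (rank + 1) * shard_size
        owned = []
        for index, (p_lo, p_hi) in enumerate(offsets):
            # Intersect the shard range with the parameter range.
            o_lo, o_hi = max(lo, p_lo), min(hi, p_hi)
            if o_lo < o_hi:
                owned.append((index, o_lo - p_lo, o_hi - p_lo))
        plan.append(owned)
    return plan


if __name__ == "__main__":
    # Three parameters of 6, 4, and 2 elements sharded over 3 ranks.
    for rank, owned in enumerate(shard_ranges([6, 4, 2], 3)):
        print(rank, owned)
```

Under this scheme every rank holds the same number of buffer elements, regardless of how the individual parameter sizes line up with the shard boundaries.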

The figures below illustrate the distributed optimizer’s sharding scheme, and the key steps of the distributed optimizer’s parameter update:

Data flow#

Sharding scheme#