core.distributed.reduce_scatter_with_fp32_accumulation#

Module Contents#

Classes#

_ReduceScatterWithFP32AccumulationWorkHandle

Work handle returned to the user when calling reduce_scatter_with_fp32_accumulation with async_op=True.

Functions#

reduce_scatter_with_fp32_accumulation

Reduce-scatter with FP32 accumulation.

API#

class core.distributed.reduce_scatter_with_fp32_accumulation._ReduceScatterWithFP32AccumulationWorkHandle(
all_to_all_handle: Any,
all_to_all_output_tensor: torch.Tensor,
output_tensor: torch.Tensor,
world_size: int,
)#

Work handle returned to the user when calling reduce_scatter_with_fp32_accumulation with async_op=True.

Initialization

Initialize WorkHandle object.

wait()#

Wait until communication (and associated computation) is completed.

core.distributed.reduce_scatter_with_fp32_accumulation.reduce_scatter_with_fp32_accumulation(
output_tensor: torch.Tensor,
input_tensor: torch.Tensor,
op: torch.distributed.ReduceOp,
group: torch.distributed.ProcessGroup,
async_op: bool,
)#

Reduce-scatter with FP32 accumulation.

Collects input_tensor shards in lower precision using an all-to-all, then locally accumulates them in FP32 precision, then downcasts the final sum into the right location in output_tensor.
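The three steps above can be sketched as a single-process simulation. All names here are illustrative (this is not the library's implementation), and plain Python floats stand in for the FP32 accumulator:

```python
def simulate_reduce_scatter_fp32_accum(inputs_per_rank, world_size):
    """Simulate reduce-scatter with higher-precision local accumulation.

    inputs_per_rank[r] is rank r's full input, of length
    world_size * shard_len; rank d keeps the d-th shard of the sum.
    """
    shard_len = len(inputs_per_rank[0]) // world_size
    outputs = []
    for dest in range(world_size):
        # Step 1 (all-to-all): rank `dest` receives its shard from every
        # source rank, still in the input's (lower) precision.
        received = [
            inputs_per_rank[src][dest * shard_len:(dest + 1) * shard_len]
            for src in range(world_size)
        ]
        # Step 2: accumulate locally in FP32 (Python floats stand in here).
        acc = [0.0] * shard_len
        for shard in received:
            for i, v in enumerate(shard):
                acc[i] += v
        # Step 3: downcast the final sum into the output shard
        # (the downcast is a no-op in this float-only sketch).
        outputs.append(acc)
    return outputs

# Two ranks, four elements each; every rank keeps a two-element shard.
outs = simulate_reduce_scatter_fp32_accum(
    [[1.0, 2.0, 3.0, 4.0], [10.0, 20.0, 30.0, 40.0]], world_size=2
)
# outs[0] == [11.0, 22.0], outs[1] == [33.0, 44.0]
```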

Parameters:
  • output_tensor (torch.Tensor) – Output tensor that receives the reduce-scattered result (this rank's shard only).

  • input_tensor (torch.Tensor) – Input tensor that needs to be reduce-scattered.

  • op (torch.distributed.ReduceOp) – Only torch.distributed.ReduceOp.SUM is supported.

  • group (torch.distributed.ProcessGroup) – Process group to use for reduce-scatter.

  • async_op (bool) – Only False is supported right now.
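To see why accumulating in FP32 matters, the following sketch uses NumPy's float16 to stand in for the lower-precision input dtype. Naive low-precision accumulation stalls once the running sum grows large relative to the dtype's spacing, while upcasting before the sum preserves the result:

```python
import numpy as np

# 10,000 small values whose true sum is approximately 1.0.
vals = np.full(10000, 0.0001, dtype=np.float16)

# Naive FP16 accumulation: once the running sum reaches ~0.25, each addend
# is below half the FP16 spacing there, so the sum stops growing.
fp16_sum = np.float16(0.0)
for v in vals:
    fp16_sum = np.float16(fp16_sum + v)

# Upcast first, then accumulate in FP32: close to the true sum.
fp32_sum = vals.astype(np.float32).sum()
```

This is the failure mode the all-to-all-plus-FP32-accumulation pattern avoids: the lossy part of the reduction is moved into a local FP32 sum instead of being performed in the communication dtype.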