core.distributed.reduce_scatter_with_fp32_accumulation
Module Contents

Classes

_ReduceScatterWithFP32AccumulationWorkHandle – Work handle returned to the user when reduce_scatter_with_fp32_accumulation is called with async_op=True.

Functions

reduce_scatter_with_fp32_accumulation – Reduce-scatter with FP32 accumulation.
API
class core.distributed.reduce_scatter_with_fp32_accumulation._ReduceScatterWithFP32AccumulationWorkHandle(
    all_to_all_handle: Any,
    all_to_all_output_tensor: torch.Tensor,
    output_tensor: torch.Tensor,
    world_size: int,
)

Work handle returned to the user when reduce_scatter_with_fp32_accumulation is called with async_op=True.

Initialization

Initialize the work handle.
wait()
Wait until communication (and associated computation) is completed.
core.distributed.reduce_scatter_with_fp32_accumulation.reduce_scatter_with_fp32_accumulation(
    output_tensor: torch.Tensor,
    input_tensor: torch.Tensor,
    op: torch.distributed.ReduceOp,
    group: torch.distributed.ProcessGroup,
    async_op: bool,
)

Reduce-scatter with FP32 accumulation.

Collects shards of input_tensor in their lower-precision dtype using an all-to-all, accumulates them locally in FP32, then downcasts the final sum into the correct location in output_tensor.
Parameters:

- output_tensor (torch.Tensor) – Output tensor that receives the reduce-scattered result (this rank's shard only).
- input_tensor (torch.Tensor) – Input tensor to be reduce-scattered.
- op (torch.distributed.ReduceOp) – Reduction operation; only torch.distributed.ReduceOp.SUM is supported.
- group (torch.distributed.ProcessGroup) – Process group to use for the reduce-scatter.
- async_op (bool) – Whether to run asynchronously; only False is currently supported.
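To illustrate the algorithm the description above outlines, the following is a minimal single-process sketch (not the library's implementation) that simulates the three steps — all-to-all exchange of low-precision shards, local FP32 accumulation, and a final downcast — using plain tensor slicing in place of real collectives. The function name and structure here are illustrative assumptions.

```python
import torch


def simulated_reduce_scatter_fp32(rank_inputs):
    """Simulate reduce-scatter with FP32 accumulation on one process.

    rank_inputs: one low-precision (e.g. bf16) 1-D tensor per simulated
    rank, each of length world_size * shard_size.
    Returns the per-rank output shards, in the input dtype.
    """
    world_size = len(rank_inputs)
    shard = rank_inputs[0].numel() // world_size
    outputs = []
    for r in range(world_size):
        # Step 1 ("all-to-all"): rank r receives shard r from every
        # rank, still in the low-precision communication dtype.
        received = [x[r * shard:(r + 1) * shard] for x in rank_inputs]
        # Step 2: accumulate locally in FP32 to avoid losing precision
        # when summing many low-precision contributions.
        acc = torch.zeros(shard, dtype=torch.float32)
        for t in received:
            acc += t.float()
        # Step 3: downcast the final sum back to the input dtype.
        outputs.append(acc.to(rank_inputs[0].dtype))
    return outputs
```

The point of accumulating in FP32 is that summing directly in bf16 would round after every addition; here rounding happens only once, on the final downcast.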