Source code for modulus.distributed.utils
# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# TODO this also needs more docstrings
from typing import List, Optional
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
from .manager import DistributedManager
def compute_split_shapes(size: int, num_chunks: int) -> List[int]:
# treat trivial case first
if num_chunks == 1:
return [size]
# first, check if we can split using div-up to balance the load:
chunk_size = (size + num_chunks - 1) // num_chunks
last_chunk_size = max(0, size - chunk_size * (num_chunks - 1))
if last_chunk_size == 0:
# in this case, the last shard would be empty, split with floor instead:
chunk_size = size // num_chunks
last_chunk_size = size - chunk_size * (num_chunks - 1)
# generate sections list
sections = [chunk_size for _ in range(num_chunks - 1)] + [last_chunk_size]
return sections
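
# Illustrative sketch (not part of the original module): compute_split_shapes
# uses div-up chunk sizes unless that would leave the last chunk empty, in
# which case it falls back to floor sizes. The helper below is hypothetical
# and exists only to demonstrate the resulting section lists.
def _example_compute_split_shapes():
    assert compute_split_shapes(10, 3) == [4, 4, 2]  # div-up: ceil(10 / 3) = 4
    assert compute_split_shapes(10, 4) == [3, 3, 3, 1]
    assert compute_split_shapes(9, 4) == [2, 2, 2, 3]  # floor fallback
    assert compute_split_shapes(7, 7) == [1, 1, 1, 1, 1, 1, 1]
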
def get_memory_format(tensor):
    """Gets the memory format of a tensor (channels_last or contiguous)"""
if tensor.is_contiguous(memory_format=torch.channels_last):
return torch.channels_last
else:
return torch.contiguous_format
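
# Illustrative sketch (not part of the original module): get_memory_format is
# used by the helpers below to preserve a channels_last layout across local
# reshuffles. The helper below is hypothetical and for demonstration only.
def _example_get_memory_format():
    x = torch.randn(2, 3, 8, 8).to(memory_format=torch.channels_last)
    assert get_memory_format(x) == torch.channels_last
    assert get_memory_format(torch.randn(2, 3, 8, 8)) == torch.contiguous_format
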
def pad_helper(tensor, dim, new_size, mode="zero"):
    """Pads a tensor along dimension ``dim`` up to ``new_size``"""
    ndim = tensor.ndim
    dim = (dim + ndim) % ndim
    ndim_pad = ndim - dim
    output_shape = [0 for _ in range(2 * ndim_pad)]
    orig_size = tensor.shape[dim]
    # F.pad expects padding pairs starting from the last dimension, so the
    # pair belonging to ``dim`` is the last one in the list
    output_shape[-1] = new_size - orig_size
    tensor_pad = F.pad(tensor, output_shape, mode="constant", value=0.0)
    if mode == "conj":
        lhs_slice = [
            slice(0, x) if idx != dim else slice(orig_size, new_size)
            for idx, x in enumerate(tensor.shape)
        ]
        rhs_slice = [
            slice(0, x) if idx != dim else slice(1, output_shape[-1] + 1)
            for idx, x in enumerate(tensor.shape)
        ]
        tensor_pad[tuple(lhs_slice)] = torch.flip(
            torch.conj(tensor_pad[tuple(rhs_slice)]), dims=[dim]
        )
    return tensor_pad
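
# Illustrative sketch (not part of the original module): zero-padding a local
# shard along one dimension up to a common size, e.g. before a collective that
# expects equally sized chunks. Hypothetical helper, for demonstration only.
def _example_pad_helper():
    x = torch.arange(6, dtype=torch.float32).reshape(2, 3)
    y = pad_helper(x, dim=1, new_size=5)  # columns 3 and 4 are zero-filled
    assert y.shape == (2, 5)
    assert torch.equal(y[:, 3:], torch.zeros(2, 2))
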
def truncate_helper(tensor, dim, new_size):
    """Truncates a tensor along dimension ``dim`` down to ``new_size``"""
    input_format = get_memory_format(tensor)
    ndim = tensor.ndim
    dim = (dim + ndim) % ndim
    output_slice = [
        slice(0, x) if idx != dim else slice(0, new_size)
        for idx, x in enumerate(tensor.shape)
    ]
    tensor_trunc = tensor[tuple(output_slice)].contiguous(memory_format=input_format)
    return tensor_trunc


def split_tensor_along_dim(tensor, dim, num_chunks):
if dim >= tensor.dim():
raise ValueError(
f"Error, tensor dimension is {tensor.dim()} which cannot be split along {dim}"
)
    if tensor.shape[dim] < num_chunks:
        raise ValueError(
            f"Error, cannot split dim {dim} of size {tensor.shape[dim]} into "
            f"{num_chunks} chunks. Empty slices are currently not supported."
        )
# get split
sections = compute_split_shapes(tensor.shape[dim], num_chunks)
tensor_list = torch.split(tensor, sections, dim=dim)
return tensor_list
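
# Illustrative sketch (not part of the original module): splitting a tensor
# into per-rank chunks; the section sizes follow compute_split_shapes, so the
# last chunk may be smaller. Hypothetical helper, for demonstration only.
def _example_split_tensor_along_dim():
    x = torch.randn(7, 4)
    chunks = split_tensor_along_dim(x, dim=0, num_chunks=3)
    assert [c.shape[0] for c in chunks] == [3, 3, 1]
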
@torch.no_grad()
def reduce_loss(loss: float, dst_rank: int = 0, mean: bool = True): # pragma: no cover
"""Reduces loss from all processes to destination rank for logging.
Parameters
----------
loss : float
loss value
dst_rank : int, Optional
        destination rank to reduce to, by default 0.
mean : bool, Optional
Calculate the mean of the losses gathered, by default True.
Raises
------
Exception
If DistributedManager has yet to be initialized
"""
if not DistributedManager.is_initialized():
raise Exception(
"Distributed manager should be initialized when using reduce_loss"
)
distmng = DistributedManager()
loss = torch.Tensor([loss]).to(distmng.device)
# For serial runs, just return the current loss!
if distmng.world_size == 1:
return float(loss)
op = torch.distributed.ReduceOp.SUM if not mean else torch.distributed.ReduceOp.AVG
torch.distributed.reduce(loss, dst_rank, op, group=None)
# Return loss if dst_rank, None otherwise
if distmng.rank == dst_rank:
return float(loss.cpu())
else:
        return None


# distributed primitives
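
# Hedged usage sketch for reduce_loss above (not part of the original module):
# it assumes DistributedManager.initialize() has already been called. The
# function below is hypothetical and only illustrates the calling convention.
def _example_reduce_loss(local_loss: float):
    # every rank passes its local loss; only dst_rank gets the reduced value,
    # all other ranks receive None
    reduced = reduce_loss(local_loss, dst_rank=0, mean=True)
    if reduced is not None:
        print(f"mean loss across ranks: {reduced}")
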
def distributed_transpose(tensor, dim0, dim1, group=None, async_op=False):
"""Perform distributed transpose of tensor to switch sharding dimension"""
# get input format
input_format = get_memory_format(tensor)
# get comm params
comm_size = dist.get_world_size(group=group)
# split and local transposition
split_size = tensor.shape[dim0] // comm_size
x_send = [
y.contiguous(memory_format=input_format)
for y in torch.split(tensor, split_size, dim=dim0)
]
x_recv = [torch.empty_like(x_send[0]) for _ in range(comm_size)]
# global transposition
req = dist.all_to_all(x_recv, x_send, group=group, async_op=async_op)
    return x_recv, req


def _reduce(input_, use_fp32=True, group=None):  # pragma: no cover
"""All-reduce the input tensor across model parallel group."""
# Bypass the function if we are using only 1 GPU.
if dist.get_world_size(group=group) == 1:
return input_
# All-reduce, use_fp32 only relevant for lower precisions
# if input is already in double precision, nothing changes
if use_fp32 and (input_.dtype.itemsize < 4) and input_.dtype.is_floating_point:
dtype = input_.dtype
inputf_ = input_.float()
dist.all_reduce(inputf_, group=group)
input_ = inputf_.to(dtype)
else:
dist.all_reduce(input_, group=group)
return input_
def _split(input_, dim_, group=None): # pragma: no cover
"""Split the tensor along its last dimension and keep the corresponding slice."""
# get input format
input_format = get_memory_format(input_)
# Bypass the function if we are using only 1 GPU.
comm_size = dist.get_world_size(group=group)
if comm_size == 1:
return input_
# Split along last dimension.
input_list = split_tensor_along_dim(input_, dim_, comm_size)
# Note: torch.split does not create contiguous tensors by default.
rank = dist.get_rank(group=group)
output = input_list[rank].contiguous(memory_format=input_format)
return output
def all_gather_v_wrapper(
tensor: torch.Tensor,
sizes: Optional[List[int]] = None,
dim: int = 0,
group: Optional[dist.ProcessGroup] = None,
) -> torch.Tensor: # pragma: no cover
"""
Implements a distributed AllGatherV primitive. It is based
on the idea of a single global tensor which is distributed along
a specified dimension into chunks of variable size.
This primitive gathers all local tensors from each rank into the
full global tensor onto each rank.
Parameters
----------
tensor : "torch.Tensor"
local tensor on each rank
sizes : List[int], optional
list of the sizes of each chunk on each rank along distributed dimension,
valid and set on each rank, by default None
dim : int, optional
dimension along which global tensor is distributed, by default 0
group : Optional[dist.ProcessGroup], optional
process group along which global tensor is shared, by default None
Returns
-------
torch.Tensor
full global tensor, valid on each rank
"""
comm_size = dist.get_world_size(group=group)
if (sizes is not None) and (len(sizes) != comm_size):
raise ValueError()
if dim >= tensor.dim():
raise ValueError()
if comm_size == 1:
return tensor
tensor_shape = list(tensor.shape)
tensor_format = get_memory_format(tensor)
if sizes is not None:
tensor_list = [None] * comm_size
for src in range(comm_size):
tensor_shape[dim] = sizes[src]
tensor_list[src] = torch.empty(
tensor_shape,
dtype=tensor.dtype,
device=tensor.device,
)
else:
# assume equal shape on all ranks
tensor_list = [torch.empty_like(tensor) for _ in range(comm_size)]
dist.all_gather(tensor_list, tensor, group=group)
output = torch.cat(tensor_list, dim=dim).contiguous(memory_format=tensor_format)
return output
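
# Hedged usage sketch (not part of the original module): gathering
# variable-length per-rank chunks into the full global tensor on every rank.
# Assumes torch.distributed is initialized and that ``sizes`` lists each
# rank's local length along ``dim``. Hypothetical helper, illustration only.
def _example_all_gather_v(local_chunk: torch.Tensor, sizes: List[int]):
    # e.g. sizes = compute_split_shapes(global_len, dist.get_world_size())
    full = all_gather_v_wrapper(local_chunk, sizes=sizes, dim=0)
    return full  # identical full tensor on every rank
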
def all_gather_v_bwd_wrapper(
tensor: torch.Tensor,
sizes: List[int],
dim: int = 0,
use_fp32: bool = True,
group: Optional[dist.ProcessGroup] = None,
) -> torch.Tensor: # pragma: no cover
"""
Implements a distributed AllReduceV primitive. It is based
    on the idea of a single global tensor which can be distributed
along a specified dimension into chunks of variable size.
This primitive assumes different global tensors of the same shape on each
rank. It then re-distributes chunks of all these tensors such that each rank
receives all corresponding parts of a global tensor. Each rank then sums up
the chunks after receiving it. By design, this primitive thus implements the
backward pass of the "all_gather_v" primitive. In this case, the result would
be a single global gradient tensor distributed onto different ranks.
Parameters
----------
tensor : torch.Tensor
global tensor on each rank (different one on each rank)
sizes : List[int]
list of the sizes of each chunk on each rank along distributed dimension,
valid and set on each rank
dim : int, optional
dimension along which global tensor is distributed, by default 0
use_fp32 : bool, optional
flag to specify reduction taking place at least in FP32 precision, by default True
only acts on floating point inputs in lower precision
group : Optional[dist.ProcessGroup], optional
process group along which global tensor is shared, by default None
Returns
-------
torch.Tensor
local tensor, i.e. result of reduction of all corresponding chunks
from all global tensors for each rank separately
"""
comm_size = dist.get_world_size(group=group)
rank = dist.get_rank(group=group)
if len(sizes) != comm_size:
raise ValueError()
if dim >= tensor.dim():
raise ValueError()
tensor_shape = list(tensor.shape)
tensor_shape[dim] = sizes[rank]
tmp = [
torch.empty(
tensor_shape,
dtype=tensor.dtype,
device=tensor.device,
)
for _ in range(comm_size)
]
scatter_list = list(torch.split(tensor, sizes, dim=dim))
scatter_list = [t.contiguous() for t in scatter_list]
dist.all_to_all(tmp, scatter_list, group=group)
stack_dim = tensor.dim()
tmp = torch.stack(tmp, dim=stack_dim)
if use_fp32 and (tmp.dtype.itemsize < 4) and tmp.dtype.is_floating_point:
# cast to float before sum and return float, then cast back
output = tmp.sum(dim=stack_dim, dtype=torch.float32)
output = output.to(dtype=tensor.dtype)
else:
# else: just do sum in native dtype
output = tmp.sum(dim=stack_dim)
return output
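
# Hedged sketch (not part of the original module): all_gather_v_bwd_wrapper is
# the adjoint of all_gather_v_wrapper, so a custom autograd function would
# typically pair the two roughly as below. The class is hypothetical and
# assumes an initialized process group.
class _ExampleAllGatherV(torch.autograd.Function):
    @staticmethod
    def forward(ctx, tensor, sizes, dim):
        ctx.sizes, ctx.dim = sizes, dim
        return all_gather_v_wrapper(tensor, sizes=sizes, dim=dim)

    @staticmethod
    def backward(ctx, grad_output):
        grad = all_gather_v_bwd_wrapper(grad_output, sizes=ctx.sizes, dim=ctx.dim)
        return grad, None, None
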
def gather_v_wrapper(
tensor: torch.Tensor,
sizes: List[int],
dim: int = 0,
dst: int = 0,
group: Optional[dist.ProcessGroup] = None,
) -> torch.Tensor: # pragma: no cover
"""
Implements a distributed GatherV primitive. It is based
on the idea of a single global tensor which is distributed along
a specified dimension into chunks of variable size.
This primitive assumes such a distributed tensor and gathers all
local tensors from each rank into the full global tensor valid
on the specified destination rank.
Parameters
----------
tensor : torch.Tensor
local tensor on each rank
sizes : List[int]
list of the sizes of each chunk on each rank along distributed dimension,
valid and set on each rank
dim : int, optional
dimension along which global tensor is distributed, by default 0
dst : int, optional
destination rank which contains the full global tensor after the
operation, by default 0
group : Optional[dist.ProcessGroup], optional
process group along which global tensor is shared, by default None
Returns
-------
torch.Tensor
full global tensor, valid on destination rank
"""
comm_size = dist.get_world_size(group=group)
rank = dist.get_rank(group=group)
if len(sizes) != comm_size:
raise ValueError()
if dim >= tensor.dim():
raise ValueError()
if not (0 <= dst < comm_size):
raise ValueError()
if tensor.size(dim) != sizes[rank]:
raise ValueError()
if comm_size == 1:
return tensor
tensor_shape = list(tensor.shape)
x_recv = [None] * comm_size
x_send = [None] * comm_size
for r in range(comm_size):
if rank == dst:
tensor_shape[dim] = sizes[r]
else:
tensor_shape[dim] = 0
x_recv[r] = torch.empty(
tensor_shape,
dtype=tensor.dtype,
device=tensor.device,
)
if r == dst:
x_send[r] = tensor
else:
tensor_shape[dim] = 0
x_send[r] = torch.empty(
tensor_shape,
dtype=tensor.dtype,
device=tensor.device,
)
dist.all_to_all(x_recv, x_send, group=group)
# TODO: clean gather/scatter and some examples up
# main question is around whether e.g. gather returns
    # None for rank != dst or an empty dummy or a dummy
# containing meta-information like dtype/etc..
if rank != dst:
for r in range(comm_size):
tensor_shape[dim] = sizes[r]
x_recv[r] = torch.empty(
tensor_shape,
dtype=tensor.dtype,
device=tensor.device,
)
output = torch.cat(x_recv, dim=dim)
return output
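
# Hedged usage sketch (not part of the original module): collecting all local
# chunks onto a single destination rank, e.g. so that rank 0 can write output
# to disk. Assumes an initialized process group; hypothetical helper.
def _example_gather_v(local_chunk: torch.Tensor, sizes: List[int]):
    full = gather_v_wrapper(local_chunk, sizes=sizes, dim=0, dst=0)
    if dist.get_rank() == 0:
        return full  # full global tensor, only meaningful on dst
    return None
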
def scatter_v_wrapper(
tensor: torch.Tensor,
sizes: List[int],
dim: int = 0,
src: int = 0,
group: Optional[dist.ProcessGroup] = None,
) -> torch.Tensor: # pragma: no cover
"""
Implements a distributed ScatterV primitive. It is based
on the idea of a single global tensor which is distributed along
a specified dimension into chunks of variable size.
This primitive scatters the global tensor from a specified source rank
into local chunks onto each other rank.
Parameters
----------
tensor : torch.Tensor
global tensor, valid on source rank
sizes : List[int]
list of the sizes of each chunk on each rank along distributed dimension,
        valid and set on each rank
dim : int, optional
dimension along which global tensor is distributed, by default 0
src : int, optional
source rank of primitive, i.e. rank of original full global tensor, by default 0
group : Optional[dist.ProcessGroup], optional
process group along which global tensor is shared, by default None
Returns
-------
torch.Tensor
corresponding local part of the global tensor on each rank
"""
comm_size = dist.get_world_size(group=group)
rank = dist.get_rank(group=group)
if len(sizes) != comm_size:
raise ValueError()
    # the global tensor is only guaranteed to be valid on the source rank,
    # so only validate the distributed dimension there
    if dist.get_rank(group=group) == src and dim >= tensor.dim():
        raise ValueError()
if not (0 <= src < comm_size):
raise ValueError()
# all_to_all is already all_to_all_v, use empty tensors to "mask"-out irrelevant parts
tensor_shape = list(tensor.shape)
x_send = [None] * comm_size
x_recv = [None] * comm_size
if rank == src:
scatter_list = torch.split(tensor, sizes, dim=dim)
scatter_list = [t.contiguous() for t in scatter_list]
x_send = scatter_list
else:
for r in range(comm_size):
tensor_shape[dim] = 0
x_send[r] = torch.empty(
tensor_shape, device=tensor.device, dtype=tensor.dtype
)
for r in range(comm_size):
if r == src:
tensor_shape[dim] = sizes[rank]
else:
tensor_shape[dim] = 0
x_recv[r] = torch.empty(tensor_shape, device=tensor.device, dtype=tensor.dtype)
dist.all_to_all(x_recv, x_send, group=group)
return x_recv[src]
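
# Hedged usage sketch (not part of the original module): distributing a global
# tensor from rank 0 into per-rank chunks. Every rank still passes a tensor of
# matching dtype/device (its contents only matter on src) and must know the
# global size. Assumes an initialized process group; hypothetical helper.
def _example_scatter_v(global_tensor: torch.Tensor):
    world_size = dist.get_world_size()
    sizes = compute_split_shapes(global_tensor.size(0), world_size)
    local_chunk = scatter_v_wrapper(global_tensor, sizes=sizes, dim=0, src=0)
    return local_chunk  # length along dim 0 equals sizes[rank]
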
def indexed_all_to_all_v_wrapper(
tensor: torch.Tensor,
indices: List[torch.Tensor],
sizes: List[List[int]],
dim: int = 0,
group: Optional[dist.ProcessGroup] = None,
) -> torch.Tensor: # pragma: no cover
"""
Implements an indexed version of a distributed AllToAllV
primitive. It is based on the idea of a single global tensor which
is distributed along a specified dimension into chunks of variable size.
This primitive assumes a set of indices into this dimension which indicate
the corresponding slices sent to each other rank forming an indexed version
of an AllToAllV primitive.
Parameters
----------
tensor : torch.Tensor
local part of global tensor on each rank
indices : List[torch.Tensor]
list of indices on each rank of slices being sent to
each other rank from this rank
sizes : List[List[int]]
number of indices each rank sends to each other rank,
valid and set on each rank, e.g. sizes[0][3] corresponds
to the number of slices rank 0 sends to rank 3
dim : int
dimension along which global tensor is distributed, by default 0
group : Optional[dist.ProcessGroup], optional
process group along which global tensor is shared, by default None
Returns
-------
torch.Tensor
local result of primitive corresponding to indexed global tensor
"""
comm_size = dist.get_world_size(group=group)
rank = dist.get_rank(group=group)
if len(sizes) != comm_size:
raise ValueError()
if dim >= tensor.dim():
raise ValueError()
if len(sizes[rank]) != comm_size:
raise ValueError()
if len(indices) != comm_size:
raise ValueError()
x_send = [tensor[idx] for idx in indices]
x_recv = [None] * comm_size
tensor_shape = list(tensor.shape)
for r in range(comm_size):
tensor_shape[dim] = sizes[r][rank]
x_recv[r] = torch.empty(
tensor_shape,
dtype=tensor.dtype,
device=tensor.device,
)
dist.all_to_all(x_recv, x_send, group=group)
tensor_to_recv = torch.cat(x_recv, dim=dim)
return tensor_to_recv
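
# Hedged usage sketch (not part of the original module): an indexed all-to-all
# as used e.g. to exchange neighbor/halo rows of a partitioned graph or point
# cloud. ``indices[r]`` selects the rows this rank sends to rank r and
# ``sizes[p][q]`` is the number of rows rank p sends to rank q. Assumes an
# initialized process group; hypothetical helper.
def _example_indexed_all_to_all_v(
    local_rows: torch.Tensor,
    indices: List[torch.Tensor],
    sizes: List[List[int]],
):
    recv = indexed_all_to_all_v_wrapper(local_rows, indices, sizes, dim=0)
    return recv  # rows received from rank 0, 1, ... concatenated along dim 0
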
def indexed_all_to_all_v_wrapper_bwd(
tensor: torch.Tensor,
indices: List[torch.Tensor],
sizes: List[List[int]],
tensor_size_along_dim: int,
use_fp32: bool = True,
dim: int = 0,
group: Optional[dist.ProcessGroup] = None,
) -> torch.Tensor: # pragma: no cover
"""
Implements the backward pass to the indexed version of a distributed
AllToAllV primitive.
Parameters
----------
tensor : torch.Tensor
local tensor, i.e. gradient on resulting tensor from forward pass
indices : List[torch.Tensor]
list of indices on each rank of slices being sent to
each other rank from this rank
sizes : List[List[int]]
list of the sizes of each chunk on each rank along distributed dimension,
valid and set on each rank
tensor_size_along_dim : int
size of original local tensor along specified dimension,
i.e. from the corresponding forward pass
use_fp32 : bool, optional
flag to specify reduction taking place at least in FP32 precision, by default True
only acts on floating point inputs in lower precision
dim : int, optional
        dimension along which the global tensor is distributed, by default 0
group : Optional[dist.ProcessGroup], optional
process group along which global tensor is shared, by default None
Returns
-------
torch.Tensor
result of primitive corresponding to indexed global tensor
"""
comm_size = dist.get_world_size(group=group)
rank = dist.get_rank(group=group)
if len(sizes) != comm_size:
raise ValueError()
if dim >= tensor.dim():
raise ValueError()
if len(sizes[rank]) != comm_size:
raise ValueError()
if len(indices) != comm_size:
raise ValueError()
tensor_shape = list(tensor.shape)
# scatter gradients, roles reversed compared to forward pass
# recv_sizes in forward pass
recv_sizes = [sizes[i][rank] for i in range(comm_size)]
# send_sizes in forward pass
send_sizes = [sizes[rank][i] for i in range(comm_size)]
x_send = torch.split(tensor, recv_sizes, dim=dim)
x_send = [t.contiguous() for t in x_send]
x_recv = [None] * comm_size
for r in range(comm_size):
tensor_shape[dim] = send_sizes[r]
x_recv[r] = torch.empty(
tensor_shape,
dtype=tensor.dtype,
device=tensor.device,
)
dist.all_to_all(x_recv, x_send, group=group)
tensor_to_recv = torch.cat(x_recv, dim=dim)
    # sum up gathered gradients, taking care of
    # precision handling as specified by the use_fp32 flag
indices = torch.cat(indices, dim=0)
tensor_shape[dim] = tensor_size_along_dim
if use_fp32 and (tensor.dtype.itemsize < 4) and tensor.dtype.is_floating_point:
out = torch.zeros(
tensor_shape,
dtype=torch.float32,
device=tensor.device,
)
tensor_to_recv = tensor_to_recv.to(dtype=torch.float32)
else:
out = torch.zeros(
tensor_shape,
dtype=tensor.dtype,
device=tensor.device,
)
out.index_add_(source=tensor_to_recv, index=indices, dim=dim)
if out.dtype != tensor.dtype:
out = out.to(tensor.dtype)
return out
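
# Hedged sketch (not part of the original module): the backward wrapper above
# routes the incoming gradient back to the ranks that sent the corresponding
# slices and accumulates it with index_add_. A custom autograd function would
# typically pair the two wrappers roughly as below; the class is hypothetical
# and assumes an initialized process group.
class _ExampleIndexedAllToAllV(torch.autograd.Function):
    @staticmethod
    def forward(ctx, tensor, indices, sizes, dim):
        ctx.indices, ctx.sizes, ctx.dim = indices, sizes, dim
        ctx.tensor_size_along_dim = tensor.size(dim)
        return indexed_all_to_all_v_wrapper(tensor, indices, sizes, dim=dim)

    @staticmethod
    def backward(ctx, grad_output):
        grad = indexed_all_to_all_v_wrapper_bwd(
            grad_output,
            ctx.indices,
            ctx.sizes,
            tensor_size_along_dim=ctx.tensor_size_along_dim,
            dim=ctx.dim,
        )
        return grad, None, None, None
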