
Source code for modulus.distributed.utils

# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# TODO this also needs more docstrings
from typing import List, Optional

import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F

from .manager import DistributedManager


def compute_split_shapes(size: int, num_chunks: int) -> List[int]:

    # treat trivial case first
    if num_chunks == 1:
        return [size]

    # first, check if we can split using div-up to balance the load:
    chunk_size = (size + num_chunks - 1) // num_chunks
    last_chunk_size = max(0, size - chunk_size * (num_chunks - 1))
    if last_chunk_size == 0:
        # in this case, the last shard would be empty, split with floor instead:
        chunk_size = size // num_chunks
        last_chunk_size = size - chunk_size * (num_chunks - 1)

    # generate sections list
    sections = [chunk_size for _ in range(num_chunks - 1)] + [last_chunk_size]

    return sections
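
# Illustrative sketch, not part of the original module: a quick sanity check of
# the div-up / floor fallback above. For size 10 in 4 chunks the div-up rule
# gives ceil(10 / 4) = 3, leaving 1 for the last shard; for size 4 in 3 chunks
# div-up would leave the last shard empty, so the floor rule is used instead.
def _example_compute_split_shapes() -> None:
    assert compute_split_shapes(10, 4) == [3, 3, 3, 1]
    assert compute_split_shapes(4, 3) == [1, 1, 2]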


def get_memory_format(tensor):
    """Gets format for tensor"""
    if tensor.is_contiguous(memory_format=torch.channels_last):
        return torch.channels_last
    else:
        return torch.contiguous_format


def pad_helper(tensor, dim, new_size, mode="zero"):
    """Util for padding tensors"""
    ndim = tensor.ndim
    dim = (dim + ndim) % ndim
    ndim_pad = ndim - dim
    output_shape = [0 for _ in range(2 * ndim_pad)]
    orig_size = tensor.shape[dim]
    output_shape[1] = new_size - orig_size
    tensor_pad = F.pad(tensor, output_shape, mode="constant", value=0.0)

    if mode == "conj":
        lhs_slice = [
            slice(0, x) if idx != dim else slice(orig_size, new_size)
            for idx, x in enumerate(tensor.shape)
        ]
        rhs_slice = [
            slice(0, x) if idx != dim else slice(1, output_shape[1] + 1)
            for idx, x in enumerate(tensor.shape)
        ]
        tensor_pad[lhs_slice] = torch.flip(
            torch.conj(tensor_pad[rhs_slice]), dims=[dim]
        )

    return tensor_pad


def truncate_helper(tensor, dim, new_size):
    """Util for truncating"""
    input_format = get_memory_format(tensor)
    ndim = tensor.ndim
    dim = (dim + ndim) % ndim
    output_slice = [
        slice(0, x) if idx != dim else slice(0, new_size)
        for idx, x in enumerate(tensor.shape)
    ]
    tensor_trunc = tensor[output_slice].contiguous(memory_format=input_format)

    return tensor_trunc
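
# Illustrative sketch, not part of the original module: pad_helper and
# truncate_helper are inverses along the modified dimension, e.g. zero-padding
# a dimension from 5 to 8 and truncating back recovers the original tensor.
def _example_pad_truncate_roundtrip() -> None:
    x = torch.randn(2, 5)
    padded = pad_helper(x, dim=1, new_size=8)  # zero-pad dim 1 from 5 to 8
    restored = truncate_helper(padded, dim=1, new_size=5)
    assert torch.equal(restored, x)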

def split_tensor_along_dim(tensor, dim, num_chunks):
    if dim >= tensor.dim():
        raise ValueError(
            f"Error, tensor dimension is {tensor.dim()} which cannot be split along {dim}"
        )
    if tensor.shape[dim] < num_chunks:
        raise ValueError(
            f"Error, cannot split dim {dim} of size {tensor.shape[dim]} into "
            f"{num_chunks} chunks. Empty slices are currently not supported."
        )

    # get split
    sections = compute_split_shapes(tensor.shape[dim], num_chunks)
    tensor_list = torch.split(tensor, sections, dim=dim)

    return tensor_list

@torch.no_grad()
def reduce_loss(loss: float, dst_rank: int = 0, mean: bool = True):  # pragma: no cover
    """Reduces loss from all processes to destination rank for logging.

    Parameters
    ----------
    loss : float
        loss value
    dst_rank : int, optional
        destination rank to reduce to, by default 0.
    mean : bool, optional
        Calculate the mean of the losses gathered, by default True.

    Raises
    ------
    Exception
        If DistributedManager has yet to be initialized
    """
    if not DistributedManager.is_initialized():
        raise Exception(
            "Distributed manager should be initialized when using reduce_loss"
        )

    distmng = DistributedManager()
    loss = torch.Tensor([loss]).to(distmng.device)

    # For serial runs, just return the current loss!
    if distmng.world_size == 1:
        return float(loss)

    op = torch.distributed.ReduceOp.SUM if not mean else torch.distributed.ReduceOp.AVG
    torch.distributed.reduce(loss, dst_rank, op, group=None)

    # Return loss if dst_rank, None otherwise
    if distmng.rank == dst_rank:
        return float(loss.cpu())
    else:
        return None
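
# Usage sketch, not part of the original module: a typical logging pattern,
# assuming DistributedManager.initialize() has already been called by the
# training script. Only the destination rank receives a value; all other
# ranks get None.
def _example_reduce_loss(local_loss: float) -> None:
    avg_loss = reduce_loss(local_loss, dst_rank=0, mean=True)
    if avg_loss is not None:
        print(f"mean loss across ranks: {avg_loss:.4f}")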

# distributed primitives

def distributed_transpose(tensor, dim0, dim1, group=None, async_op=False):
    """Perform distributed transpose of tensor to switch sharding dimension"""
    # get input format
    input_format = get_memory_format(tensor)

    # get comm params
    comm_size = dist.get_world_size(group=group)

    # split and local transposition
    split_size = tensor.shape[dim0] // comm_size
    x_send = [
        y.contiguous(memory_format=input_format)
        for y in torch.split(tensor, split_size, dim=dim0)
    ]
    x_recv = [torch.empty_like(x_send[0]) for _ in range(comm_size)]

    # global transposition
    req = dist.all_to_all(x_recv, x_send, group=group, async_op=async_op)

    return x_recv, req
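
# Usage sketch, not part of the original module: one common pattern is to
# re-shard a tensor from dimension 1 to dimension 0 by concatenating the
# received chunks along dim1. Assumes torch.distributed is initialized and
# tensor.shape[0] is divisible by the (default) group size.
def _example_distributed_transpose(tensor: torch.Tensor) -> torch.Tensor:
    x_recv, _ = distributed_transpose(tensor, dim0=0, dim1=1, async_op=False)
    return torch.cat(x_recv, dim=1)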

def _reduce(input_, use_fp32=True, group=None):  # pragma: no cover
    """All-reduce the input tensor across model parallel group."""
    # Bypass the function if we are using only 1 GPU.
    if dist.get_world_size(group=group) == 1:
        return input_

    # All-reduce.
    if use_fp32:
        dtype = input_.dtype
        inputf_ = input_.float()
        dist.all_reduce(inputf_, group=group)
        input_ = inputf_.to(dtype)
    else:
        dist.all_reduce(input_, group=group)

    return input_


def _split(input_, dim_, group=None):  # pragma: no cover
    """Split the tensor along the specified dimension and keep the corresponding slice."""
    # get input format
    input_format = get_memory_format(input_)

    # Bypass the function if we are using only 1 GPU.
    comm_size = dist.get_world_size(group=group)
    if comm_size == 1:
        return input_

    # Split along the requested dimension.
    # Note: torch.split does not create contiguous tensors by default.
    input_list = split_tensor_along_dim(input_, dim_, comm_size)
    rank = dist.get_rank(group=group)
    output = input_list[rank].contiguous(memory_format=input_format)

    return output

def all_gather_v_wrapper(
    tensor: torch.Tensor,
    sizes: Optional[List[int]] = None,
    dim: int = 0,
    group: Optional[dist.ProcessGroup] = None,
) -> torch.Tensor:  # pragma: no cover
    """
    Implements a distributed AllGatherV primitive. It is based
    on the idea of a single global tensor which is distributed along
    a specified dimension into chunks of variable size.
    This primitive gathers all local tensors from each rank into the
    full global tensor onto each rank.

    Parameters
    ----------
    tensor : torch.Tensor
        local tensor on each rank
    sizes : List[int], optional
        list of the sizes of each chunk on each rank along distributed
        dimension, valid and set on each rank, by default None
    dim : int, optional
        dimension along which global tensor is distributed, by default 0
    group : Optional[dist.ProcessGroup], optional
        process group along which global tensor is shared, by default None

    Returns
    -------
    torch.Tensor
        full global tensor, valid on each rank
    """

    comm_size = dist.get_world_size(group=group)

    if (sizes is not None) and (len(sizes) != comm_size):
        raise ValueError()
    if dim >= tensor.dim():
        raise ValueError()

    if comm_size == 1:
        return tensor

    tensor_shape = list(tensor.shape)
    tensor_format = get_memory_format(tensor)

    if sizes is not None:
        tensor_list = [None] * comm_size

        for src in range(comm_size):
            tensor_shape[dim] = sizes[src]
            tensor_list[src] = torch.empty(
                tensor_shape,
                dtype=tensor.dtype,
                device=tensor.device,
            )
    else:
        # assume equal shape on all ranks
        tensor_list = [torch.empty_like(tensor) for _ in range(comm_size)]

    dist.all_gather(tensor_list, tensor, group=group)

    output = torch.cat(tensor_list, dim=dim).contiguous(memory_format=tensor_format)

    return output

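
# Usage sketch, not part of the original module: gathering variably sized row
# chunks into the full global tensor on every rank. Assumes torch.distributed
# is initialized and that local_chunk.shape[0] equals sizes[rank] on each rank,
# e.g. sizes = compute_split_shapes(global_rows, dist.get_world_size()).
def _example_all_gather_v(local_chunk: torch.Tensor, sizes: List[int]) -> torch.Tensor:
    return all_gather_v_wrapper(local_chunk, sizes=sizes, dim=0)
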
def all_gather_v_bwd_wrapper(
    tensor: torch.Tensor,
    sizes: List[int],
    dim: int = 0,
    use_fp32: bool = True,
    group: Optional[dist.ProcessGroup] = None,
) -> torch.Tensor:  # pragma: no cover
    """
    Implements a distributed AllReduceV primitive. It is based
    on the idea of a single global tensor which can be distributed
    along a specified dimension into chunks of variable size.
    This primitive assumes different global tensors of the same shape on each
    rank. It then re-distributes chunks of all these tensors such that each rank
    receives all corresponding parts of a global tensor. Each rank then sums up
    the chunks after receiving them. By design, this primitive thus implements the
    backward pass of the "all_gather_v" primitive. In this case, the result would
    be a single global gradient tensor distributed onto different ranks.

    Parameters
    ----------
    tensor : torch.Tensor
        global tensor on each rank (different one on each rank)
    sizes : List[int]
        list of the sizes of each chunk on each rank along distributed
        dimension, valid and set on each rank
    dim : int, optional
        dimension along which global tensor is distributed, by default 0
    use_fp32 : bool, optional
        flag to specify FP32 precision for the reduction, by default True
    group : Optional[dist.ProcessGroup], optional
        process group along which global tensor is shared, by default None

    Returns
    -------
    torch.Tensor
        local tensor, i.e. result of reduction of all corresponding chunks
        from all global tensors for each rank separately
    """

    comm_size = dist.get_world_size(group=group)
    rank = dist.get_rank(group=group)

    if len(sizes) != comm_size:
        raise ValueError()
    if dim >= tensor.dim():
        raise ValueError()

    tensor_shape = list(tensor.shape)
    tensor_shape[dim] = sizes[rank]
    tmp = [
        torch.empty(
            tensor_shape,
            dtype=tensor.dtype,
            device=tensor.device,
        )
        for _ in range(comm_size)
    ]
    scatter_list = list(torch.split(tensor, sizes, dim=dim))

    dist.all_to_all(tmp, scatter_list, group=group)
    stack_dim = tensor.dim()
    tmp = torch.stack(tmp, dim=stack_dim)

    if use_fp32:
        # cast to float before sum and return float, then cast back
        output = tmp.sum(dim=stack_dim, dtype=torch.float32)
        output = output.to(dtype=tensor.dtype)
    else:
        # else: just do sum in native dtype
        output = tmp.sum(dim=stack_dim)

    return output

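
# Usage sketch, not part of the original module: the backward counterpart of
# the all-gather example above. grad_full has the global shape on every rank;
# the result keeps only the sizes[rank] rows owned by this rank, summed over
# the contributions from all ranks. Assumes torch.distributed is initialized.
def _example_all_gather_v_bwd(grad_full: torch.Tensor, sizes: List[int]) -> torch.Tensor:
    return all_gather_v_bwd_wrapper(grad_full, sizes=sizes, dim=0)
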
def gather_v_wrapper(
    tensor: torch.Tensor,
    sizes: List[int],
    dim: int = 0,
    dst: int = 0,
    group: Optional[dist.ProcessGroup] = None,
) -> torch.Tensor:  # pragma: no cover
    """
    Implements a distributed GatherV primitive. It is based
    on the idea of a single global tensor which is distributed along
    a specified dimension into chunks of variable size.
    This primitive assumes such a distributed tensor and gathers all
    local tensors from each rank into the full global tensor valid
    on the specified destination rank.

    Parameters
    ----------
    tensor : torch.Tensor
        local tensor on each rank
    sizes : List[int]
        list of the sizes of each chunk on each rank along distributed
        dimension, valid and set on each rank
    dim : int, optional
        dimension along which global tensor is distributed, by default 0
    dst : int, optional
        destination rank which contains the full global tensor after the
        operation, by default 0
    group : Optional[dist.ProcessGroup], optional
        process group along which global tensor is shared, by default None

    Returns
    -------
    torch.Tensor
        full global tensor, valid on destination rank
    """

    comm_size = dist.get_world_size(group=group)
    rank = dist.get_rank(group=group)

    if len(sizes) != comm_size:
        raise ValueError()
    if dim >= tensor.dim():
        raise ValueError()
    if not (0 <= dst < comm_size):
        raise ValueError()
    if tensor.size(dim) != sizes[rank]:
        raise ValueError()

    if comm_size == 1:
        return tensor

    gather_list = [None] * comm_size
    tensor_shape = list(tensor.shape)

    for r in range(comm_size):
        tensor_shape[dim] = sizes[r]
        gather_list[r] = torch.empty(
            tensor_shape,
            dtype=tensor.dtype,
            device=tensor.device,
        )

    # dist.gather doesn't support tensors of different shape
    # so this implementation is using explicit send/recv combinations
    if rank == dst:
        req_list = [None] * comm_size
        for r in range(comm_size):
            if r == dst:
                gather_list[r] = tensor
            else:
                req_list[r] = dist.irecv(gather_list[r], src=r, group=group)
        for r in range(comm_size):
            if r != dst:
                req_list[r].wait()
    else:
        req = dist.isend(tensor, dst=dst, group=group)
        req.wait()

    output = torch.cat(gather_list, dim=dim)

    return output


def scatter_v_wrapper(
    tensor: torch.Tensor,
    sizes: List[int],
    dim: int = 0,
    src: int = 0,
    group: Optional[dist.ProcessGroup] = None,
) -> torch.Tensor:  # pragma: no cover
    """
    Implements a distributed ScatterV primitive. It is based
    on the idea of a single global tensor which is distributed along
    a specified dimension into chunks of variable size.
    This primitive scatters the global tensor from a specified source rank
    into local chunks onto each other rank.

    Parameters
    ----------
    tensor : torch.Tensor
        global tensor, valid on source rank
    sizes : List[int]
        list of the sizes of each chunk on each rank along distributed
        dimension, valid and set on each rank
    dim : int, optional
        dimension along which global tensor is distributed, by default 0
    src : int, optional
        source rank of primitive, i.e. rank of original full global tensor,
        by default 0
    group : Optional[dist.ProcessGroup], optional
        process group along which global tensor is shared, by default None

    Returns
    -------
    torch.Tensor
        corresponding local part of the global tensor on each rank
    """

    comm_size = dist.get_world_size(group=group)
    rank = dist.get_rank(group=group)

    if len(sizes) != comm_size:
        raise ValueError()
    if dim >= tensor.dim():
        raise ValueError()
    if not (0 <= src < comm_size):
        raise ValueError()

    tensor_shape = list(tensor.shape)
    tensor_shape[dim] = sizes[rank]
    output = torch.empty(
        tensor_shape,
        dtype=tensor.dtype,
        device=tensor.device,
    )

    # dist.scatter doesn't support tensors of different shape
    # so this implementation is using explicit send/recv combinations
    scatter_list = None
    if rank == src:
        scatter_list = torch.split(tensor, sizes, dim=dim)
        req_list = [None] * comm_size
        for r in range(comm_size):
            tensor_to_scatter_to_r = scatter_list[r]
            if r == src:
                output = tensor_to_scatter_to_r
            else:
                req_list[r] = dist.isend(tensor_to_scatter_to_r, dst=r, group=group)
        for r in range(comm_size):
            if r != src:
                req_list[r].wait()
    else:
        req = dist.irecv(output, src=src, group=group)
        req.wait()

    return output

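
# Usage sketch, not part of the original module: scatter_v_wrapper and
# gather_v_wrapper invert each other, so scattering from rank 0 and gathering
# back to rank 0 reproduces the global tensor there. Note that every rank
# passes a tensor of the full global shape; only the values on src matter.
# Assumes torch.distributed is initialized.
def _example_scatter_gather_roundtrip(global_tensor: torch.Tensor) -> torch.Tensor:
    sizes = compute_split_shapes(global_tensor.size(0), dist.get_world_size())
    local_part = scatter_v_wrapper(global_tensor, sizes, dim=0, src=0)
    return gather_v_wrapper(local_part, sizes, dim=0, dst=0)
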
def indexed_all_to_all_v_wrapper(
    tensor: torch.Tensor,
    indices: List[torch.Tensor],
    sizes: List[List[int]],
    dim: int = 0,
    group: Optional[dist.ProcessGroup] = None,
) -> torch.Tensor:  # pragma: no cover
    """
    Implements an indexed version of a distributed AllToAllV
    primitive. It is based on the idea of a single global tensor which
    is distributed along a specified dimension into chunks of variable size.
    This primitive assumes a set of indices into this dimension which indicate
    the corresponding slices sent to each other rank forming an indexed version
    of an AllToAllV primitive.

    Parameters
    ----------
    tensor : torch.Tensor
        local part of global tensor on each rank
    indices : List[torch.Tensor]
        list of indices on each rank of slices being sent to
        each other rank from this rank
    sizes : List[List[int]]
        number of indices each rank sends to each other rank,
        valid and set on each rank, e.g. sizes[0][3] corresponds
        to the number of slices rank 0 sends to rank 3
    dim : int
        dimension along which global tensor is distributed, by default 0
    group : Optional[dist.ProcessGroup], optional
        process group along which global tensor is shared, by default None

    Returns
    -------
    torch.Tensor
        local result of primitive corresponding to indexed global tensor
    """

    comm_size = dist.get_world_size(group=group)
    rank = dist.get_rank(group=group)

    if len(sizes) != comm_size:
        raise ValueError()
    if dim >= tensor.dim():
        raise ValueError()
    if len(sizes[rank]) != comm_size:
        raise ValueError()
    if len(indices) != comm_size:
        raise ValueError()

    indices = torch.cat(indices, dim=0)
    tensor_to_send = torch.index_select(tensor, dim=dim, index=indices)

    recv_list = [None] * comm_size
    for r in range(comm_size):
        recv_list[r] = scatter_v_wrapper(
            tensor_to_send,
            sizes=sizes[r],
            src=r,
            dim=dim,
            group=group,
        )
    tensor_to_recv = torch.cat(recv_list, dim=dim)

    return tensor_to_recv


def indexed_all_to_all_v_wrapper_bwd(
    tensor: torch.Tensor,
    indices: List[torch.Tensor],
    sizes: List[List[int]],
    tensor_size_along_dim: int,
    use_fp32: bool = True,
    dim: int = 0,
    group: Optional[dist.ProcessGroup] = None,
) -> torch.Tensor:  # pragma: no cover
    """
    Implements the backward pass to the indexed version of a distributed
    AllToAllV primitive.

    Parameters
    ----------
    tensor : torch.Tensor
        local tensor, i.e. gradient on resulting tensor from forward pass
    indices : List[torch.Tensor]
        list of indices on each rank of slices being sent to
        each other rank from this rank
    sizes : List[List[int]]
        list of the sizes of each chunk on each rank along distributed
        dimension, valid and set on each rank
    tensor_size_along_dim : int
        size of original local tensor along specified dimension,
        i.e. from the corresponding forward pass
    use_fp32 : bool, optional
        flag to specify FP32 precision, by default True
    dim : int, optional
        dimension along which global tensor is distributed, by default 0
    group : Optional[dist.ProcessGroup], optional
        process group along which global tensor is shared, by default None

    Returns
    -------
    torch.Tensor
        result of primitive corresponding to indexed global tensor
    """

    comm_size = dist.get_world_size(group=group)
    rank = dist.get_rank(group=group)

    if len(sizes) != comm_size:
        raise ValueError()
    if dim >= tensor.dim():
        raise ValueError()
    if len(sizes[rank]) != comm_size:
        raise ValueError()
    if len(indices) != comm_size:
        raise ValueError()

    indices = torch.cat(indices, dim=0)
    tensor_shape = list(tensor.shape)

    # scatter gradients, roles reversed compared to forward pass
    recv_list = [None] * comm_size
    for r in range(comm_size):
        recv_sizes = [sizes[i][r] for i in range(comm_size)]
        recv_list[r] = scatter_v_wrapper(
            tensor, recv_sizes, dim=dim, src=r, group=group
        )
    tensor_to_recv = torch.cat(recv_list, dim=dim)

    # sum up the gathered gradients, taking care of precision
    # handling as specified by the use_fp32 flag
    tensor_shape[dim] = tensor_size_along_dim
    if use_fp32:
        out = torch.zeros(
            tensor_shape,
            dtype=torch.float32,
            device=tensor.device,
        )
        tensor_to_recv = tensor_to_recv.to(dtype=torch.float32)
    else:
        out = torch.zeros(
            tensor_shape,
            dtype=tensor.dtype,
            device=tensor.device,
        )

    out.index_add_(source=tensor_to_recv, index=indices, dim=dim)

    if use_fp32:
        out = out.to(tensor.dtype)

    return out


def mark_module_as_shared(
    module: nn.Module,
    process_group: Optional[str],
    recurse: bool = True,
    use_fp32_reduction: bool = True,
) -> nn.Module:
    """
    Helper function to mark parameters of a module as being shared
    across ranks by attaching gradient hooks to the corresponding tensors.

    Parameters
    ----------
    module : nn.Module
        PyTorch module which is to be marked as having shared parameters.
    process_group : str | None
        str indicating process_group which contains ranks across which
        the module's parameters are shared. If passed as None, will default
        to the world group.
    recurse : bool, default=True
        Flag indicating whether the module's parameters are traversed in
        a recursive fashion, i.e. whether sub-modules are also considered
        as having shared parameters.
    use_fp32_reduction : bool, default=True
        Flag indicating whether the reduction for accumulating gradients
        will be done in FP32 or the native datatype.
    """

    group = DistributedManager().group(process_group)
    handle_key = "_shared_weight_dist_hook"

    def hook(grad: torch.Tensor) -> torch.Tensor:
        # the documentation states that
        # "The hook should not modify its argument, but it can optionally return
        # a new gradient which will be used in place of grad."
        # as all_reduce is an in-place operation, need to copy gradient
        grad = _reduce(grad.clone(), group=group, use_fp32=use_fp32_reduction)
        return grad

    def hook_post_accum(param: torch.Tensor) -> None:
        # the documentation states that
        # "Note that, unlike other autograd hooks, this hook operates on the tensor
        # that requires grad and not the grad itself. The hook can in-place modify
        # and access its Tensor argument, including its .grad field."
        param.grad = _reduce(param.grad, group=group, use_fp32=use_fp32_reduction)

    for name, param in module.named_parameters(recurse=recurse):
        error_msg = f"Parameter {name} already marked as having shared weights, can't mark it again!"
        if hasattr(param, handle_key):
            raise RuntimeError(error_msg)
        if torch.__version__ < (2, 1):
            handle = param.register_hook(hook)
        else:
            handle = param.register_post_accumulate_grad_hook(hook_post_accum)
        setattr(param, handle_key, handle)

    return module


def unmark_module_as_shared(
    module: nn.Module,
    recurse: bool = True,
) -> nn.Module:
    """
    Helper function to unmark parameters of a module as being shared
    across ranks by removing attached gradient hooks.

    Parameters
    ----------
    module : nn.Module
        PyTorch module which is to be unmarked as having shared parameters.
    recurse : bool, default=True
        Flag indicating whether the module's parameters are traversed in
        a recursive fashion, i.e. whether sub-modules are also considered
        as having shared parameters.
    """
    handle_key = "_shared_weight_dist_hook"
    for name, param in module.named_parameters(recurse=recurse):
        error_msg = (
            f"Parameter {name} NOT marked as having shared weights, can't unmark it!"
        )
        if not hasattr(param, handle_key):
            raise RuntimeError(error_msg)
        handle = getattr(param, handle_key)
        handle.remove()
        delattr(param, handle_key)

    return module
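
# Usage sketch, not part of the original module: marking a model's parameters
# as shared across a hypothetical "model_parallel" process group so that their
# gradients are all-reduced automatically during backward, then removing the
# hooks again. Assumes DistributedManager is initialized and that the named
# group has been created.
def _example_shared_parameters(model: nn.Module) -> nn.Module:
    model = mark_module_as_shared(model, process_group="model_parallel")
    # ... training: loss.backward() triggers the gradient reduction hooks ...
    return unmark_module_as_shared(model)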