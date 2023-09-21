# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # TODO this also needs more docstrings import torch import torch.nn.functional as F import torch.distributed as dist from .manager import DistributedManager # def get_memory_format(tensor): # """Gets format for tensor""" # if tensor.is_contiguous(memory_format=torch.channels_last): # return torch.channels_last # else: # return torch.contiguous_format # def pad_helper(tensor, dim, new_size, mode="zero"): # """Util for padding tensors""" # ndim = tensor.ndim # dim = (dim + ndim) % ndim # ndim_pad = ndim - dim # output_shape = [0 for _ in range(2 * ndim_pad)] # orig_size = tensor.shape[dim] # output_shape[1] = new_size - orig_size # tensor_pad = F.pad(tensor, output_shape, mode="constant", value=0.0) # if mode == "conj": # lhs_slice = [ # slice(0, x) if idx != dim else slice(orig_size, new_size) # for idx, x in enumerate(tensor.shape) # ] # rhs_slice = [ # slice(0, x) if idx != dim else slice(1, output_shape[1] + 1) # for idx, x in enumerate(tensor.shape) # ] # tensor_pad[lhs_slice] = torch.flip( # torch.conj(tensor_pad[rhs_slice]), dims=[dim] # ) # return tensor_pad # def truncate_helper(tensor, dim, new_size): # """Util for truncating""" # input_format = get_memory_format(tensor) # ndim = tensor.ndim # dim = (dim + ndim) % ndim # output_slice = [ # slice(0, 
#         x) if idx != dim else slice(0, new_size)
#         for idx, x in enumerate(tensor.shape)
#     ]
#     tensor_trunc = tensor[output_slice].contiguous(memory_format=input_format)
#     return tensor_trunc


# def split_tensor_along_dim(tensor, dim, num_chunks):
#     """splits tensor along specific dim"""
#     assert (
#         dim < tensor.dim()
#     ), f"Error, tensor dimension is {tensor.dim()} which cannot be split along {dim}"
#     assert (
#         tensor.shape[dim] % num_chunks == 0
#     ), f"Error, cannot split dim {dim} evenly. Dim size is \
#         {tensor.shape[dim]} and requested numnber of splits is {num_chunks}"
#     chunk_size = tensor.shape[dim] // num_chunks
#     tensor_list = torch.split(tensor, chunk_size, dim=dim)
#     return tensor_list


@torch.no_grad()
def gather_loss(loss: float, dst_rank: int = 0, mean: bool = True):
    """Gathers loss from all processes to one for logging

    Every rank contributes its local loss; only ``dst_rank`` receives the
    gathered values and reduces them to a single number.

    Parameters
    ----------
    loss : float
        loss value
    dst_rank : int, optional
        destination rank to gather to, by default 0
    mean : bool, optional
        Calculate the mean of the losses gathered, by default True

    Returns
    -------
    float or None
        In a serial run (world size 1) the local loss. Otherwise the sum of
        the gathered losses (divided by the world size if ``mean`` is True)
        on ``dst_rank``, and ``None`` on every other rank.

    Raises
    ------
    RuntimeError
        If DistributedManager has yet to be initialized
    """
    if not DistributedManager.is_initialized():
        raise RuntimeError(
            "Distributed manager should be initialized when using gather_loss"
        )
    distmng = DistributedManager()

    # Normalise to a one-element tensor of the default dtype so it matches
    # the receive buffers allocated below with torch.zeros.
    loss = torch.tensor([float(loss)])

    # For serial runs, just return the current loss!
    if distmng.world_size == 1:
        return float(loss)

    is_dst = distmng.rank == dst_rank

    # Only the destination rank supplies receive buffers; they are allocated
    # directly on the target device instead of on CPU followed by a copy.
    gather_list = None
    if is_dst:
        gather_list = [
            torch.zeros(1, device=distmng.device) for _ in range(distmng.world_size)
        ]

    # Gather using PyTorch distributed function (called on every rank)
    dist.gather(loss.to(distmng.device), gather_list, dst_rank)

    # Return loss if dst_rank, None otherwise
    if not is_dst:
        return None

    total = torch.sum(torch.cat(gather_list))
    if mean:
        total = total / distmng.world_size
    return float(total.cpu())


# # distributed primitives
# def _transpose(tensor, dim0, dim1, group=None, async_op=False):
#     # get input format
#     input_format = get_memory_format(tensor)
#     # get comm params
#     comm_size = dist.get_world_size(group=group)
#     # split and local transposition
#     split_size = tensor.shape[dim0] // comm_size
#     x_send = [
#         y.contiguous(memory_format=input_format)
#         for y in torch.split(tensor, split_size, dim=dim0)
#     ]
#     x_recv = [torch.empty_like(x_send[0]) for _ in range(comm_size)]
#     # global transposition
#     req = dist.all_to_all(x_recv, x_send, group=group, async_op=async_op)
#     return x_recv, req


# def _reduce(input_, use_fp32=True, group=None):
#     """All-reduce the input tensor across model parallel group."""
#     # Bypass the function if we are using only 1 GPU.
#     if dist.get_world_size(group=group) == 1:
#         return input_
#     # All-reduce.
#     if use_fp32:
#         dtype = input_.dtype
#         inputf_ = input_.float()
#         dist.all_reduce(inputf_, group=group)
#         input_ = inputf_.to(dtype)
#     else:
#         dist.all_reduce(input_, group=group)
#     return input_


# def _split(input_, dim_, group=None):
#     """Split the tensor along its last dimension and keep the corresponding slice."""
#     # get input format
#     input_format = get_memory_format(input_)
#     # Bypass the function if we are using only 1 GPU.
#     comm_size = dist.get_world_size(group=group)
#     if comm_size == 1:
#         return input_
#     # Split along last dimension.
#     input_list = split_tensor_along_dim(input_, dim_, comm_size)
#     # Note: torch.split does not create contiguous tensors by default.
#     rank = dist.get_rank(group=group)
#     output = input_list[rank].contiguous(memory_format=input_format)
#     return output


# def _gather(input_, dim_, group=None):
#     """Gather tensors and concatinate along the last dimension."""
#     # get input format
#     input_format = get_memory_format(input_)
#     comm_size = dist.get_world_size(group=group)
#     # Bypass the function if we are using only 1 GPU.
# if comm_size == 1: # return input_ # # sanity checks # assert ( # dim_ < input_.dim() # ), f"Error, cannot gather along {dim_} for tensor with {input_.dim()} dimensions." # # Size and dimension. # comm_rank = dist.get_rank(group=group) # tensor_list = [torch.empty_like(input_) for _ in range(comm_size)] # tensor_list[comm_rank] = input_ # dist.all_gather(tensor_list, input_, group=group) # # Note: torch.cat already creates a contiguous tensor. # output = torch.cat(tensor_list, dim=dim_).contiguous(memory_format=input_format) # return output