# Source code for modulus.distributed.utils

# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# TODO this also needs more docstrings
import torch
import torch.nn.functional as F
import torch.distributed as dist
from .manager import DistributedManager


# def get_memory_format(tensor):
#     """Gets format for tensor"""
#     if tensor.is_contiguous(memory_format=torch.channels_last):
#         return torch.channels_last
#     else:
#         return torch.contiguous_format


# def pad_helper(tensor, dim, new_size, mode="zero"):
#     """Util for padding tensors"""
#     ndim = tensor.ndim
#     dim = (dim + ndim) % ndim
#     ndim_pad = ndim - dim
#     output_shape = [0 for _ in range(2 * ndim_pad)]
#     orig_size = tensor.shape[dim]
#     output_shape[1] = new_size - orig_size
#     tensor_pad = F.pad(tensor, output_shape, mode="constant", value=0.0)

#     if mode == "conj":
#         lhs_slice = [
#             slice(0, x) if idx != dim else slice(orig_size, new_size)
#             for idx, x in enumerate(tensor.shape)
#         ]
#         rhs_slice = [
#             slice(0, x) if idx != dim else slice(1, output_shape[1] + 1)
#             for idx, x in enumerate(tensor.shape)
#         ]
#         tensor_pad[lhs_slice] = torch.flip(
#             torch.conj(tensor_pad[rhs_slice]), dims=[dim]
#         )

#     return tensor_pad


# def truncate_helper(tensor, dim, new_size):
#     """Util for truncating"""
#     input_format = get_memory_format(tensor)
#     ndim = tensor.ndim
#     dim = (dim + ndim) % ndim
#     output_slice = [
#         slice(0, x) if idx != dim else slice(0, new_size)
#         for idx, x in enumerate(tensor.shape)
#     ]
#     tensor_trunc = tensor[output_slice].contiguous(memory_format=input_format)

#     return tensor_trunc


# def split_tensor_along_dim(tensor, dim, num_chunks):
#     """splits tensor along specific dim"""
#     assert (
#         dim < tensor.dim()
#     ), f"Error, tensor dimension is {tensor.dim()} which cannot be split along {dim}"
#     assert (
#         tensor.shape[dim] % num_chunks == 0
#     ), f"Error, cannot split dim {dim} evenly. Dim size is \
#                                                   {tensor.shape[dim]} and requested numnber of splits is {num_chunks}"
#     chunk_size = tensor.shape[dim] // num_chunks
#     tensor_list = torch.split(tensor, chunk_size, dim=dim)

#     return tensor_list



@torch.no_grad()
def gather_loss(loss: float, dst_rank: int = 0, mean: bool = True):
    """Gathers loss from all processes to one for logging

    Parameters
    ----------
    loss : float
        loss value
    dst_rank : int, optional
        destination rank to gather to, by default 0
    mean : bool, optional
        Calculate the mean of the losses gathered, by default True.
        When False, the sum of the gathered losses is returned instead.

    Returns
    -------
    float or None
        For a serial run (world size of 1), the input loss as a float.
        Otherwise the mean (or sum) of the gathered losses on ``dst_rank``
        and ``None`` on every other rank.

    Raises
    ------
    Exception
        If DistributedManager has yet to be initialized
    """
    if not DistributedManager.is_initialized():
        raise Exception(
            "Distributed manager should be initialized when using gather_loss"
        )

    distmng = DistributedManager()
    loss = torch.Tensor([loss])

    # For serial runs, just return the current loss!
    if distmng.world_size == 1:
        return float(loss)

    local_loss = loss.to(distmng.device)
    is_dst = distmng.rank == dst_rank

    # Only the destination rank supplies receive buffers. zeros_like keeps
    # their dtype and device identical to the tensor being sent and avoids
    # allocating on the CPU before copying to the device.
    gather_list = None
    if is_dst:
        gather_list = [
            torch.zeros_like(local_loss) for _ in range(distmng.world_size)
        ]

    # Gather using PyTorch distributed function
    dist.gather(local_loss, gather_list, dst_rank)

    # Return loss if dst_rank, None otherwise
    if not is_dst:
        return None

    total = torch.sum(torch.cat(gather_list))
    if mean:
        total = total / distmng.world_size
    return float(total.cpu())
# # distributed primitives
# def _transpose(tensor, dim0, dim1, group=None, async_op=False):
#     # get input format
#     input_format = get_memory_format(tensor)

#     # get comm params
#     comm_size = dist.get_world_size(group=group)

#     # split and local transposition
#     split_size = tensor.shape[dim0] // comm_size
#     x_send = [
#         y.contiguous(memory_format=input_format)
#         for y in torch.split(tensor, split_size, dim=dim0)
#     ]
#     x_recv = [torch.empty_like(x_send[0]) for _ in range(comm_size)]

#     # global transposition
#     req = dist.all_to_all(x_recv, x_send, group=group, async_op=async_op)

#     return x_recv, req


# def _reduce(input_, use_fp32=True, group=None):
#     """All-reduce the input tensor across model parallel group."""

#     # Bypass the function if we are using only 1 GPU.
#     if dist.get_world_size(group=group) == 1:
#         return input_

#     # All-reduce.
#     if use_fp32:
#         dtype = input_.dtype
#         inputf_ = input_.float()
#         dist.all_reduce(inputf_, group=group)
#         input_ = inputf_.to(dtype)
#     else:
#         dist.all_reduce(input_, group=group)

#     return input_


# def _split(input_, dim_, group=None):
#     """Split the tensor along dimension dim_ and keep the slice for this rank."""
#     # get input format
#     input_format = get_memory_format(input_)

#     # Bypass the function if we are using only 1 GPU.
#     comm_size = dist.get_world_size(group=group)
#     if comm_size == 1:
#         return input_

#     # Split along dimension dim_.
#     input_list = split_tensor_along_dim(input_, dim_, comm_size)

#     # Note: torch.split does not create contiguous tensors by default.
#     rank = dist.get_rank(group=group)
#     output = input_list[rank].contiguous(memory_format=input_format)

#     return output


# def _gather(input_, dim_, group=None):
#     """Gather tensors and concatenate along dimension dim_."""
#     # get input format
#     input_format = get_memory_format(input_)

#     comm_size = dist.get_world_size(group=group)
#     # Bypass the function if we are using only 1 GPU.
#     if comm_size == 1:
#         return input_

#     # sanity checks
#     assert (
#         dim_ < input_.dim()
#     ), f"Error, cannot gather along {dim_} for tensor with {input_.dim()} dimensions."

#     # Size and dimension.
#     comm_rank = dist.get_rank(group=group)

#     tensor_list = [torch.empty_like(input_) for _ in range(comm_size)]
#     tensor_list[comm_rank] = input_
#     dist.all_gather(tensor_list, input_, group=group)

#     # Note: torch.cat already creates a contiguous tensor.
#     output = torch.cat(tensor_list, dim=dim_).contiguous(memory_format=input_format)

#     return output
