# Source code for modulus.distributed.utils
# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# TODO this also needs more docstrings
import torch
import torch.nn.functional as F
import torch.distributed as dist
from .manager import DistributedManager
# def get_memory_format(tensor):
# """Gets format for tensor"""
# if tensor.is_contiguous(memory_format=torch.channels_last):
# return torch.channels_last
# else:
# return torch.contiguous_format
# def pad_helper(tensor, dim, new_size, mode="zero"):
# """Util for padding tensors"""
# ndim = tensor.ndim
# dim = (dim + ndim) % ndim
# ndim_pad = ndim - dim
# output_shape = [0 for _ in range(2 * ndim_pad)]
# orig_size = tensor.shape[dim]
# output_shape[1] = new_size - orig_size
# tensor_pad = F.pad(tensor, output_shape, mode="constant", value=0.0)
# if mode == "conj":
# lhs_slice = [
# slice(0, x) if idx != dim else slice(orig_size, new_size)
# for idx, x in enumerate(tensor.shape)
# ]
# rhs_slice = [
# slice(0, x) if idx != dim else slice(1, output_shape[1] + 1)
# for idx, x in enumerate(tensor.shape)
# ]
# tensor_pad[lhs_slice] = torch.flip(
# torch.conj(tensor_pad[rhs_slice]), dims=[dim]
# )
# return tensor_pad
# def truncate_helper(tensor, dim, new_size):
# """Util for truncating"""
# input_format = get_memory_format(tensor)
# ndim = tensor.ndim
# dim = (dim + ndim) % ndim
# output_slice = [
# slice(0, x) if idx != dim else slice(0, new_size)
# for idx, x in enumerate(tensor.shape)
# ]
# tensor_trunc = tensor[output_slice].contiguous(memory_format=input_format)
# return tensor_trunc
# def split_tensor_along_dim(tensor, dim, num_chunks):
# """splits tensor along specific dim"""
# assert (
# dim < tensor.dim()
# ), f"Error, tensor dimension is {tensor.dim()} which cannot be split along {dim}"
# assert (
# tensor.shape[dim] % num_chunks == 0
# ), f"Error, cannot split dim {dim} evenly. Dim size is \
# {tensor.shape[dim]} and requested number of splits is {num_chunks}"
# chunk_size = tensor.shape[dim] // num_chunks
# tensor_list = torch.split(tensor, chunk_size, dim=dim)
# return tensor_list
def gather_loss(loss: float, dst_rank: int = 0, mean: bool = True):
    """Gathers loss from all processes to one for logging

    Parameters
    ----------
    loss : float
        Loss value on this rank. Anything accepted by ``float()`` (e.g. a
        Python number or a single-element tensor) is also accepted.
    dst_rank : int, optional
        Destination rank to gather to, by default 0
    mean : bool, optional
        Calculate the mean of the losses gathered, by default True

    Returns
    -------
    float or None
        On ``dst_rank``: the sum of the losses from all ranks, divided by the
        world size when ``mean`` is True. On every other rank: ``None``.
        In a serial run (world size 1) the local loss is returned, after a
        round trip through a tensor of torch's default dtype.

    Raises
    ------
    Exception
        If DistributedManager has yet to be initialized

    Note
    ----
    In a multi-process run this wraps ``torch.distributed.gather``, which is
    a collective operation, so every rank must call this function.
    """
    if not DistributedManager.is_initialized():
        raise Exception(
            "Distributed manager should be initialized when using gather_loss"
        )

    distmng = DistributedManager()
    # float() first so 0-d / single-element tensors (possibly on GPU or
    # requiring grad) are handled the same way as plain Python numbers.
    loss_tensor = torch.tensor([float(loss)])

    # For serial runs, just return the current loss!
    if distmng.world_size == 1:
        return float(loss_tensor)

    # Gather using PyTorch distributed function. Only the destination rank
    # allocates receive buffers; non-destination ranks pass None.
    gather_list = None
    if distmng.rank == dst_rank:
        gather_list = [
            torch.zeros(1, device=distmng.device) for _ in range(distmng.world_size)
        ]
    # The send tensor must live on the same device as the receive buffers
    # (e.g. GPU tensors when using the NCCL backend).
    dist.gather(loss_tensor.to(distmng.device), gather_list, dst_rank)

    # Return loss if dst_rank, None otherwise
    if distmng.rank != dst_rank:
        return None
    total = torch.sum(torch.cat(gather_list))
    if mean:
        total = total / distmng.world_size
    return float(total.cpu())


# # distributed primitives
# def _transpose(tensor, dim0, dim1, group=None, async_op=False):
# # get input format
# input_format = get_memory_format(tensor)
# # get comm params
# comm_size = dist.get_world_size(group=group)
# # split and local transposition
# split_size = tensor.shape[dim0] // comm_size
# x_send = [
# y.contiguous(memory_format=input_format)
# for y in torch.split(tensor, split_size, dim=dim0)
# ]
# x_recv = [torch.empty_like(x_send[0]) for _ in range(comm_size)]
# # global transposition
# req = dist.all_to_all(x_recv, x_send, group=group, async_op=async_op)
# return x_recv, req
# def _reduce(input_, use_fp32=True, group=None):
# """All-reduce the input tensor across model parallel group."""
# # Bypass the function if we are using only 1 GPU.
# if dist.get_world_size(group=group) == 1:
# return input_
# # All-reduce.
# if use_fp32:
# dtype = input_.dtype
# inputf_ = input_.float()
# dist.all_reduce(inputf_, group=group)
# input_ = inputf_.to(dtype)
# else:
# dist.all_reduce(input_, group=group)
# return input_
# def _split(input_, dim_, group=None):
# """Split the tensor along its last dimension and keep the corresponding slice."""
# # get input format
# input_format = get_memory_format(input_)
# # Bypass the function if we are using only 1 GPU.
# comm_size = dist.get_world_size(group=group)
# if comm_size == 1:
# return input_
# # Split along last dimension.
# input_list = split_tensor_along_dim(input_, dim_, comm_size)
# # Note: torch.split does not create contiguous tensors by default.
# rank = dist.get_rank(group=group)
# output = input_list[rank].contiguous(memory_format=input_format)
# return output
# def _gather(input_, dim_, group=None):
# """Gather tensors and concatinate along the last dimension."""
# # get input format
# input_format = get_memory_format(input_)
# comm_size = dist.get_world_size(group=group)
# # Bypass the function if we are using only 1 GPU.
# if comm_size == 1:
# return input_
# # sanity checks
# assert (
# dim_ < input_.dim()
# ), f"Error, cannot gather along {dim_} for tensor with {input_.dim()} dimensions."
# # Size and dimension.
# comm_rank = dist.get_rank(group=group)
# tensor_list = [torch.empty_like(input_) for _ in range(comm_size)]
# tensor_list[comm_rank] = input_
# dist.all_gather(tensor_list, input_, group=group)
# # Note: already creates a contiguous tensor.
# output = torch.cat(tensor_list, dim=dim_).contiguous(memory_format=input_format)
# return output