Modulus Distributed
Distributed utilities in Modulus are designed to simplify the implementation of parallel training and make inference scripts easier by providing a unified way to configure and query parameters associated with the distributed environment. The utilities in modulus.distributed build on top of the utilities from torch.distributed and abstract away some of the complexity of setting up a distributed execution environment.
The example below shows how to set up a simple distributed data parallel training recipe using the distributed utilities in Modulus. DistributedDataParallel in PyTorch provides the framework for data parallel training by reducing parameter gradients across multiple worker processes after the backward pass. The code below shows how to specify the device_ids, output_device, broadcast_buffers, and find_unused_parameters arguments of the DistributedDataParallel utility using the DistributedManager.
import torch
from torch.nn.parallel import DistributedDataParallel

from modulus.distributed import DistributedManager
from modulus.models.mlp.fully_connected import FullyConnected


def main():
    # Initialize the DistributedManager. This will automatically
    # detect the number of processes the job was launched with and
    # set those configuration parameters appropriately. Currently
    # torchrun (or any other PyTorch-compatible launcher), mpirun (OpenMPI)
    # and SLURM based launchers are supported.
    DistributedManager.initialize()

    # Since this is a singleton class, you can just get an instance
    # of it anytime after initialization and not need to reinitialize
    # each time.
    dist = DistributedManager()

    # Set up model on the appropriate device. DistributedManager
    # figures out what device should be used on this process.
    arch = FullyConnected(in_features=32, out_features=64).to(dist.device)

    # Set up DistributedDataParallel if using more than a single process.
    # The `distributed` property of DistributedManager can be used to
    # check this.
    if dist.distributed:
        ddps = torch.cuda.Stream()
        with torch.cuda.stream(ddps):
            arch = DistributedDataParallel(
                arch,
                device_ids=[dist.local_rank],  # Set device_ids to the local rank
                                               # of this process on this node
                output_device=dist.device,
                broadcast_buffers=dist.broadcast_buffers,
                find_unused_parameters=dist.find_unused_parameters,
            )
        torch.cuda.current_stream().wait_stream(ddps)

    # Set up the optimizer
    optimizer = torch.optim.Adam(
        arch.parameters(),
        lr=0.001,
    )

    def training_step(input, target):
        optimizer.zero_grad()
        pred = arch(input)
        loss = torch.sum(torch.pow(pred - target, 2))
        loss.backward()
        optimizer.step()
        return loss

    # Sample training loop
    for i in range(20):
        # Random inputs and targets for simplicity
        input = torch.randn(128, 32, device=dist.device)
        target = torch.randn(128, 64, device=dist.device)

        # Training step
        loss = training_step(input, target)


if __name__ == "__main__":
    main()
This training script can be run on a single GPU using python train.py, or on multiple GPUs using torchrun --standalone --nnodes=1 --nproc_per_node=<num_gpus> train.py, or, if using OpenMPI, mpirun -np <num_gpus> python train.py. The script can also be run on a SLURM cluster using srun -n <num_gpus> python train.py.
An important aspect of the DistributedManager is that it follows the Borg pattern. This means that DistributedManager essentially functions like a singleton class: once configured, all utilities in Modulus can access the same configuration and adapt to the specified distributed structure.
For example, see the constructor of the DistributedAFNO
class:
def __init__(
    self,
    inp_shape: Tuple[int, int],
    in_channels: int,
    out_channels: Union[int, Any] = None,
    patch_size: int = 16,
    embed_dim: int = 256,
    depth: int = 4,
    num_blocks: int = 4,
    channel_parallel_inputs: bool = False,
    channel_parallel_outputs: bool = False,
) -> None:
    super().__init__()

    out_channels = out_channels or in_channels

    if DistributedManager().group("model_parallel") is None:
        raise RuntimeError(
            "Distributed AFNO needs to have model parallel group created first. "
            "Check the MODEL_PARALLEL_SIZE environment variable"
        )
    comm_size = DistributedManager().group_size("model_parallel")

    if channel_parallel_inputs:
        if not (in_channels % comm_size == 0):
            raise ValueError(
                "Error, in_channels needs to be divisible by model_parallel size"
            )

    self._impl = DistributedAFNONet(
        inp_shape=inp_shape,
        patch_size=(patch_size, patch_size),
        in_chans=in_channels,
        out_chans=out_channels,
        embed_dim=embed_dim,
        depth=depth,
        num_blocks=num_blocks,
        input_is_matmul_parallel=False,
        output_is_matmul_parallel=False,
    )
This model parallel implementation can just instantiate DistributedManager and query whether the process group named "model_parallel" exists and, if so, what its size is. Similarly, other utilities can query which device to run on, the total size of the distributed run, etc., without having to explicitly pass those parameters down the call stack.
This singleton/Borg pattern is very useful for the DistributedManager since it takes charge of bootstrapping the distributed run and unifies how all utilities become aware of the distributed configuration. However, the singleton/Borg pattern is not just a way to avoid passing parameters to utilities. Use of this pattern should be limited and have good justification, to avoid losing traceability and to keep the code readable.
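As a minimal sketch of this shared-state behavior (the function names below are hypothetical; only documented DistributedManager members are used), the manager can be initialized in one place and queried from another without passing anything along:
from modulus.distributed import DistributedManager


def report_env():
    # A fresh DistributedManager() sees the state configured in setup_distributed()
    # because the class follows the Borg/singleton pattern.
    dist = DistributedManager()
    print(f"rank {dist.rank} of {dist.world_size} running on {dist.device}")


def setup_distributed():
    DistributedManager.initialize()


if __name__ == "__main__":
    setup_distributed()
    report_env()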
modulus.distributed.manager
- class modulus.distributed.manager.DistributedManager[source]
Bases:
object
Distributed Manager for setting up the distributed training environment.
This is a singleton that creates a persistent class instance for storing parallel environment information throughout the lifetime of the program. This should be used to help set up Distributed Data Parallel and parallel datapipes.
Note: One should call DistributedManager.initialize() prior to constructing a manager object.
Example
>>> DistributedManager.initialize()
>>> manager = DistributedManager()
>>> manager.rank
0
>>> manager.world_size
1
- property broadcast_buffers
broadcast_buffers in PyTorch DDP
- static cleanup()[source]
Clean up distributed group and singleton
- static create_orthogonal_process_group(orthogonal_group_name: str, group_name: str, verbose: bool = False)[source]
Create a process group that is orthogonal to the specified process group.
- Parameters
orthogonal_group_name (str) – Name of the orthogonal process group to be created.
group_name (str) – Name of the existing process group.
verbose (bool) – Print out ranks of each created process group, default False.
- static create_process_subgroup(name: str, size: int, group_name: Optional[str] = None, verbose: bool = False)[source]
Create a process subgroup of a parent process group. This must be a collective call by all processes participating in this application.
- Parameters
name (str) – Name of the process subgroup to be created.
size (int) – Size of the process subgroup to be created. This must be an integer factor of the parent group’s size.
group_name (Optional[str]) – Name of the parent process group, optional. If None, the default process group will be used. Default None.
verbose (bool) – Print out ranks of each created process group, default False.
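As a hedged sketch of how these two calls can be combined (the group names "model_parallel" and "data_parallel" and the group size of 2 are illustrative assumptions; the world size must be a multiple of 2), a model parallel group and its orthogonal data parallel group could be created as follows:
from modulus.distributed import DistributedManager

DistributedManager.initialize()
dist = DistributedManager()

if dist.distributed:
    # Split the world into model parallel groups of size 2
    # (the world size must be divisible by 2 for this to be valid).
    DistributedManager.create_process_subgroup("model_parallel", size=2, verbose=True)

    # Ranks that occupy the same position across the model parallel groups
    # form the orthogonal data parallel groups.
    DistributedManager.create_orthogonal_process_group(
        "data_parallel", "model_parallel", verbose=True
    )

    print(
        f"rank {dist.rank}: model_parallel rank {dist.group_rank('model_parallel')}, "
        f"data_parallel rank {dist.group_rank('data_parallel')}"
    )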
- property cuda
If cuda is available
- property device
Process device
- property distributed
Distributed environment
- property find_unused_parameters
find_unused_parameters in PyTorch DDP
- static get_available_backend()[source]
Get communication backend
- group(name=None)[source]
Returns a process group with the given name. If name is None, the group is also None, indicating the default process group. If the named group does not exist, a ModulusUndefinedGroupError exception is raised.
- group_name(group=None)[source]
Returns the name of process group
- property group_names
Returns a list of all named process groups created
- group_rank(name=None)[source]
Returns the rank in named process group
- group_size(name=None)[source]
Returns the size of named process group
- static initialize()[source]
Initialize distributed manager
- Currently supported initialization methods are:
- ENV: PyTorch environment variable initialization
https://pytorch.org/docs/stable/distributed.html#environment-variable-initialization
- SLURM: Initialization on SLURM systems.
Uses SLURM_PROCID, SLURM_NPROCS, SLURM_LOCALID and SLURM_LAUNCH_NODE_IPADDR environment variables.
- OPENMPI: Initialization for OpenMPI launchers.
Uses OMPI_COMM_WORLD_RANK, OMPI_COMM_WORLD_SIZE and OMPI_COMM_WORLD_LOCAL_RANK environment variables.
Initialization by default is done using the first valid method in the order listed above. Initialization method can also be explicitly controlled using the MODULUS_DISTRIBUTED_INITIALIZATION_METHOD environment variable and setting it to one of the options above.
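As a hedged sketch (the choice of the SLURM method here is only an example), the automatic detection can be overridden before calling initialize():
import os

from modulus.distributed import DistributedManager

# Force SLURM-based initialization instead of relying on automatic detection.
# The value is assumed to be one of the methods listed above: ENV, SLURM or OPENMPI.
os.environ["MODULUS_DISTRIBUTED_INITIALIZATION_METHOD"] = "SLURM"

DistributedManager.initialize()
dist = DistributedManager()
print(dist.rank, dist.world_size)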
- static initialize_env()[source]
Setup method using generic initialization
- static initialize_open_mpi(addr, port)[source]
Setup method using OpenMPI initialization
- static initialize_slurm(port)[source]
Setup method using SLURM initialization
- classmethod is_initialized() → bool[source]
If manager singleton has been initialized
- property local_rank
Process rank on local machine
- property rank
Process rank
- static setup(rank=0, world_size=1, local_rank=None, addr='localhost', port='12355', backend='nccl', method='env')[source]
Set up PyTorch distributed process group and update manager attributes
- property world_size
Number of processes in distributed environment
- exception modulus.distributed.manager.ModulusUndefinedGroupError(name: str)[source]
Bases:
Exception
Exception for querying an undefined process group using the Modulus DistributedManager
- exception modulus.distributed.manager.ModulusUninitializedDistributedManagerWarning[source]
Bases:
Warning
Warning to indicate usage of an uninitialized DistributedManager
modulus.distributed.utils
- modulus.distributed.utils.all_gather_v_bwd_wrapper(tensor: Tensor, sizes: List[int], dim: int = 0, use_fp32: bool = True, group: Optional[ProcessGroup] = None) → Tensor[source]
Implements a distributed AllReduceV primitive. It is based on the idea of a single global tensor which can be distributed along a specified dimension into chunks of variable size. This primitive assumes different global tensors of the same shape on each rank. It then re-distributes chunks of all these tensors such that each rank receives all corresponding parts of a global tensor. Each rank then sums up the chunks after receiving them. By design, this primitive thus implements the backward pass of the "all_gather_v" primitive. In this case, the result would be a single global gradient tensor distributed onto different ranks.
- Parameters
tensor (torch.Tensor) – global tensor on each rank (different one on each rank)
sizes (List[int]) – list of the sizes of each chunk on each rank along distributed dimension, valid and set on each rank
dim (int, optional) – dimension along which global tensor is distributed, by default 0
use_fp32 (bool, optional) – flag to specify that the reduction takes place in at least FP32 precision, by default True; this only affects floating point inputs of lower precision
group (Optional[dist.ProcessGroup], optional) – process group along which global tensor is shared, by default None
- Returns
local tensor, i.e. result of reduction of all corresponding chunks from all global tensors for each rank separately
- Return type
torch.Tensor
- modulus.distributed.utils.all_gather_v_wrapper(tensor: Tensor, sizes: Optional[List[int]] = None, dim: int = 0, group: Optional[ProcessGroup] = None) → Tensor[source]
Implements a distributed AllGatherV primitive. It is based on the idea of a single global tensor which is distributed along a specified dimension into chunks of variable size. This primitive gathers all local tensors from each rank into the full global tensor onto each rank.
- Parameters
tensor ("torch.Tensor") – local tensor on each rank
sizes (List[int], optional) – list of the sizes of each chunk on each rank along distributed dimension, valid and set on each rank, by default None
dim (int, optional) – dimension along which global tensor is distributed, by default 0
group (Optional[dist.ProcessGroup], optional) – process group along which global tensor is shared, by default None
- Returns
full global tensor, valid on each rank
- Return type
torch.Tensor
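A hedged two-rank sketch (the chunk sizes are illustrative): each rank contributes a chunk of different length along dim 0 and every rank receives the full concatenated tensor. The all_gather_v_bwd_wrapper above implements the corresponding reduction in the reverse direction.
import torch

from modulus.distributed import DistributedManager
from modulus.distributed.utils import all_gather_v_wrapper

DistributedManager.initialize()
dist = DistributedManager()

# Example with 2 ranks: rank 0 holds 3 rows, rank 1 holds 5 rows.
sizes = [3, 5]
local = torch.randn(sizes[dist.rank], 8, device=dist.device)

# Every rank ends up with the full (3 + 5) x 8 global tensor.
full = all_gather_v_wrapper(local, sizes, dim=0)
print(full.shape)  # torch.Size([8, 8])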
- modulus.distributed.utils.distributed_transpose(tensor, dim0, dim1, group=None, async_op=False)[source]
Perform distributed transpose of tensor to switch sharding dimension
- modulus.distributed.utils.gather_v_wrapper(tensor: Tensor, sizes: List[int], dim: int = 0, dst: int = 0, group: Optional[ProcessGroup] = None) → Tensor[source]
Implements a distributed GatherV primitive. It is based on the idea of a single global tensor which is distributed along a specified dimension into chunks of variable size. This primitive assumes such a distributed tensor and gathers all local tensors from each rank into the full global tensor valid on the specified destination rank.
- Parameters
tensor (torch.Tensor) – local tensor on each rank
sizes (List[int]) – list of the sizes of each chunk on each rank along distributed dimension, valid and set on each rank
dim (int, optional) – dimension along which global tensor is distributed, by default 0
dst (int, optional) – destination rank which contains the full global tensor after the operation, by default 0
group (Optional[dist.ProcessGroup], optional) – process group along which global tensor is shared, by default None
- Returns
full global tensor, valid on destination rank
- Return type
torch.Tensor
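A hedged sketch along the same lines, again assuming two ranks with illustrative chunk sizes; only the destination rank holds the valid global tensor afterwards:
import torch

from modulus.distributed import DistributedManager
from modulus.distributed.utils import gather_v_wrapper

DistributedManager.initialize()
dist = DistributedManager()

# Example with 2 ranks holding chunks of different length along dim 0.
sizes = [4, 2]
local = torch.randn(sizes[dist.rank], 16, device=dist.device)

# Only rank 0 (dst) holds the valid 6 x 16 global tensor afterwards.
gathered = gather_v_wrapper(local, sizes, dim=0, dst=0)
if dist.rank == 0:
    print(gathered.shape)  # torch.Size([6, 16])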
- modulus.distributed.utils.get_memory_format(tensor)[source]
Gets format for tensor
- modulus.distributed.utils.indexed_all_to_all_v_wrapper(tensor: Tensor, indices: List[Tensor], sizes: List[List[int]], dim: int = 0, group: Optional[ProcessGroup] = None) → Tensor[source]
Implements an indexed version of a distributed AllToAllV primitive. It is based on the idea of a single global tensor which is distributed along a specified dimension into chunks of variable size. This primitive assumes a set of indices into this dimension which indicate the corresponding slices sent to each other rank forming an indexed version of an AllToAllV primitive.
- Parameters
tensor (torch.Tensor) – local part of global tensor on each rank
indices (List[torch.Tensor]) – list of indices on each rank of slices being sent to each other rank from this rank
sizes (List[List[int]]) – number of indices each rank sends to each other rank, valid and set on each rank, e.g. sizes[0][3] corresponds to the number of slices rank 0 sends to rank 3
dim (int) – dimension along which global tensor is distributed, by default 0
group (Optional[dist.ProcessGroup], optional) – process group along which global tensor is shared, by default None
- Returns
local result of primitive corresponding to indexed global tensor
- Return type
torch.Tensor
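The indexing convention is easiest to see in a small hedged sketch. Assuming two ranks that each hold 4 slices along dim 0, sizes[i][j] states how many slices rank i sends to rank j and indices[j] lists which local slices this rank sends to rank j; the concrete numbers are illustrative only.
import torch

from modulus.distributed import DistributedManager
from modulus.distributed.utils import indexed_all_to_all_v_wrapper

DistributedManager.initialize()
dist = DistributedManager()

# Each rank holds 4 slices of length 8 along dim 0.
local = torch.randn(4, 8, device=dist.device)

# sizes[i][j]: number of slices rank i sends to rank j.
# Here each rank keeps 2 slices for itself and sends 2 to the other rank.
sizes = [[2, 2], [2, 2]]

# indices[j]: which local slices this rank sends to rank j.
indices = [
    torch.tensor([0, 1], device=dist.device),  # slices sent to rank 0
    torch.tensor([2, 3], device=dist.device),  # slices sent to rank 1
]

result = indexed_all_to_all_v_wrapper(local, indices, sizes, dim=0)
print(result.shape)  # each rank receives 2 + 2 = 4 slices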
- modulus.distributed.utils.indexed_all_to_all_v_wrapper_bwd(tensor: Tensor, indices: List[Tensor], sizes: List[List[int]], tensor_size_along_dim: int, use_fp32: bool = True, dim: int = 0, group: Optional[ProcessGroup] = None) → Tensor[source]
Implements the backward pass to the indexed version of a distributed AllToAllV primitive.
- Parameters
tensor (torch.Tensor) – local tensor, i.e. gradient on resulting tensor from forward pass
indices (List[torch.Tensor]) – list of indices on each rank of slices being sent to each other rank from this rank
sizes (List[List[int]]) – list of the sizes of each chunk on each rank along distributed dimension, valid and set on each rank
tensor_size_along_dim (int) – size of original local tensor along specified dimension, i.e. from the corresponding forward pass
use_fp32 (bool, optional) – flag to specify that the reduction takes place in at least FP32 precision, by default True; this only affects floating point inputs of lower precision
dim (int, optional) – dimension along which global tensor is distributed, by default 0
group (Optional[dist.ProcessGroup], optional) – process group along which global tensor is shared, by default None
- Returns
result of primitive corresponding to indexed global tensor
- Return type
torch.Tensor
Helper function to mark parameters of a module as being shared across ranks by attaching gradient hooks to the corresponding tensors.
- Parameters
module (nn.Module) – PyTorch module which is to be marked as having shared parameters.
process_group (str | None) – str indicating process_group which contains ranks across which the module’s parameters are shared. If passed as None, will default to the world group.
recurse (bool, default=True) – Flag indicating whether the module’s parameters are traversed in a recursive fashion, i.e. whether sub-modules are also considered as having shared parameters.
use_fp32_reduction (bool, default=True) – Flag indicating whether the reduction for accumulating gradients will be done in at least FP32 or the native datatype.
- modulus.distributed.utils.pad_helper(tensor, dim, new_size, mode='zero')[source]
Util for padding tensors
- modulus.distributed.utils.reduce_loss(loss: float, dst_rank: int = 0, mean: bool = True)[source]
Reduces loss from all processes to destination rank for logging.
- Parameters
loss (float) – loss value
dst_rank (int, Optional) – destination rank to reduce to, by default 0.
mean (bool, Optional) – Calculate the mean of the losses gathered, by default True.
- Raises
Exception – If DistributedManager has yet to be initialized
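A minimal hedged sketch for logging a mean loss on rank 0 (the loss value is made up for illustration):
from modulus.distributed import DistributedManager
from modulus.distributed.utils import reduce_loss

DistributedManager.initialize()
dist = DistributedManager()

# Suppose each rank computed its own scalar loss this iteration.
local_loss = 0.1 * (dist.rank + 1)

# Mean loss across all ranks, reduced to rank 0 for logging.
global_loss = reduce_loss(local_loss, dst_rank=0, mean=True)
if dist.rank == 0:
    print(f"mean loss: {global_loss}")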
- modulus.distributed.utils.scatter_v_wrapper(tensor: Tensor, sizes: List[int], dim: int = 0, src: int = 0, group: Optional[ProcessGroup] = None) → Tensor[source]
Implements a distributed ScatterV primitive. It is based on the idea of a single global tensor which is distributed along a specified dimension into chunks of variable size. This primitive scatters the global tensor from a specified source rank into local chunks onto each other rank.
- Parameters
tensor (torch.Tensor) – global tensor, valid on source rank
sizes (List[int]) – list of the sizes of each chunk on each rank along distributed dimension, valid and set on each rank
dim (int, optional) – dimension along which global tensor is distributed, by default 0
src (int, optional) – source rank of primitive, i.e. rank of original full global tensor, by default 0
group (Optional[dist.ProcessGroup], optional) – process group along which global tensor is shared, by default None
- Returns
corresponding local part of the global tensor on each rank
- Return type
torch.Tensor
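A hedged two-rank sketch mirroring the gather example above; the global tensor only needs to be valid on the source rank:
import torch

from modulus.distributed import DistributedManager
from modulus.distributed.utils import scatter_v_wrapper

DistributedManager.initialize()
dist = DistributedManager()

# The full 6 x 16 tensor only needs to be valid on the source rank.
sizes = [4, 2]
global_tensor = torch.randn(sum(sizes), 16, device=dist.device)

# Rank 0 receives a 4 x 16 chunk, rank 1 a 2 x 16 chunk.
local = scatter_v_wrapper(global_tensor, sizes, dim=0, src=0)
print(dist.rank, local.shape)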
- modulus.distributed.utils.truncate_helper(tensor, dim, new_size)[source]
Util for truncating
Helper function to unmark parameters of a module as being shared across ranks by removing attached gradient hooks.
- Parameters
module (nn.Module) – PyTorch module which is to be unmarked as having shared parameters.
recurse (bool, default=True) – Flag indicating whether the module’s parameters are traversed in a recursive fashion, i.e. whether sub-modules are also considered as having shared parameters.
modulus.distributed.autograd
- class modulus.distributed.autograd.AllGatherVAutograd(*args, **kwargs)[source]
Bases:
Function
Autograd Wrapper for a distributed AllGatherV primitive. It is based on the idea of a single global tensor which is distributed along a specified dimension into chunks of variable size. This primitive gathers all local tensors from each rank into the full global tensor onto each rank. It is intended to be used in tensor-parallel settings on tensors which require gradients to be passed through. The backward pass performs an AllReduceV operation where each rank gathers its corresponding chunk of a global tensor from each other rank and sums up these individual gradients.
- static backward(ctx, grad_output: Tensor)[source]
backward pass of the Distributed AllGatherV primitive
- static forward(ctx, tensor: Tensor, sizes: List[int], dim: int = 0, use_fp32: bool = True, group: Optional[ProcessGroup] = None) → Tensor[source]
forward pass of the Distributed AllGatherV primitive
- class modulus.distributed.autograd.GatherVAutograd(*args, **kwargs)[source]
Bases:
Function
Autograd Wrapper for a distributed GatherV primitive. It is based on the idea of a single global tensor which is distributed along a specified dimension into chunks of variable size. This primitive assumes such a distributed tensor and gathers all local tensors from each rank into the full global tensor valid on the specified destination rank. It is intended to be used in tensor-parallel settings on tensors which require gradients to be passed through. The backward pass corresponds to a straightforward ScatterV primitive distributing the global gradient from the specified destination rank to all the other ranks.
- static backward(ctx, grad_output: Tensor) → Tensor[source]
backward pass of the Distributed GatherV primitive
- static forward(ctx, tensor: Tensor, sizes: List[int], dim: int = 0, dst: int = 0, group: Optional[ProcessGroup] = None) → Tensor[source]
forward pass of the distributed GatherV primitive
- class modulus.distributed.autograd.IndexedAllToAllVAutograd(*args, **kwargs)[source]
Bases:
Function
Autograd Wrapper for an Indexed AllToAllV primitive. It is based on the idea of a single global tensor which is distributed along a specified dimension into chunks of variable size. This primitive assumes a set of indices into this dimension which indicate the corresponding slices sent to each other rank forming an indexed version of an AllToAllV primitive. It is intended to be used in tensor-parallel settings on tensors which require gradients to be passed through. The backward pass more or less corresponds to the same operation as in the forward pass but with reversed roles and does an additional reduction of gathered gradients so that each rank finally will compute the overall gradient on its local tensor partition.
- static backward(ctx, grad_output: Tensor) → Tensor[source]
backward pass of the Distributed IndexedAlltoAllV primitive
- static forward(ctx, tensor: Tensor, indices: List[Tensor], sizes: List[List[int]], use_fp32: bool = True, dim: int = 0, group: Optional[ProcessGroup] = None) → Tensor[source]
forward pass of the Distributed IndexedAlltoAllV primitive
- class modulus.distributed.autograd.ScatterVAutograd(*args, **kwargs)[source]
Bases:
Function
Autograd Wrapper for Distributed ScatterV. It is based on the idea of a single global tensor which is distributed along a specified dimension into chunks of variable size. This primitive scatters the global tensor from a specified source rank into local chunks onto each other rank. It is intended to be used in tensor-parallel settings on tensors which require gradients to be passed through. The backward pass corresponds to a GatherV primitive gathering local gradients from all the other ranks into a single global gradient on the specified source rank.
- static backward(ctx, grad_output: Tensor) → Tensor[source]
backward pass of the Distributed ScatterV primitive
- static forward(ctx, tensor: Tensor, sizes: List[int], dim: int = 0, src: int = 0, group=typing.Optional[torch.distributed.distributed_c10d.ProcessGroup]) → Tensor[source]
forward pass of the Distributed ScatterV primitive
- modulus.distributed.autograd.all_gather_v(tensor: Tensor, sizes: List[int], dim: int = 0, use_fp32: bool = True, group: Optional[ProcessGroup] = None) → Tensor[source]
Autograd Wrapper for a distributed AllGatherV primitive. It is based on the idea of a single global tensor which is distributed along a specified dimension into chunks of variable size. This primitive gathers all local tensors from each rank into the full global tensor onto each rank. It is intended to be used in tensor-parallel settings on tensors which require gradients to be passed through. The backward pass performs an AllReduceV operation where each rank gathers its corresponding chunk of a global tensor from each other rank and sums up these individual gradients.
- Parameters
tensor ("torch.Tensor") – local tensor on each rank
sizes (List[int]) – list of the sizes of each chunk on each rank along distributed dimension, valid and set on each rank
dim (int, optional) – dimension along which global tensor is distributed, by default 0
use_fp32 (bool, optional) – boolean flag to indicate whether to use FP32 precision for the reduction in the backward pass, by default True
group (Optional[dist.ProcessGroup], optional) – process group along which global tensor is shared, by default None
- Returns
full global tensor, valid on each rank
- Return type
torch.Tensor
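A hedged sketch showing gradient flow through the autograd wrapper (two ranks assumed, illustrative sizes); gather_v, scatter_v, and indexed_all_to_all_v below are used in the same way:
import torch

from modulus.distributed import DistributedManager
from modulus.distributed.autograd import all_gather_v

DistributedManager.initialize()
dist = DistributedManager()

sizes = [3, 5]
local = torch.randn(sizes[dist.rank], 8, device=dist.device, requires_grad=True)

# Gather the variable-sized chunks into the full tensor on every rank.
full = all_gather_v(local, sizes, dim=0)

# Gradients flow back to each rank's local chunk via the AllReduceV backward.
loss = full.square().sum()
loss.backward()
print(local.grad.shape)  # matches the local chunk shape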
- modulus.distributed.autograd.gather_v(tensor: Tensor, sizes: List[int], dim: int = 0, dst: int = 0, group: Optional[ProcessGroup] = None) → Tensor[source]
Autograd Wrapper for a distributed GatherV primitive. It is based on the idea of a single global tensor which is distributed along a specified dimension into chunks of variable size. This primitive assumes such a distributed tensor and gathers all local tensors from each rank into the full global tensor valid on the specified destination rank. It is intended to be used in tensor-parallel settings on tensors which require gradients to be passed through. The backward pass corresponds to a straightforward ScatterV primitive distributing the global gradient from the specified destination rank to all the other ranks.
- Parameters
tensor (torch.Tensor) – local tensor on each rank
sizes (List[int]) – list of the sizes of each chunk on each rank along distributed dimension, valid and set on each rank
dim (int, optional) – dimension along which global tensor is distributed, by default 0
dst (int, optional) – destination rank which contains the full global tensor after the operation, by default 0
group (Optional[dist.ProcessGroup], optional) – process group along which global tensor is shared, by default None
- Returns
full global tensor, valid on destination rank
- Return type
torch.Tensor
- modulus.distributed.autograd.indexed_all_to_all_v(tensor: Tensor, indices: List[Tensor], sizes: List[List[int]], use_fp32: bool = True, dim: int = 0, group: Optional[ProcessGroup] = None) → Tensor[source]
Autograd Wrapper for an Indexed AllToAllV primitive. It is based on the idea of a single global tensor which is distributed along a specified dimension into chunks of variable size. This primitive assumes a set of indices into this dimension which indicate the corresponding slices sent to each other rank forming an indexed version of an AllToAllV primitive. It is intended to be used in tensor-parallel settings on tensors which require gradients to be passed through. The backward pass more or less corresponds to the same operation as in the forward pass but with reversed roles and does an additional reduction of gathered gradients so that each rank finally will compute the overall gradient on its local tensor partition.
- Parameters
tensor (torch.Tensor) – local part of global tensor on each rank
indices (List[torch.Tensor]) – list of indices on each rank of slices being sent to each other rank from this rank
sizes (List[List[int]]) – number of indices each rank sends to each other rank, valid and set on each rank, e.g. sizes[0][3] corresponds to the number of slices rank 0 sends to rank 3
use_fp32 (bool, optional) – flag to specify whether to use FP32 precision in the reduction in the backward pass, by default True
dim (int) – dimension along which global tensor is distributed, by default 0
group (Optional[dist.ProcessGroup], optional) – process group along which global tensor is shared, by default None
- Returns
local result of primitive corresponding to indexed global tensor
- Return type
torch.Tensor
- modulus.distributed.autograd.scatter_v(tensor: Tensor, sizes: List[int], dim: int = 0, src: int = 0, group: Optional[ProcessGroup] = None) → Tensor[source]
Autograd Wrapper for Distributed ScatterV. It is based on the idea of a single global tensor which is distributed along a specified dimension into chunks of variable size. This primitive scatters the global tensor from a specified source rank into local chunks onto each other rank. It is intended to be used in tensor-parallel settings on tensors which require gradients to be passed through. The backward pass corresponds to a GatherV primitive gathering local gradients from all the other ranks into a single global gradient on the specified source rank.
- Parameters
tensor (torch.Tensor) – global tensor, valid on source rank
sizes (List[int]) – list of the sizes of each chunk on each rank along distributed dimension, valid and set on each rank
dim (int, optional) – dimension along which global tensor is distributed, by default 0
src (int, optional) – source rank of primitive, i.e. rank of original full global tensor, by default 0
group (Optional[dist.ProcessGroup], optional) – process group along which global tensor is shared, by default None
- Returns
corresponding local part of the global tensor on each rank
- Return type
torch.Tensor
modulus.distributed.fft
- class modulus.distributed.fft.DistributedIRFFT2(*args, **kwargs)[source]
Bases:
Function
Autograd Wrapper for a distributed 2D complex to real IFFT primitive. It is based on the idea of a single global tensor which is distributed along a specified dimension into chunks of equal size. This primitive computes a 1D IFFT first along dim[1], then performs an AllToAll transpose before computing a 1D IFFT along dim[0]. The backward pass performs an FFT operation with communication in the opposite order as in the forward pass.
For the forward method, data should be split along dim[0] across the “spatial_parallel” process group. The output is data split in dim[1].
- static backward(ctx, grad_output)[source]
Define a formula for differentiating the operation with backward mode automatic differentiation.
This function is to be overridden by all subclasses. (Defining this function is equivalent to defining the vjp function.)
It must accept a context ctx as the first argument, followed by as many outputs as the forward() returned (None will be passed in for non-tensor outputs of the forward function), and it should return as many tensors as there were inputs to forward(). Each argument is the gradient w.r.t. the given output, and each returned value should be the gradient w.r.t. the corresponding input. If an input is not a Tensor or is a Tensor not requiring grads, you can just pass None as a gradient for that input.
The context can be used to retrieve tensors saved during the forward pass. It also has an attribute ctx.needs_input_grad as a tuple of booleans representing whether each input needs gradient. E.g., backward() will have ctx.needs_input_grad[0] = True if the first input to forward() needs gradient computed w.r.t. the output.
- static forward(ctx, x, s, dim, norm='ortho')[source]
Define the forward of the custom autograd Function.
This function is to be overridden by all subclasses. There are two ways to define forward:
Usage 1 (Combined forward and ctx):
@staticmethod
def forward(ctx: Any, *args: Any, **kwargs: Any) -> Any:
    pass
It must accept a context ctx as the first argument, followed by any number of arguments (tensors or other types).
See combining-forward-context for more details.
Usage 2 (Separate forward and ctx):
@staticmethod
def forward(*args: Any, **kwargs: Any) -> Any:
    pass

@staticmethod
def setup_context(ctx: Any, inputs: Tuple[Any, ...], output: Any) -> None:
    pass
The forward no longer accepts a ctx argument. Instead, you must also override the torch.autograd.Function.setup_context() staticmethod to handle setting up the ctx object. output is the output of the forward, inputs are a Tuple of inputs to the forward.
See extending-autograd for more details.
The context can be used to store arbitrary data that can be then retrieved during the backward pass. Tensors should not be stored directly on ctx (though this is not currently enforced for backward compatibility). Instead, tensors should be saved either with ctx.save_for_backward() if they are intended to be used in backward (equivalently, vjp) or ctx.save_for_forward() if they are intended to be used in jvp.
- class modulus.distributed.fft.DistributedRFFT2(*args, **kwargs)[source]
Bases:
Function
Autograd Wrapper for a distributed 2D real to complex FFT primitive. It is based on the idea of a single global tensor which is distributed along a specified dimension into chunks of equal size. This primitive computes a 1D FFT first along dim[0], then performs an AllToAll transpose before computing a 1D FFT along dim[1]. The backward pass performs an IFFT operation with communication in the opposite order as in the forward pass.
For the forward method, data should be split along dim[1] across the “spatial_parallel” process group. The output is data split in dim[0].
- static backward(ctx, grad_output)[source]
Define a formula for differentiating the operation with backward mode automatic differentiation.
This function is to be overridden by all subclasses. (Defining this function is equivalent to defining the vjp function.)
It must accept a context ctx as the first argument, followed by as many outputs as the forward() returned (None will be passed in for non-tensor outputs of the forward function), and it should return as many tensors as there were inputs to forward(). Each argument is the gradient w.r.t. the given output, and each returned value should be the gradient w.r.t. the corresponding input. If an input is not a Tensor or is a Tensor not requiring grads, you can just pass None as a gradient for that input.
The context can be used to retrieve tensors saved during the forward pass. It also has an attribute ctx.needs_input_grad as a tuple of booleans representing whether each input needs gradient. E.g., backward() will have ctx.needs_input_grad[0] = True if the first input to forward() needs gradient computed w.r.t. the output.
- static forward(ctx, x, s, dim, norm='ortho')[source]
Define the forward of the custom autograd Function.
This function is to be overridden by all subclasses. There are two ways to define forward:
Usage 1 (Combined forward and ctx):
@staticmethod
def forward(ctx: Any, *args: Any, **kwargs: Any) -> Any:
    pass
It must accept a context ctx as the first argument, followed by any number of arguments (tensors or other types).
See combining-forward-context for more details.
Usage 2 (Separate forward and ctx):
@staticmethod
def forward(*args: Any, **kwargs: Any) -> Any:
    pass

@staticmethod
def setup_context(ctx: Any, inputs: Tuple[Any, ...], output: Any) -> None:
    pass
The forward no longer accepts a ctx argument. Instead, you must also override the torch.autograd.Function.setup_context() staticmethod to handle setting up the ctx object. output is the output of the forward, inputs are a Tuple of inputs to the forward.
See extending-autograd for more details.
The context can be used to store arbitrary data that can be then retrieved during the backward pass. Tensors should not be stored directly on ctx (though this is not currently enforced for backward compatibility). Instead, tensors should be saved either with ctx.save_for_backward() if they are intended to be used in backward (equivalently, vjp) or ctx.save_for_forward() if they are intended to be used in jvp.
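A heavily hedged sketch of how these wrappers might be called. The group name "spatial_parallel" comes from the descriptions above, while the assumption that s and dim mirror the torch.fft.rfft2 convention and the concrete shapes are illustrative only:
import torch

from modulus.distributed import DistributedManager
from modulus.distributed.fft import DistributedRFFT2

DistributedManager.initialize()
dist = DistributedManager()

# The FFT wrappers expect a process group named "spatial_parallel".
DistributedManager.create_process_subgroup("spatial_parallel", size=dist.world_size)

# Global field of size 64 x 64, split equally along dim[1] (the last dim)
# across the spatial_parallel group, as required by the forward method.
local_w = 64 // dist.group_size("spatial_parallel")
x = torch.randn(1, 1, 64, local_w, device=dist.device)

# s and dim are assumed to follow the torch.fft.rfft2 convention.
y = DistributedRFFT2.apply(x, (64, 64), (-2, -1), "ortho")

# The output is split along dim[0] of the transformed pair of dimensions.
print(y.shape)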
modulus.distributed.mappings
- modulus.distributed.mappings.copy_to_parallel_region(input, group)[source]
Copy input
- modulus.distributed.mappings.gather_from_parallel_region(input, dim, shapes, group)[source]
Gather the input from matmul parallel region and concatenate.
- modulus.distributed.mappings.reduce_from_parallel_region(input, group)[source]
All-reduce the input from the matmul parallel region.
- modulus.distributed.mappings.scatter_to_parallel_region(input, dim, group)[source]
Split the input and keep only the chunk corresponding to the rank.
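A hedged sketch of how these mappings can compose in a tensor-parallel layer. The group name "model_parallel" and the shapes are illustrative; it is also an assumption here that group is passed as a registered group name and that shapes lists the per-rank sizes along the gathered dimension:
import torch

from modulus.distributed import DistributedManager
from modulus.distributed.mappings import (
    copy_to_parallel_region,
    gather_from_parallel_region,
    reduce_from_parallel_region,
    scatter_to_parallel_region,
)

DistributedManager.initialize()
dist = DistributedManager()
DistributedManager.create_process_subgroup("model_parallel", size=dist.world_size)
group = "model_parallel"

size = dist.group_size(group)
x = torch.randn(8, 4 * size, device=dist.device, requires_grad=True)

# Forward identity / backward all-reduce: replicate the input onto the group.
x_rep = copy_to_parallel_region(x, group)

# Split the feature dimension so each rank works on its own chunk.
x_local = scatter_to_parallel_region(x_rep, dim=1, group=group)

# ... rank-local computation would go here ...

# Reassemble the feature dimension from all ranks.
shapes = [x_local.shape[1]] * size
x_full = gather_from_parallel_region(x_local, dim=1, shapes=shapes, group=group)

# Alternatively, sum partial results across the group.
x_sum = reduce_from_parallel_region(x_local, group)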