PhysicsNeMo Distributed#
Distributed utilities in PhysicsNeMo are designed to simplify the implementation of parallel training and
inference scripts by providing a unified way to configure and query parameters associated
with the distributed environment. The utilities in physicsnemo.distributed
build on top of the
utilities from torch.distributed
and abstract away some of the complexities of setting up a
distributed execution environment.
The example below shows how to set up a simple distributed data parallel training recipe using the
distributed utilities in PhysicsNeMo.
DistributedDataParallel
in PyTorch provides the framework for data parallel training by reducing parameter gradients
across multiple worker processes after the backwards pass. The code below shows how to specify
the device_ids, output_device, broadcast_buffers, and find_unused_parameters
arguments of the DistributedDataParallel utility using the DistributedManager.
import torch
from torch.nn.parallel import DistributedDataParallel
from physicsnemo.distributed import DistributedManager
from physicsnemo.models.mlp.fully_connected import FullyConnected
def main():
    # Initialize the DistributedManager. This will automatically
    # detect the number of processes the job was launched with and
    # set those configuration parameters appropriately. Currently
    # torchrun (or any other PyTorch compatible launcher), mpirun (OpenMPI)
    # and SLURM based launchers are supported.
    DistributedManager.initialize()

    # Since this is a singleton class, you can just get an instance
    # of it anytime after initialization and not need to reinitialize
    # each time.
    dist = DistributedManager()

    # Set up model on the appropriate device. DistributedManager
    # figures out what device should be used on this process.
    arch = FullyConnected(in_features=32, out_features=64).to(dist.device)

    # Set up DistributedDataParallel if using more than a single process.
    # The `distributed` property of DistributedManager can be used to
    # check this.
    if dist.distributed:
        ddps = torch.cuda.Stream()
        with torch.cuda.stream(ddps):
            arch = DistributedDataParallel(
                arch,
                device_ids=[dist.local_rank],  # Set the device_id to be
                                               # the local rank of this process on
                                               # this node
                output_device=dist.device,
                broadcast_buffers=dist.broadcast_buffers,
                find_unused_parameters=dist.find_unused_parameters,
            )
        torch.cuda.current_stream().wait_stream(ddps)

    # Set up the optimizer
    optimizer = torch.optim.Adam(
        arch.parameters(),
        lr=0.001,
    )

    def training_step(input, target):
        optimizer.zero_grad()
        pred = arch(input)
        loss = torch.sum(torch.pow(pred - target, 2))
        loss.backward()
        optimizer.step()
        return loss

    # Sample training loop
    for i in range(20):
        # Random inputs and targets for simplicity
        input = torch.randn(128, 32, device=dist.device)
        target = torch.randn(128, 64, device=dist.device)

        # Training step
        loss = training_step(input, target)


if __name__ == "__main__":
    main()
This training script can be run on a single GPU
using python train.py
or on multiple GPUs using
torchrun --standalone --nnodes=1 --nproc_per_node=<num_gpus> train.py
or
mpirun -np <num_gpus> python train.py
if using OpenMPI. The script can also be run on a SLURM cluster using
srun -n <num_gpus> python train.py
How does this work?#
An important aspect of the DistributedManager
is that it follows the
Borg pattern.
This means that DistributedManager
essentially functions like a singleton
class and once configured, all utilities in PhysicsNeMo can access the same configuration
and adapt to the specified distributed structure.
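As a quick illustration of this shared state, here is a minimal sketch that constructs the manager
twice and reads the same configuration from both handles (assuming DistributedManager.initialize()
has been called, e.g. under torchrun or a single process):

import torch
from physicsnemo.distributed import DistributedManager

DistributedManager.initialize()

# Any number of "instances" share the same underlying state (Borg pattern),
# so utilities deep in the call stack can simply construct their own handle.
dm_a = DistributedManager()
dm_b = DistributedManager()

assert dm_a.rank == dm_b.rank
assert dm_a.world_size == dm_b.world_size
assert dm_a.device == dm_b.device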
For example, see the constructor of the DistributedAFNO
class:
def __init__(
    self,
    inp_shape: Tuple[int, int],
    in_channels: int,
    out_channels: Union[int, Any] = None,
    patch_size: int = 16,
    embed_dim: int = 256,
    depth: int = 4,
    num_blocks: int = 4,
    channel_parallel_inputs: bool = False,
    channel_parallel_outputs: bool = False,
) -> None:
    super().__init__()
    out_channels = out_channels or in_channels

    if DistributedManager().group("model_parallel") is None:
        raise RuntimeError(
            "Distributed AFNO needs to have model parallel group created first. "
            "Check the MODEL_PARALLEL_SIZE environment variable"
        )

    comm_size = DistributedManager().group_size("model_parallel")
    if channel_parallel_inputs:
        if not (in_channels % comm_size == 0):
            raise ValueError(
                "Error, in_channels needs to be divisible by model_parallel size"
            )

    self._impl = DistributedAFNONet(
        inp_shape=inp_shape,
        patch_size=(patch_size, patch_size),
        in_chans=in_channels,
        out_chans=out_channels,
        embed_dim=embed_dim,
        depth=depth,
        num_blocks=num_blocks,
        input_is_matmul_parallel=False,
        output_is_matmul_parallel=False,
    )
This model-parallel implementation can simply instantiate DistributedManager
and query whether the process group named "model_parallel"
exists and, if so, what its size is. Similarly,
other utilities can query what device to run on, the total size of the distributed run, etc.
without having to explicitly pass those parameters down the call stack.
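For instance, a minimal sketch of such a utility might look like the following (the helper name
shard_channels is purely illustrative and not part of PhysicsNeMo):

from physicsnemo.distributed import DistributedManager

def shard_channels(num_channels: int) -> int:
    """Illustrative helper: split a channel count across the "model_parallel" group."""
    dm = DistributedManager()
    # Fall back to no sharding if the named group was never created.
    if "model_parallel" not in dm.group_names:
        return num_channels
    comm_size = dm.group_size("model_parallel")
    if num_channels % comm_size != 0:
        raise ValueError("num_channels must be divisible by the model_parallel group size")
    return num_channels // comm_size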
Note
This singleton/borg pattern is very useful for the DistributedManager
since it takes charge
of bootstrapping the distributed run and unifies how all utilities become aware of the distributed
configuration. However, the singleton/borg pattern should not be used merely as a way to avoid passing
parameters to utilities. Its use should be limited and well justified, to avoid losing traceability and
to keep the code readable.
physicsnemo.distributed.manager#
- class physicsnemo.distributed.manager.DistributedManager[source]#
Bases:
object
Distributed Manager for setting up distributed training environment.
This is a singleton that creates a persistent class instance for storing parallel environment information throughout the lifetime of the program. This should be used to help set up Distributed Data Parallel and parallel datapipes.
Note
One should call DistributedManager.initialize() prior to constructing a manager object
Example
>>> DistributedManager.initialize()
>>> manager = DistributedManager()
>>> manager.rank
0
>>> manager.world_size
1
- property broadcast_buffers#
broadcast_buffers in PyTorch DDP
- static create_orthogonal_process_group(
- orthogonal_group_name: str,
- group_name: str,
- verbose: bool = False,
Create a process group that is orthogonal to the specified process group.
- Parameters:
orthogonal_group_name (str) – Name of the orthogonal process group to be created.
group_name (str) – Name of the existing process group.
verbose (bool) – Print out ranks of each created process group, default False.
- static create_process_subgroup(
- name: str,
- size: int,
- group_name: str | None = None,
- verbose: bool = False,
Create a process subgroup of a parent process group. This must be a collective call by all processes participating in this application.
- Parameters:
name (str) – Name of the process subgroup to be created.
size (int) – Size of the process subgroup to be created. This must be an integer factor of the parent group’s size.
group_name (Optional[str]) – Name of the parent process group, optional. If None, the default process group will be used. Default None.
verbose (bool) – Print out ranks of each created process group, default False.
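As a minimal usage sketch (assuming a job launched with, for example, 8 processes), a 2-way
model-parallel group and a complementary data-parallel group could be created as follows:

from physicsnemo.distributed import DistributedManager

DistributedManager.initialize()

# These are collective calls: every rank in the job must execute them.
DistributedManager.create_process_subgroup(
    name="model_parallel",
    size=2,  # must be an integer factor of the parent (world) group size
)
DistributedManager.create_orthogonal_process_group(
    orthogonal_group_name="data_parallel",
    group_name="model_parallel",
)

dm = DistributedManager()
mp_group = dm.group("model_parallel")  # this rank's model-parallel group (size 2)
dp_group = dm.group("data_parallel")   # this rank's data-parallel group (size 4 with 8 ranks)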
- property cuda#
If cuda is available
- property device#
Process device
- property distributed#
Distributed environment
- property find_unused_parameters#
find_unused_parameters in PyTorch DDP
- get_mesh_group(
- mesh: DeviceMesh,
Get the process group for a given mesh.
Creating a group is an expensive operation, so we cache the result manually.
We hash the mesh and use that as the key.
- property global_mesh#
Returns the global mesh. If it’s not initialized, it will be created when this is called.
- group(name=None)[source]#
Returns a process group with the given name. If name is None, the returned group is also None, indicating the default process group. If the named group does not exist, a PhysicsNeMoUndefinedGroupError exception is raised.
- property group_names#
Returns a list of all named process groups created
- static initialize()[source]#
Initialize distributed manager
- Current supported initialization methods are:
- ENV: PyTorch environment variable initialization
https://pytorch.org/docs/stable/distributed.html#environment-variable-initialization
- SLURM: Initialization on SLURM systems.
Uses SLURM_PROCID, SLURM_NPROCS, SLURM_LOCALID and SLURM_LAUNCH_NODE_IPADDR environment variables.
- OPENMPI: Initialization for OpenMPI launchers.
Uses OMPI_COMM_WORLD_RANK, OMPI_COMM_WORLD_SIZE and OMPI_COMM_WORLD_LOCAL_RANK environment variables.
Initialization by default is done using the first valid method in the order listed above. Initialization method can also be explicitly controlled using the PHYSICSNEMO_DISTRIBUTED_INITIALIZATION_METHOD environment variable and setting it to one of the options above.
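For example, a sketch of forcing a particular bootstrap path rather than relying on auto-detection
(in practice the variable would typically be set by the launcher or shell environment rather than in
the script itself; the method names listed above are assumed to be accepted verbatim):

import os
from physicsnemo.distributed import DistributedManager

# Explicitly select SLURM-style initialization.
os.environ["PHYSICSNEMO_DISTRIBUTED_INITIALIZATION_METHOD"] = "SLURM"
DistributedManager.initialize()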
- initialize_mesh(
- mesh_shape: Tuple[int, ...],
- mesh_dim_names: Tuple[str, ...],
Initialize a global device mesh over the entire distributed job.
Creates a multi-dimensional mesh of processes that can be used for distributed operations. The mesh shape must multiply to equal the total world size, with one dimension optionally being flexible (-1).
- Parameters:
mesh_shape (Tuple[int, ...]) – Tuple of ints describing the size of each mesh dimension. Product must equal world_size. One dimension can be -1 to be automatically calculated.
mesh_dim_names (Tuple[str, ...]) – Names for each mesh dimension. Must match length of mesh_shape.
- Returns:
The initialized device mesh
- Return type:
torch.distributed.DeviceMesh
- Raises:
RuntimeError – If mesh dimensions are invalid or don’t match world size
AssertionError – If distributed environment is not available
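A minimal sketch of creating a two-dimensional mesh, assuming the job's world size is divisible by 2
(the -1 entry is computed automatically from the world size):

from physicsnemo.distributed import DistributedManager

DistributedManager.initialize()
dm = DistributedManager()

# With 8 processes this produces a 4 x 2 mesh.
mesh = dm.initialize_mesh(
    mesh_shape=(-1, 2),
    mesh_dim_names=("data_parallel", "model_parallel"),
)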
- property local_rank#
Process rank on local machine
- mesh(name=None)[source]#
Return a device_mesh with the given name. Does not initialize the mesh; if the mesh has not already been created, an error is raised.
- Parameters:
name (str, optional) – Name of desired mesh, by default None
- property mesh_dims#
Mesh dimensions as a dictionary mapping axis name to size
- property rank#
Process rank
- static setup(
- rank=0,
- world_size=1,
- local_rank=None,
- addr='localhost',
- port='12355',
- backend='nccl',
- method='env',
Set up PyTorch distributed process group and update manager attributes
- property world_size#
Number of processes in distributed environment
physicsnemo.distributed.utils#
- physicsnemo.distributed.utils.all_gather_v_bwd_wrapper(
- tensor: Tensor,
- sizes: List[int],
- dim: int = 0,
- use_fp32: bool = True,
- group: ProcessGroup | None = None,
Implements a distributed AllReduceV primitive. It is based on the idea of a single global tensor which can be distributed along a specified dimension into chunks of variable size. This primitive assumes different global tensors of the same shape on each rank. It then re-distributes chunks of all these tensors such that each rank receives all corresponding parts of a global tensor. Each rank then sums up the chunks after receiving them. By design, this primitive thus implements the backward pass of the “all_gather_v” primitive. In this case, the result would be a single global gradient tensor distributed onto different ranks.
- Parameters:
tensor (torch.Tensor) – global tensor on each rank (different one on each rank)
sizes (List[int]) – list of the sizes of each chunk on each rank along distributed dimension, valid and set on each rank
dim (int, optional) – dimension along which global tensor is distributed, by default 0
use_fp32 (bool, optional) – flag to specify that the reduction takes place in at least FP32 precision, by default True; only acts on floating point inputs in lower precision
group (Optional[dist.ProcessGroup], optional) – process group along which global tensor is shared, by default None
- Returns:
local tensor, i.e. result of reduction of all corresponding chunks from all global tensors for each rank separately
- Return type:
torch.Tensor
- physicsnemo.distributed.utils.all_gather_v_wrapper(
- tensor: Tensor,
- sizes: List[int] | None = None,
- dim: int = 0,
- group: ProcessGroup | None = None,
Implements a distributed AllGatherV primitive. It is based on the idea of a single global tensor which is distributed along a specified dimension into chunks of variable size. This primitive gathers all local tensors from each rank into the full global tensor onto each rank.
- Parameters:
tensor ("torch.Tensor") – local tensor on each rank
sizes (List[int], optional) – list of the sizes of each chunk on each rank along distributed dimension, valid and set on each rank, by default None. Can be single integer per rank (assuming all other dimensions except dim below are equal) or can be full
dim (int, optional) – dimension along which global tensor is distributed, by default 0
group (Optional[dist.ProcessGroup], optional) – process group along which global tensor is shared, by default None
- Returns:
full global tensor, valid on each rank
- Return type:
torch.Tensor
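A minimal usage sketch (to be run under a distributed launcher; here each rank r is assumed to hold
r + 1 rows of the global tensor):

import torch
from physicsnemo.distributed import DistributedManager
from physicsnemo.distributed.utils import all_gather_v_wrapper

DistributedManager.initialize()
dm = DistributedManager()

# Each rank holds a different number of rows of the global tensor.
sizes = [r + 1 for r in range(dm.world_size)]
local = torch.randn(sizes[dm.rank], 16, device=dm.device)

# Every rank receives the full global tensor of shape (sum(sizes), 16).
global_tensor = all_gather_v_wrapper(local, sizes=sizes, dim=0)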
- physicsnemo.distributed.utils.distributed_transpose(tensor, dim0, dim1, group=None, async_op=False)[source]#
Perform distributed transpose of tensor to switch sharding dimension
- physicsnemo.distributed.utils.gather_v_wrapper(
- tensor: Tensor,
- sizes: List[int],
- dim: int = 0,
- dst: int = 0,
- group: ProcessGroup | None = None,
Implements a distributed GatherV primitive. It is based on the idea of a single global tensor which is distributed along a specified dimension into chunks of variable size. This primitive assumes such a distributed tensor and gathers all local tensors from each rank into the full global tensor valid on the specified destination rank.
- Parameters:
tensor (torch.Tensor) – local tensor on each rank
sizes (List[int]) – list of the sizes of each chunk on each rank along distributed dimension, valid and set on each rank
dim (int, optional) – dimension along which global tensor is distributed, by default 0
dst (int, optional) – destination rank which contains the full global tensor after the operation, by default 0
group (Optional[dist.ProcessGroup], optional) – process group along which global tensor is shared, by default None
- Returns:
full global tensor, valid on destination rank
- Return type:
torch.Tensor
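For example, a minimal sketch that gathers variable-sized shards onto rank 0 (e.g. for saving or
logging):

import torch
from physicsnemo.distributed import DistributedManager
from physicsnemo.distributed.utils import gather_v_wrapper

DistributedManager.initialize()
dm = DistributedManager()

sizes = [2 * (r + 1) for r in range(dm.world_size)]
local = torch.randn(sizes[dm.rank], 8, device=dm.device)

# Only the destination rank is guaranteed to hold the assembled global tensor.
gathered = gather_v_wrapper(local, sizes=sizes, dim=0, dst=0)
if dm.rank == 0:
    print(gathered.shape)  # (sum(sizes), 8)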
- physicsnemo.distributed.utils.indexed_all_to_all_v_wrapper(
- tensor: Tensor,
- indices: List[Tensor],
- sizes: List[List[int]],
- dim: int = 0,
- group: ProcessGroup | None = None,
Implements an indexed version of a distributed AllToAllV primitive. It is based on the idea of a single global tensor which is distributed along a specified dimension into chunks of variable size. This primitive assumes a set of indices into this dimension which indicate the corresponding slices sent to each other rank forming an indexed version of an AllToAllV primitive.
- Parameters:
tensor (torch.Tensor) – local part of global tensor on each rank
indices (List[torch.Tensor]) – list of indices on each rank of slices being sent to each other rank from this rank
sizes (List[List[int]]) – number of indices each rank sends to each other rank, valid and set on each rank, e.g. sizes[0][3] corresponds to the number of slices rank 0 sends to rank 3
dim (int) – dimension along which global tensor is distributed, by default 0
group (Optional[dist.ProcessGroup], optional) – process group along which global tensor is shared, by default None
- Returns:
local result of primitive corresponding to indexed global tensor
- Return type:
torch.Tensor
- physicsnemo.distributed.utils.indexed_all_to_all_v_wrapper_bwd(
- tensor: Tensor,
- indices: List[Tensor],
- sizes: List[List[int]],
- tensor_size_along_dim: int,
- use_fp32: bool = True,
- dim: int = 0,
- group: ProcessGroup | None = None,
Implements the backward pass to the indexed version of a distributed AllToAllV primitive.
- Parameters:
tensor (torch.Tensor) – local tensor, i.e. gradient on resulting tensor from forward pass
indices (List[torch.Tensor]) – list of indices on each rank of slices being sent to each other rank from this rank
sizes (List[List[int]]) – list of the sizes of each chunk on each rank along distributed dimension, valid and set on each rank
tensor_size_along_dim (int) – size of original local tensor along specified dimension, i.e. from the corresponding forward pass
use_fp32 (bool, optional) – flag to specify that the reduction takes place in at least FP32 precision, by default True; only acts on floating point inputs in lower precision
dim (int, optional) – dimension along which global tensor is distributed, by default 0
group (Optional[dist.ProcessGroup], optional) – process group along which global tensor is shared, by default None
- Returns:
result of primitive corresponding to indexed global tensor
- Return type:
torch.Tensor
- physicsnemo.distributed.utils.mark_module_as_shared(
- module: Module,
- process_group: str | None,
- recurse: bool = True,
- use_fp32_reduction: bool = True,
Helper function to mark parameters of a module as being shared across ranks by attaching gradient hooks to the corresponding tensors.
- Parameters:
module (nn.Module) – PyTorch module which is to be marked as having shared parameters.
process_group (str | None) – str indicating process_group which contains ranks across which the module’s parameters are shared. If passed as None, will default to the world group.
recurse (bool, default=True) – Flag indicating whether the module’s parameters are traversed in a recursive fashion, i.e. whether sub-modules are also considered as having shared parameters.
use_fp32_reduction (bool, default=True) – Flag indicating whether the reduction for accumulating gradients will be done in at least FP32 precision or in the native datatype.
- physicsnemo.distributed.utils.pad_helper(tensor, dim, new_size, mode='zero')[source]#
Util for padding tensors
- physicsnemo.distributed.utils.reduce_loss(loss: float, dst_rank: int = 0, mean: bool = True)[source]#
Reduces loss from all processes to destination rank for logging.
- Parameters:
loss (float) – loss value
dst_rank (int, Optional) – destination rank to reduce to, by default 0.
mean (bool, Optional) – Calculate the mean of the losses gathered, by default True.
- Raises:
Exception – If DistributedManager has yet to be initialized
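A brief sketch of how this might be used in a training loop (assuming the reduced value is returned
on the destination rank, and that DistributedManager has already been initialized):

from physicsnemo.distributed import DistributedManager
from physicsnemo.distributed.utils import reduce_loss

DistributedManager.initialize()
dm = DistributedManager()

local_loss = 0.123  # per-rank scalar loss from the current step
# Assumed to return the (mean-)reduced loss on the destination rank.
mean_loss = reduce_loss(local_loss, dst_rank=0, mean=True)
if dm.rank == 0:
    print(f"mean loss across ranks: {mean_loss}")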
- physicsnemo.distributed.utils.scatter_v_wrapper(
- tensor: Tensor,
- sizes: List[int],
- dim: int = 0,
- src: int = 0,
- group: ProcessGroup | None = None,
Implements a distributed ScatterV primitive. It is based on the idea of a single global tensor which is distributed along a specified dimension into chunks of variable size. This primitive scatters the global tensor from a specified source rank into local chunks onto each other rank.
- Parameters:
tensor (torch.Tensor) – global tensor, valid on source rank
sizes (List[int]) – list of the sizes of each chunk on each rank along distributed dimension, valid and set on each rank
dim (int, optional) – dimension along which global tensor is distributed, by default 0
src (int, optional) – source rank of primitive, i.e. rank of original full global tensor, by default 0
group (Optional[dist.ProcessGroup], optional) – process group along which global tensor is shared, by default None
- Returns:
corresponding local part of the global tensor on each rank
- Return type:
torch.Tensor
- physicsnemo.distributed.utils.unmark_module_as_shared(
- module: Module,
- recurse: bool = True,
Helper function to unmark parameters of a module as being shared across ranks by removing attached gradient hooks.
- Parameters:
module (nn.Module) – PyTorch module which is to be unmarked as having shared parameters.
recurse (bool, default=True) – Flag indicating whether the module’s parameters are traversed in a recursive fashion, i.e. whether sub-modules are also considered as having shared parameters.
physicsnemo.distributed.autograd#
- class physicsnemo.distributed.autograd.AllGatherVAutograd(*args, **kwargs)[source]#
Bases:
Function
Autograd Wrapper for a distributed AllGatherV primitive. It is based on the idea of a single global tensor which is distributed along a specified dimension into chunks of variable size. This primitive gathers all local tensors from each rank into the full global tensor onto each rank. It is intended to be used in tensor-parallel settings on tensors which require gradients to be passed through. The backward pass performs an AllReduceV operation where each rank gathers its corresponding chunk of a global tensor from each other rank and sums up these individual gradients.
- class physicsnemo.distributed.autograd.GatherVAutograd(*args, **kwargs)[source]#
Bases:
Function
Autograd Wrapper for a distributed GatherV primitive. It is based on the idea of a single global tensor which is distributed along a specified dimension into chunks of variable size. This primitive assumes such a distributed tensor and gathers all local tensors from each rank into the full global tensor valid on the specified destination rank. It is intended to be used in tensor-parallel settings on tensors which require gradients to be passed through. The backward pass corresponds to a straightforward ScatterV primitive distributing the global gradient from the specified destination rank to all the other ranks.
- class physicsnemo.distributed.autograd.IndexedAllToAllVAutograd(*args, **kwargs)[source]#
Bases:
Function
Autograd Wrapper for an Indexed AllToAllV primitive. It is based on the idea of a single global tensor which is distributed along a specified dimension into chunks of variable size. This primitive assumes a set of indices into this dimension which indicate the corresponding slices sent to each other rank forming an indexed version of an AllToAllV primitive. It is intended to be used in tensor-parallel settings on tensors which require gradients to be passed through. The backward pass more or less corresponds to the same operation as in the forward pass but with reversed roles and does an additional reduction of gathered gradients so that each rank finally will compute the overall gradient on its local tensor partition.
- class physicsnemo.distributed.autograd.ScatterVAutograd(*args, **kwargs)[source]#
Bases:
Function
Autograd Wrapper for Distributed ScatterV. It is based on the idea of a single global tensor which is distributed along a specified dimension into chunks of variable size. This primitive scatters the global tensor from a specified source rank into local chunks onto each other rank. It is intended to be used in tensor-parallel settings on tensors which require gradients to be passed through. The backward pass corresponds to a GatherV primitive gathering local gradients from all the other ranks into a single global gradient on the specified source rank.
- physicsnemo.distributed.autograd.all_gather_v(
- tensor: Tensor,
- sizes: List[int],
- dim: int = 0,
- use_fp32: bool = True,
- group: ProcessGroup | None = None,
Autograd Wrapper for a distributed AllGatherV primitive. It is based on the idea of a single global tensor which is distributed along a specified dimension into chunks of variable size. This primitive gathers all local tensors from each rank into the full global tensor onto each rank. It is intended to be used in tensor-parallel settings on tensors which require gradients to be passed through. The backward pass performs an AllReduceV operation where each rank gathers its corresponding chunk of a global tensor from each other rank and sums up these individual gradients.
- Parameters:
tensor ("torch.Tensor") – local tensor on each rank
sizes (List[int]) – list of the sizes of each chunk on each rank along distributed dimension, valid and set on each rank
dim (int, optional) – dimension along which global tensor is distributed, by default 0
use_fp32 (bool, optional) – boolean flag to indicate whether to use FP32 precision for the reduction in the backward pass, by default True
group (Optional[dist.ProcessGroup], optional) – process group along which global tensor is shared, by default None
- Returns:
full global tensor, valid on each rank
- Return type:
torch.Tensor
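A minimal sketch of the autograd variant in a tensor-parallel setting, where the backward pass routes
and sums gradients back to each rank's local shard:

import torch
from physicsnemo.distributed import DistributedManager
from physicsnemo.distributed.autograd import all_gather_v

DistributedManager.initialize()
dm = DistributedManager()

sizes = [4 for _ in range(dm.world_size)]
local = torch.randn(sizes[dm.rank], 8, device=dm.device, requires_grad=True)

# Forward: every rank assembles the full tensor.
full = all_gather_v(local, sizes=sizes, dim=0)
loss = full.square().sum()

# Backward: each rank receives the gradient chunk for its own shard.
loss.backward()
print(local.grad.shape)  # matches the local shard's shape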
- physicsnemo.distributed.autograd.gather_v(
- tensor: Tensor,
- sizes: List[int],
- dim: int = 0,
- dst: int = 0,
- group: ProcessGroup | None = None,
Autograd Wrapper for a distributed GatherV primitive. It is based on the idea of a single global tensor which is distributed along a specified dimension into chunks of variable size. This primitive assumes such a distributed tensor and gathers all local tensors from each rank into the full global tensor valid on the specified destination rank. It is intended to be used in tensor-parallel settings on tensors which require gradients to be passed through. The backward pass corresponds to a straightforward ScatterV primitive distributing the global gradient from the specified destination rank to all the other ranks.
- Parameters:
tensor (torch.Tensor) – local tensor on each rank
sizes (List[int]) – list of the sizes of each chunk on each rank along distributed dimension, valid and set on each rank
dim (int, optional) – dimension along which global tensor is distributed, by default 0
dst (int, optional) – destination rank which contains the full global tensor after the operation, by default 0
group (Optional[dist.ProcessGroup], optional) – process group along which global tensor is shared, by default None
- Returns:
full global tensor, valid on destination rank
- Return type:
torch.Tensor
- physicsnemo.distributed.autograd.indexed_all_to_all_v(
- tensor: Tensor,
- indices: List[Tensor],
- sizes: List[List[int]],
- use_fp32: bool = True,
- dim: int = 0,
- group: ProcessGroup | None = None,
Autograd Wrapper for an Indexed AllToAllV primitive. It is based on the idea of a single global tensor which is distributed along a specified dimension into chunks of variable size. This primitive assumes a set of indices into this dimension which indicate the corresponding slices sent to each other rank forming an indexed version of an AllToAllV primitive. It is intended to be used in tensor-parallel settings on tensors which require gradients to be passed through. The backward pass more or less corresponds to the same operation as in the forward pass but with reversed roles and does an additional reduction of gathered gradients so that each rank finally will compute the overall gradient on its local tensor partition.
- Parameters:
tensor (torch.Tensor) – local part of global tensor on each rank
indices (List[torch.Tensor]) – list of indices on each rank of slices being sent to each other rank from this rank
sizes (List[List[int]]) – number of indices each rank sends to each other rank, valid and set on each rank, e.g. sizes[0][3] corresponds to the number of slices rank 0 sends to rank 3
use_fp32 (bool, optional) – flag to specify whether to use FP32 precision in the reduction in the backward pass, by default True
dim (int) – dimension along which global tensor is distributed, by default 0
group (Optional[dist.ProcessGroup], optional) – process group along which global tensor is shared, by default None
- Returns:
local result of primitive corresponding to indexed global tensor
- Return type:
torch.Tensor
- physicsnemo.distributed.autograd.scatter_v(
- tensor: Tensor,
- sizes: List[int],
- dim: int = 0,
- src: int = 0,
- group: ProcessGroup | None = None,
Autograd Wrapper for Distributed ScatterV. It is based on the idea of a single global tensor which is distributed along a specified dimension into chunks of variable size. This primitive scatters the global tensor from a specified source rank into local chunks onto each other rank. It is intended to be used in tensor-parallel settings on tensors which require gradients to be passed through. The backward pass corresponds to a GatherV primitive gathering local gradients from all the other ranks into a single global gradient on the specified source rank.
- Parameters:
tensor (torch.Tensor) – global tensor, valid on source rank
sizes (List[int]) – list of the sizes of each chunk on each rank along distributed dimension, valid and set on each rank
dim (int, optional) – dimension along which global tensor is distributed, by default 0
src (int, optional) – source rank of primitive, i.e. rank of original full global tensor, by default 0
group (Optional[dist.ProcessGroup], optional) – process group along which global tensor is shared, by default None
- Returns:
corresponding local part of the global tensor on each rank
- Return type:
torch.Tensor
physicsnemo.distributed.fft#
- class physicsnemo.distributed.fft.DistributedIRFFT2(*args, **kwargs)[source]#
Bases:
Function
Autograd Wrapper for a distributed 2D complex to real IFFT primitive. It is based on the idea of a single global tensor which is distributed along a specified dimension into chunks of equal size. This primitive computes a 1D IFFT first along dim[1], then performs an AllToAll transpose before computing a 1D IFFT along dim[0]. The backward pass performs an FFT operation with communication in the opposite order as in the forward pass.
For the forward method, data should be split along dim[0] across the “spatial_parallel” process group. The output is data split in dim[1].
- class physicsnemo.distributed.fft.DistributedRFFT2(*args, **kwargs)[source]#
Bases:
Function
Autograd Wrapper for a distributed 2D real to complex FFT primitive. It is based on the idea of a single global tensor which is distributed along a specified dimension into chunks of equal size. This primitive computes a 1D FFT first along dim[0], then performs an AllToAll transpose before computing a 1D FFT along dim[1]. The backward pass performs an IFFT operation with communication in the opposite order as in the forward pass.
For the forward method, data should be split along dim[1] across the “spatial_parallel” process group. The output is data split in dim[0].
physicsnemo.distributed.mappings#
- physicsnemo.distributed.mappings.gather_from_parallel_region(input, dim, shapes, group)[source]#
Gather the input from matmul parallel region and concatenate.