# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch

from modulus.distributed.manager import DistributedManager
from modulus.distributed.mappings import (
    gather_from_parallel_region,
    scatter_to_parallel_region,
)
from modulus.distributed.utils import distributed_transpose, pad_helper, truncate_helper


def conj_pad_helper_2d(tensor, pad_dim, other_dim, new_size):
    """Pad ``tensor`` along ``pad_dim`` in "conj" mode and flip the padded region.

    The tensor is padded to ``new_size`` along ``pad_dim`` via
    ``pad_helper(..., mode="conj")``, gathered along ``other_dim`` across the
    "spatial_parallel" group, and the newly padded region is reversed along
    ``other_dim`` (leaving index 0 of ``other_dim`` untouched). The result is
    then scattered back along ``other_dim``.

    NOTE(review): presumably this reconstructs the redundant half of a
    Hermitian-symmetric 2D spectrum before a complex IFFT; the exact padding
    values are defined by ``pad_helper`` -- confirm there.

    Parameters
    ----------
    tensor : torch.Tensor
        Local input tensor.
    pad_dim : int
        Dimension to pad; negative values are wrapped using ``tensor.ndim``.
    other_dim : int
        Dimension that is gathered/scattered and along which the padded
        region is flipped; negative values are wrapped as well.
    new_size : int
        Target size of ``pad_dim`` after padding.

    Returns
    -------
    torch.Tensor
        Output of ``scatter_to_parallel_region`` applied to the padded and
        flipped tensor along ``other_dim``.
    """
    ndim = tensor.ndim
    pad_dim = (pad_dim + ndim) % ndim
    other_dim = (other_dim + ndim) % ndim

    # pad with conj
    orig_size = tensor.shape[pad_dim]
    tensor_pad = pad_helper(tensor, pad_dim, new_size, mode="conj")

    # gather
    tensor_pad_gather = gather_from_parallel_region(
        tensor_pad, dim=other_dim, group="spatial_parallel"
    )

    # Select only the padded region along pad_dim, and everything except
    # index 0 along other_dim; all remaining dimensions are taken in full.
    # other_dim is assigned first so that the pad_dim slice takes precedence
    # should both refer to the same dimension.
    gathered_shape = tensor_pad_gather.shape
    flip_slices = [slice(0, size) for size in gathered_shape]
    flip_slices[other_dim] = slice(1, gathered_shape[other_dim])
    flip_slices[pad_dim] = slice(orig_size, new_size)
    # Index with a tuple: a list of slices relies on the legacy
    # "sequence treated as tuple" interpretation of multi-dim indexing.
    flip_index = tuple(flip_slices)

    # flip dims (torch.flip returns a copy, so the in-place write is safe)
    tensor_pad_gather[flip_index] = torch.flip(
        tensor_pad_gather[flip_index], dims=[other_dim]
    )

    # scatter back along the gathered dimension
    result = scatter_to_parallel_region(
        tensor_pad_gather, dim=other_dim, group="spatial_parallel"
    )

    return result


class DistributedRFFT2(torch.autograd.Function):
    """
    Autograd Wrapper for a distributed 2D real to complex FFT primitive.
    It is based on the idea of a single global tensor which is distributed
    along a specified dimension into chunks of equal size.
    This primitive computes a 1D FFT first along dim[0], then performs
    an AllToAll transpose before computing a 1D FFT along dim[1].
    The backward pass performs an IFFT operation with communication
    in the opposite order as in the forward pass.

    For the forward method, data should be split along dim[1] across the
    "spatial_parallel" process group. The output is data split in dim[0].
    """

    @staticmethod
    def forward(ctx, x, s, dim, norm="ortho"):
        """Compute the distributed forward FFT.

        Parameters
        ----------
        ctx : object
            Autograd context; ``s``, ``dim``, ``norm`` and the untruncated
            size of ``dim[1]`` are stored on it for the backward pass.
        x : torch.Tensor
            Local input tensor.
        s : sequence of int
            ``s[0]`` and ``s[1]`` are passed as ``n`` to the FFTs along
            ``dim[0]`` and ``dim[1]`` respectively.
        dim : sequence of int
            The two dimensions to transform.
        norm : str, optional
            Normalization mode forwarded to ``torch.fft.fft``.

        Returns
        -------
        torch.Tensor
            Result truncated along ``dim[1]`` to ``size // 2 + 1`` entries.
        """
        # NVTX marker
        torch.cuda.nvtx.range_push("DistributedRFFT2.forward")

        # save:
        ctx.s = s
        ctx.dim = dim
        ctx.norm = norm

        # assume last dim is split (second to last is contiguous):
        x1 = torch.fft.fft(x, n=s[0], dim=dim[0], norm=norm)

        # transpose
        x1_recv, _ = distributed_transpose(
            x1,
            dim[0],
            dim[1],
            group=DistributedManager().group("spatial_parallel"),
            async_op=False,
        )
        x1_tran = torch.cat(x1_recv, dim=dim[1])

        # another fft:
        x2 = torch.fft.fft(x1_tran, n=s[1], dim=dim[1], norm=norm)

        # truncate in last dim:
        ctx.last_dim_size = x2.shape[dim[1]]
        last_dim_size_trunc = ctx.last_dim_size // 2 + 1
        output = truncate_helper(x2, dim[1], last_dim_size_trunc)

        # pop range: exactly one pop for the single push above, so that
        # NVTX ranges opened by callers are not closed from here.
        torch.cuda.nvtx.range_pop()

        return output

    @staticmethod
    def backward(ctx, grad_output):
        """Compute the gradient with respect to ``x``.

        Pads ``grad_output`` along ``dim[1]`` back to the size recorded in
        the forward pass, applies an IFFT along ``dim[1]``, transposes, and
        applies an IFFT along ``dim[0]``; the real part is returned.

        Returns
        -------
        tuple
            ``(grad_input, None, None, None)`` -- ``s``, ``dim`` and
            ``norm`` receive no gradient.
        """
        # load
        dim = ctx.dim
        norm = ctx.norm
        s = ctx.s
        last_dim_size = ctx.last_dim_size

        # pad the input to perform the backward fft
        g_pad = pad_helper(grad_output, dim[1], last_dim_size)

        # do fft
        g1 = torch.fft.ifft(g_pad, n=s[1], dim=dim[1], norm=norm)

        # transpose
        g1_recv, _ = distributed_transpose(
            g1,
            dim[1],
            dim[0],
            group=DistributedManager().group("spatial_parallel"),
            async_op=False,
        )
        g1_tran = torch.cat(g1_recv, dim=dim[0])

        # now do the BW fft:
        grad_input = torch.real(torch.fft.ifft(g1_tran, n=s[0], dim=dim[0], norm=norm))

        return grad_input, None, None, None


class DistributedIRFFT2(torch.autograd.Function):
    """
    Autograd Wrapper for a distributed 2D complex to real IFFT primitive.
    It is based on the idea of a single global tensor which is distributed
    along a specified dimension into chunks of equal size.
    This primitive computes a 1D IFFT first along dim[1], then performs
    an AllToAll transpose before computing a 1D IFFT along dim[0].
    The backward pass performs an FFT operation with communication
    in the opposite order as in the forward pass.

    For the forward method, data should be split along dim[0] across the
    "spatial_parallel" process group. The output is data split in dim[1].
    """

    @staticmethod
    def forward(ctx, x, s, dim, norm="ortho"):
        """Compute the distributed inverse FFT and return its real part.

        Parameters
        ----------
        ctx : object
            Autograd context; ``s``, ``dim``, ``norm``, the input size along
            ``dim[1]`` and the padded size of ``dim[1]`` are stored on it.
        x : torch.Tensor
            Local input tensor.
        s : sequence of int or None
            Output sizes along ``dim[0]`` and ``dim[1]``. When ``None``,
            ``x.shape[dim[0]]`` and ``2 * (x.shape[dim[1]] - 1)`` are used.
        dim : sequence of int
            The two dimensions to transform.
        norm : str, optional
            Normalization mode forwarded to ``torch.fft.ifft``.

        Returns
        -------
        torch.Tensor
            Contiguous real part of the inverse transform.
        """
        # NVTX marker
        torch.cuda.nvtx.range_push("DistributedIRFFT2.forward")

        # save:
        ctx.s = s
        ctx.dim = dim
        ctx.norm = norm
        ctx.orig_dim_size = x.shape[dim[1]]
        if s is not None:
            first_dim_size = s[0]
            ctx.last_dim_size = s[1]
        else:
            first_dim_size = x.shape[dim[0]]
            ctx.last_dim_size = 2 * (ctx.orig_dim_size - 1)

        # pad dim[1] up to its full size, then ifft in the contig dim
        x_pad = conj_pad_helper_2d(x, dim[1], dim[0], ctx.last_dim_size)
        x1 = torch.fft.ifft(x_pad, n=ctx.last_dim_size, dim=dim[1], norm=norm)

        # transpose
        x1_recv, _ = distributed_transpose(
            x1,
            dim[1],
            dim[0],
            group=DistributedManager().group("spatial_parallel"),
            async_op=False,
        )
        x1_tran = torch.cat(x1_recv, dim=dim[0])

        # ifft in contig dim
        x2 = torch.fft.ifft(x1_tran, n=first_dim_size, dim=dim[0], norm=norm)

        # take real part
        output = torch.real(x2).contiguous()

        # pop range
        torch.cuda.nvtx.range_pop()

        return output

    @staticmethod
    def backward(ctx, grad_output):
        """Compute the gradient with respect to ``x``.

        Applies an FFT along ``dim[0]``, transposes, applies an FFT along
        ``dim[1]`` and truncates ``dim[1]`` to the size the forward input
        had along that dimension.

        Returns
        -------
        tuple
            ``(grad_input, None, None, None)`` -- ``s``, ``dim`` and
            ``norm`` receive no gradient.
        """
        # load
        dim = ctx.dim
        norm = ctx.norm
        orig_dim_size = ctx.orig_dim_size

        # do fft
        g1 = torch.fft.fft(grad_output, dim=dim[0], norm=norm)

        # transpose
        g1_recv, _ = distributed_transpose(
            g1,
            dim[0],
            dim[1],
            group=DistributedManager().group("spatial_parallel"),
            async_op=False,
        )
        g1_tran = torch.cat(g1_recv, dim=dim[1])

        # now do the BW fft:
        x2 = torch.fft.fft(g1_tran, dim=dim[1], norm=norm)

        # truncate
        grad_input = truncate_helper(x2, dim[1], orig_dim_size)

        return grad_input, None, None, None