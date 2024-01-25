# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.

import

torch

from

modulus.distributed.manager

import

DistributedManager

from

modulus.distributed.utils

import

_gather

,

_reduce

,

_split

class _CopyToParallelRegion(torch.autograd.Function):
    """Pass the input to the parallel region.

    The forward pass is the identity. The backward pass reduces the incoming
    gradient across the process group with ``_reduce``.
    """

    @staticmethod
    def symbolic(graph, input_, group_):
        """Symbolic hook (consulted by the ONNX exporter): identity on ``input_``."""
        return input_

    @staticmethod
    def forward(ctx, input_, group_):
        """Return ``input_`` unchanged and remember ``group_`` for backward.

        Args:
            ctx: autograd context.
            input_: tensor entering the parallel region.
            group_: group identifier, resolved through
                ``DistributedManager().group(...)`` in backward.

        Returns:
            ``input_`` itself.
        """
        ctx.group = group_
        return input_

    @staticmethod
    def backward(ctx, grad_output):
        """Reduce ``grad_output`` over the group saved in ``forward``.

        Returns:
            A 2-tuple: the reduced gradient for ``input_``, and ``None`` for the
            non-differentiable ``group_`` argument.
        """
        # Autograd requires exactly one gradient per forward input (input_, group_).
        return (
            _reduce(grad_output, group=DistributedManager().group(ctx.group)),
            None,
        )

class _ReduceFromParallelRegion(torch.autograd.Function):
    """All-reduce the input from the parallel region.

    The forward pass applies ``_reduce`` over the given group. The backward pass
    passes the gradient through unchanged.
    """

    @staticmethod
    def symbolic(graph, input_, group_):
        """Symbolic hook (consulted by the ONNX exporter): reduce ``input_``."""
        return _reduce(input_, group=DistributedManager().group(group_))

    @staticmethod
    def forward(ctx, input_, group_):
        """Reduce ``input_`` across the group named by ``group_``.

        Args:
            ctx: autograd context (unused).
            input_: tensor produced inside the parallel region.
            group_: group identifier, resolved through
                ``DistributedManager().group(...)``.

        Returns:
            The result of ``_reduce`` on ``input_``.
        """
        return _reduce(input_, group=DistributedManager().group(group_))

    @staticmethod
    def backward(ctx, grad_output):
        """Pass the gradient straight through.

        Returns:
            A 2-tuple: ``grad_output`` for ``input_``, and ``None`` for the
            non-differentiable ``group_`` argument.
        """
        # Autograd requires exactly one gradient per forward input (input_, group_).
        return grad_output, None

class _ScatterToParallelRegion(torch.autograd.Function):
    """Split the input and keep only the chunk corresponding to the rank.

    The forward pass splits along ``dim_`` with ``_split``. The backward pass
    gathers the gradient chunks back along the same dimension with ``_gather``.
    """

    @staticmethod
    def symbolic(graph, input_, dim_, group_):
        """Symbolic hook (consulted by the ONNX exporter): split ``input_``."""
        return _split(input_, dim_, group=DistributedManager().group(group_))

    @staticmethod
    def forward(ctx, input_, dim_, group_):
        """Split ``input_`` along ``dim_`` and keep this rank's chunk.

        Args:
            ctx: autograd context; ``dim`` and ``group`` are stored on it for
                the backward pass.
            input_: tensor to scatter.
            dim_: dimension along which to split.
            group_: group identifier, resolved through
                ``DistributedManager().group(...)``.

        Returns:
            The result of ``_split`` on ``input_``.
        """
        ctx.dim = dim_
        ctx.group = group_
        return _split(input_, dim_, group=DistributedManager().group(group_))

    @staticmethod
    def backward(ctx, grad_output):
        """Gather gradient chunks along the dimension used in ``forward``.

        Returns:
            A 3-tuple: the gathered gradient for ``input_``, and ``None`` for the
            non-differentiable ``dim_`` and ``group_`` arguments.
        """
        # forward stored the group as ``ctx.group``; there is no ``ctx.group_``.
        # Autograd requires one gradient per forward input (input_, dim_, group_).
        return (
            _gather(grad_output, ctx.dim, group=DistributedManager().group(ctx.group)),
            None,
            None,
        )

class _GatherFromParallelRegion(torch.autograd.Function):
    """Gather the input from the parallel region and concatenate.

    The forward pass gathers along ``dim_`` with ``_gather``. The backward pass
    splits the gradient back along the same dimension with ``_split``.
    """

    @staticmethod
    def symbolic(graph, input_, dim_, group_):
        """Symbolic hook (consulted by the ONNX exporter): gather ``input_``."""
        return _gather(input_, dim_, group=DistributedManager().group(group_))

    @staticmethod
    def forward(ctx, input_, dim_, group_):
        """Gather ``input_`` from all ranks of the group along ``dim_``.

        Args:
            ctx: autograd context; ``dim`` and ``group`` are stored on it for
                the backward pass.
            input_: this rank's chunk.
            dim_: dimension along which to gather/concatenate.
            group_: group identifier, resolved through
                ``DistributedManager().group(...)``.

        Returns:
            The result of ``_gather`` on ``input_``.
        """
        ctx.dim = dim_
        ctx.group = group_
        return _gather(input_, dim_, group=DistributedManager().group(group_))

    @staticmethod
    def backward(ctx, grad_output):
        """Split the gradient along the dimension used in ``forward``.

        Returns:
            A 3-tuple: this rank's gradient chunk for ``input_``, and ``None``
            for the non-differentiable ``dim_`` and ``group_`` arguments.
        """
        # Autograd requires one gradient per forward input (input_, dim_, group_).
        return (
            _split(grad_output, ctx.dim, group=DistributedManager().group(ctx.group)),
            None,
            None,
        )

class _GatherWithinParallelRegion(torch.autograd.Function):
    """Gather the input within the parallel region and concatenate.

    The forward pass is the same as in ``_GatherFromParallelRegion``; only the
    backward pass differs. Here the gradient is reduced across the group before
    it is split, whereas ``_GatherFromParallelRegion`` only splits it.
    """

    @staticmethod
    def symbolic(graph, input_, dim_, group_):
        """Symbolic hook (consulted by the ONNX exporter): gather ``input_``."""
        return _gather(input_, dim_, group=DistributedManager().group(group_))

    @staticmethod
    def forward(ctx, input_, dim_, group_):
        """Gather ``input_`` from all ranks of the group along ``dim_``.

        Args:
            ctx: autograd context; ``dim`` and ``group`` are stored on it for
                the backward pass.
            input_: this rank's chunk.
            dim_: dimension along which to gather/concatenate.
            group_: group identifier, resolved through
                ``DistributedManager().group(...)``.

        Returns:
            The result of ``_gather`` on ``input_``.
        """
        ctx.dim = dim_
        ctx.group = group_
        return _gather(input_, dim_, group=DistributedManager().group(group_))

    @staticmethod
    def backward(ctx, grad_output):
        """Reduce the gradient over the group, then split along ``ctx.dim``.

        Returns:
            A 3-tuple: this rank's reduced gradient chunk for ``input_``, and
            ``None`` for the non-differentiable ``dim_`` and ``group_`` arguments.
        """
        # forward stored the group as ``ctx.group``; there is no ``ctx.group_``.
        group = DistributedManager().group(ctx.group)
        reduced = _reduce(grad_output, group=group)
        # Autograd requires one gradient per forward input (input_, dim_, group_).
        return (
            _split(reduced, ctx.dim, group=group),
            None,
            None,
        )

# -----------------

# Helper functions.

# -----------------

def copy_to_parallel_region(input, group):
    """Copy the input into the parallel region.

    The forward pass is the identity; in the backward pass the gradient is
    reduced across ``group``.

    Args:
        input: tensor entering the parallel region.
        group: group identifier, resolved through ``DistributedManager().group``.

    Returns:
        ``input`` (as an autograd-tracked output of ``_CopyToParallelRegion``).
    """
    return _CopyToParallelRegion.apply(input, group)

def reduce_from_parallel_region(input, group):
    """All-reduce the input from the matmul parallel region.

    In the backward pass the gradient is passed through unchanged.

    Args:
        input: tensor produced inside the parallel region.
        group: group identifier, resolved through ``DistributedManager().group``.

    Returns:
        The reduced tensor.
    """
    return _ReduceFromParallelRegion.apply(input, group)

def scatter_to_parallel_region(input, dim, group):
    """Split the input and keep only the chunk corresponding to the rank.

    In the backward pass the gradient chunks are gathered back along ``dim``.

    Args:
        input: tensor to scatter.
        dim: dimension along which to split.
        group: group identifier, resolved through ``DistributedManager().group``.

    Returns:
        This rank's chunk of ``input``.
    """
    return _ScatterToParallelRegion.apply(input, dim, group)

def gather_from_parallel_region(input, dim, group):
    """Gather the input from the matmul parallel region and concatenate.

    In the backward pass the gradient is split back along ``dim``.

    Args:
        input: this rank's chunk.
        dim: dimension along which to gather/concatenate.
        group: group identifier, resolved through ``DistributedManager().group``.

    Returns:
        The gathered tensor.
    """
    return _GatherFromParallelRegion.apply(input, dim, group)