Modulus Distributed

Core v0.3.0
class modulus.distributed.manager.DistributedManager

Bases: object

Distributed Manager for setting up a distributed training environment.

This is a singleton that creates a persistent class instance for storing parallel environment information throughout the lifetime of the program. It should be used to help set up Distributed Data Parallel and parallel datapipes.

Note

One should call DistributedManager.initialize() before constructing a manager object.

Example

>>> DistributedManager.initialize()
>>> manager = DistributedManager()
>>> manager.rank
0
>>> manager.world_size
1
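The singleton behaviour described above can be illustrated with the shared-state (Borg) idiom, a common Python way to make every instance of a class see the same attributes. SharedStateManager below is a hypothetical stand-in written for illustration; it is a sketch of the idea under that assumption, not the Modulus source, and it performs no communication.

```python
class SharedStateManager:
    """Hypothetical stand-in for a shared-state (Borg) singleton."""

    # One dictionary shared by every instance of the class.
    _shared_state = {}

    def __new__(cls):
        obj = super().__new__(cls)
        # Point each new instance at the shared attribute dictionary, so
        # attributes set through one instance are visible through all others.
        obj.__dict__ = cls._shared_state
        return obj

    @classmethod
    def initialize(cls, rank=0, world_size=1):
        # Store environment information once; later instances read it back.
        manager = cls()
        manager.rank = rank
        manager.world_size = world_size
        manager.initialized = True

    @classmethod
    def is_initialized(cls):
        return cls._shared_state.get("initialized", False)

    @classmethod
    def cleanup(cls):
        # Clear the shared state in place so existing instances see the reset.
        cls._shared_state.clear()


SharedStateManager.initialize()
a = SharedStateManager()
b = SharedStateManager()
print(a.rank, b.world_size, a is b)  # 0 1 False
```

Because a and b are distinct objects backed by one attribute dictionary, code anywhere in the program can construct a manager and read the same rank and world size without passing an object around.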

property broadcast_buffers

The broadcast_buffers setting used for PyTorch DistributedDataParallel (DDP).

static cleanup()

Clean up the distributed process group and the singleton.

static create_orthogonal_process_group(name: str, group_name: str, verbose: bool = False)

Create a process group that is orthogonal to the specified process group.

Parameters
  • name (str) – Name of the process group to be created.

  • group_name (str) – Name of the existing process group.

  • verbose (bool) – Print out ranks of each created process group, default False.
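To make "orthogonal" concrete: if an existing group layout splits the ranks into equally sized groups, the orthogonal layout collects the ranks that occupy the same position in each existing group. The sketch below only computes rank lists in plain Python; the helper orthogonal_groups and the example layout are hypothetical, and the exact rank ordering Modulus uses is an assumption.

```python
from typing import List


def orthogonal_groups(groups: List[List[int]]) -> List[List[int]]:
    """Hypothetical helper: the i-th new group takes the i-th rank of
    every existing group."""
    group_size = len(groups[0])
    if any(len(g) != group_size for g in groups):
        raise ValueError("all existing groups must have the same size")
    return [[g[i] for g in groups] for i in range(group_size)]


# 8 ranks split into 4 existing groups of 2 ranks each
model_parallel = [[0, 1], [2, 3], [4, 5], [6, 7]]
data_parallel = orthogonal_groups(model_parallel)
print(data_parallel)  # [[0, 2, 4, 6], [1, 3, 5, 7]]
```

A common use of such a pair is model parallelism inside each existing group and data parallelism across the orthogonal groups; applying the same construction twice returns the original layout.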

static create_process_subgroup(name: str, size: int, group_name: Optional[str] = None, verbose: bool = False)

Create a process subgroup of a parent process group. This must be a collective call by all processes participating in this application.

Parameters
  • name (str) – Name of the process subgroup to be created.

  • size (int) – Size of the process subgroup to be created. This must be an integer factor of the parent group’s size.

  • group_name (Optional[str]) – Name of the parent process group, optional. If None, the default process group will be used. Default None.

  • verbose (bool) – Print out ranks of each created process group, default False.
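The requirement that size be an integer factor of the parent group's size follows from splitting the parent group into equal subgroups. The sketch below assumes that consecutive ranks are placed in the same subgroup; contiguous_subgroups is a hypothetical helper that only computes rank lists and creates no process groups.

```python
from typing import List


def contiguous_subgroups(parent_ranks: List[int], size: int) -> List[List[int]]:
    """Hypothetical helper: partition a parent group's ranks into
    consecutive subgroups of `size` ranks each."""
    if size <= 0 or len(parent_ranks) % size != 0:
        # Mirrors the documented requirement that `size` divides the
        # parent group's size.
        raise ValueError(
            f"size {size} must be an integer factor of {len(parent_ranks)}"
        )
    return [parent_ranks[i:i + size] for i in range(0, len(parent_ranks), size)]


print(contiguous_subgroups(list(range(8)), 2))  # [[0, 1], [2, 3], [4, 5], [6, 7]]
```

The collective-call requirement matches PyTorch: torch.distributed.new_group must be entered by all processes in the default group, even those that will not be members of the new group, so every process has to compute the same layout and create the subgroups in the same order.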

property cuda

Whether CUDA is available.

property device

Process device

property distributed

Whether a distributed environment is in use.

property find_unused_parameters

The find_unused_parameters setting used for PyTorch DistributedDataParallel (DDP).

static get_available_backend()

Get the available communication backend.

group(name=None)

Returns the process group with the given name. If name is None, the returned group is also None, indicating the default process group. If the named group does not exist, None is returned as well.

group_name(group=None)

Returns the name of the given process group.

property group_names

Returns a list of all named process groups created

group_rank(name=None)

Returns the rank of this process in the named process group.

group_size(name=None)

Returns the size of the named process group.

static initialize()

Initialize the distributed manager.

Currently supported initialization methods are:
  • ENV: PyTorch environment variable initialization (https://pytorch.org/docs/stable/distributed.html#environment-variable-initialization).

  • SLURM: Initialization on SLURM systems. Uses the SLURM_PROCID, SLURM_NPROCS, SLURM_LOCALID and SLURM_LAUNCH_NODE_IPADDR environment variables.

  • OPENMPI: Initialization for OpenMPI launchers. Uses the OMPI_COMM_WORLD_RANK, OMPI_COMM_WORLD_SIZE and OMPI_COMM_WORLD_LOCAL_RANK environment variables.

By default, initialization uses the first valid method in the order listed above. The initialization method can also be controlled explicitly by setting the MODULUS_DISTRIBUTED_INITIALIZATION_METHOD environment variable to one of the options above.
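The selection rule can be sketched as a pure function over the environment. What counts as "valid" is an assumption here: a method is treated as valid when all of its variables are set, with the ENV variables taken from the PyTorch documentation linked above. choose_init_method is hypothetical and does not reproduce Modulus's exact checks or its behaviour when no method applies.

```python
import os
from typing import Mapping, Optional

# Variables each method is assumed to need. ENV follows PyTorch's
# environment-variable initialization; SLURM and OPENMPI use the
# variables listed above. Dict order encodes the documented priority.
REQUIRED_VARS = {
    "ENV": ["MASTER_ADDR", "MASTER_PORT", "RANK", "WORLD_SIZE"],
    "SLURM": ["SLURM_PROCID", "SLURM_NPROCS", "SLURM_LOCALID", "SLURM_LAUNCH_NODE_IPADDR"],
    "OPENMPI": ["OMPI_COMM_WORLD_RANK", "OMPI_COMM_WORLD_SIZE", "OMPI_COMM_WORLD_LOCAL_RANK"],
}


def choose_init_method(env: Mapping[str, str] = os.environ) -> Optional[str]:
    """Hypothetical helper mirroring the selection rule described above."""
    # An explicit choice takes precedence over auto-detection.
    explicit = env.get("MODULUS_DISTRIBUTED_INITIALIZATION_METHOD")
    if explicit is not None:
        if explicit not in REQUIRED_VARS:
            raise ValueError(f"unknown initialization method: {explicit}")
        return explicit
    # Otherwise take the first method, in order, whose variables are all set.
    for method, names in REQUIRED_VARS.items():
        if all(name in env for name in names):
            return method
    # Assumption: no method applies, so the caller runs as a single process.
    return None


mpi_env = {
    "OMPI_COMM_WORLD_RANK": "0",
    "OMPI_COMM_WORLD_SIZE": "2",
    "OMPI_COMM_WORLD_LOCAL_RANK": "0",
}
print(choose_init_method(mpi_env))  # OPENMPI
```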

static initialize_env()

Sets up the distributed environment using generic (environment variable) initialization.

static initialize_open_mpi(addr, port)

Sets up the distributed environment using OpenMPI initialization.

static initialize_slurm(port)

Sets up the distributed environment using SLURM initialization.
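The two launcher-specific paths differ mainly in where the rank information and the rendezvous address come from, which their signatures reflect: initialize_slurm takes only a port, while initialize_open_mpi also takes an address. Assuming both translate the variables listed under initialize() into the keyword arguments of setup(), the mapping looks roughly as follows; both helpers are hypothetical and only build a dictionary.

```python
from typing import Dict, Mapping, Union

SetupArgs = Dict[str, Union[int, str]]


def slurm_setup_args(env: Mapping[str, str], port: str) -> SetupArgs:
    """Hypothetical: translate SLURM variables into setup() keyword arguments."""
    return {
        "rank": int(env["SLURM_PROCID"]),
        "world_size": int(env["SLURM_NPROCS"]),
        "local_rank": int(env["SLURM_LOCALID"]),
        # SLURM supplies the rendezvous address itself.
        "addr": env["SLURM_LAUNCH_NODE_IPADDR"],
        "port": port,
    }


def open_mpi_setup_args(env: Mapping[str, str], addr: str, port: str) -> SetupArgs:
    """Hypothetical: translate OpenMPI variables into setup() keyword arguments."""
    return {
        "rank": int(env["OMPI_COMM_WORLD_RANK"]),
        "world_size": int(env["OMPI_COMM_WORLD_SIZE"]),
        "local_rank": int(env["OMPI_COMM_WORLD_LOCAL_RANK"]),
        # The OpenMPI variables carry no address, so it must be passed in.
        "addr": addr,
        "port": port,
    }


slurm_env = {
    "SLURM_PROCID": "3",
    "SLURM_NPROCS": "8",
    "SLURM_LOCALID": "1",
    "SLURM_LAUNCH_NODE_IPADDR": "10.0.0.1",
}
print(slurm_setup_args(slurm_env, port="12355"))
```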

classmethod is_initialized() → bool

Whether the manager singleton has been initialized.

property local_rank

Process rank on local machine

property rank

Process rank

static setup(rank=0, world_size=1, local_rank=None, addr='localhost', port='12355', backend='nccl', method='env')

Set up the PyTorch distributed process group and update the manager attributes.

property world_size

Number of processes in the distributed environment.

modulus.distributed.utils.all_gather_v_wrapper(tensor: Tensor, sizes: List[int], dim: int = 0, group: Optional[ProcessGroup] = None) → Tensor

Implements a distributed AllGatherV primitive. It is based on the idea of a single global tensor which is distributed along a specified dimension into chunks of variable size. This primitive gathers all local tensors from each rank into the full global tensor onto each rank.

Parameters
  • tensor (torch.Tensor) – local tensor on each rank

  • sizes (List[int]) – list of the sizes of each chunk on each rank along distributed dimension, valid and set on each rank

  • dim (int, optional) – dimension along which global tensor is distributed, by default 0

  • group (Optional[dist.ProcessGroup], optional) – process group along which global tensor is shared, by default None

Returns

full global tensor, valid on each rank

Return type

torch.Tensor
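The data movement can be pictured with plain Python lists, where each list element stands for one slice of the tensor along dim and one list per rank stands for that rank's local tensor. all_gather_v_sim is a hypothetical single-process simulation of the semantics described above; it does not call torch.distributed.

```python
from typing import List


def all_gather_v_sim(local_chunks: List[List[float]], sizes: List[int]) -> List[List[float]]:
    """Simulate AllGatherV along dim 0 for len(local_chunks) ranks.

    local_chunks[r] is the local tensor on rank r; the result lists the
    tensor each rank holds afterwards, i.e. the full global tensor."""
    if len(sizes) != len(local_chunks):
        raise ValueError("sizes needs one entry per rank")
    for r, (chunk, size) in enumerate(zip(local_chunks, sizes)):
        if len(chunk) != size:
            raise ValueError(f"rank {r}: chunk has {len(chunk)} slices, sizes says {size}")
    # Concatenate the chunks in rank order to form the global tensor.
    global_tensor = [x for chunk in local_chunks for x in chunk]
    # Every rank ends up with its own copy of the full global tensor.
    return [list(global_tensor) for _ in local_chunks]


chunks = [[1.0, 2.0], [3.0], [4.0, 5.0, 6.0]]  # three ranks, uneven chunks
print(all_gather_v_sim(chunks, sizes=[2, 1, 3])[1])  # [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
```

Because chunk lengths differ per rank, a real collective presumably needs sizes on every rank to size its receive buffers, which is why sizes must be valid and set everywhere.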

modulus.distributed.utils.all_reduce_v_wrapper(tensor: Tensor, sizes: List[int], dim: int = 0, use_fp32: bool = True, group: Optional[ProcessGroup] = None) → Tensor

Implements a distributed AllReduceV primitive. It is based on the idea of a single global tensor which can be distributed along a specified dimension into chunks of variable size. This primitive assumes a different global tensor of the same shape on each rank. It then redistributes chunks of all these tensors such that each rank receives all corresponding parts of a global tensor, and each rank sums up the chunks it receives. By design, this primitive thus implements the backward pass of the “all_gather_v” primitive; in that case, the result is a single global gradient tensor distributed across the ranks.

Parameters
  • tensor (torch.Tensor) – global tensor on each rank (different one on each rank)

  • sizes (List[int]) – list of the sizes of each chunk on each rank along distributed dimension, valid and set on each rank

  • dim (int, optional) – dimension along which global tensor is distributed, by default 0

  • use_fp32 (bool, optional) – flag to specify FP32 precision for the reduction, by default True

  • group (Optional[dist.ProcessGroup], optional) – process group along which global tensor is shared, by default None

Returns

local tensor, i.e. result of reduction of all corresponding chunks from all global tensors for each rank separately

Return type

torch.Tensor
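In the same list-based picture (each element is one slice along dim), the hypothetical all_reduce_v_sim below simulates the reduction in a single process. Python floats are already double precision, so use_fp32 has no counterpart in this sketch.

```python
from typing import List


def all_reduce_v_sim(global_tensors: List[List[float]], sizes: List[int]) -> List[List[float]]:
    """Simulate AllReduceV along dim 0.

    global_tensors[r] is the (different) global tensor on rank r. Rank r
    receives chunk r of every rank's tensor and sums them element-wise."""
    if len(sizes) != len(global_tensors):
        raise ValueError("sizes needs one entry per rank")
    if any(len(g) != sum(sizes) for g in global_tensors):
        raise ValueError("every global tensor must have length sum(sizes)")
    result, start = [], 0
    for size in sizes:
        # This rank's chunk, taken from every rank's global tensor.
        chunks = [g[start:start + size] for g in global_tensors]
        result.append([sum(vals) for vals in zip(*chunks)])
        start += size
    return result


# Two ranks, each holding its own length-3 "gradient" of the global tensor.
grads = [[1.0, 1.0, 1.0], [10.0, 10.0, 10.0]]
print(all_reduce_v_sim(grads, sizes=[2, 1]))  # [[11.0, 11.0], [11.0]]
```

Read as a backward pass: if each rank holds the gradient of its loss with respect to the gathered global tensor, rank r receives the summed gradient for exactly the slices it contributed to the gather.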

modulus.distributed.utils.gather_loss(loss: float, dst_rank: int = 0, mean: bool = True)

Gathers the loss from all processes onto one process for logging.

Parameters
  • loss (float) – loss value

  • dst_rank (int, optional) – destination rank to gather to, by default 0

  • mean (bool, optional) – Calculate the mean of the losses gathered, by default True

Raises

Exception – If DistributedManager has yet to be initialized
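The reduction performed by gather_loss can be sketched in a single process as below. gather_loss_sim is hypothetical, and the assumption that only dst_rank receives a meaningful value (None elsewhere) is made for illustration and is not stated above. In a real job, gather_loss would be called on every rank, for example with a Python float such as loss.item(), after DistributedManager.initialize().

```python
from typing import List, Optional


def gather_loss_sim(losses: List[float], dst_rank: int = 0, mean: bool = True) -> List[Optional[float]]:
    """Simulate combining per-rank losses onto dst_rank.

    losses[r] is the loss on rank r. Assumption: ranks other than
    dst_rank get None."""
    total = sum(losses)
    combined = total / len(losses) if mean else total
    return [combined if r == dst_rank else None for r in range(len(losses))]


print(gather_loss_sim([0.5, 1.5, 1.0]))  # [1.0, None, None]
```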

modulus.distributed.utils.gather_v_wrapper(tensor: Tensor, sizes: List[int], dim: int = 0, dst: int = 0, group: Optional[ProcessGroup] = None) → Tensor

Implements a distributed GatherV primitive. It is based on the idea of a single global tensor which is distributed along a specified dimension into chunks of variable size. This primitive assumes such a distributed tensor and gathers all local tensors from each rank into the full global tensor valid on the specified destination rank.

Parameters
  • tensor (torch.Tensor) – local tensor on each rank

  • sizes (List[int]) – list of the sizes of each chunk on each rank along distributed dimension, valid and set on each rank

  • dim (int, optional) – dimension along which global tensor is distributed, by default 0

  • dst (int, optional) – destination rank which contains the full global tensor after the operation, by default 0

  • group (Optional[dist.ProcessGroup], optional) – process group along which global tensor is shared, by default None

Returns

full global tensor, valid on destination rank

Return type

torch.Tensor
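GatherV differs from AllGatherV only in who ends up with the result. In the hypothetical single-process simulation below, each list element is one slice along dim, and None marks ranks on which the result is not valid; what the real function returns on those ranks is not specified above.

```python
from typing import List, Optional


def gather_v_sim(local_chunks: List[List[float]], sizes: List[int], dst: int = 0) -> List[Optional[List[float]]]:
    """Simulate GatherV along dim 0: only rank dst holds the global tensor."""
    if len(sizes) != len(local_chunks) or any(len(c) != s for c, s in zip(local_chunks, sizes)):
        raise ValueError("sizes must match the length of every rank's chunk")
    # Concatenate the chunks in rank order on the destination rank only.
    global_tensor = [x for chunk in local_chunks for x in chunk]
    return [global_tensor if r == dst else None for r in range(len(local_chunks))]


print(gather_v_sim([[1.0], [2.0, 3.0]], sizes=[1, 2], dst=1))  # [None, [1.0, 2.0, 3.0]]
```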

modulus.distributed.utils.get_memory_format(tensor)

Gets the memory format of a tensor.

modulus.distributed.utils.indexed_all_to_all_v_wrapper(tensor: Tensor, indices: List[Tensor], sizes: List[List[int]], dim: int = 0, group: Optional[ProcessGroup] = None) → Tensor

Implements an indexed version of a distributed AllToAllV primitive. It is based on the idea of a single global tensor which is distributed along a specified dimension into chunks of variable size. This primitive assumes a set of indices into this dimension which indicate the corresponding slices sent to each other rank, forming an indexed version of an AllToAllV primitive.

Parameters
  • tensor (torch.Tensor) – local part of global tensor on each rank

  • indices (List[torch.Tensor]) – list of indices on each rank of slices being sent to each other rank from this rank

  • sizes (List[List[int]]) – number of indices each rank sends to each other rank, valid and set on each rank, e.g. sizes[0][3] corresponds to the number of slices rank 0 sends to rank 3

  • dim (int) – dimension along which global tensor is distributed, by default 0

  • group (Optional[dist.ProcessGroup], optional) – process group along which global tensor is shared, by default None

Returns

local result of primitive corresponding to indexed global tensor

Return type

torch.Tensor
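To visualize the indexing, the hypothetical simulation below labels each slice with a string. indices[r][s] stands for the indices that rank r passes for destination s, so sizes[r][s] equals len(indices[r][s]). That each rank concatenates its received slices in source-rank order is an assumption of this sketch.

```python
from typing import List


def indexed_all_to_all_v_sim(
    local: List[List[str]], indices: List[List[List[int]]]
) -> List[List[str]]:
    """Simulate an indexed AllToAllV along dim 0.

    local[r] holds the slices of the global tensor owned by rank r and
    indices[r][s] lists which of them rank r sends to rank s."""
    world = len(local)
    # Rank s receives, from each source rank r in turn, the selected slices.
    return [
        [local[r][i] for r in range(world) for i in indices[r][s]]
        for s in range(world)
    ]


local = [["a0", "a1", "a2"], ["b0", "b1"]]
indices = [[[0, 2], [1, 2]], [[1], [0, 0]]]
sizes = [[len(idx) for idx in row] for row in indices]  # [[2, 2], [1, 2]]
print(indexed_all_to_all_v_sim(local, indices))
# [['a0', 'a2', 'b1'], ['a1', 'a2', 'b0', 'b0']]
```

A slice may be sent to several ranks (a2) or several times to the same rank (b0), which is why the backward pass accumulates gradients rather than simply permuting them.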

modulus.distributed.utils.indexed_all_to_all_v_wrapper_bwd(tensor: Tensor, indices: List[Tensor], sizes: List[List[int]], tensor_size_along_dim: int, use_fp32: bool = True, dim: int = 0, group: Optional[ProcessGroup] = None) → Tensor

Implements the backward pass to the indexed version of a distributed AllToAllV primitive.

Parameters
  • tensor (torch.Tensor) – local tensor, i.e. gradient on resulting tensor from forward pass

  • indices (List[torch.Tensor]) – list of indices on each rank of slices being sent to each other rank from this rank

  • sizes (List[List[int]]) – list of the sizes of each chunk on each rank along distributed dimension, valid and set on each rank

  • tensor_size_along_dim (int) – size of original local tensor along specified dimension, i.e. from the corresponding forward pass

  • use_fp32 (bool, optional) – flag to specify FP32 precision, by default True

  • dim (int, optional) – dimension along which global tensor is distributed, by default 0

  • group (Optional[dist.ProcessGroup], optional) – process group along which global tensor is shared, by default None

Returns

result of primitive corresponding to indexed global tensor

Return type

torch.Tensor
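Because the forward pass may send one slice to several destinations, the backward pass has to route each gradient slice back to the rank and index it came from and sum the contributions into a buffer of length tensor_size_along_dim. The hypothetical simulation below shows this with plain lists, under the layout assumed here: indices[r][s] lists what rank r sent to rank s in the forward pass, and each rank's received slices are ordered by source rank. use_fp32 is not modeled.

```python
from typing import List


def indexed_all_to_all_v_bwd_sim(
    grads: List[List[float]],
    indices: List[List[List[int]]],
    local_sizes: List[int],
) -> List[List[float]]:
    """Simulate the backward pass of an indexed AllToAllV along dim 0.

    grads[s] is the gradient of rank s's forward result, ordered as the
    forward pass received it. local_sizes[r] plays the role of
    tensor_size_along_dim on rank r."""
    world = len(grads)
    # One zero-initialized accumulation buffer per rank.
    out = [[0.0] * n for n in local_sizes]
    for s in range(world):
        pos = 0  # position within rank s's received slices
        for r in range(world):
            for i in indices[r][s]:
                # Send the gradient slice back to rank r and add it at index i.
                out[r][i] += grads[s][pos]
                pos += 1
    return out


indices = [[[0, 2], [1, 2]], [[1], [0, 0]]]
ones = [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0]]
print(indexed_all_to_all_v_bwd_sim(ones, indices, local_sizes=[3, 2]))
# [[1.0, 1.0, 2.0], [2.0, 1.0]]
```

With all-ones gradients, each output entry counts how often that slice was sent in the forward pass: slice 2 on rank 0 and slice 0 on rank 1 were each sent twice.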

modulus.distributed.utils.pad_helper(tensor, dim, new_size, mode='zero')

Utility for padding a tensor to new_size along dimension dim.

modulus.distributed.utils.scatter_v_wrapper(tensor: Tensor, sizes: List[int], dim: int = 0, src: int = 0, group: Optional[ProcessGroup] = None) → Tensor

Implements a distributed ScatterV primitive. It is based on the idea of a single global tensor which is distributed along a specified dimension into chunks of variable size. This primitive scatters the global tensor from a specified source rank into local chunks onto each other rank.

Parameters
  • tensor (torch.Tensor) – global tensor, valid on source rank

  • sizes (List[int]) – list of the sizes of each chunk on each rank along distributed dimension, valid and set on each rank

  • dim (int, optional) – dimension along which global tensor is distributed, by default 0

  • src (int, optional) – source rank of primitive, i.e. rank of original full global tensor, by default 0

  • group (Optional[dist.ProcessGroup], optional) – process group along which global tensor is shared, by default None

Returns

corresponding local part of the global tensor on each rank

Return type

torch.Tensor
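In the list-based picture (each element is one slice along dim), ScatterV cuts the source rank's global tensor into consecutive chunks of the given sizes. scatter_v_sim is a hypothetical single-process simulation; since it receives the global tensor directly, the src argument has no counterpart.

```python
from typing import List


def scatter_v_sim(global_tensor: List[float], sizes: List[int]) -> List[List[float]]:
    """Simulate ScatterV along dim 0: rank r receives the r-th chunk of
    the global tensor, whose length is sizes[r]."""
    if sum(sizes) != len(global_tensor):
        raise ValueError("sizes must add up to the global tensor's length")
    chunks, start = [], 0
    for size in sizes:
        chunks.append(global_tensor[start:start + size])
        start += size
    return chunks


print(scatter_v_sim([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], sizes=[2, 1, 3]))
# [[1.0, 2.0], [3.0], [4.0, 5.0, 6.0]]
```

Gathering the returned chunks with the same sizes, as gather_v_wrapper does, restores the original tensor on the destination rank.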

modulus.distributed.utils.split_tensor_along_dim(tensor, dim, num_chunks)

Splits a tensor into num_chunks chunks along dimension dim.

modulus.distributed.utils.truncate_helper(tensor, dim, new_size)

Utility for truncating a tensor to new_size along dimension dim.

© Copyright 2023, NVIDIA Modulus Team. Last updated on Jan 25, 2024.