PhysicsNeMo Distributed#

Distributed Training is a broad, active, and evolving area of AI research and applications. This guide serves as an entry point for distributed training (and inference) in the context of Scientific AI.

Scaling AI for scientific workloads differs in important ways from Large Language Model (LLM) parallelism, and not just because of model architecture. SciML datasets are also quite different from typical vision and language datasets.

Instead of ordered tokens or pixelated images, SciML workloads process meshes, point clouds, massive connected graphs, high-resolution 3D datasets, and other unique and complicated formats.

Scaling Strategies in Scientific AI#

Many state-of-the-art workloads for scientific AI are not multi-billion-parameter LLMs, but smaller-scale models such as Graph Neural Networks (GNNs) or Neural Operators. Parameter counts are typically below one billion, which makes the model parallelism techniques of DeepSpeed and PyTorch FSDP unnecessary for small SciML models.

Note

Not all SciML models are small, however, which makes model parallelism techniques _very_ relevant for multi-billion-parameter-scale models.

Instead, the predominant parallelism strategies currently employed for SciML workloads are:

  • data parallelism, where a large batch is processed over multiple GPUs in parallel, with each GPU taking at least one training example in its entirety.

  • domain parallelism, a further division of data parallelism where a single training example is divided across multiple GPUs and processed cooperatively.

Data Parallelism is typically enabled by tools such as PyTorch DistributedDataParallel (DDP), while Domain Parallelism has typically been driven by customized and targeted implementations of collectives in bespoke models.

There are other parallelism strategies that are also valuable in SciML. Pipeline Parallelism, for example, can be an extremely powerful scaling strategy; especially for inference workloads, where the pipeline can be nearly saturated.

Note

Another common parallelism in AI models is tensor parallelism. While “tensor parallelism” could generically refer to any program that distributes tensors across multiple devices, the term has been popularized by techniques such as NVIDIA Megatron. In this guide, we use “tensor parallelism” with the Megatron meaning: carefully partitioning both weights and inputs to enable massive parameter-count models.

PhysicsNeMo aims to make parallelism for scientific AI simpler and easier. We provide high-level tools, such as physicsnemo.distributed.DistributedManager, which greatly simplifies setup and orchestration of parallel jobs. There are also low-level collective functions, such as gather_v, all_gather_v, scatter_v, and indexed_all_to_all_v, which enable differentiable collectives on irregular data types. For more information, refer to physicsnemo.distributed.autograd.
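For example, a minimal setup sketch using DistributedManager might look like the following (assuming the job is launched with a standard launcher such as torchrun or SLURM, which provides the usual environment variables):

```python
# Minimal sketch: initialize the distributed environment with DistributedManager.
import torch
from physicsnemo.distributed import DistributedManager

# Reads launcher-provided environment variables and sets up the process group.
DistributedManager.initialize()
dm = DistributedManager()

# Per-process information provided by the manager.
print(f"rank {dm.rank} of {dm.world_size}, local rank {dm.local_rank}, device {dm.device}")

# Tensors and models can now be placed on the per-process device.
x = torch.randn(16, 3, device=dm.device)
```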

Several PhysicsNeMo models are designed with domain parallelism in mind. SFNO and MeshGraphNet can both operate natively with domain parallelism.

Work continues on a more generic and extensible domain parallelism technique called ShardTensor, which enables domain parallelism in generic models. For more information on ShardTensor, see the Domain Parallelism and Shard Tensor tutorials.

Distributing Data#

One challenge of SciML training is ensuring efficient and scalable data loading. Some tools, like PyTorch’s data loader, work well for many data formats, but for point clouds, mesh data, and graphs you can take advantage of physicsnemo.datapipes.

Typically, for scaling large and sometimes irregular data, we recommend utilizing high performance, scalable data storage formats such as:

  • HDF5 (using h5py)

  • NetCDF
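As a sketch of how rank-aware loading might look with HDF5, each rank can read only its own slice of a dataset (the file name and the "fields" dataset below are hypothetical; only the h5py and DistributedManager APIs are assumed):

```python
# Sketch: each rank reads a contiguous shard of a large HDF5 dataset.
import h5py
from physicsnemo.distributed import DistributedManager

DistributedManager.initialize()
dm = DistributedManager()

# "simulation_data.h5" and the "fields" dataset are placeholders for your own data layout.
with h5py.File("simulation_data.h5", "r") as f:
    n_samples = f["fields"].shape[0]
    per_rank = n_samples // dm.world_size
    start = dm.rank * per_rank
    stop = n_samples if dm.rank == dm.world_size - 1 else start + per_rank
    # Only this rank's shard is read from disk into memory.
    local_fields = f["fields"][start:stop]
```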

Glossary#

Distributed AI can be complex and is evolving. Below is a glossary of the most common distributed terms as they relate to SciML, along with some useful tips about relevant PhysicsNeMo functionality.

Collectives#

Also known as collective functions, these are the primitives of distributed training and inference. They enable communication of tensors between GPUs. Refer to the PyTorch Docs for details of the core primitives.

Tip

PhysicsNeMo also offers several _differentiable_ collectives in physicsnemo.distributed.autograd, and through the ShardTensor interface.
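As an illustration with the core PyTorch primitives, a basic all-reduce sums a tensor across all ranks (this sketch assumes a process group has already been initialized, for example via DistributedManager or torch.distributed.init_process_group):

```python
# Sketch: sum a tensor across all ranks with an all-reduce collective.
import torch
import torch.distributed as dist

# Each rank contributes a tensor filled with its own rank index.
x = torch.ones(4, device="cuda") * dist.get_rank()

# After the all-reduce, every rank holds the same summed result in-place.
dist.all_reduce(x, op=dist.ReduceOp.SUM)
```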

Data Parallel#

Replicate the full model on each process (rank), feed each rank a unique batch of data and synchronize gradients using all-reduce each step.

Tip

Use when the model fits on a single GPU and you want to scale batch size and throughput. A common implementation is PyTorch DistributedDataParallel.
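A minimal data-parallel sketch, assuming DistributedManager has set up the process group (the torch.nn.Linear stand-in model and the optimizer settings are illustrative only):

```python
# Sketch: wrap a model with DistributedDataParallel for data-parallel training.
import torch
from torch.nn.parallel import DistributedDataParallel as DDP
from physicsnemo.distributed import DistributedManager

DistributedManager.initialize()
dm = DistributedManager()

# Stand-in for a real SciML model.
model = torch.nn.Linear(32, 32).to(dm.device)
if dm.world_size > 1:
    # DDP all-reduces gradients across ranks during backward().
    model = DDP(model, device_ids=[dm.local_rank], output_device=dm.device)

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
```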

DeepSpeed#

A deep learning optimization library that provides ZeRO optimizer sharding, pipeline/3D parallelism, activation checkpointing, and communication/memory optimizations for large-scale training. Refer to DeepSpeed for more information.

DeviceMesh#

A PyTorch object that represents a set of GPUs, organized analogously to a multi-dimensional tensor, that can be used to enable multi-dimensional parallelism. Learn more.
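For example, with the PyTorch 2.x API (the 2x4 mesh shape and dimension names below are illustrative):

```python
# Sketch: build a 2D device mesh over 8 GPUs, e.g. 2-way data parallel x 4-way domain parallel.
from torch.distributed.device_mesh import init_device_mesh

# Assumes the job was launched with WORLD_SIZE=8 and one process per GPU.
mesh = init_device_mesh("cuda", (2, 4), mesh_dim_names=("data_parallel", "domain_parallel"))

# 1D sub-meshes can be used for collectives along a single parallelism dimension.
dp_mesh = mesh["data_parallel"]
domain_mesh = mesh["domain_parallel"]
```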

DistributedDataParallel#

DistributedDataParallel (DDP, in PyTorch) is synchronous data parallel training that averages gradients across ranks each iteration using all-reduce. It efficiently overlaps communication with computation for excellent weak scaling. Learn more.

Domain Parallel#

Partition the computational domain (for example, grid/mesh) across devices where each GPU computes on its subdomain and exchanges halo and border data with neighbors as necessary.

Tip

You can utilize Domain Parallelism in PhysicsNeMo with ShardTensor.
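A conceptual sketch of a 1D halo exchange using plain PyTorch point-to-point collectives (illustrative only, not the ShardTensor implementation; the subdomain size and halo width are arbitrary):

```python
# Sketch: exchange one-cell halos with left/right neighbors in a 1D domain decomposition.
import torch
import torch.distributed as dist

# Assumes an initialized process group; each rank owns a contiguous chunk of a 1D field.
rank, world = dist.get_rank(), dist.get_world_size()
local = torch.randn(1024, device="cuda")  # this rank's subdomain
halo = 1

ops, recv_left, recv_right = [], None, None
if rank - 1 >= 0:
    recv_left = torch.empty(halo, device=local.device)
    ops += [dist.P2POp(dist.isend, local[:halo].contiguous(), rank - 1),
            dist.P2POp(dist.irecv, recv_left, rank - 1)]
if rank + 1 < world:
    recv_right = torch.empty(halo, device=local.device)
    ops += [dist.P2POp(dist.isend, local[-halo:].contiguous(), rank + 1),
            dist.P2POp(dist.irecv, recv_right, rank + 1)]

if ops:
    for req in dist.batch_isend_irecv(ops):
        req.wait()

# Pad the local subdomain with neighbor data so stencil operations can cross boundaries.
padded = torch.cat([t for t in (recv_left, local, recv_right) if t is not None])
```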

FSDP#

FullyShardedDataParallel (PyTorch) shards parameters, gradients, and optimizer states across ranks to reduce memory usage and enable larger models and batches.
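A minimal sketch of wrapping a small module with FSDP (a real model would typically also configure an auto-wrap policy and mixed precision):

```python
# Sketch: shard parameters, gradients, and optimizer state across ranks with FSDP.
import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

# Assumes the process group is initialized and a CUDA device is set per rank.
model = torch.nn.Sequential(
    torch.nn.Linear(256, 256),
    torch.nn.GELU(),
    torch.nn.Linear(256, 256),
).cuda()

model = FSDP(model)  # parameters are now sharded across ranks
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
```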

Intra-Node#

Parallelism within a single server across multiple local GPUs (for example, connected by PCIe, NVLink/NVSwitch), typically with lower latency and higher bandwidth.

Inter-Node#

Parallelism across multiple servers over a network fabric (for example, InfiniBand/Ethernet), with higher latency and more sensitivity to communication overhead.

Local Rank#

Each node participating in a distributed application will have one or more ranks. The local rank is a unique index within a node.

Tip

In PhysicsNeMo, you can access the local rank value with dm.local_rank.

Model Parallel#

Split a single model across multiple devices so different parts of the model (layers or tensors) live on different GPUs. This includes pipeline parallelism and tensor (intra-layer) parallelism.
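A naive sketch of layer-wise model parallelism on two GPUs, without pipelining (the layer sizes and device placement are illustrative):

```python
# Sketch: place different layers on different GPUs and move activations between them.
import torch

class TwoDeviceModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.stage0 = torch.nn.Linear(128, 128).to("cuda:0")
        self.stage1 = torch.nn.Linear(128, 10).to("cuda:1")

    def forward(self, x):
        x = self.stage0(x.to("cuda:0"))
        # Activations cross devices here; pipeline parallelism overlaps this with more work.
        return self.stage1(x.to("cuda:1"))

model = TwoDeviceModel()
out = model(torch.randn(8, 128))
```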

Node#

A single physical or virtual machine (host) in a cluster. A node typically has multiple CPU cores and one or more GPUs. Distributed jobs may run multiple ranks per node. Communication within a node is “intra-node” and across nodes it is “inter-node”.

Pipeline Parallel#

A multi-GPU parallelism technique that distributes a model sequentially across several GPUs. Refer to PyTorch docs for more information.

Rank#

A single “unit” of computing in a distributed application. For AI models, this is typically a single GPU. Many distributed programs have multiple GPUs per node, and therefore multiple ranks per node. Each rank is identified by an index (0, 1, …, N-1), where N is the total number of processes.

Tip

In PhysicsNeMo, use the DistributedManager with dm.rank to get the numerical rank of a process. Similarly, dm.device will return the target device within each process. The total number of processes is dm.world_size.

Strong Scaling#

Keep the total problem size fixed and increase the number of workers; ideal strong scaling reduces runtime inversely with the number of workers.
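In symbols, a compact statement of the same idea, where T(N) is the runtime on N workers (notation introduced here for illustration):

```latex
% Strong scaling: total problem size fixed as workers are added
S(N) = \frac{T(1)}{T(N)}, \qquad
E_{\mathrm{strong}}(N) = \frac{S(N)}{N}
% Ideal strong scaling: S(N) = N, i.e. E_strong(N) = 1
```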

Tensor Parallel#

Intra-layer model parallelism that splits large operations (for example, matrix multiplies, attention heads) across devices and coordinates with collective communications. Most commonly supported through frameworks such as NVIDIA Megatron and DeepSpeed.

Note

Effective for very wide layers. It relies on collectives like all-gather/reduce-scatter.

Weak Scaling#

Increase total problem size proportionally with the number of workers while keeping work per worker constant. Ideal weak scaling holds runtime roughly constant as the problem grows.
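In symbols, with T(N) the runtime on N workers for a problem scaled up N times (notation introduced here for illustration):

```latex
% Weak scaling: work per worker held constant as the total problem grows
E_{\mathrm{weak}}(N) = \frac{T(1)}{T(N)}
% Ideal weak scaling: T(N) \approx T(1), i.e. E_weak(N) \approx 1
```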
