Domain Decomposition, ShardTensor, and FSDP Tutorial#
In this tutorial, we will see how to combine domain parallelism, ShardTensor, and a training or inference recipe. Before starting this tutorial, we recommend that you read the other domain parallelism tutorials.
This tutorial demonstrates how to use PhysicsNeMo’s ShardTensor
functionality alongside PyTorch’s FSDP
(Fully Sharded Data Parallel)
to train or evaluate a simple ViT. Here’s what’s in the tutorial:
ViT Model Overview
Benchmarking the ViT on a single GPU
Enabling domain parallelism with ShardTensor
Training and evaluating the model with domain parallelism
Basic ViT Model#
The model we’ll use for this tutorial is a straightforward ViT. It’s very similar to the original vision transformer from “An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale” (Dosovitskiy et al.). The model consists of two main conceptual pieces:
a convolutional tokenizer: it is a convolution with stride==kernel_size (so, non-overlapping image pieces) followed by a reshape to a sequence-like tensor with channels last.
a transformer block with residual attention and a residual MLP.
The overall model architecture is straightforward. The input image is tokenized using the convolutional tokenizer, a positional embedding is added, and then a series of transformer blocks are applied. At the end of the transformer layers, all of the tokens are averaged together. The entire architecture has one final layer to project the embedding dimension onto the output dimension.
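As a rough sketch of that data flow (an illustration only, not the model code below; shapes assume a 256x256 input with patch_size=8 and embed_dim=768, matching the defaults used later):
import torch
import torch.nn as nn
from einops import rearrange

x = torch.randn(4, 3, 256, 256)                           # (B, C, H, W)
tokenizer = nn.Conv2d(3, 768, kernel_size=8, stride=8)    # non-overlapping patches
tokens = rearrange(tokenizer(x), "b c h w -> b (h w) c")  # (4, 1024, 768): 32x32 patches
pos_embed = torch.zeros(1, 1024, 768)                     # learnable in the real model
tokens = tokens + pos_embed                               # add positional embedding
# ... a stack of transformer blocks keeps the (B, N, C) shape ...
pooled = tokens.mean(dim=1)                               # (4, 768): average over tokens
logits = nn.Linear(768, 1000)(pooled)                     # (4, 1000): classification head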
Note
This isn’t really how you might implement a transformer for a vision classification task in practice - there are better, more sophisticated techniques. Since the original ViT publication, technical advances such as Convolution Transformers, Shifted Windows, Neighborhood Attention, and others have outperformed basic ViTs like this for classification. We encourage you to pick the model architecture most suitable for your task. To demonstrate the domain parallel techniques, we’ve picked a “Standard” vision transformer here.
Here’s the core of the model:
Model Implementation
# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn as nn
from .PatchEmbed2d import PatchEmbedding2d
from .PatchEmbed3d import PatchEmbedding3d
from .TransformerBlock import TransformerBlock
class HybridViT(nn.Module):
"""
Hybrid Vision Transformer with conv patch embedding and multiple transformer layers.
Args:
img_size: Input image size
patch_size: Size of patches for tokenization
in_channels: Number of input channels
num_classes: Number of classes for classification
embed_dim: Embedding dimension (same for all layers)
num_heads: Number of attention heads for each stage
depth: Number of transformer layers
mlp_ratio: MLP ratios for each layer
qkv_bias: Whether to use bias in QKV projections
"""
def __init__(
self,
img_size: list[int] = [256, 256],
patch_size: int = 8,
in_channels: int = 3,
num_classes: int = 1000,
embed_dim: int = 768,
num_heads: int = 6,
depth: int = 16,
mlp_ratio: float = 4.0,
qkv_bias: bool = True,
) -> None:
super().__init__()
# Use the image size to select the padding:
if len(img_size) == 2:
self.patch_embed = PatchEmbedding2d(
img_size=img_size,
patch_size=patch_size,
in_channels=in_channels,
embed_dim=embed_dim,
)
elif len(img_size) == 3:
self.patch_embed = PatchEmbedding3d(
img_size=img_size,
patch_size=patch_size,
in_channels=in_channels,
embed_dim=embed_dim,
)
# Positional embeddings (one per patch; this model has no CLS token)
self.pos_embed = nn.Parameter(
torch.zeros(1, self.patch_embed.num_patches, embed_dim)
)
# Build transformer stages (all operating on same resolution)
self.stages = nn.ModuleList(
[
TransformerBlock(
dim=embed_dim,
num_heads=num_heads,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
)
for _ in range(depth)
]
)
# Classification head
self.head = (
nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
)
def forward_features(self, x: torch.Tensor) -> torch.Tensor:
"""Extract features through all stages.
Args:
x: Input tensor of shape (B, C, H, W)
Returns:
Mean-pooled token features of shape (B, embed_dim)
"""
B = x.shape[0]
# Patch embedding
x = self.patch_embed(x) # B, N, C
# Add positional embeddings
x = x + self.pos_embed
# Apply transformer stages
for stage in self.stages:
x = stage(x)
# Return the mean of all tokens
return x.mean(dim=(1,))
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Full forward pass for classification.
Args:
x: Input tensor of shape (B, C, H, W)
Returns:
Classification logits of shape (B, num_classes)
"""
x = self.forward_features(x)
x = self.head(x)
return x
For more information on the components, expand the following sections to see the code:
Patch Embedding Implementations
# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn as nn
from einops import rearrange
class PatchEmbedding2d(nn.Module):
"""Single patch embedding layer that tokenizes and embeds input 2D images."""
def __init__(
self,
img_size: tuple[int],
patch_size: int = 16,
in_channels: int = 3,
embed_dim: int = 768,
) -> None:
super().__init__()
for i in img_size:
assert i % patch_size == 0, (
f"Image size {i} must be divisible by patch size {patch_size}"
)
self.img_size = img_size
self.patch_size = patch_size
self.num_patches = (img_size[0] // patch_size) * (img_size[1] // patch_size)
# Single convolution that acts as both tokenizer and linear embedding
self.conv = nn.Conv2d(
in_channels, embed_dim, kernel_size=patch_size, stride=patch_size
)
self.norm = nn.LayerNorm(embed_dim)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Convert image to patch embeddings.
Args:
x: Input tensor of shape (B, C, H, W)
Returns:
Patch embeddings of shape (B, num_patches, embed_dim)
"""
x = self.conv(x)
# Rearrange to apply LayerNorm correctly: BCHW -> B(HW)C
x = rearrange(x, "b c h w -> b (h w) c")
x = self.norm(x)
# Tokens are now in (B, N, C) sequence format for downstream processing
x = nn.functional.relu(x)
return x
# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn as nn
from einops import rearrange
class PatchEmbedding3d(nn.Module):
"""Single patch embedding layer that tokenizes and embeds input 3D images."""
def __init__(
self,
img_size: tuple[int],
patch_size: int = 16,
in_channels: int = 3,
embed_dim: int = 768,
) -> None:
super().__init__()
for i in img_size:
assert i % patch_size == 0, (
f"Image size {i} must be divisible by patch size {patch_size}"
)
self.img_size = img_size
self.patch_size = patch_size
self.num_patches = (
(img_size[0] // patch_size)
* (img_size[1] // patch_size)
* (img_size[2] // patch_size)
)
# Single convolution that acts as both tokenizer and linear embedding
self.conv = nn.Conv3d(
in_channels, embed_dim, kernel_size=patch_size, stride=patch_size
)
self.norm = nn.LayerNorm(embed_dim)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Convert image to patch embeddings.
Args:
x: Input tensor of shape (B, C, H, W, D)
Returns:
Patch embeddings of shape (B, num_patches, embed_dim)
"""
x = self.conv(x)
# Rearrange to apply LayerNorm correctly: BCHWD -> B(HWD)C
x = rearrange(x, "b c h w d -> b (h w d) c")
x = self.norm(x)
# Tokens are now in (B, N, C) sequence format for downstream processing
x = nn.functional.relu(x)
return x
Transformer Block
# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from torch import nn
from .MultiHeadAttention import MultiHeadAttention
from .MLP import MLP
class TransformerBlock(nn.Module):
"""Standard transformer block with multi-head attention and MLP."""
def __init__(
self,
dim: int,
num_heads: int,
mlp_ratio: float = 4.0,
qkv_bias: bool = False,
norm_layer: nn.Module = nn.LayerNorm,
) -> None:
super().__init__()
self.norm1 = norm_layer(dim)
self.attn = MultiHeadAttention(dim, num_heads=num_heads, qkv_bias=qkv_bias)
self.norm2 = norm_layer(dim)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = MLP(
in_features=dim, hidden_features=mlp_hidden_dim, out_features=dim
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Apply transformer block with residual connections.
Args:
x: Input tensor of shape (B, N, C)
Returns:
Transformed tensor of shape (B, N, C)
"""
# Attention block with residual connection
x = x + self.attn(self.norm1(x))
# MLP block with residual connection
x = x + self.mlp(self.norm2(x))
return x
MLP
# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from torch import nn
class MLP(nn.Module):
"""MLP as used in Vision Transformer."""
def __init__(
self, in_features: int, hidden_features: int, out_features: int
) -> None:
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
# Two-layer MLP with activation
self.fc1 = nn.Linear(in_features, hidden_features)
self.act = nn.GELU()
self.fc2 = nn.Linear(hidden_features, out_features)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Apply MLP transformation.
Args:
x: Input tensor of shape (B, N, C)
Returns:
Transformed tensor of shape (B, N, out_features)
"""
x = self.fc1(x)
x = self.act(x)
x = self.fc2(x)
return x
Multi-head Attention
# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from torch import nn
class MultiHeadAttention(nn.Module):
"""Standard multi-head attention using PyTorch's scaled_dot_product_attention."""
def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False) -> None:
super().__init__()
assert dim % num_heads == 0
self.num_heads = num_heads
self.head_dim = dim // num_heads
# Combined QKV projection for efficiency
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.proj = nn.Linear(dim, dim)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Apply multi-head self-attention.
Args:
x: Input tensor of shape (B, N, C)
Returns:
Attention output of shape (B, N, C)
"""
B, N, C = x.shape
# Project to Q, K, V and reshape for multi-head attention
qkv = (
self.qkv(x)
.reshape(B, N, 3, self.num_heads, self.head_dim)
.permute(2, 0, 3, 1, 4)
)
q, k, v = qkv[0], qkv[1], qkv[2] # B, num_heads, N, head_dim
# Use PyTorch's optimized scaled dot product attention
x = nn.functional.scaled_dot_product_attention(
q, k, v, dropout_p=0.0, is_causal=False
)
x = x.transpose(1, 2).reshape(B, N, C)
x = self.proj(x)
return x
Running the ViT#
The training script for this tutorial uses only synthetic data, with no real dataset or labels. We loop over image sizes, initialize the ViT model, and then evaluate its computational performance (not model accuracy) using a basic loop. We measure both inference and training performance using torch.cuda.Event objects to capture timing information, averaged over a few iterations. Each of those pieces has been packaged into basic functions so that you can run and reproduce this code:
How to measure model performance
We use torch.cuda.Event objects to capture timing information and average over a few iterations.
# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn as nn
from torch.amp import autocast
import contextlib
import numpy as np
def benchmark_model(
model,
x,
target,
optimizer,
num_warmup=5,
num_iterations=10,
use_mixed_precision=False,
):
"""Benchmark forward pass and training step performance.
Args:
model: The model to benchmark
x: Input tensor
target: Target tensor for loss computation
optimizer: Optimizer for training step
num_warmup: Number of warmup iterations
num_iterations: Number of benchmark iterations
use_mixed_precision: Whether to use mixed precision training
Returns:
Tuple of (forward_time, training_time) in seconds
"""
# Making a flexible context here to enable us to flip mixed precision on/off easily.
if use_mixed_precision:
context = autocast("cuda")
else:
context = contextlib.nullcontext()
# HEADS UP:
# You would use a grad scaler to do stable mixed precision in real training!
# https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.GradScaler
# With only a few iterations of training here, on synthetic data, we won't worry about it.
# Warmup runs
for _ in range(num_warmup):
# Inference only
with torch.no_grad():
with context:
_ = model(x)
# Training warmup step
optimizer.zero_grad()
with context:
output = model(x)
loss = nn.CrossEntropyLoss()(output, target)
loss.backward()
optimizer.step()
# Benchmark forward pass
torch.cuda.synchronize()
forward_times = []
for _ in range(num_iterations):
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
start_event.record()
with torch.no_grad():
with context:
_ = model(x)
end_event.record()
torch.cuda.synchronize()
elapsed_time = (
start_event.elapsed_time(end_event) / 1000.0
) # Convert ms to seconds
forward_times.append(elapsed_time)
avg_forward_time = np.mean(forward_times)
# Benchmark training step
torch.cuda.synchronize()
training_times = []
for _ in range(num_iterations):
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
start_event.record()
optimizer.zero_grad()
with context:
output = model(x)
loss = nn.CrossEntropyLoss()(output, target)
loss.backward()
optimizer.step()
end_event.record()
torch.cuda.synchronize()
elapsed_time = (
start_event.elapsed_time(end_event) / 1000.0
) # Convert ms to seconds
training_times.append(elapsed_time)
avg_training_time = np.mean(training_times)
return avg_forward_time, avg_training_time
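As a usage sketch (assuming a CUDA device and that benchmark_model is in scope, for example imported from measure_perf as the main benchmarking script does below), it might be called like this:
# Hypothetical standalone usage of benchmark_model with a toy model and synthetic data.
# Assumes benchmark_model is already imported (e.g. from measure_perf).
import torch
import torch.nn as nn
import torch.optim as optim

device = torch.device("cuda")
model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 64 * 64, 10)).to(device)
x = torch.randn(8, 3, 64, 64, device=device)
target = torch.randint(0, 10, (8,), device=device)
optimizer = optim.AdamW(model.parameters(), lr=1e-3)

forward_time, training_time = benchmark_model(
    model, x, target, optimizer, use_mixed_precision=False
)
print(f"forward: {forward_time * 1000:.2f} ms, train step: {training_time * 1000:.2f} ms")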
Measuring memory usage
We use torch.cuda.reset_peak_memory_stats() and torch.cuda.max_memory_allocated() to measure memory usage.
# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn as nn
from torch.amp import autocast
import contextlib
def get_model_memory_usage(
model, x, target=None, optimizer=None, mode="inference", use_mixed_precision=False
):
"""Estimate model memory usage for inference or training.
Args:
model: The model to measure
x: Input tensor
target: Target tensor (required for training mode)
optimizer: Optimizer (required for training mode)
mode: 'inference' or 'training'
use_mixed_precision: Whether to use mixed precision
Returns:
Peak memory usage in GB
"""
if use_mixed_precision:
context = autocast("cuda")
else:
context = contextlib.nullcontext()
torch.cuda.reset_peak_memory_stats()
if mode == "inference":
with torch.no_grad():
with context:
_ = model(x)
elif mode == "training":
if target is None or optimizer is None:
raise ValueError("target and optimizer must be provided for training mode")
optimizer.zero_grad()
with context:
output = model(x)
loss = nn.CrossEntropyLoss()(output, target)
loss.backward()
return torch.cuda.max_memory_allocated() / 1024**3 # GB
End to End Benchmarking
# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.optim as optim
from .measure_perf import benchmark_model
from .measure_memory import get_model_memory_usage
def end_to_end_benchmark(args, model, inputs, full_img_size, device, num_classes):
x, target = inputs
# Count parameters
num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
# Create optimizer
optimizer = optim.AdamW(model.parameters(), lr=1e-3, weight_decay=0.05)
try:
# Benchmark model
forward_time, training_time = benchmark_model(
model, x, target, optimizer, use_mixed_precision=args.use_mixed_precision
)
# Memory usage - measure both inference and training
inference_memory = get_model_memory_usage(
model, x, mode="inference", use_mixed_precision=args.use_mixed_precision
)
training_memory = get_model_memory_usage(
model,
x,
target,
optimizer,
mode="training",
use_mixed_precision=args.use_mixed_precision,
)
# Store results
results = {
"image_size": full_img_size[0],
"params": num_params,
"forward_time": forward_time,
"training_time": training_time,
"inference_memory": inference_memory,
"training_memory": training_memory,
"mixed_precision": args.use_mixed_precision and torch.cuda.is_available(),
}
except RuntimeError as e:
print(f" Error: {e}")
# Store failed result
results = {
"image_size": full_img_size[0],
"params": num_params,
"forward_time": float("inf"),
"training_time": float("inf"),
"inference_memory": float("inf"),
"training_memory": float("inf"),
"mixed_precision": args.use_mixed_precision and torch.cuda.is_available(),
}
# Clear cache to free memory
if torch.cuda.is_available():
torch.cuda.empty_cache()
del model, optimizer
return results
Users of PyTorch’s DDP are familiar with wrapping their model in a DDP object rather than modifying the model directly. To minimize the amount of model code you must change, ShardTensor follows the same philosophy as DDP: you should not have to modify your model code or training loops significantly to enable domain parallelism with ShardTensor.
The rest of the tutorial walks through the main script to highlight how components of the script change to enable domain parallelism.
Setting Up the Environment#
There are extra imports for DDP
, ShardTensor
, and FSDP
:
import torch
import torch.nn as nn
import torch
import torch.nn as nn
# Use PhysicsNeMo's distributed manager to simplify initialization
from physicsnemo.distributed import DistributedManager
# Add DDP import
from torch.nn.parallel import DistributedDataParallel as DDP
import torch
import torch.nn as nn
# Use PhysicsNeMo's distributed manager to simplify initialization
from physicsnemo.distributed import DistributedManager
# Upstream imports for FSDP:
from torch.distributed.tensor import distribute_module, distribute_tensor
# FSDP instead of DDP
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.tensor.placement_types import ( # noqa: E402
Replicate,
Shard,
)
# PhysicsNeMo imports to turn your inputs into ShardTensors
from physicsnemo.distributed import scatter_tensor
Run Configuration#
The configuration is the same for all three cases:
args = parse_args()
image_sizes = list(range(args.image_size_start, args.image_size_stop + 1, args.image_size_step))
device = torch.device('cuda')
# Generate image sizes based on start, stop, and step
if args.dimension == 2:
image_sizes = list(range(args.image_size_start, args.image_size_stop + 1, args.image_size_step))
elif args.dimension == 3:
image_sizes = list(range(args.image_size_start, min(args.image_size_stop + 1, 513), args.image_size_step))
# Should we use mixed precision?
precision_mode = "FP16" if args.use_mixed_precision and torch.cuda.is_available() else "FP32"
Distributed Configuration#
Here physicsnemo.distributed.DistributedManager
is used to set up the 1D or 2D parallelization:
# Initialize distributed manager first
DistributedManager.initialize()
dm = DistributedManager()
# Set device based on local rank
device = dm.device
torch.cuda.set_device(device)
# Initialize distributed manager first
DistributedManager.initialize()
dm = DistributedManager()
# Set via commandline and argparse:
ddp_size = args.ddp_size
domain_size = args.domain_size
# Set device based on local rank
device = dm.device
torch.cuda.set_device(device)
# Initialize distributed manager first
DistributedManager.initialize()
dm = DistributedManager()
# Set via commandline and argparse:
ddp_size = args.ddp_size
domain_size = args.domain_size
# Set device based on local rank
device = dm.device
torch.cuda.set_device(device)
# Use the PhysicsNeMo distributed manager to quickly and easily set up a PyTorch DeviceMesh:
mesh = dm.initialize_mesh(
mesh_shape=(ddp_size, domain_size,), # -1 works the same way as reshaping
mesh_dim_names = ["ddp","domain"]
)
ddp_mesh = mesh["ddp"]
domain_mesh = mesh["domain"]
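As an optional sanity check (a sketch, assuming the mesh setup above has already run on every rank), you can confirm each rank's position on both mesh axes:
# Sketch: report where this process sits on the ddp and domain axes of the mesh.
import torch.distributed as dist

ddp_group = ddp_mesh.get_group()
domain_group = domain_mesh.get_group()

print(
    f"global rank {dist.get_rank()}: "
    f"ddp rank {dist.get_rank(ddp_group)} of {dist.get_world_size(ddp_group)}, "
    f"domain rank {dist.get_rank(domain_group)} of {dist.get_world_size(domain_group)}"
)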
Preparing the Inputs#
We use synthetic inputs for this tutorial. The global batch size is assumed to be configured on the command line; when running in parallel, we divide the global batch size by the number of model replicas (the ddp axis), not by the total number of GPUs.
For 2D parallelism, we divide the global batch size by the replica count and additionally apply a scatter to shard each individual example across multiple GPUs.
Because we parallelize over both the batch and the domain, each batch of data is scattered along an image axis. Note the Shard(2) placement below: in PyTorch's BCHW(D) layout, dimension 2 is the height H, which is the axis we target in this example:
if args.dimension == 2:
full_img_size = (img_size, img_size)
elif args.dimension == 3:
full_img_size = (img_size, img_size, img_size)
if args.dimension == 2:
full_img_size = (img_size, img_size)
elif args.dimension == 3:
full_img_size = (img_size, img_size, img_size)
# Create synthetic data - scale the batch size down by DDP size.
x = torch.randn(args.batch_size // ddp_size, 3, * full_img_size, device=device)
target = torch.randint(0, num_classes, (args.batch_size // ddp_size,), device=device)
if args.dimension == 2:
full_img_size = (img_size, img_size)
elif args.dimension == 3:
full_img_size = (img_size, img_size, img_size)
# Create synthetic data - scale the batch size down by DDP size.
x = torch.randn(args.batch_size // ddp_size, 3, * full_img_size, device=device)
target = torch.randint(0, num_classes, (args.batch_size // ddp_size,), device=device)
# Domain Parallel NOTE: we're generating data once per GPU but only keeping the data once per domain.
# In a real application, you'd do this properly - each GPU would read its own shard of the data.
if args.domain_size > 1:
# When scattering the data, we need to know the global rank of the source
# But by definition, we use the domain_rank == 0 as the source. Convert:
global_rank_of_source = torch.distributed.get_global_rank(domain_mesh.get_group(), 0)
# Scatter the input data across the domain:
x = scatter_tensor(
x,
global_rank_of_source,
domain_mesh,
placements=(Shard(2),), # Shard along the 2nd dimension (B C **H** W) which is the Height
global_shape = x.shape, # This will be inferred if not provided!
dtype = x.dtype, # This will be inferred if not provided!
)
target = scatter_tensor(
target,
global_rank_of_source,
domain_mesh,
placements=(Replicate(),), # REPLICATE the target
global_shape = target.shape, # This will be inferred if not provided!
dtype = target.dtype, # This will be inferred if not provided!
)
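As a quick check (a sketch, assuming ShardTensor exposes the usual DTensor inspection API of global shape, local shard, and placements), each rank can report what it actually holds after the scatter:
# Sketch: inspect the sharded input on each rank after scatter_tensor.
if args.domain_size > 1:
    print(
        f"rank {dm.rank}: global input shape {tuple(x.shape)}, "
        f"local shard shape {tuple(x.to_local().shape)}, "
        f"placements {x.placements}"
    )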
Configure the Model#
To configure the model, build it as usual and then use some torch
functionality to distribute it across 1D or 2D parallelism:
# Base model
model = HybridViT(img_size = full_img_size, in_channels=3, num_classes=num_classes)
model = model.to(device)
# Base model
model = HybridViT(img_size = full_img_size, in_channels=3, num_classes=num_classes)
model = model.to(device)
# Wrap model with DDP
model = DDP(model, device_ids=[dm.local_rank], output_device=dm.local_rank)
# Base model
model = HybridViT(img_size = full_img_size, in_channels=3, num_classes=num_classes)
model = model.to(device)
# This step syncs across the domain only
model = distribute_module(
model,
device_mesh=domain_mesh,
partition_fn = partition_model, # See below to understand what this is!
)
# This step goes in the other axis on the mesh: every rank "i" of
# each domain will sync up here.
model = FSDP(model, device_mesh=ddp_mesh, use_orig_params=False)
Above, in the ShardTensor + FSDP column, you might have noticed the partition_fn argument to distribute_module. It gives you full control over how your model's parameters are sharded across the domain mesh. For more detail, refer to the PyTorch docs.
Here, most of the parameters are replicated. However, this ViT includes a learnable positional encoding that is the same size as the tokenized data we are sharding, so we use the partition function to shard that embedding in the same way:
def partition_model(name, submodule, device_mesh):
for key, param in submodule._parameters.items():
if "pos_embed" in key:
# Replace the pos_embed with a scattered ShardTensor
# Global source is the global rank of local rank 0:
scattered_pos_embed = distribute_tensor(
submodule.pos_embed,
device_mesh=device_mesh,
placements=[
Shard(1),
],
)
submodule.register_parameter(key, torch.nn.Parameter(scattered_pos_embed))
The partition function is applied recursively to your module; it finds the parameter named pos_embed, shards it, and replaces it in the original model.
By default, all parameters that aren't handled here are converted to replicated DTensors, which is what we want for 2D parallelism.
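If you want to confirm the sharding, a quick check right after distribute_module (and before the FSDP wrap) might look like the following sketch:
# Sketch: inspect parameter placements right after distribute_module and before
# wrapping with FSDP. pos_embed should be sharded; everything else replicated.
for name, param in model.named_parameters():
    placements = getattr(param.data, "placements", None)
    if "pos_embed" in name:
        print(f"{name}: {placements}")  # expect (Shard(dim=1),) on the domain mesh
    # other parameters should report (Replicate(),)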
After adding a few extra imports, setting up a DeviceMesh
, sharding the
inputs, and distributing the model, everything else proceeds as usual.
You can run the benchmark with the same code across all three implementations:
results = end_to_end_benchmark(args, model, (x, target), full_img_size, device, num_classes)
if dm.rank == 0:
print_and_save_results(results, args, precision_mode, dm.world_size)
Note
The full training script and all worker functions, configurable by domain size and DDP size, are available in the PhysicsNeMo GitHub examples.
Benchmark Results#
Benchmark results can be useful for deciding when to use ShardTensor
or
DDP. We recommend that you use ShardTensor when you can’t fit
batch_size==1
on a single GPU.
1024x1024 2D Image#
At a resolution of 1024 pixels on a side, our baseline ViT shows reasonable performance on a single GPU.
We can keep the per-GPU batch size fixed, scale out with DDP, and get very good scaling.
We can also scale in two directions and see that latency, at fixed global
batch size, decreases; however, ShardTensor
isn’t ideal in this regime:
Training Throughput (Images / second), at 1024 pixels per side, decreases with more GPUs per image (that is, using ShardTensor), but total throughput is highest with each GPU responsible for a full image.
| GPUs / Image | B=1  | B=2  | B=4 | B=8 |
|---|---|---|---|---|
| 1 | 0.46 | 0.91 | 1.8 | 3.6 |
| 2 | 0.76 | 1.6  | 3.1 |     |
| 4 | 1.3  | 2.7  |     |     |
| 8 | 1.9  |      |     |     |
Training Memory Usage (GB) At this resolution, the model uses only 14 GB of GPU memory per image out of the available 80 GB total.
| GPUs / Image | B=1  | B=2  | B=4  | B=8  |
|---|---|---|---|---|
| 1 | 13.9 | 14.4 | 14.4 | 14.4 |
| 2 | 7.6  | 7.4  | 7.1  |      |
| 4 | 4.5  | 4.2  |      |      |
| 8 | 2.9  |      |      |      |
ShardTensor does add a small amount of overhead to most operations. Most of the kernels that benefit from domain parallelism require communication between GPUs, and efficiency increases as the computational size grows from 1024 squared to 2048 squared:
Latency per step (s) The processing time increases linearly with the number of tokens in each layer, but tokens scale as the resolution squared.
| GPUs | Inference 1024 | Train 1024 | Inference 2048 | Train 2048 |
|---|---|---|---|---|
| 1 | 0.55 | 7.96 | 2.2  | 31.4 |
| 2 | 0.32 | 4.13 | 1.32 | 16.4 |
| 4 | 0.19 | 2.23 | 0.76 | 8.78 |
| 8 | 0.13 | 1.33 | 0.54 | 5.02 |
Speedup Beyond a certain data size, ShardTensor is always faster with more GPUs, and larger images show bigger benefits.
| GPUs | Inference 1024 | Train 1024 | Inference 2048 | Train 2048 |
|---|---|---|---|---|
| 1 | 1.0 | 1.0 | 1.0 | 1.0 |
| 2 | 1.7 | 1.9 | 1.7 | 1.9 |
| 4 | 2.9 | 3.6 | 2.9 | 3.6 |
| 8 | 4.2 | 6.0 | 4.1 | 6.3 |
Memory Usage (GB) Like latency, memory usage in training scales roughly like the number of tokens. For inference, it’s driven mostly by model size.
| GPUs | Inference 1024 | Train 1024 | Inference 2048 | Train 2048 |
|---|---|---|---|---|
| 1 | 2.5 | 13.9 | 4.6 | 51.4 |
| 2 | 2.2 | 7.6  | 3.8 | 26.5 |
| 4 | 2.0 | 4.5  | 2.8 | 13.9 |
| 8 | 1.9 | 2.9  | 2.3 | 7.6  |
Memory Usage Relative to a Single GPU (%) For the highest-resolution data, we obtain a close-to-linear reduction in memory with more GPUs.
| GPUs | Inference 1024 | Train 1024 | Inference 2048 | Train 2048 |
|---|---|---|---|---|
| 1 | 100% | 100% | 100% | 100% |
| 2 | 88%  | 55%  | 83%  | 52%  |
| 4 | 80%  | 32%  | 61%  | 27%  |
| 8 | 76%  | 21%  | 50%  | 15%  |
If you are tracking the memory scaling of this model, you'll see that training memory at higher resolution is roughly proportional to the total number of pixels in the image. At 51.4 GB of training memory for 2048x2048 images, we expect the next doubling (4096x4096 pixels) to require more than 200 GB of memory per GPU.
Using ShardTensor, we can run that configuration out of the box on 8 GPUs, and we see about 26 GB of memory used per GPU, as expected. You can also run large-scale 3D vision models this way; however, because memory usage scales with the cube of the resolution (rather than the square, as in 2D), memory issues arise even faster.
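As a back-of-the-envelope check (assuming training memory scales roughly with pixel count, as the measurements above suggest):
# Rough scaling estimate based on the 2D training-memory measurements above.
mem_2048 = 51.4                            # GB, measured for 2048x2048 on one GPU
mem_4096 = mem_2048 * (4096 / 2048) ** 2   # ~205.6 GB estimated for 4096x4096
per_gpu = mem_4096 / 8                     # ~25.7 GB per GPU with 8-way domain sharding
print(f"estimated {mem_4096:.0f} GB total, ~{per_gpu:.0f} GB per GPU on 8 GPUs")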
Review of Tutorial Steps#
This tutorial covered the key steps to enable ShardTensor
in your model.
ShardTensor
performance and broad layer support are still evolving. Many
key models will work out of the box, while others contain operations that are
not yet fully supported. If you have specific requests for support, open an
issue on GitHub and
review the tutorial for Implementing New Layers for ShardTensor.
Summary of the Workflow for 2D Domain Parallelism
Define the Device Mesh
Split the mesh into two dimensions: one for data parallelism (FSDP) and one for spatial decomposition (ShardTensor). Example: mesh = dm.initialize_mesh((-1, 2), mesh_dim_names=["data", "spatial"]). For multilevel parallelism, the mesh can be extended to additional dimensions. A DeviceMesh can be conceptualized as an N-dimensional tensor in which each element is one GPU and each dimension is one axis of parallelism.
Shard Input Data
Distribute the input tensor across the spatial dimension using ShardTensor.
Handle Parameters
Use FSDP to shard parameters and optimizer state across the data-parallel dimension.
Scale Spatial Dimensions
Larger spatial dimensions can be processed efficiently by distributing computation across devices.
A condensed code sketch of these steps follows.
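This is a minimal sketch, assuming the model, the partition_model function, and the input tensors from earlier in this tutorial are already defined on every rank:
# Condensed sketch of the 2D-parallel workflow described above.
import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.tensor import distribute_module
from torch.distributed.tensor.placement_types import Replicate, Shard
from physicsnemo.distributed import DistributedManager, scatter_tensor

DistributedManager.initialize()
dm = DistributedManager()

# 1. Device mesh: one axis for data parallelism, one for the spatial domain.
mesh = dm.initialize_mesh((-1, 2), mesh_dim_names=["data", "spatial"])
data_mesh, spatial_mesh = mesh["data"], mesh["spatial"]

# 2. Shard the input along the height axis of (B, C, H, W); replicate the target.
source = torch.distributed.get_global_rank(spatial_mesh.get_group(), 0)
x = scatter_tensor(x, source, spatial_mesh, placements=(Shard(2),))
target = scatter_tensor(target, source, spatial_mesh, placements=(Replicate(),))

# 3. Distribute parameters over the spatial mesh, then shard with FSDP over data.
model = distribute_module(model, device_mesh=spatial_mesh, partition_fn=partition_model)
model = FSDP(model, device_mesh=data_mesh, use_orig_params=False)

# 4. Train or evaluate as usual; larger spatial sizes are now spread across GPUs.
output = model(x)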
You are now ready to scale your models and data to very high resolutions
using ShardTensor
.