Source code for nemo_automodel.distributed.ddp

# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from dataclasses import dataclass, field

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP


@dataclass
class DDPManager:
    """
    Manages setting up distributed training using PyTorch's DDP.

    Attributes:
        backend (str): The distributed backend to use (e.g. "nccl" or "gloo").
            Defaults to "nccl".
        rank (int): Global rank of this process. Set during distributed setup.
        world_size (int): Total number of processes in the distributed group.
            Set during distributed setup.
    """

    backend: str = field(
        default="nccl",
        metadata={"help": "Distributed backend, e.g. 'nccl' or 'gloo'."},
    )
    world_size: int = field(
        default=1,
        metadata={"help": "Total number of distributed processes."},
    )
    # This is populated in setup_distributed(), not by the user:
    rank: int = field(
        init=False,
        default=0,
        metadata={"help": "Global rank of this process."},
    )
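    # Construction example (illustrative, not part of the original fields): the
    # manager is created with just the backend, e.g. DDPManager() for GPU
    # training over NCCL or DDPManager(backend="gloo") for CPU-only runs;
    # `rank` is init=False and is filled in later by setup_distributed().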
    def setup_distributed(self):
        """
        Initialize the torch.distributed process group and set up device configuration.

        This method reads the following environment variables:

        - RANK: Global rank of the process (required).
        - WORLD_SIZE: Total number of processes (required).
        - MASTER_ADDR: Address of the master node (defaults to "localhost").
        - MASTER_PORT: Port on which the master node is listening (defaults to "29500").

        The method sets the `rank` and `world_size` of the DDPManager, configures the
        device (GPU for the "nccl" backend, CPU otherwise), and initializes the
        process group.
        """
        if not dist.is_initialized():
            rank = int(os.environ["RANK"])
            world = int(os.environ["WORLD_SIZE"])
            os.environ.setdefault("MASTER_ADDR", "localhost")
            os.environ.setdefault("MASTER_PORT", "29500")
            dist.init_process_group(self.backend, rank=rank, world_size=world)

        self.rank = dist.get_rank()
        self.world_size = dist.get_world_size()

        # Pin this process to a single GPU when using NCCL.
        if self.backend == "nccl":
            local_gpu = self.rank % torch.cuda.device_count()
            torch.cuda.set_device(local_gpu)
            self.device = torch.device("cuda", index=local_gpu)
        else:
            self.device = torch.device("cpu")
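    # Example launch (illustrative, not part of the module): the environment
    # variables consumed by setup_distributed() are typically exported by a
    # launcher such as torchrun, e.g.
    #
    #   torchrun --nnodes=1 --nproc_per_node=8 train.py
    #
    # which sets RANK, LOCAL_RANK, WORLD_SIZE, MASTER_ADDR and MASTER_PORT for
    # every worker process before the script starts.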
    def wrap_model(self, model):
        """
        Wrap the given model with DistributedDataParallel (DDP).

        Moves the model to the initialized device before wrapping. For CUDA devices,
        the device is passed to DDP via `device_ids`; for CPU, no device ids are
        provided.

        Args:
            model (torch.nn.Module): The PyTorch model to be wrapped.

        Returns:
            torch.nn.parallel.DistributedDataParallel: The DDP-wrapped model.
        """
        return DDP(
            model.to(self.device),
            device_ids=[self.device] if self.device.type == "cuda" else None,
        )
    # @contextmanager
    # def no_sync(self):
    #     """
    #     Context manager to temporarily disable gradient synchronization during backpropagation.
    #
    #     This can be used for gradient accumulation:
    #         with manager.no_sync():
    #             loss.backward()
    #
    #     When used within a DDP-wrapped model, it skips the gradient all-reduce.
    #     """
    #     if isinstance(self.model, DDP):
    #         with self.model.no_sync():
    #             yield
    #     else:
    #         yield
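
Example usage (a minimal sketch, not part of the module above; MyModel, train_loader and loss_fn are placeholders). It shows the intended flow of DDPManager together with gradient accumulation via the DDP-wrapped model's built-in no_sync():

    manager = DDPManager(backend="nccl")
    manager.setup_distributed()

    ddp_model = manager.wrap_model(MyModel())
    optimizer = torch.optim.AdamW(ddp_model.parameters(), lr=1e-4)

    accum_steps = 4
    for step, (inputs, targets) in enumerate(train_loader):
        inputs, targets = inputs.to(manager.device), targets.to(manager.device)
        loss = loss_fn(ddp_model(inputs), targets) / accum_steps
        if (step + 1) % accum_steps != 0:
            # Intermediate accumulation step: skip the gradient all-reduce.
            with ddp_model.no_sync():
                loss.backward()
        else:
            # Final accumulation step: backward() outside no_sync() triggers the
            # all-reduce of the accumulated gradients, then the optimizer steps.
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()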