# Source code for nemo_automodel.distributed.nvfsdp

# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass, field
from typing import List, Optional

import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
)

from nemo_automodel.distributed.parallelizer import (
    get_hf_tp_shard_plan,
    nvfsdp_strategy_parallelize,
)


@dataclass
class NVFSDPManager:
    """
    Manager for setting up and parallelizing models using nvFSDP with TP, DP, CP sharding.

    On construction this manager checks that torch.distributed is available and
    already initialized, resolves the data-parallel (DP), context-parallel (CP)
    and tensor-parallel (TP) group sizes, and builds a 3-D device mesh with
    dimensions ``("data_parallel", "context_parallel", "tensor_parallel")``.
    ``parallelize`` then applies nvFSDP, optionally combined with a
    tensor-parallel sharding plan.

    Attributes:
        dp_size (Optional[int]): Data-parallel group size. If None or
            non-positive, it is inferred as ``world_size // (tp_size * cp_size)``
            so that the mesh covers exactly ``world_size`` ranks.
        tp_size (Optional[int]): Tensor-parallel group size. Defaults to 1 if zero/None.
        cp_size (Optional[int]): Context-parallel group size. Defaults to 1 if
            zero/None. When > 1, the data and context mesh dimensions are also
            flattened into a ``"dp_cp"`` dimension.
        sequence_parallel (bool): Requested sequence parallelism. It is not
            applied by this manager; a warning is printed on rank 0 when set.
        backend (str): Distributed backend. ``"nccl"`` builds a CUDA device
            mesh; any other value builds a CPU device mesh.
        world_size (Optional[int]): Total number of processes. If None, it is
            read from ``torch.distributed.get_world_size()``.
        nvfsdp_unit_modules (List[str]): Fully-qualified class names of the
            modules to be wrapped as nvFSDP units.
        data_parallel_sharding_strategy, init_nvfsdp_with_meta_device,
        grad_reduce_in_fp32, preserve_fp32_weights, overlap_grad_reduce,
        overlap_param_gather, check_for_nan_in_grad, average_in_collective,
        disable_bucketing, calculate_per_token_loss,
        keep_fp8_transpose_cache_when_using_custom_fsdp, nccl_ub,
        fsdp_double_buffer: nvFSDP options forwarded unchanged to
            ``nvfsdp_strategy_parallelize``.

    Methods:
        __post_init__(): Automatically sets up the distributed environment after initialization.
        _setup_distributed(): Validates torch.distributed, resolves parallel sizes
            and builds the device mesh.
        parallelize(model): Applies nvFSDP (and TP sharding when the TP mesh
            dimension is > 1) to the given model.
    """

    dp_size: Optional[int] = field(
        default=None,
        metadata={"help": "Data-parallel group size; if None, infer from WORLD_SIZE."},
    )
    tp_size: Optional[int] = field(
        default=1,
        metadata={"help": "Tensor-parallel group size; if None, defaults to 1."},
    )
    cp_size: Optional[int] = field(
        default=1,
        metadata={"help": "Context-parallel group size (for pipeline-like sharding)."},
    )
    sequence_parallel: Optional[bool] = field(
        default=False,
        metadata={
            "help": "Enable sequence parallelism in TP plan if True. Not supported with nvFSDP right now."
        },
    )
    backend: Optional[str] = field(
        default="nccl", metadata={"help": "Distributed backend, e.g. 'nccl' or 'gloo'."}
    )
    world_size: Optional[int] = field(
        default=None,
        metadata={"help": "Total number of processes."},
    )
    nvfsdp_unit_modules: Optional[List[str]] = field(
        default_factory=lambda: [
            "transformers.models.llama.modeling_llama.LlamaDecoderLayer",
        ],
        metadata={"help": "List of unit modules to be wrapped with nvFSDP."},
    )

    # nvFSDP config
    data_parallel_sharding_strategy: Optional[str] = field(
        default="optim_grads_params",
        metadata={"help": "Data parallel sharding strategy."},
    )
    init_nvfsdp_with_meta_device: Optional[bool] = field(
        default=False, metadata={"help": "Initialize nvFSDP with meta device if True."}
    )
    grad_reduce_in_fp32: Optional[bool] = field(
        default=False, metadata={"help": "Reduce gradients in fp32 if True."}
    )
    preserve_fp32_weights: Optional[bool] = field(
        default=False, metadata={"help": "Preserve fp32 weights if True."}
    )
    overlap_grad_reduce: Optional[bool] = field(
        default=True, metadata={"help": "Overlap gradient reduction if True."}
    )
    overlap_param_gather: Optional[bool] = field(
        default=True, metadata={"help": "Overlap parameter gathering if True."}
    )
    check_for_nan_in_grad: Optional[bool] = field(
        default=True, metadata={"help": "Check for NaN in gradients if True."}
    )
    average_in_collective: Optional[bool] = field(
        default=False, metadata={"help": "Average in collective if True."}
    )
    disable_bucketing: Optional[bool] = field(
        default=False, metadata={"help": "Disable bucketing if True."}
    )
    calculate_per_token_loss: Optional[bool] = field(
        default=False, metadata={"help": "Calculate per token loss if True."}
    )
    keep_fp8_transpose_cache_when_using_custom_fsdp: Optional[bool] = field(
        default=False, metadata={"help": "Keep fp8 transpose cache when using custom FSDP if True."}
    )
    nccl_ub: Optional[bool] = field(
        default=False, metadata={"help": "Use NCCL UBs if True."}
    )
    fsdp_double_buffer: Optional[bool] = field(
        default=False, metadata={"help": "Use double buffer if True."}
    )

    def __post_init__(self):
        """
        Post-initialization hook that sets up the distributed environment.
        """
        return self._setup_distributed()

    @staticmethod
    def _infer_dp_size(world_size, tp_size, cp_size):
        """
        Return the data-parallel size such that ``dp * cp * tp == world_size``.

        Args:
            world_size (int): Total number of processes.
            tp_size (int): Tensor-parallel group size.
            cp_size (int): Context-parallel group size.

        Returns:
            int: ``world_size // (tp_size * cp_size)``.

        Raises:
            ValueError: If ``tp_size * cp_size`` is not positive or does not
                divide ``world_size`` evenly.
        """
        model_parallel_size = tp_size * cp_size
        if model_parallel_size <= 0 or world_size % model_parallel_size != 0:
            raise ValueError(
                "world_size ({}) must be divisible by tp_size * cp_size ({} * {} = {})".format(
                    world_size, tp_size, cp_size, model_parallel_size
                )
            )
        return world_size // model_parallel_size

    def _setup_distributed(self):
        """
        Initializes the distributed environment.

        - Checks availability and initialization of torch.distributed.
        - Reads ``world_size`` from torch.distributed if it was not provided.
        - Normalizes TP/CP sizes (zero/None -> 1) and infers the DP size from
          ``world_size // (tp_size * cp_size)`` if it was not provided.
        - Builds a ``[dp, cp, tp]`` device mesh.
        - Flattens the data and context dimensions into ``"dp_cp"`` if context
          parallelism is enabled.

        Requires torch.distributed to have been initialized by the caller.

        Raises:
            RuntimeError: If torch.distributed is not available or not initialized.
            ValueError: If DP must be inferred and ``world_size`` is not
                divisible by ``tp_size * cp_size``.

        Returns:
            NVFSDPManager: This instance, with ``device_mesh`` configured.
        """
        if not dist.is_available():
            raise RuntimeError("torch.distributed not available")
        if not dist.is_initialized():
            raise RuntimeError("expected torch.distributed to be initialized")

        if self.world_size is None:
            self.world_size = dist.get_world_size()

        # TP and CP must be resolved before DP, because the inferred DP size
        # has to leave room for them in the mesh.
        self.tp_size = self.tp_size or 1
        self.cp_size = self.cp_size or 1
        if self.dp_size is None or self.dp_size <= 0:
            self.dp_size = self._infer_dp_size(self.world_size, self.tp_size, self.cp_size)

        mesh_shape = (self.dp_size, self.cp_size, self.tp_size)
        mesh_names = ("data_parallel", "context_parallel", "tensor_parallel")
        for shape, name in zip(mesh_shape, mesh_names):
            assert isinstance(
                shape, int
            ), "Expected {} to be an int, but got {}".format(name, type(shape))
            assert shape > 0, "Expected {} > 0, {}".format(name, shape)

        # build mesh [dp, cp, tp]
        self.device_mesh = init_device_mesh(
            device_type="cuda" if self.backend == "nccl" else "cpu",
            mesh_shape=mesh_shape,
            mesh_dim_names=mesh_names,
        )
        # flatten dp+cp if cp>1. The return value is intentionally discarded;
        # presumably _flatten registers "dp_cp" on the root mesh so it can be
        # looked up later by name (private torch API -- confirm on upgrade).
        if self.cp_size > 1:
            self.device_mesh[("data_parallel", "context_parallel")]._flatten(
                mesh_dim_name="dp_cp"
            )
        return self

    def parallelize(self, model, use_hf_tp_plan=False):
        """
        Parallelizes the given model using nvFSDP and, optionally, TP sharding.

        Must be called after the device mesh has been built (done automatically
        at construction). When the ``"tensor_parallel"`` mesh dimension is > 1,
        a TP sharding plan is selected: the model's Hugging Face plan (via
        ``get_hf_tp_shard_plan``) if ``use_hf_tp_plan`` is True, otherwise a
        built-in plan for llama-style module names. The model is then passed
        to ``nvfsdp_strategy_parallelize`` with this manager's nvFSDP options.

        Args:
            model: The model to be parallelized.
            use_hf_tp_plan (bool): If True, query the model for its TP plan.

        Returns:
            The parallelized model, as returned by ``nvfsdp_strategy_parallelize``.
        """
        is_rank_zero = self.device_mesh.get_rank() == 0

        if self.data_parallel_sharding_strategy != "optim_grads_params" and is_rank_zero:
            print(
                "Warning: nvFSDP data_parallel_sharding_strategy is not optim_grads_params. "
                "Parameters will not be sharded."
            )

        # TODO(boxiangw): investigate SP
        # sequence_parallel is never forwarded to nvfsdp_strategy_parallelize,
        # so it is ignored on every path; warn whenever it is requested.
        if self.sequence_parallel and is_rank_zero:
            # TODO(boxiangw): Change this to a log
            print(
                "Sequence parallelism is disabled. It is not compatible with nvFSDP."
            )

        if self.device_mesh["tensor_parallel"].size() > 1:
            if use_hf_tp_plan:
                tp_shard_plan = get_hf_tp_shard_plan(model)
            else:
                # Shard the attention q/k/v/o and MLP up/gate/down projections of
                # every decoder layer: column-wise in, row-wise out.
                tp_shard_plan = {
                    "model.layers.*.self_attn.q_proj": ColwiseParallel(),
                    "model.layers.*.self_attn.k_proj": ColwiseParallel(),
                    "model.layers.*.self_attn.v_proj": ColwiseParallel(),
                    "model.layers.*.self_attn.o_proj": RowwiseParallel(),
                    "model.layers.*.mlp.up_proj": ColwiseParallel(),
                    "model.layers.*.mlp.gate_proj": ColwiseParallel(),
                    "model.layers.*.mlp.down_proj": RowwiseParallel(),
                }
                # TODO(boxiangw): Change this to a log
                if is_rank_zero:
                    print(
                        "Using default TP plan for parallelization. "
                        "It is compatible with huggingface llama3-style models."
                    )
        else:
            tp_shard_plan = None

        model = nvfsdp_strategy_parallelize(
            model,
            device_mesh=self.device_mesh,
            nvfsdp_unit_modules=self.nvfsdp_unit_modules,
            tp_shard_plan=tp_shard_plan,
            data_parallel_sharding_strategy=self.data_parallel_sharding_strategy,
            init_nvfsdp_with_meta_device=self.init_nvfsdp_with_meta_device,
            grad_reduce_in_fp32=self.grad_reduce_in_fp32,
            preserve_fp32_weights=self.preserve_fp32_weights,
            overlap_grad_reduce=self.overlap_grad_reduce,
            overlap_param_gather=self.overlap_param_gather,
            check_for_nan_in_grad=self.check_for_nan_in_grad,
            average_in_collective=self.average_in_collective,
            disable_bucketing=self.disable_bucketing,
            calculate_per_token_loss=self.calculate_per_token_loss,
            keep_fp8_transpose_cache_when_using_custom_fsdp=self.keep_fp8_transpose_cache_when_using_custom_fsdp,
            nccl_ub=self.nccl_ub,
            fsdp_double_buffer=self.fsdp_double_buffer,
        )

        return model