# Source code for nemo_automodel.distributed.fsdp2

# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass, field
from typing import Optional

import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import CPUOffloadPolicy, MixedPrecisionPolicy
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    SequenceParallel,
)
from torch.distributed.tensor.placement_types import Replicate, Shard

from nemo_automodel.distributed.parallelizer import (
    fsdp2_strategy_parallelize,
    get_hf_tp_shard_plan,
)


[docs] @dataclass class FSDP2Manager: """ Manager for setting up and parallelizing models using FSDP2 with TP, DP, CP sharding. This manager initializes the torch.distributed process group, infers the group sizes for data parallelism (DP) and tensor parallelism (TP), builds the device mesh for distributed operations, and applies parallelization to the model using a prescribed TP sharding plan. It also supports mixed precision and CPU offloading options. Attributes: dp_size (Optional[int]): Data-parallel group size. If None or non-positive, it is inferred from WORLD_SIZE. tp_size (Optional[int]): Tensor-parallel group size. Defaults to 1 if zero/None. cp_size (int): Context-parallel group size for pipeline-like sharding. sequence_parallel (bool): Enables sequence parallelism in the TP plan when True. mp_policy (MixedPrecisionPolicy): Defines the mixed precision policy for parameters, reductions, and outputs. offload_policy (CPUOffloadPolicy): Policy to offload parameters or optimizer states to CPU, if specified. backend (str): Distributed backend to use (e.g., 'nccl' for GPUs or 'gloo' for CPUs). world_size (int): Total number of processes. Methods: __post_init__(): Automatically sets up the distributed environment after initialization. _setup_distributed(): Initializes the torch.distributed process group, infers parallel sizes, builds the device mesh, and registers a destroy handler. parallelize(model): Applies FSDP2 and Tensor-Parallel sharding strategies to the given model. 
""" dp_size: Optional[int] = field( default=None, metadata={"help": "Data-parallel group size; if None, infer from WORLD_SIZE."}, ) tp_size: Optional[int] = field( default=1, metadata={"help": "Tensor-parallel group size; if None, defaults to 1."}, ) cp_size: Optional[int] = field( default=1, metadata={"help": "Context-parallel group size (for pipeline-like sharding)."}, ) sequence_parallel: Optional[bool] = field( default=False, metadata={"help": "Enable sequence parallelism in TP plan if True."}, ) mp_policy: Optional[MixedPrecisionPolicy] = field( default=MixedPrecisionPolicy( param_dtype=torch.bfloat16, reduce_dtype=torch.bfloat16, output_dtype=torch.bfloat16, cast_forward_inputs=True, ), metadata={ "help": "MixedPrecisionPolicy for FSDP2 (param/reduce/output dtypes)." }, ) offload_policy: Optional[CPUOffloadPolicy] = field( default=None, metadata={ "help": "CPUOffloadPolicy to offload parameters/optim states to CPU." }, ) backend: Optional[str] = field( default="nccl", metadata={"help": "Distributed backend, e.g. 'nccl' or 'gloo'."} ) world_size: Optional[int] = field( default=None, # init=False, metadata={"help": "Total number of processes."}, )
[docs] def __post_init__(self): """ Post-initialization hook that sets up the distributed environment. """ return self._setup_distributed()
[docs] def _setup_distributed(self): """ Initializes the distributed environment. - Checks availability and initialization of torch.distributed. - Infers data-parallel and tensor-parallel sizes if not provided. - Builds a device mesh based on the specified mesh shape and dimension names. - Flattens data and context dimensions if context parallelism is enabled. Requires the environment variables: RANK, WORLD_SIZE, MASTER_ADDR, MASTER_PORT. Raises: RuntimeError: If torch.distributed is not available or not initialized. Returns: FSDP2Manager: Instance with the device mesh configured. """ if not dist.is_available(): raise RuntimeError("torch.distributed not available") if not dist.is_initialized(): raise RuntimeError("expected torch.distributed to be initialized") # infer if not provided self.dp_size = self.dp_size if self.dp_size is None or self.dp_size <= 0: self.dp_size = self.world_size self.tp_size = self.tp_size or 1 mesh_shape = (self.dp_size, self.cp_size, self.tp_size) mesh_names = ("data_parallel", "context_parallel", "tensor_parallel") for shape, name in zip(mesh_shape, mesh_names): assert isinstance( shape, int ), "Expected {} to be an int, but got {}".format(name, type(shape)) assert shape > 0, "Expected {} > 0, {}".format(name, shape) # build mesh [dp, cp, tp] self.device_mesh = init_device_mesh( device_type="cuda" if self.backend == "nccl" else "cpu", mesh_shape=mesh_shape, mesh_dim_names=mesh_names, ) # flatten dp+cp if cp>1 if self.cp_size > 1: self.device_mesh[("data_parallel", "context_parallel")]._flatten( mesh_dim_name="dp_cp" ) return self
[docs] def parallelize(self, model, use_hf_tp_plan=False): """ Parallelizes the given model using FSDP2 and TP sharding strategies. This method must be called after the distributed environment has been set up. It selects a TP sharding plan (currently supporting Hugging Face TP plan via get_hf_tp_shard_plan) and applies the FSDP2 parallelization strategy. Args: model (nn.Module): The model to be parallelized. use_hf_tp_plan (bool): if true, will attempt to get the TP plan from the model. Returns: The parallelized model. Raises: NotImplemented: If the required TP sharding plan is not supported. """ if self.device_mesh["tensor_parallel"].size() > 1: if use_hf_tp_plan: tp_shard_plan = get_hf_tp_shard_plan(model) else: # Parallelize the first embedding and the last linear out projection base_model_tp_plan = { "model.embed_tokens": RowwiseParallel(input_layouts=Replicate()), "model.layers.*.self_attn.q_proj": ColwiseParallel(), "model.layers.*.self_attn.k_proj": ColwiseParallel(), "model.layers.*.self_attn.v_proj": ColwiseParallel(), "model.layers.*.self_attn.o_proj": RowwiseParallel(), "model.layers.*.mlp.up_proj": ColwiseParallel(), "model.layers.*.mlp.gate_proj": ColwiseParallel(), "model.layers.*.mlp.down_proj": RowwiseParallel(), "lm_head": ColwiseParallel(output_layouts=Replicate()), } base_model_sp_plan = { "model.embed_tokens": RowwiseParallel( input_layouts=Replicate(), output_layouts=Shard(1) ), "model.norm": SequenceParallel(), "model.layers.*.input_layernorm": SequenceParallel(), "model.layers.*.self_attn.o_proj": RowwiseParallel( output_layouts=Shard(1) ), "model.layers.*.post_attention_layernorm": SequenceParallel(), "model.layers.*.mlp.down_proj": RowwiseParallel( output_layouts=Shard(1) ), "lm_head": ColwiseParallel( input_layouts=Shard(1), output_layouts=Replicate() ), } if self.sequence_parallel: # Enable sequence parallelism only if TP size > 1 base_model_tp_plan.update(base_model_sp_plan) tp_shard_plan = base_model_tp_plan # TODO(boxiangw): Change 
this to a log if self.device_mesh.get_rank() == 0: print( "Using default TP plan for parallelization. " "It is compatible with huggingface llama3-style models." ) else: tp_shard_plan = None fsdp2_strategy_parallelize( model, device_mesh=self.device_mesh, mp_policy=self.mp_policy, tp_shard_plan=tp_shard_plan, offload_policy=self.offload_policy, ) return model