Source code for nemo_automodel.distributed.cp_utils

# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import contextlib
from typing import List, Set

import torch
import torch._dynamo.utils  # provides maybe_enable_compiled_autograd
import torch.distributed.tensor.parallel  # provides loss_parallel
from torch.distributed.device_mesh import DeviceMesh

def _build_position_ids(batch, device):
    """Add position_ids to the batch only if they are missing."""
    # TODO(@boxiangw): Refactor. Needed for SP support.
    # Only add 'position_ids' if the batch does not already contain them.
    # Packed-sequence batches come with their own 'position_ids', which we
    # must not override.
    if 'position_ids' not in batch:
        seq_len = batch['input_ids'].shape[1]
        batch['position_ids'] = torch.arange(seq_len, device=device).unsqueeze(0)
    return batch
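
# Illustrative usage (a sketch, not part of the original module): for an
# unpacked batch, position ids are derived from the sequence length; a
# packed-sequence batch that already carries 'position_ids' is returned
# unchanged.
#
#   >>> batch = {"input_ids": torch.zeros(2, 4, dtype=torch.long)}
#   >>> _build_position_ids(batch, device="cpu")["position_ids"]
#   tensor([[0, 1, 2, 3]])
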
# based on https://github.com/pytorch/torchtitan/blob/0b44d4c437c424b6bf719661c0eb4283dc4068bc/torchtitan/distributed/utils.py#L180 # pylint: disable=C0301
def get_train_context(enable_loss_parallel: bool, enable_compiled_autograd: bool, cp_context=None):
    """
    Create a train context.

    Args:
        enable_loss_parallel (bool): Whether to enable loss parallelism.
        enable_compiled_autograd (bool): Whether to enable compiled autograd.
        cp_context (optional): A context-parallel context manager to enter, if any.

    Returns:
        A context manager that enters the requested contexts around a training step.
    """

    @contextlib.contextmanager
    def context():
        with contextlib.ExitStack() as stack:
            if enable_loss_parallel:
                stack.enter_context(torch.distributed.tensor.parallel.loss_parallel())
            if enable_compiled_autograd:
                stack.enter_context(
                    torch._dynamo.utils.maybe_enable_compiled_autograd(True)
                )
            if cp_context is not None:
                from torch.nn.attention import SDPBackend, sdpa_kernel

                # Currently only these two SDP backends are supported;
                # SDPBackend.MATH is not compatible with DTensor.
                stack.enter_context(
                    sdpa_kernel(
                        [SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION]
                    )
                )
                stack.enter_context(cp_context)
            yield

    return context
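
# Illustrative usage (a sketch; a distributed process group is assumed to be
# initialized when a cp_context is supplied). With both flags False and no
# cp_context, the returned context manager is effectively a no-op:
#
#   >>> train_context = get_train_context(
#   ...     enable_loss_parallel=False, enable_compiled_autograd=False
#   ... )
#   >>> with train_context():
#   ...     pass  # run forward/backward here
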
# based on https://github.com/pytorch/torchtitan/blob/main/torchtitan/distributed/utils.py#L113
def create_context_parallel_ctx(
    cp_mesh: DeviceMesh,
    cp_buffers: List[torch.Tensor],
    cp_seq_dims: List[int],
    cp_no_restore_buffers: Set[torch.Tensor],
    cp_rotate_method: str,
):
    """
    Create a context parallel context.

    Args:
        cp_mesh (DeviceMesh): The device mesh for context parallel.
        cp_buffers (List[torch.Tensor]): The buffers for context parallel.
        cp_seq_dims (List[int]): The sequence dimensions for context parallel.
        cp_no_restore_buffers (Set[torch.Tensor]): The no-restore buffers for context parallel.
        cp_rotate_method (str): The rotation method for context parallel,
            such as "allgather" or "alltoall".
    """
    from torch.distributed.tensor.experimental import context_parallel

    # TODO: uncomment this when torch.distributed.tensor.experimental._attention.set_rotate_method
    # is available
    # from torch.distributed.tensor.experimental._attention import set_rotate_method
    # set_rotate_method(cp_rotate_method)

    return context_parallel(
        cp_mesh,
        buffers=cp_buffers,
        buffer_seq_dims=cp_seq_dims,
        no_restore_buffers=cp_no_restore_buffers,
    )
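
# Illustrative usage (a sketch, assuming torch.distributed is initialized and
# `mesh` contains a "context_parallel" dimension of size > 1; `input_ids`,
# `labels`, and `position_ids` are placeholder tensors with the sequence on
# dim 1):
#
#   >>> ctx = create_context_parallel_ctx(
#   ...     cp_mesh=mesh["context_parallel"],
#   ...     cp_buffers=[input_ids, labels, position_ids],
#   ...     cp_seq_dims=[1, 1, 1],
#   ...     cp_no_restore_buffers={input_ids, labels},
#   ...     cp_rotate_method="allgather",
#   ... )
#   >>> with ctx:
#   ...     ...  # attention here sees sequence-sharded buffers
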
def make_cp_batch_and_ctx(device_mesh, batch, labels, loss_mask):
    """
    Build a CP context manager and shard a batch.

    If the input device_mesh is None or the size of the "context_parallel"
    submesh is 1, this function is effectively a no-op.

    Args:
        device_mesh (DeviceMesh): The device mesh containing a "context_parallel" submesh.
        batch (Dict[str, torch.Tensor]): The input batch; must contain "input_ids"
            and "position_ids".
        labels (torch.Tensor): The target labels, sharded along with the batch.
        loss_mask (Optional[torch.Tensor]): The loss mask, if any, sharded along with the batch.

    Returns:
        tuple (contextmanager, dict[str, torch.Tensor]): Returns a tuple with a context
        manager and a new batch. The context manager is either nullcontext (no CP) or
        the CP context manager returned by `create_context_parallel_ctx`. The batch has
        also been passed to `create_context_parallel_ctx` and is sharded accordingly.
    """
    from contextlib import nullcontext

    if device_mesh is None:
        cp_mesh = None
    else:
        cp_mesh = device_mesh["context_parallel"]
    if cp_mesh is None or cp_mesh.size() == 1:
        return nullcontext, batch

    input_ids = batch["input_ids"]
    position_ids = batch["position_ids"]
    if loss_mask is not None:
        cp_buffers = [input_ids, labels, position_ids, loss_mask]
        cp_seq_dims = [1, 1, 1, 1]
        cp_no_restore_buffers = {input_ids, labels, loss_mask}
    else:
        cp_buffers = [input_ids, labels, position_ids]
        cp_seq_dims = [1, 1, 1]
        cp_no_restore_buffers = {input_ids, labels}

    cp_ctx = create_context_parallel_ctx(
        cp_mesh=cp_mesh,
        cp_buffers=cp_buffers,
        cp_seq_dims=cp_seq_dims,
        cp_no_restore_buffers=cp_no_restore_buffers,
        cp_rotate_method="allgather",  # TODO: expose through cfg
    )
    # TODO(@akoumparouli): surface these in the future.
    enable_loss_parallel: bool = False
    enable_compiled_autograd: bool = False
    return get_train_context(enable_loss_parallel, enable_compiled_autograd, cp_ctx), batch
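
# Illustrative end-to-end usage (a sketch; `device_mesh`, `batch`, `labels`,
# and `loss_mask` are assumed to come from the trainer). Note the returned
# value is a context-manager *factory*, so it is called before entering:
#
#   >>> ctx, batch = make_cp_batch_and_ctx(device_mesh, batch, labels, loss_mask)
#   >>> with ctx():
#   ...     loss = model(**batch)  # buffers are sharded along the sequence dim
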