#!/usr/bin/python3
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
from contextlib import ContextDecorator, nullcontext
from datetime import datetime
from typing import Optional

import torch
import torch.distributed as dist

logger = logging.getLogger(__name__)
class FirstRankPerNode(ContextDecorator):
"""
Context manager to enforce rank0 to process section over other ranks.
- Lets LOCAL_RANK==0 run the protected code first on each node.
- Inserts an extra barrier across *only* the node‑local rank‑0 processes.
- Works on a single GPU (no env flags, no distributed initialisation).
Note: it is assumed the scoped code is not torch.distributed heavy.
"""
def __enter__(self):
"""
Create / bootstrap a (distributed) proc. group that rank0 enters first.
Returns:
bool: ``True`` – if the current process is node-rank-0
``False`` – otherwise
"""
self._created_pg = False
self._node0_group = None
self._first = True # default for single‑GPU / no‑dist case
# ------------------------------------------------------------------ #
# 1. Make sure there is at least *some* process‑group initialised
# ------------------------------------------------------------------ #
if not dist.is_initialized():
self._created_pg = self._try_bootstrap_pg()
if not dist.is_initialized():
# pure single GPU
return True
# ------------------------------------------------------------------ #
# 2. Figure out local/global ranks
# ------------------------------------------------------------------ #
env = os.environ
global_rank = dist.get_rank()
local_rank = int(env.get("LOCAL_RANK", global_rank)) # fallback
self._first = local_rank == 0
# ------------------------------------------------------------------ #
# 3. Synchronisation logic
# ------------------------------------------------------------------ #
if not self._first:
# Non‑rank‑0 processes wait for their node‑rank-0
dist.barrier()
return self._first
def __exit__(self, exc_type, exc_val, exc_tb):
"""
Tear down the context.
1. If the current process was the first on its node, release the
waiting peer ranks by issuing a barrier.
2. If an exception occurred, abort the *entire* distributed job.
3. If this context manager created the process group, destroy it.
Args:
exc_type (Type[BaseException] | None): Exception class if one
occurred inside the ``with`` block.
exc_val (BaseException | None): The raised exception instance.
exc_tb (TracebackType | None): Traceback associated with the
exception.
Returns:
bool: ``False`` so that any exception raised inside the ``with``
block is propagated to the caller (standard CM semantics).
"""
try:
if self._first and dist.is_initialized():
# Re‑sync the whole world so that non‑rank‑0s can proceed
dist.barrier()
if exc_type is not None:
dist.abort() # propagate failure to the entire job
finally:
if self._created_pg:
dist.destroy_process_group()
# propagate any exception to outer scope
return False
def _try_bootstrap_pg(self) -> bool:
"""
Try to create a default pg from env:// variables.
"""
env = os.environ
required = ("WORLD_SIZE", "RANK", "MASTER_ADDR", "MASTER_PORT")
if all(k in env for k in required):
dist.init_process_group(
backend="gloo",
world_size=int(env.get("WORLD_SIZE")),
rank=int(env.get("RANK")),
)
return True
return False
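# Illustrative usage of `FirstRankPerNode` (a sketch; `prepare_cache` and `load_cache`
# are hypothetical helpers, not part of this module): the node-local rank-0 process
# materializes a shared artifact first, then the remaining ranks on the node read it.
#
#     with FirstRankPerNode() as is_first:
#         if is_first:
#             prepare_cache("/tmp/shared_cache")    # runs on LOCAL_RANK == 0 first
#         data = load_cache("/tmp/shared_cache")    # all ranks run this afterwards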
def get_rank_safe() -> int:
"""
Get the distributed rank safely, even if torch.distributed is not initialized.
Returns:
The current process rank.
"""
    # In Megatron init, args.rank comes from the torchrun env var.
    # Once init is done, args.rank is updated to the value of torch.distributed.get_rank().
if torch.distributed.is_initialized():
return torch.distributed.get_rank()
else:
return int(os.getenv("RANK", "0"))
def get_world_size_safe() -> int:
"""
Get the distributed world size safely, even if torch.distributed is not initialized.
Returns:
The total number of processes in the distributed job.
"""
    # In Megatron init, args.world_size comes from the torchrun env var.
    # Once init is done, args.world_size is updated to the value of torch.distributed.get_world_size().
if torch.distributed.is_initialized():
return torch.distributed.get_world_size()
else:
return int(os.getenv("WORLD_SIZE", "1"))
def get_local_rank_preinit() -> int:
"""
Get the local rank from the environment variable, intended for use before full init.
Returns:
The local rank of the current process.
"""
return int(os.getenv("LOCAL_RANK", "0"))
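# Illustrative use of the "safe" helpers above during early start-up (a sketch):
# before torch.distributed.init_process_group() runs, they fall back to the
# RANK / WORLD_SIZE / LOCAL_RANK environment variables set by torchrun.
#
#     if get_rank_safe() == 0:
#         print(f"world={get_world_size_safe()} local_rank={get_local_rank_preinit()}")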
def append_to_progress_log(save_dir: str, string: str, barrier: bool = True) -> None:
"""
Append a formatted string to the progress log file (rank 0 only).
Includes timestamp, job ID, and number of GPUs in the log entry.
Args:
save_dir: The directory where the 'progress.txt' file is located.
string: The message string to append.
barrier: If True, performs a distributed barrier before writing (rank 0 only).
"""
if save_dir is None:
return
progress_log_filename = os.path.join(save_dir, "progress.txt")
if barrier and torch.distributed.is_initialized():
torch.distributed.barrier()
if get_rank_safe() == 0:
os.makedirs(os.path.dirname(progress_log_filename), exist_ok=True)
with open(progress_log_filename, "a+") as f:
job_id = os.getenv("SLURM_JOB_ID", "")
num_gpus = get_world_size_safe()
f.write(
f"{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\tJob ID: {job_id}\t# GPUs: {num_gpus}\t{string}\n"
)
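# Illustrative call of `append_to_progress_log` (the checkpoint directory is a
# placeholder): every rank must reach the call when barrier=True, but only rank 0
# appends a line such as "2024-01-01 12:00:00\tJob ID: 12345\t# GPUs: 8\tStarting job".
#
#     append_to_progress_log("/checkpoints/run1", "Starting job")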
def barrier_and_log(string: str) -> None:
"""
Perform a distributed barrier and then log a message on rank 0.
Args:
string: The message string to log.
"""
if torch.distributed.is_initialized():
torch.distributed.barrier()
time_str = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
logger.info("[{}] datetime: {} ".format(string, time_str))
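# Illustrative call of `barrier_and_log`, e.g. to mark phase boundaries in training:
#
#     barrier_and_log("after dataloaders are built")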
def reduce_loss(
loss_store: list[torch.Tensor],
total_num_tokens: torch.Tensor,
per_token_loss: bool = True,
dp_group: Optional[torch.distributed.ProcessGroup] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Reduce loss across all ranks.
Args:
loss_store: List of loss tensors to reduce.
total_num_tokens: Total number of tokens to divide the loss by.
per_token_loss: Whether to divide the loss by the number of tokens.
dp_group: Process group to reduce the loss across.
Returns:
Tuple of reduced loss and denominator.
"""
loss = torch.sum(torch.stack(loss_store).float()).view(1).clone().detach()
if dp_group is not None:
dist.all_reduce(loss, op=torch.distributed.ReduceOp.SUM, group=dp_group)
if per_token_loss:
denominator = total_num_tokens.clone().detach().to(torch.int)
else:
denominator = torch.tensor([len(loss_store)], dtype=torch.int, device="cuda")
if dp_group is not None:
dist.all_reduce(denominator, op=torch.distributed.ReduceOp.SUM, group=dp_group)
return loss, denominator
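# Illustrative use of `reduce_loss` in a training step (a sketch; `losses` holds the
# per-micro-batch loss tensors on this rank and `num_tokens` the per-rank token count):
#
#     loss_sum, denom = reduce_loss(losses, num_tokens, per_token_loss=True, dp_group=dp_group)
#     reported_loss = (loss_sum / denom).item()   # globally averaged per-token loss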
def get_sync_ctx(model, is_optim_step):
"""
Get the synchronization context for the model.
Args:
model: The model to synchronize.
is_optim_step: Whether the current step is an optimizer step.
Returns:
A context manager that synchronizes the model.
"""
    # Use `no_sync` on DDP models when we are *not* on the final micro-batch for
    # this gradient update (i.e., when `is_optim_step` is False). This avoids an
    # all-reduce for every micro-batch and greatly improves throughput.
if isinstance(model, dist.fsdp._fully_shard._fully_shard.FSDPModule):
model.set_requires_gradient_sync(is_optim_step)
sync_ctx = nullcontext()
elif isinstance(model, torch.nn.parallel.DistributedDataParallel):
if is_optim_step:
sync_ctx = nullcontext()
else:
sync_ctx = model.no_sync()
else:
sync_ctx = nullcontext()
return sync_ctx
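# Illustrative gradient-accumulation loop using `get_sync_ctx` (a sketch; `micro_batches`,
# `loss_fn`, and `optimizer` are placeholders): gradient synchronization is deferred to
# the final micro-batch before the optimizer step.
#
#     for i, batch in enumerate(micro_batches):
#         is_optim_step = i == len(micro_batches) - 1
#         with get_sync_ctx(model, is_optim_step):
#             loss_fn(model(batch)).backward()
#     optimizer.step()
#     optimizer.zero_grad()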
@torch.no_grad()
def rescale_gradients(model, num_tokens_for_grad_scaling, dp_group=None):
"""
Rescale gradients across the DP group.
Args:
model: The model to rescale.
num_tokens_for_grad_scaling: The number of tokens to divide the gradients by.
dp_group: The process group to rescale the gradients across.
"""
num_tokens_for_grad_scaling = num_tokens_for_grad_scaling.clone().detach()
dp_group_size = 1
if dp_group is not None:
dist.all_reduce(num_tokens_for_grad_scaling, group=dp_group)
dp_group_size = dist.get_world_size(group=dp_group)
    # DDP/FSDP average gradients across ranks, so scale by the group size to undo
    # that averaging before normalizing by the global token count
scaling_factor = dp_group_size / num_tokens_for_grad_scaling
for param in model.parameters():
if param.grad is not None:
param.grad.data.mul_(scaling_factor)
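# Worked example for `rescale_gradients` (illustrative numbers): with a 4-way DP group
# and 1,000 tokens on each rank, the all-reduced token count is 4,000 and the scaling
# factor is 4 / 4000 = 1/1000, turning the averaged gradients into per-token gradients.
#
#     rescale_gradients(model, torch.tensor(1000, device="cuda"), dp_group=dp_group)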
# based on: https://github.com/pytorch/torchtitan/blob/main/torchtitan/distributed/utils.py#L278
@torch.no_grad()
def clip_gradients(model, clip_norm, foreach=True):
"""
Clip gradients across the DP group.
Args:
model: The model to clip the gradients of.
clip_norm: The maximum norm of the gradients.
foreach: if enabled will use fused operations.
"""
grads = [p.grad for p in model.parameters() if p.grad is not None]
grad_norm = torch.nn.utils.get_total_norm(grads, foreach=foreach)
if isinstance(grad_norm, torch.distributed.tensor.DTensor):
grad_norm = grad_norm.full_tensor()
torch.nn.utils.clip_grads_with_norm_([p for p in model.parameters()], clip_norm, grad_norm, foreach=foreach)
return grad_norm
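# Illustrative end-of-step sequence combining the helpers above (a sketch; `optimizer`,
# `num_tokens`, and `dp_group` are placeholders):
#
#     rescale_gradients(model, num_tokens, dp_group=dp_group)
#     grad_norm = clip_gradients(model, clip_norm=1.0)
#     optimizer.step()
#     optimizer.zero_grad()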