nemo_automodel.components.distributed.ddp

Module Contents

Classes

Name | Description
DDPManager | Manager for distributed training using PyTorch’s DDP.

Data

logger

API

class nemo_automodel.components.distributed.ddp.DDPManager(
config: nemo_automodel.components.distributed.config.DDPConfig
)

Manager for distributed training using PyTorch’s DDP.

This manager wraps models with DistributedDataParallel for data-parallel distributed training.

Parameters:

config
DDPConfig

Configuration for DDP distributed training.

activation_checkpointing = config.activation_checkpointing
broadcast_buffers = config.broadcast_buffers
bucket_cap_mb = config.bucket_cap_mb
find_unused_parameters = config.find_unused_parameters
gradient_as_bucket_view = config.gradient_as_bucket_view
static_graph = config.static_graph

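
Five of the attributes above (broadcast_buffers, bucket_cap_mb, find_unused_parameters, gradient_as_bucket_view, static_graph) share their names with keyword arguments of torch.nn.parallel.DistributedDataParallel. The sketch below shows how such options could be forwarded to that constructor. It is an illustration only: SketchDDPOptions, wrap_with_options, and run_options_demo are hypothetical names, the default values are choices made for this example, and whether DDPManager forwards its attributes in exactly this way is an assumption. activation_checkpointing is left out because it is not a DistributedDataParallel argument.

```python
import os
import tempfile
from dataclasses import asdict, dataclass

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP


@dataclass
class SketchDDPOptions:
    """Hypothetical stand-in for illustration; NOT the real DDPConfig."""

    broadcast_buffers: bool = True
    bucket_cap_mb: int = 25
    find_unused_parameters: bool = False
    gradient_as_bucket_view: bool = False
    static_graph: bool = False


def wrap_with_options(model, options):
    # Each field name matches a DistributedDataParallel keyword argument,
    # so the dataclass can be unpacked directly into the constructor.
    return DDP(model, **asdict(options))


def run_options_demo():
    # Single-process "gloo" group on CPU, initialized through a temporary file
    # so the example needs neither GPUs nor a launcher.
    init_file = os.path.join(tempfile.mkdtemp(), "ddp_init")
    dist.init_process_group(
        backend="gloo",
        init_method="file://" + init_file,
        rank=0,
        world_size=1,
    )
    try:
        options = SketchDDPOptions(gradient_as_bucket_view=True)
        wrapped = wrap_with_options(torch.nn.Linear(4, 2), options)
        output = wrapped(torch.randn(3, 4))
        return isinstance(wrapped, DDP), tuple(output.shape)
    finally:
        dist.destroy_process_group()


if __name__ == "__main__":
    print(run_options_demo())
```
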
nemo_automodel.components.distributed.ddp.DDPManager._setup_distributed()

Initialize device configuration for DDP.

Sets the rank, world_size, and device based on the process group backend.

nemo_automodel.components.distributed.ddp.DDPManager.parallelize(
model
)

Wraps the given model with DistributedDataParallel (DDP).

Moves the model to the initialized device before wrapping. For CUDA devices, the device id is passed to DDP as device_ids; for CPU, no device ids are provided.

Parameters:

model
torch.nn.Module

The PyTorch model to be wrapped.

Returns:

torch.nn.parallel.DistributedDataParallel: The DDP-wrapped model.
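
The behaviour described for _setup_distributed and parallelize can be approximated in plain PyTorch. The following is a minimal sketch under stated assumptions, not the DDPManager source: setup_distributed, parallelize, and run_single_process_demo are hypothetical free functions, the mapping "nccl backend means CUDA, anything else means CPU" is an assumption about how the device is chosen, and the LOCAL_RANK environment variable is assumed to be set by a launcher such as torchrun.

```python
import os
import tempfile

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP


def setup_distributed():
    """Return (rank, world_size, device) for an already initialized process group."""
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    if dist.get_backend() == "nccl":
        # Assumption: one GPU per process, selected by LOCAL_RANK.
        device = torch.device("cuda", int(os.environ.get("LOCAL_RANK", 0)))
        torch.cuda.set_device(device)
    else:
        device = torch.device("cpu")
    return rank, world_size, device


def parallelize(model, device):
    """Move the model to the device, then wrap it in DDP."""
    model = model.to(device)
    # CUDA: pass the device index as device_ids; CPU: pass no device ids.
    device_ids = [device.index] if device.type == "cuda" else None
    return DDP(model, device_ids=device_ids)


def run_single_process_demo():
    # A one-process "gloo" group keeps the example runnable on a CPU-only machine.
    init_file = os.path.join(tempfile.mkdtemp(), "ddp_init")
    dist.init_process_group(
        backend="gloo",
        init_method="file://" + init_file,
        rank=0,
        world_size=1,
    )
    try:
        rank, world_size, device = setup_distributed()
        wrapped = parallelize(torch.nn.Linear(4, 2), device)
        output = wrapped(torch.randn(2, 4, device=device))
        return rank, world_size, device.type, tuple(output.shape)
    finally:
        dist.destroy_process_group()


if __name__ == "__main__":
    print(run_single_process_demo())
```

In a real multi-process job the process group would normally be created by the training launcher before the model is wrapped, and each rank would run the same code with its own rank and device.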

nemo_automodel.components.distributed.ddp.logger = logging.getLogger(__name__)