nemo_automodel.components.distributed.tensor_utils


Tensor utilities for device transfers and memory management in distributed settings.

This module provides utilities for handling tensor operations across different devices and distributed tensor types, with optimizations for performance in distributed training scenarios.

Module Contents

Functions

| Name | Description |
| --- | --- |
| get_cpu_state_dict | Copy the state dict generator to CPU memory. |
| to_cpu | Move a tensor or distributed tensor to the CPU. |
| to_local_if_dtensor | Returns the local shard of the given tensor if it is a DTensor. |

API

nemo_automodel.components.distributed.tensor_utils.get_cpu_state_dict(
state_generator: typing.Iterable[tuple[str, typing.Union[torch.Tensor, torch.distributed.tensor.DTensor]]],
pin_memory: bool = False
) -> dict[str, torch.Tensor]

Copy the tensors yielded by a state dict generator to CPU memory.

Parameters:

state_generator
Iterable[tuple[str, Union[torch.Tensor, DTensor]]]

An iterable that yields (key, tensor) pairs from a model state.

pin_memory
bool, defaults to False

Whether to allocate the CPU tensors in pinned memory for faster GPU transfer.

Returns: dict[str, torch.Tensor]

A dictionary mapping parameter names to CPU tensors.
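
The behavior described above can be sketched with plain PyTorch. This is an illustration under stated assumptions, not the library's implementation: the helper name `cpu_state_dict_sketch` is hypothetical, DTensor inputs are left out, and pinning is only attempted when CUDA is available, because pinned allocation needs an accelerator runtime.

```python
from typing import Iterable

import torch


def cpu_state_dict_sketch(
    state_generator: Iterable[tuple[str, torch.Tensor]],
    pin_memory: bool = False,
) -> dict[str, torch.Tensor]:
    """Hypothetical stand-in for get_cpu_state_dict; plain tensors only."""
    # Assumption of this sketch: only pin when CUDA is present.
    use_pinned = pin_memory and torch.cuda.is_available()
    cpu_state: dict[str, torch.Tensor] = {}
    for key, value in state_generator:
        value = value.detach()
        # Allocate the destination on the CPU, optionally in pinned memory,
        # then copy the source values into it.
        cpu_tensor = torch.empty(
            value.shape, dtype=value.dtype, device="cpu", pin_memory=use_pinned
        )
        cpu_tensor.copy_(value)
        cpu_state[key] = cpu_tensor
    return cpu_state


# Usage: snapshot the parameters of a small module.
model = torch.nn.Linear(4, 2)
cpu_state = cpu_state_dict_sketch(model.state_dict().items())
```

Going by the signature above, the equivalent call to the real function would be `get_cpu_state_dict(model.state_dict().items(), pin_memory=True)`. Pinned host buffers mainly pay off when the snapshot is later copied back to a GPU.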

nemo_automodel.components.distributed.tensor_utils.to_cpu(
v
)

Move a tensor or distributed tensor to the CPU.

This function takes an input tensor, which can be either a DTensor (distributed tensor) or a standard Tensor, and ensures that it is moved to the CPU.

Parameters:

v
DTensor | Tensor | any

The input value, which can be a DTensor, a Tensor, or any other object. If v is a DTensor, its device is checked and the tensor is moved to the CPU accordingly.

Returns:

Tensor | any: The corresponding CPU tensor if v is a DTensor or Tensor; otherwise v is returned unchanged.

Raises:

  • ValueError: If v is a DTensor but its device is neither ‘cuda’ nor ‘cpu’.
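
The dispatch documented for `to_cpu` can be sketched as follows. The function name `to_cpu_sketch` is hypothetical, and the way each DTensor branch produces its result (gathering with `full_tensor()` on CUDA, taking `to_local()` on CPU) is an assumption of this sketch rather than something the reference above states.

```python
import torch

try:  # public import path in recent PyTorch releases
    from torch.distributed.tensor import DTensor
except ImportError:  # older releases expose DTensor under a private module
    from torch.distributed._tensor import DTensor


def to_cpu_sketch(v):
    """Hypothetical stand-in that mirrors the documented dispatch of to_cpu."""
    # DTensor subclasses torch.Tensor, so it has to be checked first.
    if isinstance(v, DTensor):
        if v.device.type == "cuda":
            # Assumption: gather the full tensor, then copy it to the host.
            return v.full_tensor().cpu()
        if v.device.type == "cpu":
            # Assumption: the local shard already lives on the CPU.
            return v.to_local()
        raise ValueError(f"Unknown device for DTensor: {v.device}")
    if isinstance(v, torch.Tensor):
        return v.cpu()
    # Non-tensor values pass through unchanged.
    return v


# Usage with a plain tensor and a non-tensor value.
moved = to_cpu_sketch(torch.ones(2))
passthrough = to_cpu_sketch("not a tensor")
```

Because non-tensor values are returned as-is, a helper of this shape can be mapped over a mixed container, such as an optimizer state holding both tensors and Python scalars.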

nemo_automodel.components.distributed.tensor_utils.to_local_if_dtensor(
tensor: typing.Union[torch.Tensor, torch.distributed.tensor.DTensor]
) -> torch.Tensor

Returns the local shard of the given tensor if it is a DTensor.

Taken and modified from: https://github.com/NVIDIA/Megatron-LM/blob/605f618f237cda8fa80132bc2ccff933512d5a0d/megatron/core/utils.py#L746
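
A minimal sketch of this helper's behavior, assuming that a plain tensor is returned unchanged. The description does not say so explicitly, but the `torch.Tensor` return type suggests it. The name `to_local_if_dtensor_sketch` is hypothetical.

```python
import torch

try:  # public import path in recent PyTorch releases
    from torch.distributed.tensor import DTensor
except ImportError:  # older releases expose DTensor under a private module
    from torch.distributed._tensor import DTensor


def to_local_if_dtensor_sketch(tensor):
    """Return the local shard of a DTensor; pass other tensors through."""
    # Assumption: non-DTensor inputs are returned as-is, without a copy.
    return tensor.to_local() if isinstance(tensor, DTensor) else tensor


# Usage: a regular tensor comes back as the very same object.
x = torch.arange(4)
same = to_local_if_dtensor_sketch(x)
```

A helper like this lets downstream code treat sharded and unsharded parameters uniformly, since every caller ends up holding a regular `torch.Tensor`.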