core.distributed.fsdp.src.megatron_fsdp.param_and_grad_buffer#

Module Contents#

Classes#

MultiGroupUBRAllocator

A custom allocator class that registers a single memory pool with multiple different communication groups, which is not natively supported by apex’s nccl_allocator.

BucketingPolicy

A policy for bucketing in Fully Sharded Data Parallel (FSDP) training.

Bucket

A container for holding data in Fully Sharded Data Parallel (FSDP) training.

TemporaryBucketAllocator

A utility class for managing temporary buckets (buffers) used in FSDP operations such as parameter unsharding and gradient reduction.

StorageResizeBasedBucketAllocator

A specialized temporary bucket allocator that resizes the storage of temporary buckets based on the required size.

RotaryBucketAllocator

A specialized temporary bucket allocator that implements a circular buffer recycling strategy to minimize memory fragmentation in FSDP operations.

FixedPoolAllocator

A specialized temporary bucket allocator that implements a buffer recycling strategy to minimize memory fragmentation in FSDP operations.

DataParallelBuffer

A class that manages the data parallel buffer for Fully Sharded Data Parallel (FSDP) training. It has two operating modes given a bucket of module parameters:

ParameterGroup

Represents a group of model parameters along with metadata for managing data-parallel training in PyTorch.

ParamAndGradBuffer

A class that manages parameter grouping, buffer allocation, and communication operations for data-parallel distributed training.

BucketStatus

An enumeration of possible statuses for a data-parallel communication bucket.

GradReducePipeline

Pipeline for reducing gradients.

PrefetchOrder

An enumeration of possible prefetch orders for data-parallel operations.

AllGatherPipeline

Pipeline for all-gathering parameters.

ResetParametersContext

Context manager for resetting parameters for meta device initialization module.

Functions#

_p_assert

Alternative to assert for use in the backward context: prints the error message s, which would otherwise be swallowed.

_alloc_storage

Allocate storage for tensor with the given size.

_free_storage

Frees the underlying storage of tensor.

_pad

build_data_parallel_buffer_index

Assuming that all input tensor elements contiguously compose a global buffer, give the index range of every tensor, the bucket in the buffer, and the (distributed) shard within the bucket. Note that the global bucket buffer is only temporarily allocated, but is abstractly tracked via indices deduced from the number of raw parameters assigned to this buffer / bucket.

_get_dp_buffer_shard_bucket_index

Build the data parallel buffer shard bucket index from the bucket index.

_get_parameter_groups

Get the parameter group for the given module and parameters.

gradient_reduce_preprocessing

Gradient reduce preprocessing for gradient averaging and gradient scaling.

check_gpu_memory

Check if the GPU memory is over the threshold.

override_sharded_param_methods_with_safety_checks

Override the methods of the parameters to prevent undefined behavior.

_dtype_size

Get the size of the dtype.

to_local_if_dtensor

Convert a DTensor to a local tensor.

_get_fsdp_tensor_spec

Get the DeviceMesh for the parameter and modify the placement for Megatron-FSDP.

make_fsdp_dtensor

Creates a distributed tensor (DTensor) from a local tensor with support for Megatron-FSDP and Tensor Parallel scenarios.

Data#

API#

core.distributed.fsdp.src.megatron_fsdp.param_and_grad_buffer.logger#

‘getLogger(…)’

core.distributed.fsdp.src.megatron_fsdp.param_and_grad_buffer.NCCL_ALLOCATOR#

None

core.distributed.fsdp.src.megatron_fsdp.param_and_grad_buffer.NCCL_MEMORY_POOL#

None

core.distributed.fsdp.src.megatron_fsdp.param_and_grad_buffer._p_assert(
cond: Any,
s: str,
raise_assertion_error: bool = True,
) None#

Alternative to assert for use in the backward context: prints the error message s, which would otherwise be swallowed.

core.distributed.fsdp.src.megatron_fsdp.param_and_grad_buffer._alloc_storage(tensor: torch.Tensor, size: torch.Size) None#

Allocate storage for tensor with the given size.

Returns:

True if this method allocated storage and False if the storage was already allocated.

Return type:

bool

core.distributed.fsdp.src.megatron_fsdp.param_and_grad_buffer._free_storage(tensor: torch.Tensor)#

Frees the underlying storage of tensor.

Returns:

True if the method freed the storage and False if the storage was already freed.

Return type:

bool

core.distributed.fsdp.src.megatron_fsdp.param_and_grad_buffer.TensorItemIndex#

‘namedtuple(…)’

core.distributed.fsdp.src.megatron_fsdp.param_and_grad_buffer.BucketIndex#

‘namedtuple(…)’

core.distributed.fsdp.src.megatron_fsdp.param_and_grad_buffer.ShardBucketIndex#

‘namedtuple(…)’

class core.distributed.fsdp.src.megatron_fsdp.param_and_grad_buffer.MultiGroupUBRAllocator(pool, groups)#

A custom allocator class that registers a single memory pool with multiple different communication groups, which is not natively supported by apex’s nccl_allocator.

This is particularly useful for Mixture of Experts (MoE) models where:

  • Non-expert parameters/gradients use the data-parallel + context-parallel group (dp_cp_group)

  • Expert parameters/gradients use the expert-parallel + data-parallel group (ep_dp_group)

Since Megatron-Core FSDP uses a single contiguous tensor for the entire model’s parameters, we need to register the same memory pool with both communication groups to enable NCCL algorithms that rely on user buffer registration for both expert and non-expert parameters.

Implementation: it uses apex’s nccl_allocator internally to create a tensor with ncclMemAlloc and register it with the group, and then also registers the memory pool with each additional group.

.. rubric:: Example

import apex.contrib.nccl_allocator as nccl_allocator
nccl_allocator.init()
pool = nccl_allocator.create_nccl_mem_pool()
group_1 = torch.distributed.new_group(ranks=[0, 1, 2, 3, 4, 5, 6, 7], backend="nccl")
group_2 = torch.distributed.new_group(ranks=[0, 2, 4, 6], backend="nccl")
with MultiGroupUBRAllocator(pool, groups=[group_1, group_2]):
    a = torch.zeros(1024, dtype=torch.float32, device="cuda")
    b = torch.zeros(1024, dtype=torch.float32, device="cuda")

Initialization

__enter__()#
__exit__(*args)#
class core.distributed.fsdp.src.megatron_fsdp.param_and_grad_buffer.BucketingPolicy#

A policy for bucketing in Fully Sharded Data Parallel (FSDP) training.

.. attribute:: suggested_bucket_size

The suggested size of each bucket, in number of elements.

Type:

int

.. attribute:: fsdp_unit_modules

A list of module classes that are treated as a single unit for FSDP bucketing.

Type:

list

.. attribute:: data_parallel_sharding_strategy

The strategy used for sharding data parallel modules.

Type:

str

.. note:: This policy is used to configure the bucketing behavior in FSDP training.

suggested_bucket_size: Optional[int]#

40000000

fsdp_unit_modules: List[torch.nn.Module]#

‘field(…)’

data_parallel_sharding_strategy: str#

‘no_shard’

core.distributed.fsdp.src.megatron_fsdp.param_and_grad_buffer._pad(number_to_be_padded: int, divisor: int) int#
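The `_pad` helper is undocumented here; a plausible sketch (hypothetical name `pad_to_multiple`, assuming round-up-to-multiple semantics, which matches how bucket sizes are padded to the data-parallel world size below) is:

```python
def pad_to_multiple(n: int, divisor: int) -> int:
    """Round n up to the nearest multiple of divisor."""
    return ((n + divisor - 1) // divisor) * divisor
```

For example, padding a 10-element bucket for a data-parallel world size of 4 yields 12.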
core.distributed.fsdp.src.megatron_fsdp.param_and_grad_buffer.build_data_parallel_buffer_index(
elements: List[torch.Size],
data_parallel_rank: int,
data_parallel_world_size: int,
is_data_distributed: bool,
ddp_config: megatron.core.distributed.distributed_data_parallel_config.DistributedDataParallelConfig,
bucket_id: int = 0,
chunk_size_factor: int = 1,
) Tuple[List[tuple], core.distributed.fsdp.src.megatron_fsdp.param_and_grad_buffer.BucketIndex, core.distributed.fsdp.src.megatron_fsdp.param_and_grad_buffer.ShardBucketIndex]#

Assuming that all input tensor elements contiguously compose a global buffer, give the index range of every tensor, the bucket in the buffer, and the (distributed) shard within the bucket. Note that the global bucket buffer is only temporarily allocated, but is abstractly tracked via indices deduced from the number of raw parameters assigned to this buffer / bucket.

Parameters:
  • elements (List[torch.Size]) – Sizes of the input tensors.

  • data_parallel_rank (int) – Rank of the current process in the data parallel group.

  • data_parallel_world_size (int) – World size of the data parallel group.

  • bucket_id (int, optional) – The id of the bucket. Defaults to 0.

Returns:

The index range of every tensor, of the bucket in the buffer, and of the local shard within the bucket.

Return type:

Tuple[Dict[int, TensorItemIndex], BucketIndex, ShardBucketIndex]
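To illustrate the index bookkeeping (a pure-Python sketch, not the library’s actual implementation): flattened tensors are packed into one contiguous bucket, the bucket is padded so it divides evenly across data-parallel ranks, and each rank owns one contiguous shard of the padded bucket.

```python
import math

def build_index(numels, dp_rank, dp_world_size):
    # Global [start, end) range of each tensor within the bucket.
    item_ranges = []
    offset = 0
    for n in numels:
        item_ranges.append((offset, offset + n))
        offset += n
    # Pad the bucket so it divides evenly across the DP group.
    bucket_end = math.ceil(offset / dp_world_size) * dp_world_size
    # This rank's contiguous shard of the padded bucket.
    shard_size = bucket_end // dp_world_size
    shard_range = (dp_rank * shard_size, (dp_rank + 1) * shard_size)
    return item_ranges, (0, bucket_end), shard_range
```

Note the global bucket is never materialized here: only the index ranges are tracked, matching the "abstractly tracked via indices" behavior described above.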

core.distributed.fsdp.src.megatron_fsdp.param_and_grad_buffer._get_dp_buffer_shard_bucket_index(
bucket_index: core.distributed.fsdp.src.megatron_fsdp.param_and_grad_buffer.BucketIndex,
is_data_distributed: bool,
data_parallel_world_size: int,
data_parallel_rank: int,
) core.distributed.fsdp.src.megatron_fsdp.param_and_grad_buffer.ShardBucketIndex#

Build the data parallel buffer shard bucket index from the bucket index.

Parameters:
  • bucket_index (BucketIndex) – The bucket index containing information on the items in the bucket.

  • is_data_distributed (bool) – Whether the data is distributed across multiple processes.

  • data_parallel_world_size (int) – The world size of the data parallel group.

  • data_parallel_rank (int) – The rank of the current process in the data parallel group.

Returns:

The shard bucket index containing information on the location and size of the buffer shard in the global bucket.

Return type:

ShardBucketIndex

class core.distributed.fsdp.src.megatron_fsdp.param_and_grad_buffer.Bucket#

A container for holding data in Fully Sharded Data Parallel (FSDP) training.

.. attribute:: data

A tensor containing the data elements grouped together in a bucket, used to synchronize data operations.

Type:

torch.Tensor

.. note::

Buckets are used to optimize communication in FSDP training by grouping small tensors together.

data: torch.Tensor#

None

class core.distributed.fsdp.src.megatron_fsdp.param_and_grad_buffer.TemporaryBucketAllocator#

A utility class for managing temporary buckets (buffers) used in FSDP operations such as parameter unsharding and gradient reduction.

This allocator handles the dynamic allocation and deallocation of temporary memory buffers needed during FSDP (Fully Sharded Data Parallel) operations, particularly for parameter unsharding and gradient reduction. It helps optimize memory usage by allowing temporary buckets to be released when no longer needed.

Key Features:

  • Dynamic allocation of temporary buckets for FSDP operations

  • Memory-efficient management of temporary buffers

  • Support for both parameter unsharding and gradient reduction operations

  • Automatic cleanup of unused buckets to save memory

Usage:

```python
# Create an allocator instance
allocator = TemporaryBucketAllocator(name="gpt_parameters")

# Allocate a temporary bucket
temp_bucket = allocator.allocate(size=1024, dtype=torch.float32)

# Use the temporary bucket for FSDP operations
# ... perform all-gather or reduce-scatter ...

# Free the bucket when done
allocator.free(temp_bucket)
```

.. note::

It’s important to release temporary buckets after use to prevent memory leaks and optimize memory usage during training.

Initialization

allocate(
bucket_id: int,
size: int,
dtype: torch.dtype,
device: torch.device,
mem_alloc_context: Optional[Callable] = None,
) core.distributed.fsdp.src.megatron_fsdp.param_and_grad_buffer.Bucket#

Allocate a temporary bucket.

free(bucket_id: int)#

Free a temporary bucket.

class core.distributed.fsdp.src.megatron_fsdp.param_and_grad_buffer.StorageResizeBasedBucketAllocator#

Bases: core.distributed.fsdp.src.megatron_fsdp.param_and_grad_buffer.TemporaryBucketAllocator

A specialized temporary bucket allocator that resizes the storage of temporary buckets based on the required size.

Initialization

allocate(
bucket_id: int,
size: int,
dtype: torch.dtype,
device: torch.device,
mem_alloc_context: Optional[Callable] = None,
) core.distributed.fsdp.src.megatron_fsdp.param_and_grad_buffer.Bucket#

Allocate a temporary bucket.

free(bucket_id: int)#

Free a temporary bucket.

class core.distributed.fsdp.src.megatron_fsdp.param_and_grad_buffer.RotaryBucketAllocator(name: str)#

Bases: core.distributed.fsdp.src.megatron_fsdp.param_and_grad_buffer.TemporaryBucketAllocator

A specialized temporary bucket allocator that implements a circular buffer recycling strategy to minimize memory fragmentation in FSDP operations.

RotaryBucketAllocator extends TemporaryBucketAllocator by maintaining a limited pool of pre-allocated buffers that are reused in a circular manner. This approach helps prevent memory fragmentation that typically occurs with frequent allocation and deallocation of temporary buffers during FSDP operations.

Key Features:

  • Circular buffer recycling strategy for memory efficiency

  • Reduced memory fragmentation compared to dynamic allocation

  • Pre-allocated buffer pool for faster access

  • Automatic buffer reuse without explicit deallocation

Usage:

```python
# Create a rotary allocator
allocator = RotaryBucketAllocator(name="gpt_parameters")

# Get a temporary buffer from the pool
temp_bucket = allocator.allocate(dtype=torch.float32)

# Use the temporary bucket for FSDP operations
# ... perform all-gather or reduce-scatter ...

# Free the bucket when done, returning it to the idle buffer pool
allocator.free(temp_bucket)
```
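The circular recycling strategy can be illustrated with a toy pool (a sketch of the idea only, not this class’s actual internals): a fixed set of buffers is handed out and returned to an idle queue, so a later allocation reuses the earliest-freed buffer instead of allocating new memory.

```python
from collections import deque

class RotaryPool:
    """Toy circular buffer pool: a fixed set of buffers reused round-robin."""

    def __init__(self, num_buffers: int = 2):
        self.idle = deque(range(num_buffers))  # ids of idle buffers
        self.in_use = {}                       # bucket_id -> buffer id

    def allocate(self, bucket_id: int) -> int:
        buf = self.idle.popleft()              # reuse the oldest idle buffer
        self.in_use[bucket_id] = buf
        return buf

    def free(self, bucket_id: int) -> None:
        self.idle.append(self.in_use.pop(bucket_id))
```

Because buffers cycle through the idle queue rather than being deallocated, the allocator avoids the fragmentation caused by repeated alloc/free of differently sized buffers.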

Initialization

allocate(
bucket_id: int,
size: int,
dtype: torch.dtype,
device: torch.device,
mem_alloc_context: Optional[Callable] = None,
) core.distributed.fsdp.src.megatron_fsdp.param_and_grad_buffer.Bucket#

Allocate a temporary bucket.

_get_gbuf_name(buffer_id: int)#
free(bucket_id: int)#

Free a temporary bucket.

class core.distributed.fsdp.src.megatron_fsdp.param_and_grad_buffer.FixedPoolAllocator(
name: str,
fsdp_param_groups: List[core.distributed.fsdp.src.megatron_fsdp.param_and_grad_buffer.ParameterGroup],
size: int = 2,
fallback_to_persistent_buffer: bool = False,
)#

Bases: core.distributed.fsdp.src.megatron_fsdp.param_and_grad_buffer.TemporaryBucketAllocator

A specialized temporary bucket allocator that implements a buffer recycling strategy to minimize memory fragmentation in FSDP operations.

This allocator maintains a fixed pool of pre-allocated buffers, reusing them to reduce the overhead and fragmentation caused by frequent allocation and deallocation of temporary buffers during FSDP operations.

Initialization

_is_two_bucket_group_equal(group_a, group_b)#
allocate(
bucket_id: int,
size: int,
dtype: torch.dtype,
device: torch.device,
mem_alloc_context: Optional[Callable] = None,
) core.distributed.fsdp.src.megatron_fsdp.param_and_grad_buffer.Bucket#

Allocate a temporary bucket.

_get_gbuf_name(buf_group_id: int, bucket_index: int)#
free(bucket_id: int)#

Free a temporary bucket.

class core.distributed.fsdp.src.megatron_fsdp.param_and_grad_buffer.DataParallelBuffer(
ddp_config: megatron.core.distributed.distributed_data_parallel_config.DistributedDataParallelConfig,
params: List[torch.nn.Parameter],
is_data_distributed: bool,
bucket_id: int,
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
data_parallel_group: Optional[torch.distributed.ProcessGroup] = None,
dp_rank: Optional[int] = None,
temporary_bucket_allocator: Optional[core.distributed.fsdp.src.megatron_fsdp.param_and_grad_buffer.TemporaryBucketAllocator] = None,
is_transpose_buffer: bool = False,
gradient_scaling_factor: Optional[float] = None,
chunk_size_factor: int = 1,
mem_alloc_context: Optional[Callable] = None,
item_index_map: Optional[Dict[int, core.distributed.fsdp.src.megatron_fsdp.param_and_grad_buffer.TensorItemIndex]] = None,
bucket_index: Optional[core.distributed.fsdp.src.megatron_fsdp.param_and_grad_buffer.BucketIndex] = None,
shard_bucket_index: Optional[core.distributed.fsdp.src.megatron_fsdp.param_and_grad_buffer.ShardBucketIndex] = None,
)#

A class that manages the data parallel buffer for Fully Sharded Data Parallel (FSDP) training. It has two operating modes given a bucket of module parameters:

- Sharded: The bucket is sharded across the data parallel group, and each
    rank will manage a shard of the bucket that is persistently stored in this buffer.
- Unsharded: The bucket is not sharded, and the entire bucket is persistently
    stored in this buffer. Virtual shards of this unsharded buffer can be
    retrieved from each rank when needed.

This design supports interoperability of sharded and unsharded buffers, e.g. optim and optim_grads, where buffers associated with sharded parameters can be utilized with buffers associated with unsharded parameters through the use of “virtual” or rank-specific shards for the unsharded buffers.
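The “virtual shard” idea for unsharded buffers can be sketched in plain Python (illustrative only, with a hypothetical helper name): a rank’s virtual shard is just a slice of the replicated buffer, whereas a sharded buffer already stores only that slice.

```python
def virtual_shard(unsharded, rank, world_size):
    # Slice out this rank's portion of an unsharded (replicated) buffer,
    # assuming the buffer length divides evenly across ranks.
    shard_size = len(unsharded) // world_size
    return unsharded[rank * shard_size:(rank + 1) * shard_size]
```

This is what makes sharded and unsharded buffers interoperable: both expose a per-rank shard with identical coordinates, whether stored or sliced on demand.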

Initialization

init_data(data: torch.Tensor)#

Allocate a buffer Tensor to persistently store the data for this buffer (or this rank’s shard of it).

fetch_bucket(
dtype: Optional[torch.dtype] = None,
set_param_data: bool = False,
) core.distributed.fsdp.src.megatron_fsdp.param_and_grad_buffer.Bucket#

Fetch a communication buffer for data-parallel operations.

The size of the bucket is defined by the DataParallelBuffer instance.

Parameters:

dtype (Optional[torch.dtype], optional) – The data type of the tensor to fetch a buffer for. Defaults to None.

Returns:

The communication buffer for the specified data type.

Return type:

Bucket

free_bucket_storage()#

Release the storage of the communication bucket if it is temporary.

_get_item_slice_in_shard(item_id: int) Tuple[int, int]#

Return the coordinates of the slice of the item that is contained in this buffer shard. In other words, this returns the coordinates of all of the data in this item that is stored in this shard.

Maps to the global coordinates of the item in the bucket when added to the starting coordinate of the item in the bucket, and maps to the local coordinates of the item in the shard when added to the difference between the starting coordinate of the item and the starting coordinate of the shard in the global bucket (i.e. mapping from item coordinates to global coordinates to shard coordinates).
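The item/shard intersection described above amounts to intersecting two half-open global ranges and shifting the result into item-local coordinates. A minimal sketch (hypothetical helper, not the method’s actual code):

```python
def item_slice_in_shard(item_start, item_end, shard_start, shard_end):
    # Intersect the item's global [start, end) range with the shard's,
    # then express the overlap in item-local coordinates.
    lo = max(item_start, shard_start)
    hi = min(item_end, shard_end)
    if lo >= hi:
        return (0, 0)  # item and shard do not overlap
    return (lo - item_start, hi - item_start)
```

Adding `item_start` back to the returned coordinates recovers the global range, and subtracting `shard_start` from the global range gives the shard-local range, matching the coordinate mappings described above.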

locate_item_in_global_item(item_id: int) Tuple[int, int]#

Return the coordinates of the slice of the item that is contained in this buffer shard. In other words, this returns the coordinates of all of the data in this item that is stored in this shard.

Helper function that adds a shortcut when the buffer is not sharded, in which case we don’t need to compute the item-shard intersection, and can simply return the coordinates of the entire item.

_get_item_local_shard_index(item_id: int) Tuple[int, int]#

Return the local coordinates of the slice of this buffer’s shard that contains the item with the given ID. In other words, this returns the coordinates of all of the data in this shard associated with the item.

Maps to the global coordinates of the item in the bucket when added to the starting coordinate of the shard in the global bucket, and maps to the coordinates of the item contained in the shard when added to the difference between the starting coordinate of the shard and the starting coordinate of the item in the global bucket (i.e. mapping from shard coordinates to global coordinates to item coordinates).

_get_item_local_index(item_id: int) Tuple[int, int]#

Return the local coordinates of the slice of this buffer’s data that contains the item with the given ID.

set_item(item_id: int, item_data: torch.Tensor) None#

Update a Tensor item managed by the DataParallelBuffer instance, i.e. store (a shard of) the Tensor in this buffer’s datastore.

The storage of the item is mapped to the communication bucket. This method updates the item data and ensures consistency with the bucket.

Parameters:
  • item_id (int) – The ID of the tensor item to update.

  • item_data (torch.Tensor) – The new data for the tensor item.

Returns:

None

get_item(item_id: int, only_shard: bool = False) torch.Tensor#

Retrieve a tensor item managed by the DataParallelBuffer instance, i.e. get all the item data stored in this sharded or unsharded buffer.

The storage of the item is mapped to the communication bucket. If only_shard is True, returns only the shard of the item corresponding to the current process / rank, a “virtual shard” for unsharded buffers. Otherwise, returns the entire item, which could be a bucket shard or bucket.

Parameters:
  • item_id (int) – The ID of the tensor item to retrieve.

  • only_shard (bool, optional) – Whether to return only the shard of the item. Defaults to False.

Returns:

The retrieved tensor item.

Return type:

torch.Tensor

get_item_from_bucket(
bucket: core.distributed.fsdp.src.megatron_fsdp.param_and_grad_buffer.Bucket,
item_id: int,
)#

Get Tensor item data from the given bucket specified by the item ID.

get_shard_from_bucket(
bucket: core.distributed.fsdp.src.megatron_fsdp.param_and_grad_buffer.Bucket,
)#

Get the shard from the provided bucket associated with the sharding strategy of this buffer.

get_shard_from_local_buffer() torch.Tensor#

Get the shard or virtual shard of the bucket stored in this buffer.

class core.distributed.fsdp.src.megatron_fsdp.param_and_grad_buffer.ParameterGroup#

Represents a group of model parameters along with metadata for managing data-parallel training in PyTorch.

This class encapsulates a list of parameters and associated information such as data type, gradient requirements, and references to buffers used in distributed training contexts.

.. attribute:: params

The list of model parameters grouped together.

Type:

List[torch.nn.Parameter]

.. attribute:: dtype

The desired data type for the parameters.

Type:

Optional[torch.dtype]

.. attribute:: is_expert_param

Indicates if this group contains expert parameters (e.g., in mixture-of-experts).

Type:

bool

.. attribute:: requires_grad

Specifies if gradients should be computed for these parameters.

Type:

Optional[bool]

.. attribute:: fsdp_unit_id

Identifier for Fully Sharded Data Parallel (FSDP) unit grouping.

Type:

Optional[int]

.. attribute:: chunk_size_factor

Factor determining chunk size for grouped parameter processing.

Type:

int

.. attribute:: model_weight_buffer

Buffer used to store model weights for data-parallel operations.

Type:

Optional[DataParallelBuffer]

.. attribute:: transpose_weight_buffer

Buffer used to store transpose weights for data-parallel operations.

Type:

Optional[DataParallelBuffer]

.. attribute:: main_weight_buffer

Buffer used to store main model weights for data-parallel operations.

Type:

Optional[DataParallelBuffer]

.. attribute:: main_grad_buffer

Buffer used to store main gradients for data-parallel operations.

Type:

Optional[DataParallelBuffer]

.. attribute:: hsdp_wbuf

Buffer for weights used in Hybrid Sharded Data Parallel (HSDP). Exists only if full sharding is enabled in HSDP.

Type:

Optional[DataParallelBuffer]

.. attribute:: hsdp_gbuf

Buffer for gradients used in HSDP. Exists only if full sharding is enabled in HSDP.

Type:

Optional[DataParallelBuffer]

params: List[torch.nn.Parameter]#

None

dtype: Optional[torch.dtype]#

None

is_expert_param: bool#

False

requires_grad: Optional[bool]#

None

fsdp_unit_id: Optional[int]#

None

chunk_size_factor: int#

1

model_weight_buffer: Optional[core.distributed.fsdp.src.megatron_fsdp.param_and_grad_buffer.DataParallelBuffer]#

None

transpose_weight_buffer: Optional[core.distributed.fsdp.src.megatron_fsdp.param_and_grad_buffer.DataParallelBuffer]#

None

main_weight_buffer: Optional[core.distributed.fsdp.src.megatron_fsdp.param_and_grad_buffer.DataParallelBuffer]#

None

main_grad_buffer: Optional[core.distributed.fsdp.src.megatron_fsdp.param_and_grad_buffer.DataParallelBuffer]#

None

hsdp_wbuf: Optional[core.distributed.fsdp.src.megatron_fsdp.param_and_grad_buffer.DataParallelBuffer]#

None

hsdp_gbuf: Optional[core.distributed.fsdp.src.megatron_fsdp.param_and_grad_buffer.DataParallelBuffer]#

None

core.distributed.fsdp.src.megatron_fsdp.param_and_grad_buffer._get_parameter_groups(
module: torch.nn.Module,
policy: core.distributed.fsdp.src.megatron_fsdp.param_and_grad_buffer.BucketingPolicy,
meta_device_init_fp8_params: dict,
bucket_group_by_fsdp_unit: bool = True,
)#

Get the parameter group for the given module and parameters.

Parameters:
  • module (torch.nn.Module) – The module whose parameters are to be grouped and flattened.

  • policy (BucketingPolicy) – The bucketing policy.

  • meta_device_init_fp8_params (dict) – A dictionary mapping parameter names to a boolean indicating whether the parameter is initialized on the meta device.

  • bucket_group_by_fsdp_unit (bool) – Whether to group buckets by FSDP unit.

Returns:

  • The list of parameter groups.

  • The mapping from parameters to their bucket group ID.

  • The mapping from bucket ID to the full group of bucket IDs that are NCCL-aggregated with this bucket ID.

Return type:

Tuple[List[ParameterGroup], Dict[torch.nn.Parameter, int], Dict[int, List[int]]]

class core.distributed.fsdp.src.megatron_fsdp.param_and_grad_buffer.ParamAndGradBuffer(
ddp_config: megatron.core.distributed.distributed_data_parallel_config.DistributedDataParallelConfig,
module: torch.nn.Module,
bucketing_policy: core.distributed.fsdp.src.megatron_fsdp.param_and_grad_buffer.BucketingPolicy,
dist_index: core.distributed.fsdp.src.megatron_fsdp.utils.FSDPDistributedIndex,
preserve_fp32_weights: bool = True,
grad_reduce_in_fp32: bool = True,
gradient_scaling_factor: Optional[float] = None,
expert_gradient_scaling_factor: Optional[float] = None,
device: torch.device = torch.device('cuda'),
only_create_grad_buffer_and_main_weight_buffer_for_param_requires_grad: bool = True,
reset_parameters_for_meta_device_init_module: bool = False,
)#

A class that manages parameter grouping, buffer allocation, and communication operations for data-parallel distributed training.

This class provides functionality to:

  1. Group parameters based on their data types and communication group sizes.

  2. Create contiguous buffers for model weights, gradients, and high-precision main weights.

  3. Handle parameter unsharding, gradient reduction, and weight synchronization operations.

Key Features:

  • Efficient parameter grouping based on data types and communication patterns

  • Memory-efficient contiguous buffer allocation

  • Support for mixed-precision training with main weights

  • Distributed operations including parameter all-gather and gradient reduce-scatter/all-reduce

  • Synchronized weight updates between model and main weights

.. note::

This class is designed for distributed training scenarios where efficient parameter management and communication are crucial for performance.

Parameters:
  • ddp_config (DistributedDataParallelConfig) – The distributed data parallel configuration.

  • module (torch.nn.Module) – The module whose parameters are to be grouped and flattened.

  • bucketing_policy (BucketingPolicy) – The bucketing policy.

  • dist_index (FSDPDistributedIndex) – The distributed index of process groups used by Megatron-FSDP for sharding and communication.

  • preserve_fp32_weights (bool) – Whether to preserve FP32 weights.

  • grad_reduce_in_fp32 (bool) – Whether to reduce gradients in FP32.

  • gradient_scaling_factor (Optional[float]) – The gradient scaling factor.

  • expert_gradient_scaling_factor (Optional[float]) – The expert gradient scaling factor.

  • device (torch.device) – The parameter and gradient buffer device.

  • only_create_grad_buffer_and_main_weight_buffer_for_param_requires_grad (bool) – Whether to only create the gradient buffer and main weight buffer for parameters that require gradients. Default is True.

Initialization

get_mem_alloc_context(groups=None, symmetric=True)#

Get the memory allocation context for the parameter and gradient buffers.

manual_buffer_registration()#

Manually register the FSDP communication buffers to NCCL user buffer.

_log_parameter_groups()#

Compact log of FSDP parameter groups and their parameters.

_init_each_parameter_group_buffers(meta_device_init_fp8_params)#

Initialize the buffers for each parameter group.

_reset_parameters(old_params, new_params)#
scale_gradients(scaling_factor: float) None#

Scale the gradient data by scaling_factor.

zero_grad()#

Zero out the underlying grad_buffer and reset all buckets in preparation for the next iteration of training.

_init_distributed_params()#

Register model training and high-precision parameters as optimizer named parameters and DTensor(s). Specifically, we utilize the highest precision weights available for optimization using fall-back logic on mbuf -> wbuf -> orig_param depending on if preserve_fp32_weights or “no_shard” is utilized.

_init_optimizer_named_parameters() List[Tuple[str, torch.nn.Parameter]]#
update_main_grads()#

Update the gradients in the model parameters with the main gradients from the main gradient buffer. If the model parameters are sharded, we only need to update the gradient shard associated with the model parameter shard, as both are sharded symmetrically.

Checks if high-precision main weights are utilized for optimization. Otherwise, falls back to low-precision model weights, and further falls back to the original module parameters not managed by cFSDP in the case of no sharding / cFSDP OFF.

property num_buckets#

Return the number of buckets.

copy_main_weights_to_model_weights()#

Update the model weights from the main weights.

If FP8 parameters are utilized, this function will quantize the high-precision main weights prior to installation into the model compute weight buffers.

copy_model_weights_to_main_weights()#

Copy the model weights to the main weights.

all_gather_parameters(async_op: bool = True)#

All gather the parameters.

Parameters:

async_op (bool, optional) – Whether to perform the all-gather asynchronously. Defaults to True.

reduce_scatter_gradients(async_op: bool = True)#

Reduce scatter the gradients.

Parameters:

async_op (bool, optional) – Whether to perform the reduce-scatter asynchronously. Defaults to True.

all_reduce_gradients(async_op: bool = False)#

All reduce the gradients.

Parameters:

async_op (bool, optional) – Whether to do the all-reduce asynchronously. Defaults to False.

class core.distributed.fsdp.src.megatron_fsdp.param_and_grad_buffer.BucketStatus(*args, **kwds)#

Bases: enum.Enum

An enumeration of possible statuses for a data-parallel communication bucket.

.. attribute:: EMPTY

The bucket is empty and not in use.

Type:

int

.. attribute:: COMMUNICATING

The bucket is currently being used for communication.

Type:

int

.. attribute:: READY_TO_USE

The bucket is filled with data and ready for use.

Type:

int

Initialization

EMPTY#

1

COMMUNICATING#

2

READY_TO_USE#

3

class core.distributed.fsdp.src.megatron_fsdp.param_and_grad_buffer.GradReducePipeline(
param_and_grad_buffer: core.distributed.fsdp.src.megatron_fsdp.param_and_grad_buffer.ParamAndGradBuffer,
rs_stream: Optional[torch.cuda.Stream] = None,
check_nans: bool = False,
)#

Pipeline for reducing gradients.

Initialization

property num_buckets#

Return the number of buckets.

reset()#

Handle the processing tasks and reset the pipeline.

reduce_gradients(
params: List[torch.Tensor],
suggested_queue_capacity: Optional[int] = None,
outer_fsdp_group_grad_reduce: bool = False,
)#

Reduce the gradients for the given parameters.

Parameters:
  • params (List[torch.Tensor]) – The parameters.

  • suggested_queue_capacity (int, optional) – The suggested queue capacity. Defaults to None.

  • outer_fsdp_group_grad_reduce (bool, optional) – Whether to reduce gradients across outer-DP groups. Defaults to False.

wait_for_previous_grad_reduce(
suggested_queue_size: int = 1,
suggested_queue_capacity: Optional[int] = None,
)#

Wait for the previous reduce-scatter/all-reduce to finish.

Parameters:
  • suggested_queue_size (int, optional) – The recommended queue size in buckets. Defaults to 1.

  • suggested_queue_capacity (Optional[int], optional) – The recommended queue capacity in number of parameters in all buckets in the reduction queue. Defaults to None.

_enforce_double_buffer_limit(add_buckets)#
get_ready_bucket_group_for_reduction(
bucket_id: int,
) Optional[List[int]]#

Checks if all buckets in the bucket group containing the given bucket_id are ready for gradient reduction. If so, returns the list of ready bucket IDs for reduction; otherwise, returns None.

Parameters:

bucket_id (int) – The bucket whose bucket group is checked for readiness.

Returns:

The bucket group ready for gradient reduction, or None if not all buckets are ready.

Return type:

Optional[List[int]]

get_fsdp_buffer(
bucket_id: int,
) core.distributed.fsdp.src.megatron_fsdp.param_and_grad_buffer.DataParallelBuffer#

Get the FSDP buffer for the given bucket ID.

_bucket_group_gradient_reduce(
bucket_group: List[int],
async_op: bool = False,
outer_fsdp_group_grad_reduce: bool = False,
) bool#

Mark the bucket group ready for reduce-scatter/all-reduce; if all buckets in the bucket group are ready, perform the reduce-scatter/all-reduce.

Parameters:
  • bucket_group (List[int]) – The bucket group to be reduced.

  • async_op (bool, optional) – Whether to do the reduce-scatter/all-reduce asynchronously. Defaults to False.

  • outer_fsdp_group_grad_reduce (bool, optional) – Whether to reduce gradients across outer-DP groups. Defaults to False.

Returns:

True if the bucket group was ready and the reduce-scatter/all-reduce was issued.

Return type:

bool

class core.distributed.fsdp.src.megatron_fsdp.param_and_grad_buffer.PrefetchOrder(*args, **kwds)#

Bases: enum.Enum

An enumeration of possible prefetch orders for data-parallel operations.

.. attribute:: FORWARD_PASS_ORDER

Prefetch in the order of forward pass computation.

Type:

int

.. attribute:: BACKWARD_PASS_ORDER

Prefetch in the order of backward pass computation.

Type:

int

Initialization

FORWARD_PASS_ORDER#

0

BACKWARD_PASS_ORDER#

1
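The two orders above determine which direction the prefetcher walks the bucket list. A minimal sketch of that selection (assumed logic, for illustration): the forward order steps to the next bucket, the backward order to the previous one, and the walk stops at the ends of the list.

```python
import enum
from typing import Optional


class PrefetchOrder(enum.Enum):
    # Mirrors the enumeration documented above.
    FORWARD_PASS_ORDER = 0
    BACKWARD_PASS_ORDER = 1


def next_prefetch_bucket(bucket_id: int, num_buckets: int,
                         order: PrefetchOrder) -> Optional[int]:
    # Forward pass consumes buckets in ascending order; backprop in descending.
    step = 1 if order is PrefetchOrder.FORWARD_PASS_ORDER else -1
    nxt = bucket_id + step
    return nxt if 0 <= nxt < num_buckets else None
```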

class core.distributed.fsdp.src.megatron_fsdp.param_and_grad_buffer.AllGatherPipeline(
param_and_grad_buffer: core.distributed.fsdp.src.megatron_fsdp.param_and_grad_buffer.ParamAndGradBuffer,
ag_stream: Optional[torch.cuda.Stream] = None,
)#

Pipeline for all-gathering parameters.

Initialization

get_bucket_key(bucket_id, bwd)#

Get the key for the bucket.

property num_buckets#

Return the number of buckets.

reset()#

Reset the pipeline state.

all_gather_params(
params: List[torch.Tensor],
prefetch: bool = False,
prefetch_order: core.distributed.fsdp.src.megatron_fsdp.param_and_grad_buffer.PrefetchOrder = PrefetchOrder.FORWARD_PASS_ORDER,
suggested_AG_prefetch_size: Optional[int] = None,
async_param_gather: bool = True,
outer_fsdp_group_param_gather: bool = False,
bwd: bool = False,
)#

All-gather the params. If prefetch is enabled, prefetch next buckets in the order of prefetch_order.

Parameters:
  • params (List[torch.Tensor]) – The list of params to be all-gathered.

  • prefetch (bool, optional) – Whether to prefetch the next bucket. Defaults to False.

  • prefetch_order (PrefetchOrder, optional) – The order of prefetching. Defaults to PrefetchOrder.FORWARD_PASS_ORDER.

  • suggested_AG_prefetch_size (Optional[int], optional) – The suggested prefetch size for all-gathering. Defaults to None.

  • async_param_gather (bool, optional) – Whether to perform the all-gather asynchronously. Defaults to True.

  • outer_fsdp_group_param_gather (bool, optional) – Whether to all-gather parameters across outer-DP groups. Defaults to False.

  • bwd (bool, optional) – Whether this all-gather is issued for the backward pass. Defaults to False.
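The `suggested_AG_prefetch_size` heuristic can be sketched as follows (a hypothetical illustration of the described behavior, not the module's code): starting from the next bucket, keep queueing all-gathers until the accumulated parameter size reaches the suggested prefetch size.

```python
from typing import List


def select_prefetch_buckets(bucket_sizes: List[int], start: int,
                            suggested_prefetch_size: int) -> List[int]:
    selected, accumulated = [], 0
    for bucket_id in range(start, len(bucket_sizes)):
        # Stop once we have queued at least the suggested amount of data.
        if accumulated >= suggested_prefetch_size:
            break
        selected.append(bucket_id)
        accumulated += bucket_sizes[bucket_id]
    return selected
```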

wait_bucket_ready(bucket_id, bwd, empty_ok=False)#

Wait for the bucket to be ready.

release_bucket(bucket_id, bwd, lazy: bool = False)#

Release the specified parameter bucket, freeing its associated buffer storage.

This function marks or frees the memory of a parameter bucket depending on whether lazy release is enabled. It ensures that buckets are not released while still being communicated or in use by the pipeline.

Parameters:
  • bucket_id (int) – Identifier of the bucket to be released.

  • bwd (bool) – Indicates if the release is triggered during the backward pass.

  • lazy (bool, optional) –

    Determines when the parameter buffer (bucket) is released.

    • If False, the buffer is released immediately.

    • If True, the release is deferred until just before the all-gather pipeline requests a new buffer. The delayed release is performed by invoking recycle_unused_buckets.

Raises:

ValueError – If the specified bucket is currently in communication and cannot be safely released.

.. rubric:: Notes

  • Buckets marked as lazy will be released later when the pipeline determines they are no longer needed.

  • If the bucket has a transpose weight buffer (used in FP8 backward passes), this buffer is freed; otherwise, the model weight buffer is released.
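The eager-versus-lazy release behavior described above can be sketched like this (names and data structures are assumptions for illustration): a lazy release only marks the bucket, and its storage is actually freed by recycle_unused_buckets just before the pipeline hands out a new buffer.

```python
class BucketPool:
    def __init__(self):
        self.allocated = set()       # buckets currently holding storage
        self.lazy_released = set()   # buckets marked for deferred release

    def release_bucket(self, bucket_id: int, lazy: bool = False) -> None:
        if lazy:
            # Defer the actual free until recycle_unused_buckets runs.
            self.lazy_released.add(bucket_id)
        else:
            # Free the storage immediately.
            self.allocated.discard(bucket_id)

    def recycle_unused_buckets(self) -> None:
        # Invoked before the pipeline requests a new buffer.
        self.allocated -= self.lazy_released
        self.lazy_released.clear()
```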

recycle_unused_buckets()#

Recycle the unused buckets.

get_fsdp_buffer(
bucket_id: int,
bwd=False,
) core.distributed.fsdp.src.megatron_fsdp.param_and_grad_buffer.DataParallelBuffer#

Get the FSDP buffer with the given bucket ID.

async_bucket_gather(bucket_id, bwd) None#

All-gather the bucket and set the gathered parameter items.

core.distributed.fsdp.src.megatron_fsdp.param_and_grad_buffer.gradient_reduce_preprocessing(grad_data, scaling_factor, ddp_config)#

Gradient reduce preprocessing for gradient averaging and gradient scaling.
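A rough sketch of that preprocessing (the config fields here are assumptions, not the actual `ddp_config` attributes): before the reduction collective, gradients may be pre-scaled so a summing collective produces an average, optionally combined with an extra scaling factor.

```python
from typing import List, Optional


def preprocess_grad(grad: List[float], scaling_factor: Optional[float] = None,
                    average_in_collective: bool = False,
                    data_parallel_size: int = 1) -> List[float]:
    scale = 1.0
    if scaling_factor is not None:
        scale *= scaling_factor
    if not average_in_collective:
        # The collective will sum, so divide by world size up front
        # to obtain an average.
        scale /= data_parallel_size
    return [g * scale for g in grad]
```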

core.distributed.fsdp.src.megatron_fsdp.param_and_grad_buffer.check_gpu_memory(threshold=0.9)#

Check if the GPU memory is over the threshold.

Parameters:

threshold (float, optional) – The threshold to check if the GPU memory is over. Defaults to 0.9.

Returns:

True if the GPU memory is over the threshold.

Return type:

bool
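The threshold check reduces to comparing the fraction of device memory in use against the threshold. In this minimal sketch the byte counts are passed in explicitly so the logic is testable without a GPU; the real function presumably reads them from `torch.cuda`.

```python
def is_memory_over_threshold(used_bytes: int, total_bytes: int,
                             threshold: float = 0.9) -> bool:
    # True when more than `threshold` of the device memory is in use.
    return used_bytes / total_bytes > threshold
```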

class core.distributed.fsdp.src.megatron_fsdp.param_and_grad_buffer.ResetParametersContext(
init_param_with_fp8=False,
with_cuda_rng_tracker=False,
)#

Context manager for resetting parameters for meta device initialization module.

Initialization

__enter__()#
__exit__(*exc_details)#
core.distributed.fsdp.src.megatron_fsdp.param_and_grad_buffer.override_sharded_param_methods_with_safety_checks(
params,
all_gather_pipeline,
)#

Override the methods of the parameters to prevent undefined behavior.

Parameters:
  • params (List[torch.Tensor]) – The sharded parameters whose methods are overridden with safety checks.

  • all_gather_pipeline (AllGatherPipeline) – The all-gather pipeline.

core.distributed.fsdp.src.megatron_fsdp.param_and_grad_buffer._dtype_size(dtype: torch.dtype) int#

Get the size of the dtype.

Parameters:

dtype (torch.dtype) – The dtype to get the size of.

Returns:

The size of the dtype.

Return type:

int
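For illustration, here is a rough lookup-based sketch keyed by dtype name (the byte sizes are PyTorch's standard element sizes; the real helper presumably derives the size from the `torch.dtype` itself, e.g. via `Tensor.element_size()`).

```python
# Byte sizes of common PyTorch dtypes, keyed by name for illustration.
DTYPE_SIZES = {
    "float64": 8, "float32": 4, "float16": 2, "bfloat16": 2,
    "int64": 8, "int32": 4, "int8": 1, "uint8": 1,
}


def dtype_size(dtype_name: str) -> int:
    return DTYPE_SIZES[dtype_name]
```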

core.distributed.fsdp.src.megatron_fsdp.param_and_grad_buffer.to_local_if_dtensor(tensor)#

Return the local shard of a DTensor; if the input is not a DTensor, return it unchanged.

Parameters:

tensor (torch.Tensor) – The tensor to convert.

Returns:

The local tensor.

Return type:

torch.Tensor
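A duck-typed sketch of that conversion (the real function presumably does an `isinstance` check against `DTensor`; here any object exposing `to_local()` stands in for one so the sketch runs without `torch.distributed`):

```python
def to_local_if_dtensor(tensor):
    # DTensor exposes to_local(), which returns this rank's local shard.
    if hasattr(tensor, "to_local"):
        return tensor.to_local()
    # Already a plain tensor: pass through unchanged.
    return tensor
```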

core.distributed.fsdp.src.megatron_fsdp.param_and_grad_buffer._get_fsdp_tensor_spec(
param,
dist_index: core.distributed.fsdp.src.megatron_fsdp.utils.FSDPDistributedIndex,
is_sharded_param,
is_expert_param,
)#

Get the DeviceMesh for the parameter and modify the placement for Megatron-FSDP.

core.distributed.fsdp.src.megatron_fsdp.param_and_grad_buffer.make_fsdp_dtensor(
local_tensor: torch.Tensor,
param: torch.nn.Parameter,
dist_index: core.distributed.fsdp.src.megatron_fsdp.utils.FSDPDistributedIndex,
is_sharded_param: bool = True,
is_expert_param: bool = False,
run_check: bool = False,
update_uneven_dtensor_chunk_meta: bool = False,
force_sync_tp_duplicated_param: bool = False,
)#

Creates a distributed tensor (DTensor) from a local tensor with support for Megatron-FSDP and Tensor Parallel scenarios.

This function is typically used in a FSDP setup where tensor data needs to be converted into sharded DTensors across a device mesh. It also supports model configurations involving tensor model parallelism such as Megatron-Core.

Parameters:
  • local_tensor (torch.Tensor) – The local tensor data to be converted to a DTensor.

  • param (nn.Parameter) – Template parameter used to infer shape, stride, and partition attributes.

  • dist_index (FSDPDistributedIndex) – Metadata object providing the distributed device mesh.

  • is_sharded_param (bool, optional) – Whether the parameter is sharded across devices. Defaults to True.

  • is_expert_param (bool, optional) – Indicates if the tensor corresponds to Megatron-Core expert (Mixture-of-Experts) parameters. Defaults to False.

  • run_check (bool, optional) – Enables additional internal validation for DTensor. Defaults to False.

  • update_uneven_dtensor_chunk_meta (bool, optional) – Whether to update metadata for uneven chunk distributions. Defaults to False.

  • force_sync_tp_duplicated_param (bool, optional) – Whether to force synchronization of parameters duplicated across tensor-parallel ranks. Defaults to False.

Returns:

A DTensor object sharded appropriately across devices.

Return type:

DTensor

.. rubric:: Example

>>> import torch
>>> from torch.distributed.device_mesh import init_device_mesh
>>> from torch.distributed._tensor import DeviceMesh
>>> from my_fsdp_utils import FSDPDistributedIndex  # assumed utility

Initialize device mesh (4 GPUs):

>>> device_mesh = DeviceMesh("cuda", (2, 2), dim_names=("tp", "dp"))
>>> dist_index = FSDPDistributedIndex(
...     device_mesh=device_mesh,
...     dp_mesh_dim_name="dp",
...     tp_mesh_dim_name="tp",
... )

Dummy local tensor and parameter:

>>> local_tensor = torch.randn(8, 16, device="cuda")
>>> param = torch.nn.Parameter(torch.empty(32, 32))

Attach partition metadata for tensor model parallelism:

>>> param.tensor_model_parallel = True
>>> param.partition_dim = 0
>>> param.partition_stride = 1

Convert to DTensor:

>>> dtensor = make_fsdp_dtensor(
...     local_tensor=local_tensor,
...     param=param,
...     dist_index=dist_index,
...     is_sharded_param=True,
...     run_check=True,
... )
>>> print(dtensor)
DTensor(sharded(...))

.. note::

  • For tensor model parallel use cases, the param object must either:

    • Be a tensor-parallel (TP) DTensor, or

    • Include all of these attributes: tensor_model_parallel, partition_dim, and partition_stride.