core.distributed.finalize_model_grads#

Module Contents#

Functions#

  • _get_main_grad_attr

  • _unshard_if_dtensor – Unshards the input tensor if it is a DTensor and otherwise returns the tensor unmodified.

  • _reshard_if_dtensor – Reshards the input tensor to match the sharding configuration of the reference tensor if the reference tensor is a DTensor. Otherwise, returns the reference tensor unmodified.

  • _allreduce_conditional_embedding_grads – All-reduce conditional embedding grads.

  • _get_shared_word_embedding_weight – Return the shared word-embedding weight if it is duplicated across stages.

  • _get_position_embedding_weight – Return the position-embedding weight tensor from the given model module.

  • _allreduce_word_embedding_grads – All-reduce word-embedding gradients across the first and last PP stages.

  • _allreduce_embedding_grad – Unified helper to all-reduce embedding parameters across pipeline stages.

  • _allreduce_position_embedding_grads – All-reduce position_embeddings grads across encoder and decoder stages to ensure that the position-embedding parameters stay in sync.

  • reset_model_temporary_tensors – Reset the temporary tensors of the model.

  • _update_router_expert_bias – Update the expert bias of the router for a global batch. This requires an all-reduce of local_tokens_per_expert across TPxCPxDP ranks.

  • _allreduce_non_tensor_model_parallel_grads – All-reduce, across tensor-model-parallel ranks, both layernorm gradients (for sequence parallelism) and gradients from modules with average_gradients_across_tp_domain=True.

  • finalize_model_grads – All-reduce all model gradients across DP replicas, all-reduce layernorm gradients for sequence parallelism, all-reduce embedding gradients across the first and last pipeline stages (when the input and output embeddings are shared but live on different stages), and scale gradients by num_tokens.

Data#

API#

core.distributed.finalize_model_grads._get_main_grad_attr(param: torch.nn.Parameter)#
core.distributed.finalize_model_grads._unshard_if_dtensor(
tensor: Union[torch.Tensor, torch.distributed._tensor.DTensor],
) → torch.Tensor#

Unshards the input tensor if it is a DTensor and otherwise returns the tensor unmodified.

Parameters:

tensor (Union[torch.Tensor, DTensor]) – The tensor to potentially unshard.

Returns:

An unsharded version of the input tensor if it is a DTensor, or the input tensor unmodified if it is not a DTensor.
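
A minimal sketch of the unshard step, assuming the `DTensor.full_tensor()` API gathers the sharded data onto the local rank (the actual implementation may materialize the full tensor differently):

```python
from typing import Union

import torch
from torch.distributed._tensor import DTensor


def unshard_if_dtensor(tensor: Union[torch.Tensor, DTensor]) -> torch.Tensor:
    # Gather the full (unsharded) tensor if this is a DTensor; otherwise pass through.
    if isinstance(tensor, DTensor):
        return tensor.full_tensor()
    return tensor
```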

core.distributed.finalize_model_grads._reshard_if_dtensor(
tensor_to_shard: torch.Tensor,
reference_tensor: Union[torch.Tensor, torch.distributed._tensor.DTensor],
) → Union[torch.Tensor, torch.distributed._tensor.DTensor]#

Reshards the input tensor to match the sharding configuration of the reference tensor if the reference tensor is a DTensor. Otherwise, returns the reference tensor unmodified.

Parameters:
  • tensor_to_shard (torch.Tensor) – The tensor to be potentially sharded.

  • reference_tensor (Union[torch.Tensor, DTensor]) – The reference tensor for the sharding configuration.

Returns:

The sharded tensor matching the reference tensor’s configuration, or the reference tensor itself if it is not a DTensor.

Return type:

Union[torch.Tensor, DTensor]
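
And the matching reshard step, sketched here with `distribute_tensor` to re-apply the reference DTensor's device mesh and placements; treat the exact construction as an assumption (the real code could equally build the result via `DTensor.from_local`):

```python
from typing import Union

import torch
from torch.distributed._tensor import DTensor, distribute_tensor


def reshard_if_dtensor(
    tensor_to_shard: torch.Tensor,
    reference_tensor: Union[torch.Tensor, DTensor],
) -> Union[torch.Tensor, DTensor]:
    if isinstance(reference_tensor, DTensor):
        # Re-apply the reference tensor's sharding layout to the full local tensor.
        return distribute_tensor(
            tensor_to_shard,
            device_mesh=reference_tensor.device_mesh,
            placements=reference_tensor.placements,
        )
    # Reference is a plain tensor, so there is nothing to reshard.
    return reference_tensor
```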

core.distributed.finalize_model_grads._allreduce_conditional_embedding_grads(
model: List[torch.nn.Module],
config: core.transformer.transformer_config.TransformerConfig,
pp_group: Optional[torch.distributed.ProcessGroup] = None,
)#

All-reduce conditional embedding grads.

Reduce gradients across all PP stages to ensure that the parameters of the conditional embedders (e.g., timestep embedder, FPS embedder, label embedder) stay in sync. This applies to models that replicate these embedders on every PP / VPP rank, such as diffusion models.
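
A hedged sketch of that reduction, assuming each replicated embedder parameter can be identified by a marker attribute (the `is_conditional_embedder` flag below is illustrative, not the actual Megatron marker) and that gradients live in `main_grad` when a distributed gradient buffer is in use:

```python
import torch


def allreduce_conditional_embedding_grads_sketch(model, pp_group):
    for model_chunk in model:
        for param in model_chunk.parameters():
            # Illustrative marker for replicated conditional-embedder weights.
            if not getattr(param, "is_conditional_embedder", False):
                continue
            grad = param.main_grad if hasattr(param, "main_grad") else param.grad
            if grad is None:
                continue
            # Sum contributions from every pipeline stage so all replicas of the
            # embedder receive identical gradients (and hence identical updates).
            torch.distributed.all_reduce(grad, group=pp_group)
```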

core.distributed.finalize_model_grads._get_shared_word_embedding_weight(
model_module: torch.nn.Module,
config: core.transformer.transformer_config.TransformerConfig,
) → Optional[torch.nn.Parameter]#

Return the shared word-embedding weight if it is duplicated across stages.

Parameters:
  • model_module – The model module from which to extract the word-embedding weight.

  • config – Transformer config.

Returns:

The shared embedding or output weight if available; otherwise None.

core.distributed.finalize_model_grads._get_position_embedding_weight(
model_module: torch.nn.Module,
) → torch.nn.Parameter#

Return the position-embedding weight tensor from the given model module.

Parameters:

model_module – The model module that owns the position-embedding parameter.

Returns:

The position-embedding weight tensor.

core.distributed.finalize_model_grads._allreduce_word_embedding_grads(
model: List[torch.nn.Module],
config: core.transformer.transformer_config.TransformerConfig,
embd_group: Optional[torch.distributed.ProcessGroup] = None,
pp_group: Optional[torch.distributed.ProcessGroup] = None,
)#

All-reduce word-embedding gradients across the first and last PP stages.

This ensures that the word_embeddings parameters stay in sync when they are shared between the input and output layers.

Parameters:
  • model – A list containing the pipeline chunks that constitute the model on the current rank (including any virtual pipeline chunks).

  • config – Transformer configuration. Used for edge cases like MTP where embeddings might be shared differently.

  • embd_group – The process group over which to all-reduce the word-embedding gradients. If None, it will be looked up based on the current pipeline model parallel group.

  • pp_group – The pipeline parallel process group used to identify first/last stages. If None, it will be looked up.
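
Given the unified helper documented next, this function plausibly reduces to a delegated call along these lines (a sketch of the delegation only, not the verbatim implementation):

```python
from functools import partial

# Hedged sketch: reuse the generic embedding all-reduce with the
# shared word-embedding getter bound to the current config.
_allreduce_embedding_grad(
    model,
    embd_group=embd_group,
    pp_group=pp_group,
    weight_getter=partial(_get_shared_word_embedding_weight, config=config),
)
```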

core.distributed.finalize_model_grads._allreduce_embedding_grad(
model: List[torch.nn.Module],
embd_group: torch.distributed.ProcessGroup,
pp_group: torch.distributed.ProcessGroup,
weight_getter: Callable[[torch.nn.Module], Optional[torch.nn.Parameter]],
skip_if_none: bool = True,
)#

Unified helper to all-reduce embedding parameters across pipeline stages.

Parameters:
  • model (List[torch.nn.Module]) – A list of model chunks (PP/VPP).

  • embd_group (torch.distributed.ProcessGroup) – The process group over which to reduce.

  • pp_group (torch.distributed.ProcessGroup) – The pipeline parallel process group for first/last stage detection.

  • weight_getter (Callable[[torch.nn.Module], Optional[torch.nn.Parameter]]) – A function that takes the pre-process model chunk and returns the parameter to be reduced (or None if not applicable).

  • skip_if_none (bool, optional) – If True, quietly returns when the parameter or its gradient is None. Defaults to True.
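
A hedged sketch of the pattern this helper encapsulates, under the assumptions that only the first and last pipeline stages hold a copy of the weight, that the first stage's copy lives on its first (pre-process) chunk and the last stage's on its last (post-process) chunk, and that gradients sit in `main_grad` when a distributed gradient buffer is used:

```python
import torch


def allreduce_embedding_grad_sketch(
    model, embd_group, pp_group, weight_getter, skip_if_none=True
):
    pp_rank = torch.distributed.get_rank(group=pp_group)
    pp_size = torch.distributed.get_world_size(group=pp_group)
    # Only the first and last pipeline stages hold a copy of the shared weight.
    if pp_rank not in (0, pp_size - 1):
        return

    # First stage: pre-process chunk; last stage: post-process chunk.
    model_chunk = model[0] if pp_rank == 0 else model[-1]
    weight = weight_getter(model_chunk)
    if weight is None:
        if skip_if_none:
            return
        raise RuntimeError("expected an embedding weight on this pipeline stage")

    # Prefer the fused gradient buffer (main_grad) when the DDP wrapper provides one.
    grad = weight.main_grad if hasattr(weight, "main_grad") else weight.grad
    if grad is None:
        if skip_if_none:
            return
        raise RuntimeError("expected a gradient for the embedding weight")

    # Sum the copies held by the participating stages so they stay identical.
    torch.distributed.all_reduce(grad, group=embd_group)
```

The DTensor round trip described above (`_unshard_if_dtensor` before the reduce, `_reshard_if_dtensor` after) would wrap the all-reduce when the gradient is sharded; it is omitted here to keep the sketch short.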

core.distributed.finalize_model_grads._allreduce_position_embedding_grads(
model: List[torch.nn.Module],
config: core.transformer.transformer_config.TransformerConfig,
pos_emb_group: torch.distributed.ProcessGroup,
pp_group: torch.distributed.ProcessGroup,
)#

All-reduce position_embeddings grads across encoder and decoder stages to ensure that the position-embedding parameters stay in sync.

core.distributed.finalize_model_grads.reset_model_temporary_tensors(
config: core.transformer.transformer_config.TransformerConfig,
model: List[torch.nn.Module],
)#

Reset the temporary tensors of the model.

core.distributed.finalize_model_grads._update_router_expert_bias(
model: List[torch.nn.Module],
config: core.transformer.transformer_config.TransformerConfig,
)#

Update the expert bias of the router for a global batch. This requires an all-reduce of local_tokens_per_expert across TPxCPxDP ranks.
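
A hedged sketch of a sign-based, aux-loss-free balancing update consistent with this description: sum the per-expert token counts over the TPxCPxDP group, then nudge each expert's bias opposite to its load imbalance. The attribute names (`local_tokens_per_expert`, `expert_bias`) and the update rule are illustrative assumptions, not the exact Megatron implementation:

```python
import torch


@torch.no_grad()
def update_router_expert_bias_sketch(router, tp_cp_dp_group, update_rate=1e-3):
    # Aggregate this rank's per-expert token counts into global counts for the batch.
    tokens_per_expert = router.local_tokens_per_expert.clone()  # illustrative attribute
    torch.distributed.all_reduce(tokens_per_expert, group=tp_cp_dp_group)

    # Over-loaded experts get their bias pushed down, under-loaded experts get it
    # pushed up, steering future routing decisions toward a balanced load.
    mean_load = tokens_per_expert.float().mean()
    router.expert_bias += update_rate * torch.sign(mean_load - tokens_per_expert.float())

    # Reset the per-batch counter for the next global batch.
    router.local_tokens_per_expert.zero_()
```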

core.distributed.finalize_model_grads._allreduce_non_tensor_model_parallel_grads(
model: List[torch.nn.Module],
config: core.transformer.transformer_config.TransformerConfig,
tp_group: Optional[torch.distributed.ProcessGroup] = None,
)#

All-reduce, across tensor-model-parallel ranks, both layernorm gradients (for sequence parallelism) and gradients from modules with average_gradients_across_tp_domain=True.
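
A hedged sketch of the two reductions this performs over the tensor-model-parallel group; `sequence_parallel` and `average_gradients_across_tp_domain` are used below as the parameter and module markers, but treat the exact attribute placement as an assumption:

```python
import torch


def allreduce_non_tensor_model_parallel_grads_sketch(model, tp_group):
    tp_size = torch.distributed.get_world_size(group=tp_group)
    grads_to_sum, grads_to_average = [], []
    for model_chunk in model:
        for module in model_chunk.modules():
            average = getattr(module, "average_gradients_across_tp_domain", False)
            for param in module.parameters(recurse=False):
                grad = param.main_grad if hasattr(param, "main_grad") else param.grad
                if grad is None:
                    continue
                if getattr(param, "sequence_parallel", False):
                    # Layernorm-style params are replicated across TP ranks, but under
                    # sequence parallelism each rank sees only a slice of the sequence.
                    grads_to_sum.append(grad)
                elif average:
                    # Modules that want the mean gradient over the TP domain.
                    grads_to_average.append(grad)

    for grad in grads_to_sum:
        torch.distributed.all_reduce(grad, group=tp_group)
    for grad in grads_to_average:
        torch.distributed.all_reduce(grad, group=tp_group)
        grad /= tp_size
```

A production implementation would presumably coalesce these gradients into a single flattened buffer before reducing; the per-tensor loop above keeps the sketch short.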

core.distributed.finalize_model_grads._allreduce_layernorm_grads#

None

core.distributed.finalize_model_grads.finalize_model_grads(
model: List[torch.nn.Module],
num_tokens: Optional[torch.Tensor] = None,
pg_collection: Optional[megatron.core.process_groups_config.ProcessGroupCollection] = None,
)#

All-reduce all model gradients across DP replicas, all-reduce layernorm gradients for sequence parallelism, all-reduce embedding gradients across the first and last pipeline stages (when the input and output embeddings are shared but live on different stages), and scale gradients by num_tokens.
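
Putting the pieces together, a hedged sketch of the order of operations a training loop relies on from this entry point; `finish_grad_sync()` and the `dp_cp` group attribute are assumptions about the Megatron DDP wrapper and process-group collection in use:

```python
import torch


def finalize_model_grads_sketch(model, num_tokens=None, pg_collection=None):
    # 1. Finish the data-parallel gradient reduce-scatter / all-reduce for every
    #    model chunk (assumed method on Megatron's DDP wrapper).
    for model_chunk in model:
        model_chunk.finish_grad_sync()

    # 2. All-reduce gradients that are replicated rather than sharded across TP
    #    ranks (layernorm params under sequence parallelism, TP-averaged modules).
    # 3. All-reduce shared word / position embedding gradients across the first
    #    and last pipeline stages so the duplicated copies stay identical.
    #    (Steps 2 and 3 correspond to the helpers documented above.)

    # 4. Scale every gradient by the global token count so the loss behaves as a
    #    per-token average over the whole global batch.
    if num_tokens is not None:
        # Sum the token counts contributed by all data-/context-parallel ranks;
        # the exact group attribute below is an assumption.
        group = getattr(pg_collection, "dp_cp", None) if pg_collection is not None else None
        torch.distributed.all_reduce(num_tokens, group=group)
        if num_tokens.item() > 0:
            for model_chunk in model:
                for param in model_chunk.parameters():
                    grad = param.main_grad if hasattr(param, "main_grad") else param.grad
                    if grad is not None:
                        grad /= num_tokens
```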