core.distributed.finalize_model_grads#
Module Contents#
Functions#
| Function | Description |
|---|---|
| `_unshard_if_dtensor` | Unshards the input tensor if it is a DTensor and otherwise returns the tensor unmodified. |
| `_reshard_if_dtensor` | Reshards the input tensor to match the sharding configuration of the reference tensor if the reference tensor is a DTensor. Otherwise, returns the reference tensor unmodified. |
| `_allreduce_conditional_embedding_grads` | All-reduce conditional embedding grads. |
| | Return the shared word-embedding weight if it is duplicated across stages. |
| `_get_position_embedding_weight` | Return the position-embedding weight tensor from the given model module. |
| `_allreduce_word_embedding_grads` | All-reduce word-embedding gradients across the first and last PP stages. |
| `_allreduce_embedding_grad` | Unified helper to all-reduce embedding parameters across pipeline stages. |
| `_allreduce_position_embedding_grads` | All-reduce position_embeddings grad across encoder and decoder stages to ensure that position embeddings parameters stay in sync. |
| `reset_model_temporary_tensors` | Reset the temporary tensors of the model. |
| `_update_router_expert_bias` | Update the expert bias of the router for a global batch. This requires an all-reduce of local_tokens_per_expert across TPxCPxDP ranks. |
| `_allreduce_non_tensor_model_parallel_grads` | All-reduce both layernorm grads (for sequence parallelism) and gradients from modules with average_gradients_across_tp_domain=True across tensor-model-parallel ranks. |
| `finalize_model_grads` | All-reduce all model grads across DP replicas, layernorm grads for sequence parallelism, embedding grads across first and last pipeline stages (if not tied), scale gradients by num_tokens. |
Data#
API#
- core.distributed.finalize_model_grads._get_main_grad_attr(param: torch.nn.Parameter)#
- core.distributed.finalize_model_grads._unshard_if_dtensor(
- tensor: Union[torch.Tensor, torch.distributed._tensor.DTensor],
Unshards the input tensor if it is a DTensor and otherwise returns the tensor unmodified.
- Parameters:
tensor (Union[torch.Tensor, DTensor]) – The tensor to potentially unshard.
- Returns:
An unsharded version of the input tensor if it is a DTensor, or the input tensor unmodified if it is not a DTensor.
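For reference, the behavior documented above maps onto PyTorch's DTensor API roughly as follows; this is a minimal sketch with an illustrative function name, not the module's actual code:

```python
from typing import Union

import torch
from torch.distributed._tensor import DTensor


def unshard_if_dtensor(tensor: Union[torch.Tensor, DTensor]) -> torch.Tensor:
    """Gather a DTensor into a plain local tensor; pass ordinary tensors through."""
    if isinstance(tensor, DTensor):
        # full_tensor() all-gathers the shards according to the DTensor's
        # placements and returns a regular torch.Tensor.
        return tensor.full_tensor()
    return tensor
```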
- core.distributed.finalize_model_grads._reshard_if_dtensor(
- tensor_to_shard: torch.Tensor,
- reference_tensor: Union[torch.Tensor, torch.distributed._tensor.DTensor],
Reshards the input tensor to match the sharding configuration of the reference tensor if the reference tensor is a DTensor. Otherwise, returns the reference tensor unmodified.
- Parameters:
tensor_to_shard (torch.Tensor) – The tensor to be potentially sharded.
reference_tensor (Union[torch.Tensor, DTensor]) – The reference tensor for the sharding configuration.
- Returns:
The sharded tensor matching the reference tensor’s configuration, or the reference tensor itself if it is not a DTensor.
- Return type:
Union[torch.Tensor, DTensor]
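A minimal sketch of the same idea, assuming torch.distributed._tensor's distribute_tensor and the pass-through behavior documented above for non-DTensor references:

```python
from typing import Union

import torch
from torch.distributed._tensor import DTensor, distribute_tensor


def reshard_if_dtensor(
    tensor_to_shard: torch.Tensor,
    reference_tensor: Union[torch.Tensor, DTensor],
) -> Union[torch.Tensor, DTensor]:
    """Shard a plain tensor so it matches the reference DTensor's layout."""
    if isinstance(reference_tensor, DTensor):
        # Reuse the reference tensor's mesh and placements so the result has
        # exactly the same sharding configuration.
        return distribute_tensor(
            tensor_to_shard,
            device_mesh=reference_tensor.device_mesh,
            placements=reference_tensor.placements,
        )
    # Mirror the documented behavior: with a non-DTensor reference, the
    # reference tensor itself is returned unmodified.
    return reference_tensor
```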
- core.distributed.finalize_model_grads._allreduce_conditional_embedding_grads(
- model: List[torch.nn.Module],
- config: core.transformer.transformer_config.TransformerConfig,
- pp_group: Optional[torch.distributed.ProcessGroup] = None,
All-reduce conditional embedding grads.
Reduce grads across all the pp stages to ensure that parameters of the conditional embedders (e.g., timestep embedder, FPS embedder, label embedder) stay in sync. This is for the models with replicated embedders on each PP / VPP rank, like diffusion models.
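As a rough illustration of the synchronization (the helper name and its arguments below are hypothetical, not this module's API), keeping replicated embedders in sync amounts to all-reducing their gradients over the pipeline-parallel group:

```python
import torch.distributed as dist


def allreduce_replicated_embedder_grads(embedder_params, pp_group):
    # Sum the per-stage gradient contributions so every pipeline rank ends up
    # with the same total gradient for its replicated embedder copy.
    for param in embedder_params:
        # Megatron's DDP typically accumulates fp32 grads in `main_grad`;
        # fall back to the regular autograd `.grad` otherwise.
        grad = getattr(param, "main_grad", None)
        if grad is None:
            grad = param.grad
        if grad is not None:
            dist.all_reduce(grad, op=dist.ReduceOp.SUM, group=pp_group)
```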
- model_module: torch.nn.Module,
- config: core.transformer.transformer_config.TransformerConfig,
Return the shared word-embedding weight if it is duplicated across stages.
- Parameters:
model_module – The model module from which to extract the word-embedding weight.
config – Transformer config.
- Returns:
The shared embedding or output weight if available; otherwise None.
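The lookup can be pictured roughly as follows. This sketch assumes a GPT-style Megatron module that exposes share_embeddings_and_output_weights and shared_embedding_or_output_weight(); the function name is illustrative:

```python
from typing import Optional

import torch


def get_tied_word_embedding_weight(
    model_module: torch.nn.Module,
) -> Optional[torch.nn.Parameter]:
    # GPT-style Megatron modules tie the input embedding and the output
    # projection; when they do, shared_embedding_or_output_weight() returns
    # the single shared parameter whose two copies (first and last pipeline
    # stage) must be kept in sync.
    if getattr(model_module, "share_embeddings_and_output_weights", False):
        return model_module.shared_embedding_or_output_weight()
    return None
```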
- core.distributed.finalize_model_grads._get_position_embedding_weight(
- model_module: torch.nn.Module,
Return the position-embedding weight tensor from the given model module.
- Parameters:
model_module – The model module that owns the position-embedding parameter.
- Returns:
The position-embedding weight tensor.
- core.distributed.finalize_model_grads._allreduce_word_embedding_grads(
- model: List[torch.nn.Module],
- config: core.transformer.transformer_config.TransformerConfig,
- embd_group: Optional[torch.distributed.ProcessGroup] = None,
- pp_group: Optional[torch.distributed.ProcessGroup] = None,
All-reduce word-embedding gradients across the first and last PP stages.
This ensures that the word_embeddings parameters stay in sync when they are shared between the input and output layers.
- Parameters:
model – A list containing the pipeline chunks that constitute the model on the current rank (including any virtual pipeline chunks).
config – Transformer configuration. Used for edge cases like MTP where embeddings might be shared differently.
embd_group – The process group over which to all-reduce the word-embedding gradients. If None, it will be looked up based on the current pipeline model parallel group.
pp_group – The pipeline parallel process group used to identify first/last stages. If None, it will be looked up.
- core.distributed.finalize_model_grads._allreduce_embedding_grad(
- model: List[torch.nn.Module],
- embd_group: torch.distributed.ProcessGroup,
- pp_group: torch.distributed.ProcessGroup,
- weight_getter: Callable[[torch.nn.Module], Optional[torch.nn.Parameter]],
- skip_if_none: bool = True,
Unified helper to all-reduce embedding parameters across pipeline stages.
- Parameters:
model (List[torch.nn.Module]) – A list of model chunks (PP/VPP).
embd_group (torch.distributed.ProcessGroup) – The process group over which to reduce.
pp_group (torch.distributed.ProcessGroup) – The pipeline parallel process group for first/last stage detection.
weight_getter (Callable[[torch.nn.Module], Optional[torch.nn.Parameter]]) – A function that takes the pre-process model chunk and returns the parameter to be reduced (or None if not applicable).
skip_if_none (bool, optional) – If True, quietly returns when the parameter or its gradient is None. Defaults to True.
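A simplified sketch of what such a unified helper does; the control flow and the main_grad fallback are assumptions rather than the exact implementation:

```python
import torch.distributed as dist


def allreduce_embedding_grad(model, embd_group, pp_group, weight_getter, skip_if_none=True):
    pp_rank = dist.get_rank(group=pp_group)
    pp_size = dist.get_world_size(group=pp_group)
    # Only the first and last pipeline stages hold a duplicated copy of the
    # embedding; every other stage has nothing to contribute.
    if pp_rank not in (0, pp_size - 1):
        return
    # The first stage owns the first (virtual) chunk, the last stage the last chunk.
    model_chunk = model[0] if pp_rank == 0 else model[-1]
    weight = weight_getter(model_chunk)
    grad = None
    if weight is not None:
        grad = getattr(weight, "main_grad", None)
        if grad is None:
            grad = weight.grad
    if grad is None:
        if skip_if_none:
            return
        raise RuntimeError("expected an embedding gradient on this stage")
    dist.all_reduce(grad, group=embd_group)
```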
- core.distributed.finalize_model_grads._allreduce_position_embedding_grads(
- model: List[torch.nn.Module],
- config: core.transformer.transformer_config.TransformerConfig,
- pos_emb_group: torch.distributed.ProcessGroup,
- pp_group: torch.distributed.ProcessGroup,
All-reduce position_embeddings grad across encoder and decoder stages to ensure that position embeddings parameters stay in sync.
- core.distributed.finalize_model_grads.reset_model_temporary_tensors(
- config: core.transformer.transformer_config.TransformerConfig,
- model: List[torch.nn.Module],
Reset the temporary tensors of the model.
- core.distributed.finalize_model_grads._update_router_expert_bias(
- model: List[torch.nn.Module],
- config: core.transformer.transformer_config.TransformerConfig,
Update the expert bias of the router for a global batch. This requires an all-reduce of local_tokens_per_expert across TPxCPxDP ranks.
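The update follows an auxiliary-loss-free load-balancing scheme: aggregate per-expert token counts over the TPxCPxDP domain, then nudge each expert's routing bias toward the mean load. The helper below is an illustrative sketch; its tensor and argument names are assumptions:

```python
import torch
import torch.distributed as dist


def update_router_expert_bias(local_tokens_per_expert, expert_bias, bias_update_rate, group):
    # Aggregate the per-expert token counts over the TP x CP x DP domain so
    # every rank derives the bias update from the same global-batch statistics.
    tokens_per_expert = local_tokens_per_expert.detach().clone().float()
    dist.all_reduce(tokens_per_expert, op=dist.ReduceOp.SUM, group=group)
    mean_load = tokens_per_expert.mean()
    with torch.no_grad():
        # Overloaded experts (load above the mean) get their bias pushed down,
        # underloaded experts get it pushed up, nudging routing toward balance.
        expert_bias.add_(bias_update_rate * torch.sign(mean_load - tokens_per_expert))
```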
- core.distributed.finalize_model_grads._allreduce_non_tensor_model_parallel_grads(
- model: List[torch.nn.Module],
- config: core.transformer.transformer_config.TransformerConfig,
- tp_group: Optional[torch.distributed.ProcessGroup] = None,
All-reduce both layernorm grads (for sequence parallelism) and gradients from modules with average_gradients_across_tp_domain=True across tensor-model-parallel ranks.
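Conceptually, any parameter that is replicated across the tensor-parallel domain but receives rank-local gradient contributions (layernorm weights under sequence parallelism, modules flagged with average_gradients_across_tp_domain) needs its gradient reduced over the TP group. The sketch below treats the module-level flag as if it were propagated to parameters and assumes Megatron's main_grad convention; a production implementation would typically coalesce these small gradients into one flattened buffer before the all-reduce:

```python
import torch.distributed as dist


def allreduce_non_tensor_parallel_grads(model_chunks, tp_group):
    tp_size = dist.get_world_size(group=tp_group)
    for chunk in model_chunks:
        for param in chunk.parameters():
            if not param.requires_grad:
                continue
            needs_sum = getattr(param, "sequence_parallel", False)
            needs_avg = getattr(param, "average_gradients_across_tp_domain", False)
            if not (needs_sum or needs_avg):
                continue
            grad = getattr(param, "main_grad", None)
            if grad is None:
                grad = param.grad
            if grad is None:
                continue
            # Each TP rank only saw its own slice of the sequence, so the true
            # gradient of a replicated parameter is the sum over the TP group.
            dist.all_reduce(grad, op=dist.ReduceOp.SUM, group=tp_group)
            if needs_avg:
                # Parameters flagged for averaging keep the mean instead.
                grad.div_(tp_size)
```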
- core.distributed.finalize_model_grads._allreduce_layernorm_grads#
None
- core.distributed.finalize_model_grads.finalize_model_grads(
- model: List[torch.nn.Module],
- num_tokens: Optional[torch.Tensor] = None,
- pg_collection: Optional[megatron.core.process_groups_config.ProcessGroupCollection] = None,
All-reduce all model grads across DP replicas, layernorm grads for sequence parallelism, embedding grads across first and last pipeline stages (if not tied), scale gradients by num_tokens.
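Putting the pieces together, the finalization order can be sketched as below. The function name is hypothetical; the sketch assumes model chunks wrapped in Megatron's DDP (which exposes finish_grad_sync()) and fp32 gradients accumulated in main_grad:

```python
import torch
import torch.distributed as dist


def finalize_model_grads_sketch(model_chunks, num_tokens=None, dp_group=None):
    # 1. Wait for the (possibly overlapped) data-parallel gradient
    #    reduce-scatter / all-reduce of every model chunk to complete.
    for chunk in model_chunks:
        chunk.finish_grad_sync()

    # 2. Cross-cutting reductions not handled by DP grad sync: layernorm /
    #    TP-domain grads across tensor-parallel ranks, embedding grads across
    #    the first and last pipeline stages (see the helpers sketched above).

    # 3. Optionally normalize by the global token count so that a token-summed
    #    loss yields gradients scaled consistently across DP replicas.
    if num_tokens is not None:
        dist.all_reduce(num_tokens, op=dist.ReduceOp.SUM, group=dp_group)
        scale = 1.0 / torch.clamp(num_tokens, min=1).float()
        for chunk in model_chunks:
            for param in chunk.parameters():
                grad = getattr(param, "main_grad", None)
                if grad is None:
                    grad = param.grad
                if grad is not None:
                    grad.mul_(scale)
```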