core.models.common.language_module.language_module#

Module Contents#

Classes#

LanguageModule

Base language module that has common helper functions used across GPT, BERT etc.

API#

class core.models.common.language_module.language_module.LanguageModule(
config: megatron.core.transformer.transformer_config.TransformerConfig,
pg_collection: Optional[megatron.core.process_groups_config.ProcessGroupCollection] = None,
)#

Bases: megatron.core.transformer.module.MegatronModule

Base language module that has common helper functions used across GPT, BERT etc.

Parameters:
  • config (TransformerConfig) – Transformer config object.

  • pg_collection (Optional[ProcessGroupCollection]) – Process group collection; defaults to None.

Initialization

_is_in_embd_group()#
_set_attention_backend()#

Set the attention backend.

Transformer Engine works on an opt-out basis: by default, all three attention backend flags are set to 1. If the user chooses a particular attention backend, the other two are set to 0; if the user chooses local, all three TE environment variables are set to 0.
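The opt-out scheme above can be sketched as follows. This is a minimal illustration, not the Megatron implementation; the Transformer Engine flag names (`NVTE_FLASH_ATTN`, `NVTE_FUSED_ATTN`, `NVTE_UNFUSED_ATTN`) are assumed here for concreteness.

```python
import os

# Hypothetical sketch of the opt-out scheme: all three Transformer Engine
# backend flags default to "1"; choosing one backend zeroes the other two,
# and choosing "local" zeroes all three. The flag names are assumptions.
TE_FLAGS = {
    "flash": "NVTE_FLASH_ATTN",
    "fused": "NVTE_FUSED_ATTN",
    "unfused": "NVTE_UNFUSED_ATTN",
}

def set_attention_backend(backend: str) -> None:
    if backend == "local":
        # Local backend: opt out of all three TE implementations.
        for flag in TE_FLAGS.values():
            os.environ[flag] = "0"
        return
    # A specific backend: keep its flag at "1", zero the other two.
    for name, flag in TE_FLAGS.items():
        os.environ[flag] = "1" if name == backend else "0"

set_attention_backend("flash")
```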

compute_language_model_loss(
labels: torch.Tensor,
logits: torch.Tensor,
) torch.Tensor#

Computes the language model loss (cross-entropy across the vocabulary).

Parameters:
  • labels (Tensor) – The labels of dimension [batch size, seq length]

  • logits (Tensor) – The final logits returned by the output layer of the transformer model

Returns:

Loss tensor of dimension [batch size, seq length]

Return type:

Tensor
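The per-token loss can be illustrated with a pure-Python sketch (the real implementation operates on tensors; this just shows the shape contract from the docstring): for each (batch, position) pair the loss is the negative log-softmax of the logits at the label index.

```python
import math

# Pure-Python sketch of per-token cross-entropy "across the vocabulary":
# labels is [batch, seq], logits is [batch, seq, vocab], and the returned
# loss is [batch, seq], matching the docstring above. Not the Megatron
# implementation, which is tensor-based and parallelism-aware.
def language_model_loss(labels, logits):
    loss = []
    for b_labels, b_logits in zip(labels, logits):
        row = []
        for label, token_logits in zip(b_labels, b_logits):
            # -log softmax(token_logits)[label] = log(sum exp) - logit[label]
            log_z = math.log(sum(math.exp(x) for x in token_logits))
            row.append(log_z - token_logits[label])
        loss.append(row)
    return loss

labels = [[0, 2]]                               # batch=1, seq=2
logits = [[[2.0, 0.0, 0.0], [0.0, 0.0, 2.0]]]   # vocab=3
per_token = language_model_loss(labels, logits)
```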

setup_embeddings_and_output_layer() None#

Sets up embedding layer in first stage and output layer in last stage.

This function initializes word embeddings in the final stage when we are using pipeline parallelism and sharing word embeddings, and sets up param attributes on the embedding and output layers.

Parameter attributes set:

  • is_embedding_or_output_parameter: True for embedding + output layer weights. Used by decoupled_lr, Muon optimizer, and other Megatron features.

  • is_embedding_parameter: True for MuP “embedding-class” parameters. Used by MuP for table-8 style optimizer grouping (base LR/eps for vector-like params).
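A minimal sketch of how the two attributes above might be tagged onto parameters; `Param` is a stand-in for `torch.nn.Parameter`, and the exact assignment logic here is an assumption, not the Megatron code.

```python
# Illustration only: tag embedding/output parameters with the attributes
# described in the docstring. `Param` stands in for torch.nn.Parameter.
class Param:
    pass

embedding_weight = Param()
output_weight = Param()
hidden_weight = Param()  # an ordinary parameter, left untagged

for p in (embedding_weight, output_weight):
    # True for embedding + output layer weights
    # (consumed by decoupled_lr, the Muon optimizer, etc.).
    p.is_embedding_or_output_parameter = True

# Only the input embedding table is a MuP "embedding-class" parameter
# (assumed here; the real rule lives in setup_embeddings_and_output_layer).
embedding_weight.is_embedding_parameter = True
```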

_scale_logits(logits: torch.Tensor) torch.Tensor#

Apply MuP output scaling to logits.

When MuP is enabled, scales logits by mup_output_mult (auto-set to 1/width_mult if left at default) to keep output variance stable across widths.

Parameters:

logits (Tensor) – Raw logits from the output layer.

Returns:

Scaled logits if MuP is enabled and mup_output_mult != 1.0, otherwise unchanged logits.

Return type:

Tensor
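The scaling rule can be sketched as follows. The names mirror the docstring (`mup_output_mult`, `width_mult`), but the control flow is an assumption and operates on plain lists rather than tensors.

```python
# Sketch of MuP output scaling: when mup_output_mult is left at its
# default it is auto-set to 1 / width_mult, and logits are multiplied by
# it; if MuP is disabled or the multiplier is 1.0, logits pass through
# unchanged. Illustration only, not the Megatron implementation.
def scale_logits(logits, mup_enabled, mup_output_mult):
    if not mup_enabled or mup_output_mult == 1.0:
        return logits
    return [x * mup_output_mult for x in logits]

width_mult = 4.0
scaled = scale_logits([8.0, -4.0], True, 1.0 / width_mult)
```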

shared_embedding_or_output_weight() torch.Tensor#

Gets the embedding weight or the output logit weight when sharing of embedding and output weights is set to True, or when using Multi-Token Prediction (MTP).

Returns:

During pre-processing or MTP processing it returns the input embedding weight, while during post-processing it returns the final output layer's weight

Return type:

Tensor

sharded_state_dict(
prefix: str = '',
sharded_offsets: Tuple[Tuple[int, int, int]] = (),
metadata: Optional[dict] = None,
) megatron.core.dist_checkpointing.mapping.ShardedStateDict#

Sharded state dict implementation that handles the output layer weights tying.

Parameters:
  • prefix (str) – Module name prefix.

  • sharded_offsets (tuple) – PP related offsets, expected to be empty at this module level.

  • metadata (Optional[Dict]) – metadata controlling sharded state dict creation.

Returns:

sharded state dict for the LanguageModel

Return type:

ShardedStateDict

tie_embeddings_and_output_weights_state_dict(
sharded_state_dict: megatron.core.dist_checkpointing.mapping.ShardedStateDict,
output_layer_weight_key: str,
first_stage_word_emb_key: str,
metadata: dict = {},
) None#

Ties the embedding and output weights in a given sharded state dict.

Parameters:
  • sharded_state_dict (ShardedStateDict) – state dict with the weight to tie

  • output_layer_weight_key (str) – key of the output layer weight in the state dict. This entry will be replaced with a tied version

  • first_stage_word_emb_key (str) – this must be the same as the ShardedTensor.key of the first stage word embeddings.

Returns: None, acts in-place
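At the level of plain dictionaries, the in-place tying can be sketched like this. `TiedEntry` is a stand-in for Megatron's `ShardedTensor` machinery, and the example keys are hypothetical.

```python
from dataclasses import dataclass

# Dict-level sketch of weight tying: the output layer's entry is replaced
# so its checkpoint key matches the first-stage word-embedding key, i.e.
# both parameters resolve to the same shard in the checkpoint. TiedEntry
# stands in for ShardedTensor; the replica handling is an assumption.
@dataclass
class TiedEntry:
    key: str         # checkpoint key the entry resolves to
    replica_id: int  # non-zero marks this copy as a replica

def tie_weights(sharded_state_dict, output_layer_weight_key,
                first_stage_word_emb_key):
    # Acts in place: the output entry is rewritten to point at the
    # first-stage embedding key, as a replica of it.
    sharded_state_dict[output_layer_weight_key] = TiedEntry(
        key=first_stage_word_emb_key, replica_id=1
    )

sd = {"decoder.output_layer.weight":
      TiedEntry("decoder.output_layer.weight", 0)}
tie_weights(sd, "decoder.output_layer.weight",
            "embedding.word_embeddings.weight")
```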