core.models.common.language_module.language_module#
Module Contents#
Classes#
LanguageModule – Base language module that has common helper functions used across GPT, BERT etc.
API#
- class core.models.common.language_module.language_module.LanguageModule(config: megatron.core.transformer.transformer_config.TransformerConfig, pg_collection: Optional[megatron.core.process_groups_config.ProcessGroupCollection] = None)#
Bases:
megatron.core.transformer.module.MegatronModule
Base language module that has common helper functions used across GPT, BERT etc.
- Parameters:
config (TransformerConfig) – Input transformer config for the model
pg_collection (ProcessGroupCollection) – Model communication process groups
Initialization
- _setup_mtp_cuda_graphs()#
Wrap compute_mtp_single_step with a CudaGraphManager. Must be called by subclasses after self.mtp is created.
- _is_in_embd_group()#
- _set_attention_backend()#
Set attention backend
Transformer Engine works on an opt-out basis: by default, all three attention backend flags are set to 1. If the user chooses a particular attention backend, we set the other two to 0. If the user chooses local, we set all three TE environment variables to 0.
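The opt-out logic described above can be sketched as follows. This is a simplified, hypothetical illustration; the environment variable names (`NVTE_FLASH_ATTN`, `NVTE_FUSED_ATTN`, `NVTE_UNFUSED_ATTN`) are Transformer Engine's backend flags, but the function shape here is an assumption, not the actual Megatron implementation.

```python
import os

# Sketch of the opt-out scheme: all three TE backend flags default to
# enabled ("1"); choosing one backend zeroes the other two, and choosing
# "local" zeroes all three.
BACKEND_FLAGS = {
    "flash": "NVTE_FLASH_ATTN",
    "fused": "NVTE_FUSED_ATTN",
    "unfused": "NVTE_UNFUSED_ATTN",
}

def set_attention_backend(backend: str) -> None:
    """Enable only the chosen TE backend; 'local' disables all three."""
    for name, env_var in BACKEND_FLAGS.items():
        enabled = backend != "local" and name == backend
        os.environ[env_var] = "1" if enabled else "0"

set_attention_backend("flash")
# NVTE_FLASH_ATTN=1, NVTE_FUSED_ATTN=0, NVTE_UNFUSED_ATTN=0
```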
- compute_language_model_loss(labels: torch.Tensor, logits: torch.Tensor)#
Computes the language model loss (Cross entropy across vocabulary)
- Parameters:
labels (Tensor) – The labels of dimension [batch size, seq length]
logits (Tensor) – The final logits returned by the output layer of the transformer model
- Returns:
Loss tensor of dimensions [batch size, sequence_length]
- Return type:
Tensor
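To make the per-position loss shape concrete, here is a minimal pure-Python sketch of cross entropy across the vocabulary. It is not the Megatron implementation (which operates on tensors with vocab-parallel logits); it only illustrates that the loss has one value per token position.

```python
import math

def cross_entropy_per_token(logits, labels):
    """Per-token cross entropy over the vocabulary.

    logits: [seq_len][vocab_size] nested lists; labels: [seq_len] ints.
    Returns a [seq_len] list of losses, mirroring the
    [batch size, sequence_length] loss tensor above (batch = 1 here).
    """
    losses = []
    for pos_logits, label in zip(logits, labels):
        m = max(pos_logits)  # subtract max for numerical stability
        log_z = m + math.log(sum(math.exp(x - m) for x in pos_logits))
        losses.append(log_z - pos_logits[label])  # -log softmax[label]
    return losses

losses = cross_entropy_per_token([[0.0, 0.0, 0.0]], [1])  # uniform -> log(3)
```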
- setup_embeddings_and_output_layer() -> None#
Sets up embedding layer in first stage and output layer in last stage.
This function initializes word embeddings in the final stage when we are using pipeline parallelism and sharing word embeddings, and sets up param attributes on the embedding and output layers.
Parameter attributes set:
is_embedding_or_output_parameter: True for embedding + output layer weights. Used by decoupled_lr, Muon optimizer, and other Megatron features.
is_embedding_parameter: True for MuP “embedding-class” parameters. Used by MuP for table-8 style optimizer grouping (base LR/eps for vector-like params).
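The attribute tagging above can be sketched with plain Python objects. The attribute names match the docstring; the `Param` class and the selector function are hypothetical stand-ins for `torch.nn.Parameter` and the optimizer-side grouping logic.

```python
# Minimal sketch: downstream features (decoupled_lr, Muon, MuP grouping)
# select parameters by attribute, so setup tags the relevant weights.
class Param:
    pass

embedding_weight, output_weight, mlp_weight = Param(), Param(), Param()

# Both embedding and output-layer weights get the shared tag...
for p in (embedding_weight, output_weight):
    p.is_embedding_or_output_parameter = True
# ...but only MuP "embedding-class" params get the MuP-specific tag.
embedding_weight.is_embedding_parameter = True

def is_mup_embedding_class(p):
    """Hypothetical selector, as MuP grouping might use it."""
    return getattr(p, "is_embedding_parameter", False)
```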
- _scale_logits(logits: torch.Tensor) -> torch.Tensor#
Apply MuP output scaling to logits.
When MuP is enabled, scales logits by mup_output_mult (auto-set to 1/width_mult if left at default) to keep output variance stable across widths.
- Parameters:
logits (Tensor) – Raw logits from the output layer.
- Returns:
Scaled logits if MuP is enabled and mup_output_mult != 1.0, otherwise unchanged logits.
- Return type:
Tensor
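The scaling rule above is simple enough to sketch in a few lines. This is an illustration of the described behavior, not the real method: `width_mult` and `mup_output_mult` mirror the config fields named in the docstring, and the auto-set default of `1 / width_mult` keeps output variance stable as the model is widened.

```python
# Sketch of MuP output scaling: a no-op unless MuP is enabled and the
# effective multiplier differs from 1.0.
def scale_logits(logits, mup_enabled, width_mult, mup_output_mult=None):
    if not mup_enabled:
        return logits
    # Auto-set to 1/width_mult if the user left the multiplier at default.
    mult = mup_output_mult if mup_output_mult is not None else 1.0 / width_mult
    if mult == 1.0:
        return logits  # unchanged logits
    return [x * mult for x in logits]

scale_logits([4.0, 8.0], mup_enabled=True, width_mult=4)  # -> [1.0, 2.0]
```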
Gets the embedding weight, or the output logit weights, when embedding and output weights are shared (share embedding and output weights set to True) or when Multi-Token Prediction (MTP) is used.
- Returns:
During pre-processing, or in the MTP process, it returns the input embedding weight; during post-processing it returns the final output layer's weight.
- Return type:
Tensor
- compute_mtp_single_step(hidden_states: torch.Tensor, next_token_ids: torch.Tensor, position_ids: torch.Tensor, depth: Optional[int] = None)#
Compute a single MTP depth for speculative decoding.
This is called after speculative token verification to compute MTP predictions conditioned on verified tokens only.
- Parameters:
hidden_states (Tensor) – Hidden states at last accepted positions.
next_token_ids (Tensor) – Correct next token IDs [1, N].
position_ids (Tensor) – Position IDs for the next tokens [1, N].
depth (int, optional) – MTP depth index. Only needed when mtp_use_repeated_layer is False (each depth uses a distinct layer). Omit for repeated-layer models so that a single CUDA graph can serve all depths.
- Returns:
(new_hidden_states, logits [N, 1, vocab_size]).
- Return type:
tuple
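The `depth` rule described above can be sketched as a dispatch over the MTP layers. This is a hypothetical illustration of the documented contract, not the actual implementation: with a repeated (shared) layer, one module serves every depth, which is what lets a single CUDA graph be reused.

```python
# Sketch: repeated-layer models reuse layers[0] at every depth; with
# distinct per-depth layers the caller must supply the depth index.
def select_mtp_layer(layers, use_repeated_layer, depth=None):
    if use_repeated_layer:
        return layers[0]  # same layer at every depth
    if depth is None:
        raise ValueError("depth is required when each depth has its own layer")
    return layers[depth]
```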
- sharded_state_dict(prefix: str = '', sharded_offsets: Tuple[Tuple[int, int, int]] = (), metadata: Optional[dict] = None)#
Sharded state dict implementation that handles the output layer weights tying.
- Parameters:
prefix (str) – Module name prefix.
sharded_offsets (tuple) – PP related offsets, expected to be empty at this module level.
metadata (Optional[Dict]) – metadata controlling sharded state dict creation.
- Returns:
sharded state dict for the LanguageModel
- Return type:
ShardedStateDict
- tie_embeddings_and_output_weights_state_dict(sharded_state_dict: megatron.core.dist_checkpointing.mapping.ShardedStateDict, output_layer_weight_key: str, first_stage_word_emb_key: str, metadata: dict = {})#
Ties the embedding and output weights in a given sharded state dict.
- Parameters:
sharded_state_dict (ShardedStateDict) – state dict with the weight to tie
output_layer_weight_key (str) – key of the output layer weight in the state dict. This entry will be replaced with a tied version
first_stage_word_emb_key (str) – this must be the same as the ShardedTensor.key of the first stage word embeddings.
- Returns:
None, acts in-place
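The in-place tying described above can be illustrated with a dict-based sketch. In the real sharded state dict the entries are ShardedTensor objects; here plain dicts stand in for them, and the keys are illustrative rather than the actual checkpoint keys.

```python
# Sketch: tying replaces the output-layer entry with a copy whose key
# points at the first-stage word embedding, so both resolve to one tensor
# in the checkpoint.
def tie_output_to_embedding(state_dict, output_key, first_stage_emb_key):
    """Replace the output-layer entry with one keyed to the embedding."""
    tied = dict(state_dict[output_key])   # copy the entry's metadata
    tied["key"] = first_stage_emb_key     # redirect to the embedding key
    state_dict[output_key] = tied         # acts in-place; returns None
```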