core.models.common.language_module.language_module#

Module Contents#

Classes#

LanguageModule

Base language module that has common helper functions used across GPT, BERT etc.

API#

class core.models.common.language_module.language_module.LanguageModule(
config: megatron.core.transformer.transformer_config.TransformerConfig,
pg_collection: Optional[megatron.core.process_groups_config.ProcessGroupCollection] = None,
)#

Bases: megatron.core.transformer.module.MegatronModule

Base language module that has common helper functions used across GPT, BERT etc.

Parameters:
  • config (TransformerConfig) – Transformer config object used to configure the module.

  • pg_collection (Optional[ProcessGroupCollection]) – Optional collection of process groups used by the module.

Initialization

_is_in_embd_group()#
_set_attention_backend()#

Set the attention backend.

Transformer Engine works on an opt-out basis: by default, all three attention backend flags are set to 1. If the user chooses a particular attention backend, the other two flags are set to 0. If the user chooses local, all three TE environment variables are set to 0.
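
A rough illustration of this opt-out scheme follows. NVTE_FLASH_ATTN, NVTE_FUSED_ATTN, and NVTE_UNFUSED_ATTN are real Transformer Engine environment variables, but the selection logic below is a simplified sketch, not the actual implementation:

```python
import os

def _select_attention_backend(backend: str) -> None:
    flags = {
        "flash": "NVTE_FLASH_ATTN",
        "fused": "NVTE_FUSED_ATTN",
        "unfused": "NVTE_UNFUSED_ATTN",
    }
    if backend == "local":
        # Local (non-TE) attention: opt out of every TE backend.
        for env_var in flags.values():
            os.environ[env_var] = "0"
    elif backend in flags:
        # Keep the chosen backend enabled, opt out of the other two.
        for name, env_var in flags.items():
            os.environ[env_var] = "1" if name == backend else "0"
    # Otherwise all three flags stay at their default of 1, letting TE pick.
```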

compute_language_model_loss(
labels: torch.Tensor,
logits: torch.Tensor,
) torch.Tensor#

Computes the language model loss (cross entropy across the vocabulary).

Parameters:
  • labels (Tensor) – The labels of dimension [batch size, seq length]

  • logits (Tensor) – The final logits returned by the output layer of the transformer model

Returns:

Loss tensor of dimension [batch size, sequence length]

Return type:

Tensor
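
A self-contained sketch of what "cross entropy across the vocabulary" means here, using plain PyTorch with hypothetical shapes; the actual method computes this with a cross entropy that works across tensor-parallel vocabulary shards:

```python
import torch
import torch.nn.functional as F

# Hypothetical shapes: batch size 2, sequence length 8, vocab size 32.
batch, seq, vocab = 2, 8, 32
labels = torch.randint(0, vocab, (batch, seq))
logits = torch.randn(batch, seq, vocab)

# Per-token cross entropy with no reduction, giving a loss tensor of
# shape [batch size, sequence length], matching the return value above.
loss = F.cross_entropy(
    logits.reshape(-1, vocab), labels.reshape(-1), reduction="none"
).reshape(batch, seq)
assert loss.shape == (batch, seq)
```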

setup_embeddings_and_output_layer() None#

Sets up embedding layer in first stage and output layer in last stage.

This function initializes word embeddings in the final stage when we are using pipeline parallelism and sharing word embeddings, and sets up parameter attributes on the embedding and output layers.
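
The effect of the sharing can be illustrated with a framework-free sketch using hypothetical sizes; the real setup additionally marks the parameters so their gradients can be synchronized across the embedding group under pipeline parallelism:

```python
import torch.nn as nn

# Hypothetical dimensions, for illustration only.
embedding = nn.Embedding(num_embeddings=32, embedding_dim=16)
output_layer = nn.Linear(16, 32, bias=False)

# Tie the output projection to the embedding matrix: both modules now
# share a single weight tensor.
output_layer.weight = embedding.weight
assert output_layer.weight.data_ptr() == embedding.weight.data_ptr()
```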

shared_embedding_or_output_weight() torch.Tensor#

Gets the embedding weight or output logit weight when the embedding and output weights are shared (share_embeddings_and_output_weights is set to True).

Returns:

During pre-processing it returns the input embedding's weight, while during post-processing it returns the final output layer's weight

Return type:

Tensor
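
A simplified sketch of the branch this method performs, assuming a module with pre_process and post_process flags; the attribute paths below are illustrative:

```python
def shared_embedding_or_output_weight_sketch(module):
    # First pipeline stage: return the input embedding's weight.
    if module.pre_process:
        return module.embedding.word_embeddings.weight
    # Last pipeline stage: return the output layer's weight.
    if module.post_process:
        return module.output_layer.weight
    # Intermediate stages hold neither weight.
    return None
```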

sharded_state_dict(
prefix: str = '',
sharded_offsets: Tuple[Tuple[int, int, int]] = (),
metadata: Optional[dict] = None,
) megatron.core.dist_checkpointing.mapping.ShardedStateDict#

Sharded state dict implementation that handles the output layer weights tying.

Parameters:
  • prefix (str) – Module name prefix.

  • sharded_offsets (tuple) – PP related offsets, expected to be empty at this module level.

  • metadata (Optional[Dict]) – metadata controlling sharded state dict creation.

Returns:

Sharded state dict for the LanguageModule

Return type:

ShardedStateDict
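
A hedged usage sketch with megatron.core.dist_checkpointing; the checkpoint directory is hypothetical, and `model` is assumed to be a constructed LanguageModule subclass running under an initialized distributed process group:

```python
from megatron.core import dist_checkpointing

# Build the sharded state dict for this module (and its submodules).
sharded_sd = model.sharded_state_dict(prefix="model.")

# Save it; "/tmp/ckpt" is a hypothetical checkpoint directory.
dist_checkpointing.save(sharded_sd, "/tmp/ckpt")
```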

tie_embeddings_and_output_weights_state_dict(
sharded_state_dict: megatron.core.dist_checkpointing.mapping.ShardedStateDict,
output_layer_weight_key: str,
first_stage_word_emb_key: str,
metadata: dict = {},
) None#

Ties the embedding and output weights in a given sharded state dict.

Parameters:
  • sharded_state_dict (ShardedStateDict) – state dict with the weight to tie

  • output_layer_weight_key (str) – key of the output layer weight in the state dict. This entry will be replaced with a tied version

  • first_stage_word_emb_key (str) – this must be the same as the ShardedTensor.key of the first stage word embeddings.

Returns:

None, acts in-place
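
Conceptually, the tying amounts to re-keying the output-layer entry so the checkpointing machinery treats it as another copy of the first-stage embedding. A rough sketch only; the real method also assigns a replica id so the tied copy is not written to the checkpoint twice:

```python
def tie_sketch(sharded_state_dict, output_layer_weight_key, first_stage_word_emb_key):
    # Point the output-layer entry at the first-stage embedding key, so
    # that both entries resolve to the same tensor in the checkpoint.
    tied = sharded_state_dict[output_layer_weight_key]
    tied.key = first_stage_word_emb_key
    sharded_state_dict[output_layer_weight_key] = tied
```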