core.models.bert.bert_lm_head

Module Contents

Classes

BertLMHead

Masked language modeling (MLM) head for BERT.

API

class core.models.bert.bert_lm_head.BertLMHead(
hidden_size: int,
config: megatron.core.transformer.transformer_config.TransformerConfig,
)

Bases: megatron.core.transformer.module.MegatronModule

Masked language modeling (MLM) head for BERT.

Parameters:
  • hidden_size (int) – Size of the hidden dimension of the input hidden states.

  • config (TransformerConfig) – TransformerConfig object
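As a rough illustration of what a BERT masked-LM head computes, here is a minimal plain-PyTorch sketch. It assumes the classic BERT MLM transform (a square dense layer, then GELU, then LayerNorm over the hidden dimension). `SketchBertLMHead` and `layernorm_epsilon` are hypothetical names, and the sketch deliberately omits Megatron specifics such as `MegatronModule`, `TransformerConfig`-driven weight initialization, and sequence parallelism, so it should not be read as the library's implementation.

```python
import torch
from torch import nn
import torch.nn.functional as F


class SketchBertLMHead(nn.Module):
    """Hypothetical stand-in for an MLM head: dense -> GELU -> LayerNorm."""

    def __init__(self, hidden_size: int, layernorm_epsilon: float = 1e-5):
        super().__init__()
        # Square projection, so the output keeps the hidden size.
        self.dense = nn.Linear(hidden_size, hidden_size)
        # Normalizes over the last (hidden) dimension only.
        self.layer_norm = nn.LayerNorm(hidden_size, eps=layernorm_epsilon)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.dense(hidden_states)
        hidden_states = F.gelu(hidden_states)
        hidden_states = self.layer_norm(hidden_states)
        return hidden_states


# Input laid out as [sequence, batch, hidden]; only the last dim is transformed.
head = SketchBertLMHead(hidden_size=16)
x = torch.randn(8, 2, 16)
out = head(x)
print(out.shape)  # torch.Size([8, 2, 16])
```

In this sketch the dense layer is square, so the output keeps the input shape; any projection to vocabulary logits would be a separate step.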


forward(hidden_states: torch.Tensor) → torch.Tensor

Forward pass over the input hidden states; returns the transformed hidden states as a tensor.
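To show how the output of an MLM head is typically consumed during pretraining, the sketch below assumes, as in the original BERT recipe, that the hidden-size features are projected onto the vocabulary with weights tied to the word embeddings, and that the loss is computed only at masked positions. The tensors `word_embeddings` and `output_bias`, the random stand-in for the head output, and the `-100` label convention for unmasked positions are illustrative assumptions, not part of the `BertLMHead` API.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

seq_len, batch, hidden, vocab = 6, 2, 16, 50

# Random stand-in for the head's output: [seq, batch, hidden]
head_output = torch.randn(seq_len, batch, hidden)

# Assumed tied projection: reuse a word-embedding matrix of shape [vocab, hidden]
word_embeddings = torch.randn(vocab, hidden)
output_bias = torch.zeros(vocab)

# Project hidden features to vocabulary logits: [seq, batch, vocab]
logits = F.linear(head_output, word_embeddings, output_bias)

# Labels hold the original token id at masked positions and -100 elsewhere
labels = torch.full((seq_len, batch), -100, dtype=torch.long)
labels[1, 0] = 7
labels[4, 1] = 42

# Positions labeled -100 are ignored, so the mean is over the 2 masked tokens
loss = F.cross_entropy(
    logits.reshape(-1, vocab), labels.reshape(-1), ignore_index=-100
)
print(loss.item())
```

Restricting the loss to masked positions is what makes this a masked-LM objective; every other position contributes nothing to the gradient.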