core.models.bert.bert_model#

Module Contents#

Classes#

BertModel

Transformer language model.

Functions#

get_te_version

Included for backwards compatibility.

API#

core.models.bert.bert_model.get_te_version()#

Included for backwards compatibility.
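A hedged usage sketch follows; it assumes, as in the canonical helper in megatron.core.utils, that get_te_version() returns a packaging.version.Version that can be compared against version thresholds.

```python
# Hedged usage sketch: compare the installed Transformer Engine version against a
# threshold. Assumes get_te_version() returns a packaging.version.Version object.
from packaging.version import Version as PkgVersion

from megatron.core.models.bert.bert_model import get_te_version

te_version = get_te_version()
if te_version > PkgVersion("1.10"):
    print(f"TE {te_version}: padding mask with [b, 1, s, s] works for all backends")
else:
    print(f"TE {te_version}: attention mask dimensions depend on the backend")
```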

class core.models.bert.bert_model.BertModel(
config: megatron.core.transformer.transformer_config.TransformerConfig,
num_tokentypes: int,
transformer_layer_spec: megatron.core.transformer.spec_utils.ModuleSpec,
vocab_size: int,
max_sequence_length: int,
pre_process: bool = True,
post_process: bool = True,
fp16_lm_cross_entropy: bool = False,
parallel_output: bool = True,
share_embeddings_and_output_weights: bool = False,
position_embedding_type: Literal['learned_absolute', 'rope'] = 'learned_absolute',
rotary_percent: float = 1.0,
seq_len_interpolation_factor: Optional[float] = None,
add_binary_head=True,
return_embeddings=False,
vp_stage: Optional[int] = None,
pg_collection: Optional[megatron.core.process_groups_config.ProcessGroupCollection] = None,
)#

Bases: megatron.core.models.common.language_module.language_module.LanguageModule

Transformer language model.

Parameters:
  • config (TransformerConfig) – transformer config

  • num_tokentypes (int) – Set to 2 when args.bert_binary_head is True, and 0 otherwise. Defaults to 0.

  • transformer_layer_spec (ModuleSpec) – Specifies module to use for transformer layers

  • vocab_size (int) – vocabulary size

  • max_sequence_length (int) – Maximum sequence length. This is used for positional embeddings.

  • pre_process (bool) – Include embedding layer (used with pipeline parallelism)

  • post_process (bool) – Include an output layer (used with pipeline parallelism)

  • parallel_output (bool) – Do not gather the outputs, keep them split across tensor parallel ranks

  • share_embeddings_and_output_weights (bool) – When True, input embeddings and output logit weights are shared. Defaults to False.

  • position_embedding_type (string) – Position embedding type. Options [‘learned_absolute’, ‘rope’]. Defaults to ‘learned_absolute’.

  • rotary_percent (float) – Percent of rotary dimension to use for rotary position embeddings. Defaults to 1.0 (100%). Ignored unless position_embedding_type is ‘rope’.

  • vp_stage (int) – Virtual pipeline stage.

Initialization
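A minimal construction sketch follows. It is illustrative only: the layer-spec import (bert_layer_with_transformer_engine_spec) and the config values are assumptions, and megatron.core.parallel_state must already be initialized (e.g. via initialize_model_parallel()) before the model can be built.

```python
# Illustrative construction sketch (not from the library docs). Assumes torch.distributed
# and megatron.core.parallel_state have already been initialized.
from megatron.core.models.bert.bert_model import BertModel
from megatron.core.models.bert.bert_layer_specs import bert_layer_with_transformer_engine_spec
from megatron.core.transformer.transformer_config import TransformerConfig

config = TransformerConfig(
    num_layers=2,
    hidden_size=128,
    num_attention_heads=4,
    use_cpu_initialization=True,
)

model = BertModel(
    config=config,
    num_tokentypes=2,  # 2 because the binary (NSP) head is used below
    transformer_layer_spec=bert_layer_with_transformer_engine_spec,
    vocab_size=30522,
    max_sequence_length=512,
    pre_process=True,
    post_process=True,
    add_binary_head=True,
)
```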

_sanity_check_attention_and_get_attn_mask_dimension() → str#

Performs sanity checks and returns the attention mask dimensions for self-attention.

The Transformer Engine library has changed considerably across versions, so the dimensions of the attention mask must be adjusted to the installed TE version. Some arguments are also sanity checked.

  1. If the local attention backend is used, the mask has dimensions [b, 1, s, s].

  2. If Transformer Engine > 1.10 is used, all three backends are supported with a padding mask of [b, 1, s, s].

  3. If Transformer Engine >= 1.7 but < 1.10 is used: a) flash and fused attention use a padding mask of [b, 1, 1, s]; b) unfused attention works with an arbitrary mask of [b, 1, s, s].

  4. If Transformer Engine < 1.7 is used, flash and fused attention are not supported; unfused attention works with a padding mask of [b, 1, s, s].

By default, if no NVTE_ATTN flags are set, the fused path is used for Transformer Engine versions >= 1.7 and the unfused path otherwise, as sketched below.

Parameters:

transformer_layer_spec (ModuleSpec) – The transformer layer spec

Returns:

A string describing the format of the attention mask dimensions

Return type:

str
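The version-dependent selection described above can be summarized in a short sketch. This is illustrative only, not the library's implementation; the use_local_attention and uses_flash_or_fused flags are hypothetical stand-ins for the backend choice, and the returned shape strings simply abbreviate the mask layouts (e.g. ‘b1ss’ for [b, 1, s, s]).

```python
# Illustrative sketch of the version-dependent mask dimensions described above.
# Not the library's implementation; flags and return strings are for illustration.
from packaging.version import Version as PkgVersion

from megatron.core.models.bert.bert_model import get_te_version


def choose_attn_mask_dimensions(use_local_attention: bool, uses_flash_or_fused: bool) -> str:
    """Return the mask layout as a shape string, e.g. 'b1ss' for [b, 1, s, s]."""
    if use_local_attention:
        return "b1ss"  # local backend: [b, 1, s, s]
    te_version = get_te_version()
    if te_version > PkgVersion("1.10"):
        return "b1ss"  # all three TE backends: padding mask, [b, 1, s, s]
    if te_version >= PkgVersion("1.7"):
        # flash/fused attention: padding mask [b, 1, 1, s]; unfused: arbitrary [b, 1, s, s]
        return "b11s" if uses_flash_or_fused else "b1ss"
    return "b1ss"  # TE < 1.7: only unfused attention, padding mask [b, 1, s, s]
```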

bert_extended_attention_mask(
attention_mask: torch.Tensor,
) → torch.Tensor#

Creates the extended attention mask

Converts the attention mask of dimension [batch size, 1, seq len] to [batch size, 1, seq len, seq len] or [batch size, 1, 1, seq len] and makes it binary.

Parameters:

attention_mask (Tensor) – The input attention mask

Returns:

The extended binary attention mask

Return type:

Tensor
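A minimal sketch of this conversion follows, assuming a [batch size, 1, seq len] input with 1 marking real tokens and 0 marking padding. The real method may instead produce a [batch size, 1, 1, seq len] mask depending on the TE version, and the binarization convention shown (True = masked out) is an assumption.

```python
# Illustrative only: extend a [b, 1, s] padding mask to [b, 1, s, s] and binarize it.
import torch


def extend_and_binarize_mask(attention_mask: torch.Tensor) -> torch.Tensor:
    # Outer product over the sequence dimension: position (i, j) is kept only if
    # both token i and token j are real tokens.
    mask_b1s1 = attention_mask.unsqueeze(-1)  # [b, 1, s, 1]
    mask_b11s = attention_mask.unsqueeze(-2)  # [b, 1, 1, s]
    extended = mask_b1s1 * mask_b11s          # [b, 1, s, s]
    # Binarize with True marking positions to mask out (an assumed convention).
    return extended < 0.5


mask = torch.tensor([[[1, 1, 1, 0]]])          # one sample, seq len 4, last token padded
print(extend_and_binarize_mask(mask).shape)    # torch.Size([1, 1, 4, 4])
```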

bert_position_ids(token_ids)#

Creates position ids for the BERT model.
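A minimal sketch of the usual BERT position-id construction (consecutive positions broadcast across the batch) follows; it is not guaranteed to be identical to the library helper.

```python
# Illustrative sketch: consecutive position ids 0..s-1, broadcast over the batch.
import torch


def make_position_ids(token_ids: torch.Tensor) -> torch.Tensor:
    seq_length = token_ids.size(1)
    position_ids = torch.arange(seq_length, dtype=torch.long, device=token_ids.device)
    return position_ids.unsqueeze(0).expand_as(token_ids)  # [b, s]
```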

set_input_tensor(input_tensor: torch.Tensor) → None#

Sets input tensor to the model.

See megatron.model.transformer.set_input_tensor()

Parameters:

input_tensor (Tensor) – Sets the input tensor for the model.

forward(
input_ids: torch.Tensor,
attention_mask: torch.Tensor,
tokentype_ids: torch.Tensor = None,
lm_labels: torch.Tensor = None,
inference_context=None,
*,
inference_params: Optional[megatron.core.inference.contexts.BaseInferenceContext] = None,
)#

Forward function of the BERT model.

This function passes the input tensors through the embedding layer, then the encoder, and finally into the optional post-processing layer.

It returns the loss values if labels are given, otherwise the final hidden units.
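A hedged end-to-end sketch of a forward call follows. It assumes a model built as in the construction sketch under Initialization, and an attention mask shaped [batch size, 1, seq len] as described for bert_extended_attention_mask above; the vocabulary size and sequence length are the illustrative values used earlier.

```python
# Illustrative forward pass. Assumes `model` was constructed as in the sketch above
# and that model-parallel state is initialized.
import torch

batch_size, seq_length = 2, 128
input_ids = torch.randint(0, 30522, (batch_size, seq_length))
attention_mask = torch.ones(batch_size, 1, seq_length, dtype=torch.long)  # [b, 1, s]
tokentype_ids = torch.zeros(batch_size, seq_length, dtype=torch.long)

# Without lm_labels the call returns the final hidden units; passing lm_labels
# makes it return the loss values instead.
output = model(
    input_ids=input_ids,
    attention_mask=attention_mask,
    tokentype_ids=tokentype_ids,
)
```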