models.t5 package#
Submodules#
models.t5.t5_model module#
- class core.models.T5.t5_model.T5LMHead(*args: Any, **kwargs: Any)#
Bases:
MegatronModule
Masked LM head for T5
- Parameters:
config (TransformerConfig) – transformer config
parallel_output (bool) – whether the output logits are kept split across tensor-parallel ranks or gathered.
vocab_size (int) – vocabulary size
pre_process (bool) – Include embedding layer
share_embeddings_and_output_weights (bool) – When True, input embeddings and output logit weights are shared.
- forward(
- hidden_states: torch.Tensor,
- word_embeddings_weight: torch.Tensor,
- )#
Forward pass.
- Parameters:
hidden_states (Tensor) – output hidden states from decoder
word_embeddings_weight (Tensor) – word embedding weight
- Returns:
logits tensor
- Return type:
Tensor
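The call pattern is easiest to see in code. The following is a minimal sketch, not a verbatim recipe: it assumes Megatron-Core is installed, that model-parallel state has already been initialized (e.g. via megatron.core.parallel_state.initialize_model_parallel()), and that all sizes are purely illustrative.
```python
# Minimal sketch of using T5LMHead directly (illustrative sizes; assumes
# model-parallel state is already initialized).
import torch
from megatron.core.models.T5.t5_model import T5LMHead
from megatron.core.transformer.transformer_config import TransformerConfig

config = TransformerConfig(num_layers=2, hidden_size=64, num_attention_heads=4)

lm_head = T5LMHead(
    config=config,
    parallel_output=True,                      # keep logits split across TP ranks
    vocab_size=32128,
    pre_process=True,
    share_embeddings_and_output_weights=True,  # reuse the embedding matrix as output weights
)

# Megatron uses the [sequence, batch, hidden] layout for hidden states.
hidden_states = torch.randn(128, 2, 64)
word_embeddings_weight = torch.randn(32128, 64)  # shared input-embedding matrix

logits = lm_head(hidden_states, word_embeddings_weight)  # [seq, batch, vocab partition]
```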
- class core.models.T5.t5_model.T5Model(*args: Any, **kwargs: Any)#
Bases:
LanguageModule
T5 Language model.
- Parameters:
config (TransformerConfig) – transformer config
encoder_config (TransformerConfig) – encoder transformer config
transformer_encoder_layer_spec (ModuleSpec) – transformer layer customization specs for encoder
transformer_decoder_layer_spec (ModuleSpec) – transformer layer customization specs for decoder
vocab_size (int) – vocabulary size
max_sequence_length (int) – maximum size of sequence. This is used for positional embedding
pre_process (bool) – Include embedding layer (used with pipeline parallelism)
post_process (bool) – Include an output layer (used with pipeline parallelism)
fp16_lm_cross_entropy (bool, optional) – Defaults to False
parallel_output (bool) – Do not gather the outputs, keep them split across tensor parallel ranks
share_embeddings_and_output_weights (bool) – When True, input embeddings and output logit weights are shared. Defaults to False.
position_embedding_type (string) – Position embedding type. Options: ['learned_absolute', 'rope']. Defaults to 'learned_absolute'.
rotary_percent (float) – Percent of rotary dimension to use for rotary position embeddings. Defaults to 1.0 (100%). Ignored unless position_embedding_type is ‘rope’.
seq_len_interpolation_factor (float) – scale factor used to linearly interpolate RoPE for longer sequences. The value must be a float larger than 1.0. Defaults to None.
add_encoder (bool) – Create the encoder (used with pipeline parallelism). When using pipelining, the encoder will only be created on a subset of the pipeline ranks.
add_decoder (bool) – Create the decoder (used with pipeline parallelism). As with add_encoder, when using pipelining, the decoder will only be created on a subset of the pipeline ranks.
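A construction sketch can make these arguments concrete. The spec helpers named below (get_t5_encoder_with_transformer_engine_block_spec, get_t5_decoder_with_transformer_engine_block_spec) are assumed from megatron.core.models.T5.t5_spec in recent Megatron-Core releases; substitute the ModuleSpec objects your version provides, and initialize distributed/model-parallel state before constructing the model.
```python
# Construction sketch (spec helper names are assumptions; sizes are illustrative).
from megatron.core.models.T5.t5_model import T5Model
from megatron.core.models.T5.t5_spec import (
    get_t5_decoder_with_transformer_engine_block_spec,
    get_t5_encoder_with_transformer_engine_block_spec,
)
from megatron.core.transformer.transformer_config import TransformerConfig

num_layers = 4
config = TransformerConfig(num_layers=num_layers, hidden_size=256, num_attention_heads=8)

model = T5Model(
    config=config,
    encoder_config=config,  # pass a distinct config here if the encoder differs
    transformer_encoder_layer_spec=get_t5_encoder_with_transformer_engine_block_spec(num_layers),
    transformer_decoder_layer_spec=get_t5_decoder_with_transformer_engine_block_spec(num_layers),
    vocab_size=32128,
    max_sequence_length=512,
    pre_process=True,
    post_process=True,
    parallel_output=True,
    share_embeddings_and_output_weights=True,
    position_embedding_type="learned_absolute",
)
```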
- forward(
- encoder_input_ids: torch.Tensor,
- decoder_input_ids: torch.Tensor,
- encoder_attn_mask: torch.Tensor,
- decoder_attn_mask: torch.Tensor,
- encoder_decoder_attn_mask: torch.Tensor,
- lm_labels: torch.Tensor | None = None,
- encoder_hidden_states: torch.Tensor | None = None,
- output_encoder_hidden_only: bool = False,
- inference_context: megatron.core.inference.contexts.BaseInferenceContext | None = None,
- packed_seq_params: megatron.core.packed_seq_params.PackedSeqParams | None = None,
- *,
- inference_params: megatron.core.inference.contexts.BaseInferenceContext | None = None,
- )#
Forward pass.
- Parameters:
encoder_input_ids (Tensor) – input ids for encoder
decoder_input_ids (Tensor) – input ids for decoder
encoder_attn_mask (Tensor) – self-attention mask for encoder
decoder_attn_mask (Tensor) – self-attention mask for decoder
encoder_decoder_attn_mask (Tensor) – cross-attention mask between encoder and decoder
lm_labels (Tensor) – labels for decoder output
inference_context (BaseInferenceContext) – context object carrying state used during inference
- Returns:
loss tensor
- Return type:
Tensor
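The sketch below shows a training-style call on the model constructed above; shapes and mask conventions are illustrative only and should follow your data pipeline.
```python
# Hedged forward-pass sketch (illustrative shapes; mask semantics follow your pipeline).
import torch

batch, enc_len, dec_len = 2, 512, 128
device = "cuda"

encoder_input_ids = torch.randint(0, 32128, (batch, enc_len), device=device)
decoder_input_ids = torch.randint(0, 32128, (batch, dec_len), device=device)

# [batch, seq_q, seq_k] masks; the decoder self-attention mask is causal.
encoder_attn_mask = torch.ones(batch, enc_len, enc_len, dtype=torch.bool, device=device)
decoder_attn_mask = torch.tril(
    torch.ones(batch, dec_len, dec_len, dtype=torch.bool, device=device)
)
encoder_decoder_attn_mask = torch.ones(batch, dec_len, enc_len, dtype=torch.bool, device=device)

loss = model(
    encoder_input_ids=encoder_input_ids,
    decoder_input_ids=decoder_input_ids,
    encoder_attn_mask=encoder_attn_mask,
    decoder_attn_mask=decoder_attn_mask,
    encoder_decoder_attn_mask=encoder_decoder_attn_mask,
    lm_labels=decoder_input_ids,  # a loss tensor is returned when labels are provided
)
```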
- set_input_tensor(input_tensor)#
See megatron.model.transformer.set_input_tensor()
- sharded_state_dict(
- prefix: str = '',
- sharded_offsets: Tuple[Tuple[int, int, int]] = (),
- metadata: dict | None = None,
- )#
Sharded state dict implementation handling duplication of encoder and decoder layers.
Some layers (output, embedding) are shared between the encoder and decoder. This method sets the replica_id for them to ensure there is only one layer instance with replica_id (0, 0, 0).
- Parameters:
prefix (str) – Module name prefix.
sharded_offsets (tuple) – PP related offsets, expected to be empty at this module level.
metadata (Optional[Dict]) – metadata controlling sharded state dict creation.
- Returns:
sharded state dict for the T5Model
- Return type:
ShardedStateDict
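The sharded state dict is typically handed to Megatron's distributed checkpointing utilities; a hedged sketch of that flow, assuming the save/load signatures of recent megatron.core.dist_checkpointing releases:
```python
# Hedged checkpointing sketch using megatron.core.dist_checkpointing.
from megatron.core import dist_checkpointing

checkpoint_dir = "/path/to/checkpoint"  # hypothetical location

# Save: every rank contributes its shards.
sharded_sd = model.sharded_state_dict(prefix="")
dist_checkpointing.save(sharded_sd, checkpoint_dir)

# Load: a template sharded state dict describes what each rank expects back.
template_sd = model.sharded_state_dict(prefix="")
state_dict = dist_checkpointing.load(template_sd, checkpoint_dir)
model.load_state_dict(state_dict)
```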
- shared_embedding_or_output_weight()#
Function to share the input embeddings and output logit weights.
- core.models.T5.t5_model.t5_extended_attention_mask(
- attention_mask_list: List[torch.Tensor],
- )#
Creates the extended attention mask.
Converts each attention mask of dimension [batch size, seq_len, seq_len] to [batch size, 1, seq_len, seq_len].
- Parameters:
attention_mask_list (List[Tensor]) – the input attention masks
- Returns:
The extended binary attention mask
- Return type:
Tensor
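A short illustration of the shape conversion, assuming the helper maps over the input list as its signature suggests:
```python
import torch
from megatron.core.models.T5.t5_model import t5_extended_attention_mask

enc_mask = torch.ones(2, 512, 512, dtype=torch.bool)   # [batch, seq, seq]
dec_mask = torch.ones(2, 128, 128, dtype=torch.bool)
enc_dec_mask = torch.ones(2, 128, 512, dtype=torch.bool)

extended = t5_extended_attention_mask([enc_mask, dec_mask, enc_dec_mask])
# Each mask gains a broadcastable head dimension:
# [batch, seq_q, seq_k] -> [batch, 1, seq_q, seq_k]
```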
- core.models.T5.t5_model.t5_position_ids(token_ids: torch.Tensor) → torch.Tensor#
Calculate position ids from token ids.
- Parameters:
token_ids (Tensor) – input tokens
- Returns:
position ids
- Return type:
Tensor
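A quick illustration; the printed values assume the conventional 0..seq_len-1 layout broadcast over the batch:
```python
import torch
from megatron.core.models.T5.t5_model import t5_position_ids

token_ids = torch.randint(0, 32128, (2, 8))  # [batch, seq_len]
position_ids = t5_position_ids(token_ids)

print(position_ids.shape)  # torch.Size([2, 8])
print(position_ids[0])     # tensor([0, 1, 2, 3, 4, 5, 6, 7])
```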