NeMo Megatron API#
Pretraining Model Classes#
Customization Model Classes#
Modules#
- class nemo.collections.nlp.modules.common.megatron.module.MegatronModule(*args: Any, **kwargs: Any)[source]#
Bases: Module
Megatron specific extensions of torch Module with support for pipelining.
- class nemo.collections.nlp.modules.common.megatron.module.Float16Module(*args: Any, **kwargs: Any)[source]#
Bases: MegatronModule
- class nemo.collections.nlp.modules.common.megatron.token_level_encoder_decoder.MegatronTokenLevelEncoderDecoderModule(*args: Any, **kwargs: Any)[source]#
Bases: MegatronModule, AdapterModuleMixin
Token-based (input/output is tokens) encoder-decoder model (e.g., the T5 language model).
- forward(enc_input_ids=None, enc_attn_mask=None, dec_input_ids=None, dec_attn_mask=None, token_type_ids=None, labels=None, batch_data=None, enc_output=None, enc_output_attn_mask=None, enc_input=None, output_enc_hidden_only=False)[source]#
Return value is per token / per dimension (i.e., the non-collapsed loss value).
- class nemo.collections.nlp.modules.common.megatron.retrieval_token_level_encoder_decoder.MegatronRetrievalTokenLevelEncoderDecoderModule(*args: Any, **kwargs: Any)[source]#
Bases: MegatronModule
Token-based (input/output is tokens) retrieval encoder-decoder model
- forward(input_ids, input_attn_mask, retrieved_ids, retrieved_attn_mask, token_type_ids=None, labels=None, input_emb=None, set_inference_key_value_memory=False, inference_max_sequence_len=None, neighbors=None, position_ids=None)[source]#
Return value is per token / per dimension (i.e., the non-collapsed loss value).
Datasets#
- class nemo.collections.nlp.data.language_modeling.megatron.gpt_dataset.GPTDataset(*args: Any, **kwargs: Any)[source]#
Bases: Dataset
Adapter Mixin Class#
- class nemo.collections.nlp.parts.mixins.nlp_adapter_mixins.NLPAdapterModelMixin(*args, **kwargs)[source]#
Bases: object
NLP Adapter Mixin that can augment any transformer-based model with adapter module support. This mixin class should be used only with a top-level ModelPT subclass that includes either a model or an enc_dec_model submodule. This mixin class adds several utility methods to add, load, and save adapters.
An adapter module is any PyTorch nn.Module that possesses a few properties:
Its input and output dimensions are the same, while the hidden dimension need not be.
The final layer of the adapter module is zero-initialized, so that the residual connection to the adapter yields the original output at initialization.
This mixin class is designed to integrate with PEFT, which applies one or more adapter modules to the model. Two PEFT features, layer selection and weight tying, are also supported by this mixin class.
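As an illustration of these properties, here is a minimal sketch of a generic bottleneck adapter in PyTorch. This is not NeMo's own adapter implementation; the class name and dimensions are assumptions chosen for the example.

import torch
import torch.nn as nn

class BottleneckAdapter(nn.Module):
    # Same input and output dimension; smaller hidden (bottleneck) dimension.
    def __init__(self, dim: int, hidden_dim: int):
        super().__init__()
        self.down = nn.Linear(dim, hidden_dim)
        self.act = nn.ReLU()
        self.up = nn.Linear(hidden_dim, dim)
        # Zero-initialize the final layer so the residual connection
        # reproduces the original output at initialization.
        nn.init.zeros_(self.up.weight)
        nn.init.zeros_(self.up.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # At initialization this returns x unchanged.
        return x + self.up(self.act(self.down(x)))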
- add_adapter(peft_cfgs: Union[PEFTConfig, List[PEFTConfig]])[source]#
High-level API to add one or more adapter modules to the model and freeze the base weights. This method accepts a single PEFTConfig or a list of PEFTConfigs and adds the corresponding adapter modules. Layer selection and weight tying are applied if specified in the PEFTConfig.
- Parameters
peft_cfgs – One or more PEFTConfig objects that specify the PEFT method configuration
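A usage sketch, assuming model is an already restored MegatronGPTSFTModel (or another NLPAdapterModelMixin subclass) and that the LoRA config import path below matches your NeMo version:

from nemo.collections.nlp.parts.peft_config import LoraPEFTConfig  # assumed import path

peft_cfg = LoraPEFTConfig(model.cfg)   # build the PEFT configuration from the model config
model.add_adapter(peft_cfg)            # insert the adapter modules and freeze the base weights
# A list of configs is also accepted, e.g. model.add_adapter([peft_cfg_a, peft_cfg_b])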
- setup_optimizer_param_groups()[source]#
ModelPT override. The optimizer will receive self._optimizer_param_groups. Creates two optimizer param groups: one for the frozen model params and one for the prompt-table/prompt-encoder params. The learning rate for the frozen model's params is always zero, effectively freezing the model's params while still allowing the needed gradients to be passed around in pipeline-parallel models. The prompt encoder and/or prompt table use the learning rate set by the user.
- load_adapters(filepath: str, peft_cfgs: Optional[Union[PEFTConfig, List[PEFTConfig]]] = None, map_location: Optional[str] = None)[source]#
Utility method that restores only the adapter module(s), not the entire model. This allows sharing adapters, which are often just a fraction of the size of the full model, enabling easier delivery.
Note
During restoration, it is assumed that the model does not already have one or more adapter modules.
- Parameters
filepath – Filepath of the .ckpt or .nemo file.
peft_cfgs – One or more PEFTConfig objects that specify the PEFT method configuration. If None, the configuration is inferred from the .nemo checkpoint.
map_location – PyTorch flag specifying where to place the adapter state dict(s).
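A usage sketch, with illustrative file paths and assuming model and peft_cfg as in the add_adapter example above:

model.load_adapters("/path/to/adapters.nemo")  # PEFT config inferred from the .nemo file
# Or pass the config and device placement explicitly, e.g. for a .ckpt file:
model.load_adapters("/path/to/adapters.ckpt", peft_cfgs=peft_cfg, map_location="cpu")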
- classmethod merge_cfg_with(path: str, cfg: omegaconf.DictConfig) → omegaconf.DictConfig [source]#
Merge a given configuration dictionary cfg with the configuration dictionary obtained from restoring a MegatronGPTSFTModel or MegatronT5SFTModel at the specified path.
- Parameters
path (str) – The path to the SFT model checkpoint to be restored.
cfg (DictConfig) – The configuration dictionary to merge.
- Returns
The merged configuration dictionary.
- Return type
DictConfig
Examples
>>> path = "/path/to/model/checkpoint"
>>> cfg = DictConfig({"model": {"key": "value"}, "trainer": {"precision": 16}})
>>> merged_cfg = merge_cfg_with(path, cfg)
Notes
The function resolves variables within the cfg dictionary using OmegaConf.resolve.
Keys in cfg.model will override the corresponding keys in the output dictionary.
If “train_ds” exists in cfg.model.data, it updates micro_batch_size and global_batch_size.
If cfg.trainer contains a “precision” key, it updates output.precision.
- classmethod merge_inference_cfg(path: str, cfg: omegaconf.DictConfig) → omegaconf.DictConfig [source]#
Merge a given configuration dictionary cfg with the configuration dictionary obtained from restoring a MegatronGPTSFTModel or MegatronT5SFTModel at the specified path, and modify the result for inference.
- Parameters
path (str) – The path to the SFT model checkpoint to be restored.
cfg (DictConfig) – The configuration dictionary to modify for inference.
- Returns
The configuration dictionary for inference.
- Return type
DictConfig
Examples
>>> path = "/path/to/model/checkpoint"
>>> cfg = DictConfig({"model": {"key": "value"}, "trainer": {"precision": 16}})
>>> merged_cfg = merge_inference_cfg(path, cfg)
Notes
“precision” and “test_ds” from cfg will override the corresponding keys in the output dictionary.
“activations_checkpoint” will be overridden to None in the output dictionary.
“use_flash_attention” will be True if it is True in either configuration dictionary.
“seq_len_interpolation_factor” will be overridden from cfg if it is not None in the checkpoint.