NeMo Models

Basics

NeMo Models contain everything needed to train and reproduce Conversational AI models:

  • neural network architectures

  • datasets/data loaders

  • data preprocessing/postprocessing

  • data augmentors

  • optimizers and schedulers

  • tokenizers

  • language models

NeMo uses Hydra for configuring both NeMo models and the PyTorch Lightning Trainer.

Note

Every NeMo model has an example configuration file and training script that can be found here.

The end result of using NeMo, PyTorch Lightning, and Hydra is that NeMo models all have the same look and feel and are also fully compatible with the PyTorch ecosystem.

Pretrained

NeMo comes with many pretrained models for each of our collections: ASR, NLP, and TTS.

Every pretrained NeMo model can be downloaded and used with the from_pretrained() method.

As an example we can instantiate QuartzNet with the following:

import nemo.collections.asr as nemo_asr

model = nemo_asr.models.EncDecCTCModel.from_pretrained(model_name="QuartzNet15x5Base-En")
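Once instantiated, the pretrained model can be used right away. For ASR models, a quick sanity check might look like the following sketch (the audio path is a placeholder, and transcribe() is the batch transcription helper available on NeMo ASR CTC models):

# transcribe a (placeholder) audio file with the pretrained model
transcriptions = model.transcribe(["/path/to/audio.wav"])
print(transcriptions)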

To see all available pretrained models for a specific NeMo model use the list_available_models() method.

nemo_asr.models.EncDecCTCModel.list_available_models()

For detailed information on the available pretrained models, please see the collections documentation: Automatic Speech Recognition (ASR), Natural Language Processing (NLP), and Speech Synthesis (TTS).

Training

NeMo leverages PyTorch Lightning for model training. PyTorch Lightning lets NeMo decouple the Conversational AI code from the PyTorch training code. This means that NeMo users can focus on their domain (ASR, NLP, TTS) and building complex AI applications without having to rewrite boilerplate code for PyTorch training.

When using PyTorch Lightning, NeMo users can automatically train with:

  • multi-GPU/multi-node

  • mixed precision

  • model checkpointing

  • logging

  • early stopping

  • and more

The two main aspects of the Lightning API are the LightningModule and the Trainer.

PyTorch Lightning LightningModule

Every NeMo model is a LightningModule which is an nn.module. This means that NeMo models are compatible with the PyTorch ecosystem and can be plugged into existing PyTorch workflows.

Creating a NeMo Model is similar to any other PyTorch workflow. We start by initializing our model architecture and then define the forward pass:

class TextClassificationModel(NLPModel, Exportable):
    ...
    def __init__(self, cfg: DictConfig, trainer: Trainer = None):
        """Initializes the BERTTextClassifier model."""
        ...
        super().__init__(cfg=cfg, trainer=trainer)

        # instantiate a BERT based encoder
        self.bert_model = get_lm_model(
            pretrained_model_name=cfg.language_model.pretrained_model_name,
            config_file=cfg.language_model.config_file,
            config_dict=cfg.language_model.config,
            checkpoint_file=cfg.language_model.lm_checkpoint,
            vocab_file=cfg.tokenizer.vocab_file,
        )

        # instantiate the FFN for classification
        self.classifier = SequenceClassifier(
            hidden_size=self.bert_model.config.hidden_size,
            num_classes=cfg.dataset.num_classes,
            num_layers=cfg.classifier_head.num_output_layers,
            activation='relu',
            log_softmax=False,
            dropout=cfg.classifier_head.fc_dropout,
            use_transformer_init=True,
            idx_conditioned_on=0,
        )
def forward(self, input_ids, token_type_ids, attention_mask):
    """
    No special modification required for Lightning, define it as you normally would
    in the `nn.Module` in vanilla PyTorch.
    """
    hidden_states = self.bert_model(
        input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask
    )
    logits = self.classifier(hidden_states=hidden_states)
    return logits

The LightningModule organizes PyTorch code so that across all NeMo models we have a similar look and feel. For example, the training logic can be found in training_step:

def training_step(self, batch, batch_idx):
    """
    Lightning calls this inside the training loop with the data from the training dataloader
    passed in as `batch`.
    """
    # forward pass
    input_ids, input_type_ids, input_mask, labels = batch
    logits = self.forward(input_ids=input_ids, token_type_ids=input_type_ids, attention_mask=input_mask)

    train_loss = self.loss(logits=logits, labels=labels)

    lr = self._optimizer.param_groups[0]['lr']

    self.log('train_loss', train_loss)
    self.log('lr', lr, prog_bar=True)

    return {
        'loss': train_loss,
        'lr': lr,
    }

Similarly, the validation logic can be found in validation_step:

def validation_step(self, batch, batch_idx):
    """
    Lightning calls this inside the validation loop with the data from the validation dataloader
    passed in as `batch`.
    """
    if self.testing:
        prefix = 'test'
    else:
        prefix = 'val'

    input_ids, input_type_ids, input_mask, labels = batch
    logits = self.forward(input_ids=input_ids, token_type_ids=input_type_ids, attention_mask=input_mask)

    val_loss = self.loss(logits=logits, labels=labels)

    preds = torch.argmax(logits, axis=-1)

    tp, fn, fp, _ = self.classification_report(preds, labels)

    return {'val_loss': val_loss, 'tp': tp, 'fn': fn, 'fp': fp}

PyTorch Lightning then handles all of the boilerplate code needed for training. Virtually any aspect of training can be customized via PyTorch Lightning hooks, plugins, callbacks, or by overriding methods.
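For example, early stopping can be enabled for any NeMo model simply by passing Lightning’s built-in EarlyStopping callback to the Trainer (a minimal sketch using standard PyTorch Lightning APIs; the monitored metric name depends on your model and config):

import pytorch_lightning as pl
from pytorch_lightning.callbacks import EarlyStopping

# stop training when the monitored validation loss stops improving
early_stopping = EarlyStopping(monitor="val_loss", patience=5, mode="min")
trainer = pl.Trainer(gpus=1, max_epochs=100, callbacks=[early_stopping])
trainer.fit(model)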

Please see the Automatic Speech Recognition (ASR), Natural Language Processing (NLP), and Speech Synthesis (TTS) pages for domain-specific documentation.

PyTorch Lightning Trainer

Since every NeMo Model is a LightningModule, we can automatically take advantage of the PyTorch Lightning Trainer. Every NeMo example training script uses the Trainer object to fit the model.

First instantiate the model and trainer and then call .fit:

# We first instantiate the trainer based on the model configuration.
# See the model configuration documentation for details.
trainer = pl.Trainer(**cfg.trainer)

# Then pass the model configuration and trainer object into the NeMo model
model = TextClassificationModel(cfg.model, trainer=trainer)

# Now we can train by calling .fit
trainer.fit(model)

# Or we can run the test loop on test data by calling
trainer.test(model=model)

All trainer flags can be set from the NeMo configuration; see below for more details on model configuration.

Configuration

Hydra is an open-source Python framework that simplifies configuration for complex applications that must bring together many different software libraries. Conversational AI model training is a great example of such an application. To train a Conversational AI model, we must be able to configure:

  • neural network architectures

  • training and optimization algorithms

  • data pre/post processing

  • data augmentation

  • experiment logging/visualization

  • model checkpointing

Please see the Hydra Tutorials for an introduction to using Hydra.

With Hydra we can configure everything needed for NeMo with three interfaces:

  • Command Line (CLI)

  • Configuration Files (YAML)

  • Dataclasses (Python)

YAML

NeMo provides YAML configuration files for all of our example training scripts. YAML files make it easy to experiment with different model and training configurations.

Every NeMo example YAML has the same underlying configuration structure:

  • trainer

  • exp_manager

  • model

The model configuration always contains train_ds, validation_ds, test_ds, and optim sections. Model architectures vary across domains, so please see the ASR, NLP, and TTS Collections documentation for more detailed information on model architecture configuration.

A NeMo configuration file should look something like this:

# PyTorch Lightning Trainer configuration
# any argument of the Trainer object can be set here
trainer:
    gpus: 1 # number of gpus per node
    num_nodes: 1 # number of nodes
    max_epochs: 10 # how many training epochs to run
    val_check_interval: 1.0 # run validation after every epoch

# Experiment logging configuration
exp_manager:
    exp_dir: /path/to/my/nemo/experiments
    name: name_of_my_experiment
    create_tensorboard_logger: True
    create_wandb_logger: True

# Model configuration
# model network architecture, train/val/test datasets, data augmentation, and optimization
model:
    train_ds:
        manifest_filepath: /path/to/my/train/manifest.json
        batch_size: 256
        shuffle: True
    validation_ds:
        manifest_filepath: /path/to/my/validation/manifest.json
        batch_size: 32
        shuffle: False
    test_ds:
        manifest_filepath: /path/to/my/test/manifest.json
        batch_size: 32
        shuffle: False
    optim:
        name: novograd
        lr: .01
        betas: [0.8, 0.5]
        weight_decay: 0.001
    # network architecture can vary greatly depending on the domain
    encoder:
        ...
    decoder:
        ...

More specific details about configuration files for each collection can be found on the following pages:

NeMo ASR Configuration Files

CLI

With NeMo and Hydra, every aspect of model training can be modified from the command line. This is extremely helpful for running lots of experiments on compute clusters or for quickly testing parameters while developing.

All NeMo examples come with instructions on how to run the training/inference script from the command line, see here for an example.

With Hydra, arguments are set using the = operator:

python examples/asr/speech_to_text.py \
    model.train_ds.manifest_filepath=/path/to/my/train/manifest.json \
    model.validation_ds.manifest_filepath=/path/to/my/validation/manifest.json \
    trainer.gpus=2 \
    trainer.max_epochs=50

We can use the + operator to add arguments from the CLI:

python examples/asr/speech_to_text.py \
    model.train_ds.manifest_filepath=/path/to/my/train/manifest.json \
    model.validation_ds.manifest_filepath=/path/to/my/validation/manifest.json \
    trainer.gpus=2 \
    trainer.max_epochs=50 \
    +trainer.fast_dev_run=true

We can use the ~ operator to remove configurations:

python examples/asr/speech_to_text.py \
    model.train_ds.manifest_filepath=/path/to/my/train/manifest.json \
    model.validation_ds.manifest_filepath=/path/to/my/validation/manifest.json \
    ~model.test_ds \
    trainer.gpus=2 \
    trainer.max_epochs=50 \
    +trainer.fast_dev_run=true

We can specify configuration files using the --config-path and --config-name flags:

python examples/asr/speech_to_text.py \
    --config-path=conf \
    --config-name=quartznet_15x5 \
    model.train_ds.manifest_filepath=/path/to/my/train/manifest.json \
    model.validation_ds.manifest_filepath=/path/to/my/validation/manifest.json \
    ~model.test_ds \
    trainer.gpus=2 \
    trainer.max_epochs=50 \
    +trainer.fast_dev_run=true

Dataclasses

Dataclasses allow NeMo to ship model configurations as part of the NeMo library and also enables pure Python configuration of NeMo models. With Hydra, dataclasses can be used to create structured configs for the Conversational AI application.

As an example, see the code block below for an Attention Is All You Need machine translation model. The model configuration can be instantiated and modified like any Python dataclass.

from nemo.collections.nlp.models.machine_translation.mt_enc_dec_config import AAYNBaseConfig

cfg = AAYNBaseConfig()

# modify the number of layers in the encoder
cfg.encoder.num_layers = 8

# modify the training batch size
cfg.train_ds.tokens_in_batch = 8192
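Since Hydra is built on OmegaConf, the dataclass can be converted into a DictConfig for inspection, merging with YAML, or passing to a model constructor (a small sketch using standard OmegaConf calls):

from omegaconf import OmegaConf

# convert the structured dataclass into a DictConfig and inspect it as YAML
structured_cfg = OmegaConf.structured(cfg)
print(OmegaConf.to_yaml(structured_cfg))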

Note

Configuration with Hydra always has the following precedence: CLI > YAML > Dataclass.

Optimization

Optimizers and learning rate schedules are configurable across all NeMo models and have their own namespace. Here is a sample YAML configuration for a Novograd optimizer with a Cosine Annealing learning rate schedule.

optim:
    name: novograd
    lr: 0.01

    # optimizer arguments
    betas: [0.8, 0.25]
    weight_decay: 0.001

    # scheduler setup
    sched:
        name: CosineAnnealing

        # optional arguments
        max_steps: null # computed at runtime or explicitly set here
        monitor: val_loss
        reduce_on_plateau: false

        # scheduler config override
        warmup_steps: 1000
        warmup_ratio: null
        min_lr: 1e-9

Note

NeMo Examples has optimizer and scheduler configurations for every NeMo model.

Optimizers can be configured from the CLI as well:

python examples/asr/speech_to_text.py \
    --config-path=conf \
    --config-name=quartznet_15x5 \
    ...
    # train with the adam optimizer
    model.optim.name=adam \
    # change the learning rate
    model.optim.lr=.0004 \
    # modify betas
    model.optim.betas=[.8, .5]

Optimizers

name corresponds to the lowercase name of the optimizer. The list of available optimizers can be found by

from nemo.core.optim.optimizers import AVAILABLE_OPTIMIZERS

for name, opt in AVAILABLE_OPTIMIZERS.items():
    print(f'name: {name}, opt: {opt}')
name: sgd, opt: <class 'torch.optim.sgd.SGD'>
name: adam, opt: <class 'torch.optim.adam.Adam'>
name: adamw, opt: <class 'torch.optim.adamw.AdamW'>
name: adadelta, opt: <class 'torch.optim.adadelta.Adadelta'>
name: adamax, opt: <class 'torch.optim.adamax.Adamax'>
name: adagrad, opt: <class 'torch.optim.adagrad.Adagrad'>
name: rmsprop, opt: <class 'torch.optim.rmsprop.RMSprop'>
name: rprop, opt: <class 'torch.optim.rprop.Rprop'>
name: novograd, opt: <class 'nemo.core.optim.novograd.Novograd'>

Optimizer Params

Optimizer params can vary between optimizers, but the lr param is required for all of them. To see the available params for an optimizer, we can look at its corresponding dataclass.

from nemo.core.config.optimizers import NovogradParams

print(NovogradParams())
NovogradParams(lr='???', betas=(0.95, 0.98), eps=1e-08, weight_decay=0, grad_averaging=False, amsgrad=False, luc=False, luc_trust=0.001, luc_eps=1e-08)

'???' indicates that the lr argument is required.

Register Optimizer

Register a new optimizer to be used with NeMo with:

nemo.core.optim.optimizers.register_optimizer(name: str, optimizer: torch.optim.optimizer.Optimizer, optimizer_params: nemo.core.config.optimizers.OptimizerParams)[source]

Checks if the optimizer name exists in the registry, and if it doesn’t, adds it.

This allows custom optimizers to be added and called by name during instantiation.

Parameters
  • name – Name of the optimizer. Will be used as key to retrieve the optimizer.

  • optimizer – Optimizer class

  • optimizer_params – The parameters as a dataclass of the optimizer
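As a hedged illustration, registering an optimizer that is not in the default registry might look like the sketch below. Here torch.optim.ASGD is used as the optimizer class, and the params dataclass simply mirrors its keyword arguments; the exact fields of the OptimizerParams base class may differ between NeMo versions.

from dataclasses import dataclass

from torch.optim import ASGD

from nemo.core.config.optimizers import OptimizerParams
from nemo.core.optim.optimizers import register_optimizer

@dataclass
class ASGDParams(OptimizerParams):
    # illustrative defaults mirroring torch.optim.ASGD's keyword arguments
    lambd: float = 1e-4
    alpha: float = 0.75
    weight_decay: float = 0.0

# after registration, `name: asgd` can be used in the optim config section
register_optimizer(name="asgd", optimizer=ASGD, optimizer_params=ASGDParams)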

Learning Rate Schedulers

Learning rate schedulers can be optionally configured under the optim.sched namespace.

name corresponds to the name of the learning rate schedule. The list of available schedulers can be found by

from nemo.core.optim.lr_scheduler import AVAILABLE_SCHEDULERS

for name, opt in AVAILABLE_SCHEDULERS.items():
    print(f'name: {name}, schedule: {opt}')
name: WarmupPolicy, schedule: <class 'nemo.core.optim.lr_scheduler.WarmupPolicy'>
name: WarmupHoldPolicy, schedule: <class 'nemo.core.optim.lr_scheduler.WarmupHoldPolicy'>
name: SquareAnnealing, schedule: <class 'nemo.core.optim.lr_scheduler.SquareAnnealing'>
name: CosineAnnealing, schedule: <class 'nemo.core.optim.lr_scheduler.CosineAnnealing'>
name: NoamAnnealing, schedule: <class 'nemo.core.optim.lr_scheduler.NoamAnnealing'>
name: WarmupAnnealing, schedule: <class 'nemo.core.optim.lr_scheduler.WarmupAnnealing'>
name: InverseSquareRootAnnealing, schedule: <class 'nemo.core.optim.lr_scheduler.InverseSquareRootAnnealing'>
name: SquareRootAnnealing, schedule: <class 'nemo.core.optim.lr_scheduler.SquareRootAnnealing'>
name: PolynomialDecayAnnealing, schedule: <class 'nemo.core.optim.lr_scheduler.PolynomialDecayAnnealing'>
name: PolynomialHoldDecayAnnealing, schedule: <class 'nemo.core.optim.lr_scheduler.PolynomialHoldDecayAnnealing'>
name: StepLR, schedule: <class 'torch.optim.lr_scheduler.StepLR'>
name: ExponentialLR, schedule: <class 'torch.optim.lr_scheduler.ExponentialLR'>
name: ReduceLROnPlateau, schedule: <class 'torch.optim.lr_scheduler.ReduceLROnPlateau'>
name: CyclicLR, schedule: <class 'torch.optim.lr_scheduler.CyclicLR'>

Scheduler Params

To see the available params for a scheduler we can look at its corresponding dataclass:

from nemo.core.config.schedulers import CosineAnnealingParams

print(CosineAnnealingParams())
CosineAnnealingParams(last_epoch=-1, warmup_steps=None, warmup_ratio=None, min_lr=0.0)

Register scheduler

Register a new scheduler to be used with NeMo with:

nemo.core.optim.lr_scheduler.register_scheduler(name: str, scheduler: torch.optim.lr_scheduler._LRScheduler, scheduler_params: nemo.core.config.schedulers.SchedulerParams)[source]

Checks if the scheduler name exists in the registry, and if it doesn’t, adds it.

This allows custom schedulers to be added and called by name during instantiation.

Parameters
  • name – Name of the scheduler. Will be used as key to retrieve the scheduler.

  • scheduler – Scheduler class (inherits from _LRScheduler)

  • scheduler_params – The parameters as a dataclass of the scheduler
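As a hedged illustration, registering a scheduler that is not in the default registry might look like the sketch below. Here torch.optim.lr_scheduler.MultiStepLR is used, and the params dataclass mirrors its keyword arguments; the exact fields of the SchedulerParams base class may differ between NeMo versions.

from dataclasses import dataclass, field
from typing import List

from torch.optim.lr_scheduler import MultiStepLR

from nemo.core.config.schedulers import SchedulerParams
from nemo.core.optim.lr_scheduler import register_scheduler

@dataclass
class MultiStepLRParams(SchedulerParams):
    # illustrative defaults mirroring torch.optim.lr_scheduler.MultiStepLR
    milestones: List[int] = field(default_factory=lambda: [10, 20])
    gamma: float = 0.1

# after registration, `name: MultiStepLR` can be used under optim.sched
register_scheduler(name="MultiStepLR", scheduler=MultiStepLR, scheduler_params=MultiStepLRParams)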

Save and Restore

NeMo models all come with .save_to and .restore_from methods.

Save

To save a NeMo model:

model.save_to('/path/to/model.nemo')

Everything needed to use the trained model will be packaged and saved in the .nemo file. For example, in the NLP domain, .nemo files will include necessary tokenizer models and/or vocabulary files, etc.

Note

.nemo files are simply archives like any other .tar file.
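Because of that, the contents of a saved model can be inspected with standard tooling (a small sketch using Python’s tarfile module; the path is a placeholder):

import tarfile

# list the files packaged inside the .nemo archive (config, weights, tokenizer artifacts, ...)
with tarfile.open('/path/to/model.nemo', 'r:*') as archive:
    print(archive.getnames())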

Restore

To restore a NeMo model:

model.restore_from('/path/to/model.nemo')

When using the PyTorch Lightning Trainer, PyTorch Lightning checkpoints are created. These are mainly used within NeMo to auto-resume training. Since NeMo models are LightningModules, the PyTorch Lightning method load_from_checkpoint is available. Note that load_from_checkpoint won’t necessarily work out of the box for all models, as some models require more artifacts than just the checkpoint to be restored. For these models, the user will have to override load_from_checkpoint if they wish to use it.

It’s highly recommended to use restore_from to load NeMo models.
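For instance, a trained ASR model can be restored onto the CPU as follows (a sketch; map_location follows the restore_from signature documented in the Core APIs section below):

import torch
import nemo.collections.asr as nemo_asr

# restore a NeMo model from a .nemo file and map it to the CPU
model = nemo_asr.models.EncDecCTCModel.restore_from(
    restore_path='/path/to/model.nemo',
    map_location=torch.device('cpu'),
)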

Experiment Manager

NeMo’s Experiment Manager leverages PyTorch Lightning for model checkpointing, TensorBoard Logging, and Weights and Biases logging. The Experiment Manager is included by default in all NeMo example scripts.

To use the experiment manager, simply call exp_manager and pass in the PyTorch Lightning Trainer.

exp_manager(trainer, cfg.get("exp_manager", None))

It is also configurable via YAML with Hydra:

exp_manager:
    exp_dir: /path/to/my/experiments
    name: my_experiment_name
    create_tensorboard_logger: True
    create_checkpoint_callback: True

Optionally, launch TensorBoard to view the training results in ./nemo_experiments (the default exp_dir):

tensorboard --bind_all --logdir nemo_experiments
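Putting the pieces together, a NeMo example training script typically wires the Trainer, the experiment manager, and the model roughly as follows (a condensed sketch modeled on the NeMo example scripts; hydra_runner is NeMo’s Hydra entry-point decorator, and the config path/name are placeholders):

import pytorch_lightning as pl
from omegaconf import DictConfig

from nemo.collections.asr.models import EncDecCTCModel
from nemo.core.config import hydra_runner
from nemo.utils.exp_manager import exp_manager


@hydra_runner(config_path="conf", config_name="quartznet_15x5")
def main(cfg: DictConfig):
    trainer = pl.Trainer(**cfg.trainer)                  # Trainer configured from YAML
    exp_manager(trainer, cfg.get("exp_manager", None))   # logging and checkpointing
    model = EncDecCTCModel(cfg=cfg.model, trainer=trainer)
    trainer.fit(model)


if __name__ == '__main__':
    main()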

If create_checkpoint_callback is set to True, then NeMo will automatically create checkpoints during training using PyTorch Lightning’s ModelCheckpoint. We can configure the ModelCheckpoint via YAML or CLI:

exp_manager:
    ...
    # configure the PyTorch Lightning ModelCheckpoint using checkpoint_callback_params
    # any ModelCheckpoint argument can be set here

    # save the best checkpoints based on this metric
    checkpoint_callback_params.monitor=val_loss

    # choose how many total checkpoints to save
    checkpoint_callback_params.save_top_k=5

We can auto-resume training as well by configuring the exp_manager. Being able to auto-resume is important when doing long training runs that are preemptible or may be shut down before the training procedure has completed. To auto-resume training, set the following via YAML or CLI:

exp_manager:
    ...
    # resume training if checkpoints already exist
    resume_if_exists: True

    # to start training with no existing checkpoints
    resume_ignore_no_checkpoint: True

    # by default experiments will be versioned by datetime
    # we can set our own version with
    version: my_experiment_version

Neural Modules

NeMo is built around Neural Modules, conceptual blocks of neural networks that take typed inputs and produce typed outputs. Such modules typically represent data layers, encoders, decoders, language models, loss functions, or methods of combining activations. NeMo makes it easy to combine and re-use these building blocks while providing a level of semantic correctness checking via its neural type system.

Note

All Neural Modules inherit from torch.nn.Module and are therefore compatible with the PyTorch ecosystem.

There are three types of Neural Modules:

  • Regular modules

  • Dataset/IterableDataset

  • Losses

Every Neural Module in NeMo must inherit from the nemo.core.classes.module.NeuralModule class.

class nemo.core.classes.module.NeuralModule(*args: Any, **kwargs: Any)[source]

Bases: torch.nn.Module, nemo.core.classes.common.Typing, nemo.core.classes.common.Serialization, nemo.core.classes.common.FileIO

Abstract class offering interface shared between all PyTorch Neural Modules.

as_frozen()[source]

Context manager which temporarily freezes a module, yields control and finally unfreezes the module.

freeze() → None[source]

Freeze all params for inference.

input_example()[source]

Override this method if random inputs won’t work.

Returns

A tuple sample of valid input data.

property num_weights

unfreeze() → None[source]

Unfreeze all parameters for training.

Every Neural Module inherits the nemo.core.classes.common.Typing interface and needs to define neural types for its inputs and outputs. This is done by defining two properties: input_types and output_types. Each property should return an ordered dictionary of “port name” -> “port neural type” pairs. Here is an example from the ConvASREncoder class:

@property
def input_types(self):
    return OrderedDict(
        {
            "audio_signal": NeuralType(('B', 'D', 'T'), SpectrogramType()),
            "length": NeuralType(tuple('B'), LengthsType()),
        }
    )

@property
def output_types(self):
    return OrderedDict(
        {
            "outputs": NeuralType(('B', 'D', 'T'), AcousticEncodedRepresentation()),
            "encoded_lengths": NeuralType(tuple('B'), LengthsType()),
        }
    )

@typecheck()
def forward(self, audio_signal, length=None):
    ...

The code snippet above means that nemo.collections.asr.modules.conv_asr.ConvASREncoder expects two arguments:
  • First one, named audio_signal of shape [batch, dimension, time] with elements representing spectrogram values.

  • Second one, named length of shape [batch] with elements representing lengths of corresponding signals.

It also means that .forward(…) and __call__(…) methods each produce two outputs:
  • First one, of shape [batch, dimension, time] but with elements representing encoded representation (AcousticEncodedRepresentation class)

  • Second one, of shape [batch], corresponding to their lengths.

Tip

It is a good practice to define types and add @typecheck() decorator to your .forward() method once your module is ready for use by others.

Note

The outputs of the .forward(…) method will always be of type torch.Tensor or a container of tensors and will work with any other PyTorch code. The type information will be attached to every output tensor.

If tensors without types are passed to your module, it will not fail, but the types will not be checked. Thus, it is recommended to define input/output types for all your modules, starting with data layers, and to add the @typecheck() decorator to them.

Note

To temporarily disable type checking, you can enclose your code in a `with typecheck.disable_checks():` block.
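For example (a tiny sketch; module, audio, and lengths stand in for any typed NeMo module and matching input tensors):

from nemo.core.classes.common import typecheck

# temporarily bypass neural type checking for this call
with typecheck.disable_checks():
    outputs, encoded_lengths = module(audio_signal=audio, length=lengths)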

Neural Types

Motivation

Neural Types describe the semantics, axis order, and dimensions of a tensor. The purpose of this type system is to catch semantic and dimensionality errors during model creation and facilitate module re-use.

See also: Neural Types Motivation.

NeuralType class

Neural Types perform semantic checks for module and model inputs/outputs. They contain information about:

  • Semantics of what is stored in the tensors. For example, logits, logprobs, audiosignal, embeddings, etc.

  • Axes layout, semantics, and (optionally) dimensionality. For example: [Batch, Time, Channel]

Types are implemented in the nemo.core.neural_types.NeuralType class. When you instantiate this class, you are expected to include both axes information and element type information.

class nemo.core.neural_types.NeuralType(axes: Optional[Tuple] = None, elements_type: nemo.core.neural_types.elements.ElementType = VoidType, optional=False)[source]

Bases: object

This is the main class which would represent neural type concept. It is used to represent the types of inputs and outputs.

Parameters
  • axes (Optional[Tuple]) – a tuple of AxisTypes objects representing the semantics of what varying each axis means. You can use a short, string-based form here. For example: (‘B’, ‘C’, ‘H’, ‘W’) would correspond to an NCHW format frequently used in computer vision. (‘B’, ‘T’, ‘D’) is frequently used for signal processing and means [batch, time, dimension/channel].

  • elements_type (ElementType) – an instance of ElementType class representing the semantics of what is stored inside the tensor. For example: logits (LogitsType), log probabilities (LogprobType), etc.

  • optional (bool) – By default, this is false. If set to True, it means that the input to the port of this type is optional.

compare(second) → nemo.core.neural_types.comparison.NeuralTypeComparisonResult[source]

Performs neural type comparison of self with second. When you chain two modules’ inputs/outputs via __call__ method, this comparison will be called to ensure neural type compatibility.

compare_and_raise_error(parent_type_name, port_name, second_object)[source]

Method compares definition of one type with another and raises an error if not compatible.

Type Comparison Results

When comparing two neural types the following comparison results can be generated.

class nemo.core.neural_types.NeuralTypeComparisonResult(value)[source]

Bases: enum.Enum

The result of comparing two neural type objects for compatibility. When comparing A.compare_to(B):

CONTAINER_SIZE_MISMATCH = 5
DIM_INCOMPATIBLE = 3
GREATER = 2
INCOMPATIBLE = 6
LESS = 1
SAME = 0
SAME_TYPE_INCOMPATIBLE_PARAMS = 7
TRANSPOSE_SAME = 4
UNCHECKED = 8

Examples

Long vs short notation

NeMo’s NeuralType class allows you to express axis semantics information in long and short form. Consider these two equivalent types. Both describe 3-dimensional tensors, and both contain elements of type AcousticEncodedRepresentation (this type is a typical output of ASR encoders).

long_version = NeuralType(
    axes=(AxisType(AxisKind.Batch, None), AxisType(AxisKind.Dimension, None), AxisType(AxisKind.Time, None)),
    elements_type=AcousticEncodedRepresentation(),
)
short_version = NeuralType(('B', 'D', 'T'), AcousticEncodedRepresentation())
assert long_version.compare(short_version) == NeuralTypeComparisonResult.SAME

Transpose same

Often it is useful to know if a simple transposition will solve type incompatibility. This is the case if the comparison result of two types equals nemo.core.neural_types.NeuralTypeComparisonResult.TRANSPOSE_SAME.

type1 = NeuralType(axes=('B', 'T', 'C'))
type2 = NeuralType(axes=('T', 'B', 'C'))
assert type1.compare(type2) == NeuralTypeComparisonResult.TRANSPOSE_SAME
assert type2.compare(type1) == NeuralTypeComparisonResult.TRANSPOSE_SAME

Note that in this example, we dropped the elements_type argument of the NeuralType constructor. If not supplied, the element type would be VoidType.

VoidType for elements

Sometimes it is useful to express that elements’ types don’t matter but axes layout does. VoidType for elements can be used to express this.

Note

VoidType is compatible with every other elements’ type but not the other way around. See code snippet below for details.

btc_spctr = NeuralType(('B', 'T', 'C'), SpectrogramType())
btc_spct_bad = NeuralType(('B', 'T'), SpectrogramType())
# Note the VoidType for elements here
btc_void = NeuralType(('B', 'T', 'C'), VoidType())

# This is true because VoidType is compatible with every other element type (SpectrogramType in this case)
# And axes layout between btc_void and btc_spctr is the same
assert btc_void.compare(btc_spctr) == NeuralTypeComparisonResult.SAME
# These two types are incompatible because even though VoidType is used for elements on one side,
# the axes layout is different
assert btc_void.compare(btc_spct_bad) == NeuralTypeComparisonResult.INCOMPATIBLE
# Note that even though VoidType is compatible with every other type, other types are not compatible with VoidType!
# It is one-way compatibility
assert btc_spctr.compare(btc_void) == NeuralTypeComparisonResult.INCOMPATIBLE

Element type inheritance

Neural types in NeMo support Python inheritance between element types. Consider an example where you want to develop a Neural Module which performs data augmentation for all kinds of spectrograms. In ASR, two types of spectrograms are frequently used: mel and MFCC. To express this, we will create three classes for the element types: SpectrogramType, MelSpectrogramType(SpectrogramType), and MFCCSpectrogramType(SpectrogramType).

input = NeuralType(('B', 'D', 'T'), SpectrogramType())
out1 = NeuralType(('B', 'D', 'T'), MelSpectrogramType())
out2 = NeuralType(('B', 'D', 'T'), MFCCSpectrogramType())

# MelSpectrogram and MFCCSpectrogram are not interchangeable.
assert out1.compare(out2) == NeuralTypeComparisonResult.INCOMPATIBLE
assert out2.compare(out1) == NeuralTypeComparisonResult.INCOMPATIBLE
# Type comparison detects that MFCC/MelSpectrogramType is a kind of SpectrogramType and can be accepted.
assert input.compare(out1) == NeuralTypeComparisonResult.GREATER
assert input.compare(out2) == NeuralTypeComparisonResult.GREATER

Custom element types

It is possible to create user-defined element types to express the semantics of elements in your tensors. To do so, you need to inherit from the nemo.core.neural_types.elements.ElementType class and implement its abstract methods:

class nemo.core.neural_types.elements.ElementType[source]

Bases: abc.ABC

Abstract class defining semantics of the tensor elements. We are relying on Python for inheritance checking

compare(second) → nemo.core.neural_types.comparison.NeuralTypeComparisonResult[source]

property fields

This should be used to logically represent tuples/structures. For example, if you want to represent a bounding box (x, y, width, height) you can put a tuple with names (‘x’, ‘y’, ‘w’, ‘h’) in here. Under the hood this should be converted to the last tensor dimension of fixed size = len(fields). When two types are compared their fields must match.

property type_parameters

Override this property to parametrize your type. For example, you can specify a ‘storage’ type such as float, int, or bool with the ‘dtype’ keyword. Another example is if you want to represent a signal with a particular property (say, sample frequency); then you can put sample_freq->value in there. When two types are compared their type_parameters must match.

Note that element types can be parametrized. Consider this example, which distinguishes between audio sampled at 8 kHz and 16 kHz.

audio16K = NeuralType(axes=('B', 'T'), elements_type=AudioSignal(16000))
audio8K = NeuralType(axes=('B', 'T'), elements_type=AudioSignal(8000))

assert audio8K.compare(audio16K) == NeuralTypeComparisonResult.SAME_TYPE_INCOMPATIBLE_PARAMS
assert audio16K.compare(audio8K) == NeuralTypeComparisonResult.SAME_TYPE_INCOMPATIBLE_PARAMS
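In practice, a custom element type is usually just a small subclass of ElementType, often with nothing but a docstring describing its semantics. The sketch below uses a hypothetical PitchContourType that is not part of NeMo:

from nemo.core.neural_types import NeuralType
from nemo.core.neural_types.elements import ElementType


class PitchContourType(ElementType):
    """Hypothetical element type representing per-frame pitch (F0) values."""


# a [batch, time] tensor of pitch values
pitch_type = NeuralType(('B', 'T'), PitchContourType())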

Enforcing dimensions

In addition to specifying tensor layout and elements’ semantics, neural types also allow you to enforce tensor dimensions. You will have to use long notation to specify dimensions. Short notation only allows you to specify axes semantics and assumes arbitrary dimensions.

type1 = NeuralType(
    (AxisType(AxisKind.Batch, 64), AxisType(AxisKind.Time, 10), AxisType(AxisKind.Dimension, 128)),
    SpectrogramType(),
)
type2 = NeuralType(('B', 'T', 'C'), SpectrogramType())

# type2 will accept elements of type1 because their axes semantics match and type2 does not care about dimensions
assert type2.compare(type1) == NeuralTypeComparisonResult.SAME
# type1 will not accept elements of type2 because it needs dimensions to match strictly
assert type1.compare(type2) == NeuralTypeComparisonResult.DIM_INCOMPATIBLE

Generic Axis kind

Sometimes (especially in the case of loss modules) it is useful to be able to specify a “generic” axis kind which will make it compatible with any other kind of axis. This is easy to express with Neural Types by using nemo.core.neural_types.axes.AxisKind.Any for axes.

type1 = NeuralType(('B', 'Any', 'Any'), SpectrogramType())
type2 = NeuralType(('B', 'T', 'C'), SpectrogramType())
type3 = NeuralType(('B', 'C', 'T'), SpectrogramType())

# type1 will accept elements of type2 and type3 because it only cares about element kind (SpectrogramType)
# number of axes (3) and that first one corresponds to batch
assert type1.compare(type2) == NeuralTypeComparisonResult.SAME
assert type1.compare(type3) == NeuralTypeComparisonResult.INCOMPATIBLE

Container types

The NeMo type system understands Python containers (lists). If your module returns a nested list of typed tensors, express this by using Python list notation and Neural Types together when defining your input/output types.

The example below shows how to express that your module returns a single output (“out”) which is a list of lists of two-dimensional tensors of shape [batch, dimension] containing logits:

@property
def output_types(self):
    return {
        "out": [[NeuralType(('B', 'D'), LogitsType())]],
    }

Core APIs

Base class for all NeMo models

class nemo.core.ModelPT(*args: Any, **kwargs: Any)[source]

Bases: pytorch_lightning.LightningModule, nemo.core.classes.common.Model

Interface for Pytorch-lightning based NeMo models

register_artifact(config_path: str, src: str)[source]

Register model artifacts with this function. These artifacts (files) will be included inside .nemo file when model.save_to(“mymodel.nemo”) is called.

WARNING: If you specified /example_folder/example.txt but ./example.txt exists, then ./example.txt will be used.

Parameters
  • config_path – config path where artifact is used

  • src – path to the artifact

Returns

path to be used when accessing artifact. If src=’’ or None then ‘’ or None will be returned
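A hedged usage sketch, inside a model’s constructor (the config path and vocabulary file are placeholders):

# inside a NeMo model's __init__, after super().__init__(cfg=cfg, trainer=trainer):
# record the tokenizer vocabulary so it is packaged into the .nemo file by save_to()
vocab_in_nemo = self.register_artifact(
    config_path='tokenizer.vocab_file',
    src=cfg.tokenizer.vocab_file,
)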

save_to(save_path: str)

Saves the model instance (weights and configuration) into a .nemo file.

You can use the restore_from method to fully restore the instance from the .nemo file.

The .nemo file is an archive (tar.gz) with the following contents:

  • model_config.yaml – the model configuration in .yaml format. You can deserialize this into the cfg argument for the model’s constructor.

  • model_weights.ckpt – the model checkpoint

Parameters

save_path – Path to .nemo file where model instance should be saved

classmethod restore_from(restore_path: str, override_config_path: Optional[Union[omegaconf.OmegaConf, str]] = None, map_location: Optional[torch.device] = None, strict: bool = False, return_config: bool = False)[source]

Restores model instance (weights and configuration) from .nemo file.

Parameters
  • restore_path – path to .nemo file from which model should be instantiated

  • override_config_path – path to a yaml config that will override the internal config file or an OmegaConf / DictConfig object representing the model config.

  • map_location – Optional torch.device() to map the instantiated model to a device. By default (None), it will select a GPU if available, falling back to CPU otherwise.

  • strict – Passed to load_state_dict.

  • return_config – If set to true, will return just the underlying config of the restored model as an OmegaConf DictConfig object without instantiating the model.

Example:

model = nemo.collections.asr.models.EncDecCTCModel.restore_from('asr.nemo')
assert isinstance(model, nemo.collections.asr.models.EncDecCTCModel)

Returns

An instance of type cls or its underlying config (if return_config is set).

classmethod load_from_checkpoint(checkpoint_path: str, *args, map_location: Optional[Union[Dict[str, str], str, torch.device, int, Callable]] = None, hparams_file: Optional[str] = None, strict: bool = True, **kwargs)[source]

Loads ModelPT from a checkpoint, with some maintenance of restoration. For documentation, please refer to the LightningModule.load_from_checkpoint() documentation.

abstract setup_training_data(train_data_config: Union[omegaconf.DictConfig, Dict])[source]

Sets up the data loader to be used in training.

Parameters

train_data_config – training data layer parameters.

abstract setup_validation_data(val_data_config: Union[omegaconf.DictConfig, Dict])[source]

Sets up the data loader to be used in validation.

Parameters

val_data_config – validation data layer parameters.

setup_test_data(test_data_config: Union[omegaconf.DictConfig, Dict])[source]

(Optionally) Sets up the data loader to be used in testing.

Parameters

test_data_config – test data layer parameters.

setup_multiple_validation_data(val_data_config: Union[omegaconf.DictConfig, Dict])[source]

(Optionally) Sets up the data loader to be used in validation, with support for multiple data loaders.

Parameters

val_data_config – validation data layer parameters.

setup_multiple_test_data(test_data_config: Union[omegaconf.DictConfig, Dict])[source]

(Optionally) Sets up the data loader to be used in testing, with support for multiple data loaders.

Parameters

test_data_config – test data layer parameters.

setup_optimization(optim_config: Optional[Union[omegaconf.DictConfig, Dict]] = None)[source]

Prepares an optimizer from a string name and its optional config parameters.

Parameters

optim_config

A dictionary containing the following keys:

  • “lr”: mandatory key for learning rate. Will raise ValueError if not provided.

  • “optimizer”: string name pointing to one of the available optimizers in the registry. If not provided, defaults to “adam”.

  • “opt_args”: Optional list of strings, in the format “arg_name=arg_value”. The list of “arg_value” will be parsed and a dictionary of optimizer kwargs will be built and supplied to instantiate the optimizer.
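A hedged example of calling it directly with an explicit dictionary (the values are illustrative):

# configure an AdamW optimizer with an explicit learning rate and an extra kwarg
model.setup_optimization(
    optim_config={
        "optimizer": "adamw",
        "lr": 3e-4,
        "opt_args": ["weight_decay=0.01"],
    }
)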

configure_optimizers()[source]
train_dataloader()[source]
val_dataloader()[source]
test_dataloader()[source]
validation_epoch_end(outputs: Union[List[Dict[str, torch.Tensor]], List[List[Dict[str, torch.Tensor]]]]) → Optional[Dict[str, Dict[str, torch.Tensor]]][source]

Default DataLoader for Validation set which automatically supports multiple data loaders via multi_validation_epoch_end.

If multi dataset support is not required, override this method entirely in base class. In such a case, there is no need to implement multi_validation_epoch_end either.

Note

If more than one data loader exists, and they all provide val_loss, only the val_loss of the first data loader will be used by default. This default can be changed by passing the special key val_dl_idx: int inside the validation_ds config.

Parameters

outputs – Single or nested list of tensor outputs from one or more data loaders.

Returns

A dictionary containing the union of all items from individual data_loaders, along with merged logs from all data loaders.

test_epoch_end(outputs: Union[List[Dict[str, torch.Tensor]], List[List[Dict[str, torch.Tensor]]]]) → Optional[Dict[str, Dict[str, torch.Tensor]]][source]

Default DataLoader for Test set which automatically supports multiple data loaders via multi_test_epoch_end.

If multi dataset support is not required, override this method entirely in base class. In such a case, there is no need to implement multi_test_epoch_end either.

Note

If more than one data loader exists, and they all provide test_loss, only the test_loss of the first data loader will be used by default. This default can be changed by passing the special key test_dl_idx: int inside the test_ds config.

Parameters

outputs – Single or nested list of tensor outputs from one or more data loaders.

Returns

A dictionary containing the union of all items from individual data_loaders, along with merged logs from all data loaders.

multi_validation_epoch_end(outputs: List[Dict[str, torch.Tensor]], dataloader_idx: int = 0) → Optional[Dict[str, Dict[str, torch.Tensor]]][source]

Adds support for multiple validation datasets. Should be overridden by the subclass to obtain appropriate logs for each of the dataloaders.

Parameters
  • outputs – Same as that provided by LightningModule.validation_epoch_end() for a single dataloader.

  • dataloader_idx – int representing the index of the dataloader.

Returns

A dictionary of values, optionally containing a sub-dict log, such that the values in the log will be prepended by the dataloader prefix.

multi_test_epoch_end(outputs: List[Dict[str, torch.Tensor]], dataloader_idx: int = 0) → Optional[Dict[str, Dict[str, torch.Tensor]]][source]

Adds support for multiple test datasets. Should be overridden by the subclass to obtain appropriate logs for each of the dataloaders.

Parameters
  • outputs – Same as that provided by LightningModule.validation_epoch_end() for a single dataloader.

  • dataloader_idx – int representing the index of the dataloader.

Returns

A dictionary of values, optionally containing a sub-dict log, such that the values in the log will be prepended by the dataloader prefix.

get_validation_dataloader_prefix(dataloader_idx: int = 0) → str[source]

Get the name of one or more data loaders, which will be prepended to all logs.

Parameters

dataloader_idx – Index of the data loader.

Returns

str name of the data loader at index provided.

get_test_dataloader_prefix(dataloader_idx: int = 0) → str[source]

Get the name of one or more data loaders, which will be prepended to all logs.

Parameters

dataloader_idx – Index of the data loader.

Returns

str name of the data loader at index provided.

classmethod extract_state_dict_from(restore_path: str, save_dir: str, split_by_module: bool = False)[source]

Extract the state dict(s) from a provided .nemo tarfile and save it to a directory.

Parameters
  • restore_path – path to .nemo file from which state dict(s) should be extracted

  • save_dir – directory in which the saved state dict(s) should be stored

  • split_by_module – bool flag, which determines whether the output checkpoint should be for the entire Model or for the individual modules that comprise the Model

Example

To convert the .nemo tarfile into a single Model-level PyTorch checkpoint:

state_dict = nemo.collections.asr.models.EncDecCTCModel.extract_state_dict_from('asr.nemo', './asr_ckpts')

To restore a model from a Model-level checkpoint:

model = nemo.collections.asr.models.EncDecCTCModel(cfg)  # or any other method of restoration
model.load_state_dict(torch.load('./asr_ckpts/model_weights.ckpt'))

To convert the .nemo tarfile into multiple Module-level PyTorch checkpoints:

state_dict = nemo.collections.asr.models.EncDecCTCModel.extract_state_dict_from('asr.nemo', './asr_ckpts', split_by_module=True)

To restore a module from a Module-level checkpoint:

model = nemo.collections.asr.models.EncDecCTCModel(cfg)  # or any other method of restoration

# load the individual components
model.preprocessor.load_state_dict(torch.load('./asr_ckpts/preprocessor.ckpt'))
model.encoder.load_state_dict(torch.load('./asr_ckpts/encoder.ckpt'))
model.decoder.load_state_dict(torch.load('./asr_ckpts/decoder.ckpt'))

Returns

The state dict that was loaded from the original .nemo checkpoint

prepare_test(trainer: pytorch_lightning.Trainer) → bool[source]

Helper method to check whether the model can safely be tested on a dataset after training (or loading a checkpoint).

trainer = Trainer()
if model.prepare_test(trainer):
    trainer.test(model)

Returns

bool which declares the model safe to test. Provides warnings if it has to return False to guide the user.

set_trainer(trainer: pytorch_lightning.Trainer)[source]

Set an instance of Trainer object.

Parameters

trainer – PyTorch Lightning Trainer object.

set_world_size(trainer: pytorch_lightning.Trainer)[source]

Determines the world size from the PyTorch Lightning Trainer and then updates AppState.

Parameters

trainer (Trainer) – PyTorch Lightning Trainer object

property num_weights

Utility property that returns the total number of parameters of the Model.

property cfg

Property that holds the finalized internal config of the model.

Note

Changes to this config are not reflected in the state of the model. Please create a new model using an updated config to properly update the model.
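In practice this means copying the config, modifying the copy, and constructing a fresh model from it (a sketch using standard OmegaConf calls; trainer stands in for your existing PyTorch Lightning Trainer):

import copy

from omegaconf import open_dict

# copy the finalized config, tweak it, and build a new model from it
new_cfg = copy.deepcopy(model.cfg)
with open_dict(new_cfg):
    new_cfg.optim.lr = 1e-4
new_model = model.__class__(cfg=new_cfg, trainer=trainer)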

Base Neural Module class

class nemo.core.NeuralModule(*args: Any, **kwargs: Any)[source]

Bases: torch.nn.Module, nemo.core.classes.common.Typing, nemo.core.classes.common.Serialization, nemo.core.classes.common.FileIO

Abstract class offering interface shared between all PyTorch Neural Modules.

property num_weights

input_example()[source]

Override this method if random inputs won’t work.

Returns

A tuple sample of valid input data.

freeze() → None[source]

Freeze all params for inference.

unfreeze() → None[source]

Unfreeze all parameters for training.

as_frozen()[source]

Context manager which temporarily freezes a module, yields control and finally unfreezes the module.

Neural Type classes

class nemo.core.neural_types.NeuralType(axes: Optional[Tuple] = None, elements_type: nemo.core.neural_types.elements.ElementType = VoidType, optional=False)[source]

Bases: object

This is the main class which would represent neural type concept. It is used to represent the types of inputs and outputs.

Parameters
  • axes (Optional[Tuple]) – a tuple of AxisTypes objects representing the semantics of what varying each axis means. You can use a short, string-based form here. For example: (‘B’, ‘C’, ‘H’, ‘W’) would correspond to an NCHW format frequently used in computer vision. (‘B’, ‘T’, ‘D’) is frequently used for signal processing and means [batch, time, dimension/channel].

  • elements_type (ElementType) – an instance of ElementType class representing the semantics of what is stored inside the tensor. For example: logits (LogitsType), log probabilities (LogprobType), etc.

  • optional (bool) – By default, this is false. If set to True, it means that the input to the port of this type is optional.

compare(second) → nemo.core.neural_types.comparison.NeuralTypeComparisonResult[source]

Performs neural type comparison of self with second. When you chain two modules’ inputs/outputs via __call__ method, this comparison will be called to ensure neural type compatibility.

compare_and_raise_error(parent_type_name, port_name, second_object)[source]

Method compares definition of one type with another and raises an error if not compatible.

class nemo.core.neural_types.axes.AxisType(kind: nemo.core.neural_types.axes.AxisKindAbstract, size: Optional[int] = None, is_list=False)[source]

Bases: object

This class represents axis semantics and (optionally) its dimensionality.

Parameters
  • kind (AxisKindAbstract) – what kind of axis it is, for example Batch, Height, etc.

  • size (int, optional) – specify if the axis should have a fixed size. By default it is set to None; you typically do not want to set it for Batch and Time.

  • is_list (bool, default=False) – whether this is a list or a tensor axis

class nemo.core.neural_types.elements.ElementType[source]

Bases: abc.ABC

Abstract class defining semantics of the tensor elements. We are relying on Python for inheritance checking

property type_parameters

Override this property to parametrize your type. For example, you can specify a ‘storage’ type such as float, int, or bool with the ‘dtype’ keyword. Another example is if you want to represent a signal with a particular property (say, sample frequency); then you can put sample_freq->value in there. When two types are compared their type_parameters must match.

property fields

This should be used to logically represent tuples/structures. For example, if you want to represent a bounding box (x, y, width, height) you can put a tuple with names (‘x’, ‘y’, ‘w’, ‘h’) in here. Under the hood this should be converted to the last tensor dimension of fixed size = len(fields). When two types are compared their fields must match.

compare(second) → nemo.core.neural_types.comparison.NeuralTypeComparisonResult[source]

class nemo.core.neural_types.comparison.NeuralTypeComparisonResult(value)[source]

Bases: enum.Enum

The result of comparing two neural type objects for compatibility. When comparing A.compare_to(B):

SAME = 0
LESS = 1
GREATER = 2
DIM_INCOMPATIBLE = 3
TRANSPOSE_SAME = 4
CONTAINER_SIZE_MISMATCH = 5
INCOMPATIBLE = 6
SAME_TYPE_INCOMPATIBLE_PARAMS = 7
UNCHECKED = 8

Experiment manager

class nemo.utils.exp_manager.exp_manager(trainer: pytorch_lightning.Trainer, cfg: Optional[Union[omegaconf.DictConfig, Dict]] = None)[source]

exp_manager is a helper function used to manage folders for experiments. It follows the pytorch lightning paradigm of exp_dir/model_or_experiment_name/version. If the lightning trainer has a logger, exp_manager will get exp_dir, name, and version from the logger. Otherwise it will use the exp_dir and name arguments to create the logging directory. exp_manager also allows for explicit folder creation via explicit_log_dir. The version will be a datetime string or an integer. Note that exp_manager does not handle versioning on slurm multi-node runs. Datetime versioning can be disabled by setting use_datetime_version to False. exp_manager optionally creates TensorBoardLogger, WandBLogger, and ModelCheckpoint objects from pytorch lightning. It copies sys.argv, and git information if available, to the logging directory. It creates a log file for each process to log its output into. exp_manager additionally has a resume feature, which can be used to continue training from the constructed log_dir.

Parameters
  • trainer (pytorch_lightning.Trainer) – The lightning trainer.

  • cfg (DictConfig, dict) – Can have the following keys:

    • explicit_log_dir (str, Path): Can be used to override exp_dir/name/version folder creation. Defaults to None, which will use exp_dir, name, and version to construct the logging directory.

    • exp_dir (str, Path): The base directory to create the logging directory. Defaults to None, which logs to ./nemo_experiments.

    • name (str): The name of the experiment. Defaults to None, which turns into “default” via name = name or “default”.

    • version (str): The version of the experiment. Defaults to None, which uses either a datetime string or lightning’s TensorboardLogger system of using version_{int}.

    • use_datetime_version (bool): Whether to use a datetime string for version. Defaults to True.

    • resume_if_exists (bool): Whether this experiment is resuming from a previous run. If True, it sets trainer.resume_from_checkpoint so that the trainer should auto-resume. exp_manager will move files under log_dir to log_dir/run_{int}. Defaults to False.

    • resume_past_end (bool): exp_manager errors out if resume_if_exists is True and a checkpoint matching *end.ckpt indicates that a previous training run fully completed. This behaviour can be disabled, in which case the *end.ckpt will be loaded, by setting resume_past_end to True. Defaults to False.

    • resume_ignore_no_checkpoint (bool): exp_manager errors out if resume_if_exists is True and no checkpoint could be found. This behaviour can be disabled, in which case exp_manager will print a message and continue without restoring, by setting resume_ignore_no_checkpoint to True. Defaults to False.

    • create_tensorboard_logger (bool): Whether to create a TensorBoard logger and attach it to the pytorch lightning trainer. Defaults to True.

    • summary_writer_kwargs (dict): A dictionary of kwargs that can be passed to lightning’s TensorboardLogger class. Note that log_dir is passed by exp_manager and cannot exist in this dict. Defaults to None.

    • create_wandb_logger (bool): Whether to create a Weights and Biases logger and attach it to the pytorch lightning trainer. Defaults to False.

    • wandb_logger_kwargs (dict): A dictionary of kwargs that can be passed to lightning’s WandBLogger class. Note that name and project are required parameters if create_wandb_logger is True. Defaults to None.

    • create_checkpoint_callback (bool): Whether to create a ModelCheckpoint callback and attach it to the pytorch lightning trainer. The ModelCheckpoint saves the top 3 models with the best “val_loss”, the most recent checkpoint under *last.ckpt, and the final checkpoint after training completes under *end.ckpt. Defaults to True.

    • files_to_copy (list): A list of files to copy to the experiment logging directory. Defaults to None, which copies no files.

Returns

The final logging directory where logging files are saved. Usually the concatenation of exp_dir, name, and version.

Return type

log_dir (Path)