# Source code for nemo.collections.asr.models.ssl_models

# Copyright (c) 2022, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from math import ceil
from typing import Any, Dict, List, Optional, Union

import torch
import torch.nn as nn
from lightning.pytorch import Trainer
from omegaconf import DictConfig

from nemo.collections.asr.data import audio_to_text_dataset, ssl_dataset
from nemo.collections.asr.data.audio_to_text_dali import DALIOutputs
from nemo.collections.asr.data.audio_to_text_lhotse import LhotseSpeechToTextBpeDataset
from nemo.collections.asr.modules.ssl_modules.masking import ConvFeatureMaksingWrapper
from nemo.collections.asr.parts.mixins import ASRModuleMixin
from nemo.collections.asr.parts.preprocessing.perturb import process_augmentations
from nemo.collections.common.data.lhotse import get_lhotse_dataloader_from_config
from nemo.collections.common.data.utils import move_data_to_device
from nemo.collections.common.parts.preprocessing.parsers import make_parser
from nemo.core.classes import ModelPT
from nemo.core.classes.common import PretrainedModelInfo, typecheck
from nemo.core.classes.mixins import AccessMixin, set_access_cfg
from nemo.core.neural_types import (
    AcousticEncodedRepresentation,
    AudioSignal,
    LabelsType,
    LengthsType,
    LogprobsType,
    NeuralType,
    SpectrogramType,
)
from nemo.utils import logging

__all__ = ['SpeechEncDecSelfSupervisedModel', 'EncDecMaskedTokenPredModel', 'EncDecDenoiseMaskedTokenPredModel']


class SpeechEncDecSelfSupervisedModel(ModelPT, ASRModuleMixin, AccessMixin):
    """Base class for encoder-decoder models used for self-supervised encoder pre-training.

    The encoder output is fed either to a single decoder/loss pair (``decoder`` and
    ``loss`` config sections) or to several decoder-loss pairs configured under a
    ``loss_list`` section, enabling multi-task self-supervised training.
    """
[docs] @classmethod def list_available_models(cls) -> List[PretrainedModelInfo]: """ This method returns a list of pre-trained model which can be instantiated directly from NVIDIA's NGC cloud. Returns: List of available pre-trained models. """ results = [] model = PretrainedModelInfo( pretrained_model_name="ssl_en_conformer_large", description="For details about this model, please visit https://ngc.nvidia.com/catalog/models/nvidia:nemo:ssl_en_conformer_large", location="https://api.ngc.nvidia.com/v2/models/nvidia/nemo/ssl_en_conformer_large/versions/1.10.1/files/ssl_en_conformer_large.nemo", ) results.append(model) model = PretrainedModelInfo( pretrained_model_name="ssl_en_conformer_xlarge", description="For details about this model, please visit https://ngc.nvidia.com/catalog/models/nvidia:nemo:ssl_en_conformer_xlarge", location="https://api.ngc.nvidia.com/v2/models/nvidia/nemo/ssl_en_conformer_xlarge/versions/1.10.0/files/ssl_en_conformer_xlarge.nemo", ) results.append(model) return results
    def __init__(self, cfg: DictConfig, trainer: Trainer = None):
        """Build preprocessor, encoder and decoder/loss modules from the model config.

        Args:
            cfg: Model configuration. Expects ``preprocessor``, ``encoder`` and
                ``spec_augment`` sections, plus either a ``loss_list`` section
                (multiple decoder/loss pairs) or ``decoder`` + ``loss`` sections.
            trainer: Optional PyTorch Lightning trainer, used to obtain ``world_size``.
        """
        # Get global rank and total number of GPU workers for IterableDataset partitioning, if applicable
        # Global_rank and local_rank is set by LightningModule in Lightning 1.2.0
        self.world_size = 1
        if trainer is not None:
            self.world_size = trainer.world_size

        super().__init__(cfg=cfg, trainer=trainer)
        self.preprocessor = SpeechEncDecSelfSupervisedModel.from_config_dict(self._cfg.preprocessor)
        self.encoder = SpeechEncDecSelfSupervisedModel.from_config_dict(self._cfg.encoder)

        self.decoder_losses = None

        if "loss_list" in self._cfg:
            # Multi-decoder mode: one (decoder, loss) pair per entry in loss_list.
            self.decoder_losses = {}
            self.loss_alphas = {}        # per-loss weight applied when summing losses
            self.start_step = {}         # global step at which each loss becomes active
            self.output_from_layer = {}  # optional encoder layer name to tap for the decoder input
            self.transpose_encoded = {}  # whether to transpose the tapped layer output
            self.targets_from_loss = {}  # optional sibling loss to borrow targets from
            self.decoder_losses_active = {}
            # need to be separate for moduledict

            for decoder_loss_name, decoder_loss_cfg in self._cfg.loss_list.items():
                if not decoder_loss_cfg.get("is_active", True):  # active by default
                    continue

                new_decoder_loss = {
                    'decoder': SpeechEncDecSelfSupervisedModel.from_config_dict(decoder_loss_cfg.decoder),
                    'loss': SpeechEncDecSelfSupervisedModel.from_config_dict(decoder_loss_cfg.loss),
                }
                new_decoder_loss = nn.ModuleDict(new_decoder_loss)
                self.decoder_losses[decoder_loss_name] = new_decoder_loss
                self.loss_alphas[decoder_loss_name] = decoder_loss_cfg.get("loss_alpha", 1.0)
                self.output_from_layer[decoder_loss_name] = decoder_loss_cfg.get("output_from_layer", None)
                self.targets_from_loss[decoder_loss_name] = decoder_loss_cfg.get("targets_from_loss", None)
                self.start_step[decoder_loss_name] = decoder_loss_cfg.get("start_step", 0)
                self.transpose_encoded[decoder_loss_name] = decoder_loss_cfg.get("transpose_encoded", False)
                self.decoder_losses_active[decoder_loss_name] = True

            self.decoder_losses = nn.ModuleDict(self.decoder_losses)
        else:
            # Single decoder/loss mode.
            self.decoder_ssl = SpeechEncDecSelfSupervisedModel.from_config_dict(self._cfg.decoder)
            self.loss = SpeechEncDecSelfSupervisedModel.from_config_dict(self._cfg.loss)

        self.spec_augmentation = SpeechEncDecSelfSupervisedModel.from_config_dict(self._cfg.spec_augment)

        # dropout for features/spectrograms (applied before masking)
        self.dropout_features = (
            torch.nn.Dropout(self._cfg.dropout_features) if "dropout_features" in self._cfg else None
        )

        # dropout for targets (applied before quantization)
        self.dropout_features_q = (
            torch.nn.Dropout(self._cfg.dropout_features_q) if "dropout_features_q" in self._cfg else None
        )

        # Feature penalty for preprocessor encodings (for Wav2Vec training)
        if "feature_penalty" in self._cfg:
            self.feat_pen, self.pen_factor = 0.0, self._cfg.feature_penalty
        else:
            self.feat_pen, self.pen_factor = None, None

        if "access" in self._cfg:
            set_access_cfg(self._cfg.access, self.model_guid)

        self.apply_masking = True

    def _setup_dataloader_from_config(self, config: Optional[Dict]):
        """Construct a data loader (Lhotse, DALI, tarred or regular) from a dataset config.

        Returns ``None`` when the config is missing the required ``manifest_filepath``
        or ``tarred_audio_filepaths`` entries.
        """
        if 'augmentor' in config:
            augmentor = process_augmentations(config['augmentor'])
        else:
            augmentor = None

        # Automatically inject args from model config to dataloader config
        audio_to_text_dataset.inject_dataloader_value_from_model_config(self.cfg, config, key='sample_rate')

        if config.get("use_lhotse"):
            # Lhotse builds the full dataloader itself, including sharding.
            return get_lhotse_dataloader_from_config(
                config,
                global_rank=self.global_rank,
                world_size=self.world_size,
                dataset=LhotseSpeechToTextBpeDataset(
                    tokenizer=make_parser(
                        labels=config.get('labels', None),
                        name=config.get('parser', 'en'),
                        unk_id=config.get('unk_index', -1),
                        blank_id=config.get('blank_index', -1),
                        do_normalize=config.get('normalize_transcripts', False),
                    ),
                ),
            )

        shuffle = config['shuffle']
        device = 'gpu' if torch.cuda.is_available() else 'cpu'
        if config.get('use_dali', False):
            # DALI pipelines act as their own iterator; return directly without a DataLoader.
            device_id = self.local_rank if device == 'gpu' else None
            dataset = audio_to_text_dataset.get_dali_char_dataset(
                config=config,
                shuffle=shuffle,
                device_id=device_id,
                global_rank=self.global_rank,
                world_size=self.world_size,
                preprocessor_cfg=self._cfg.preprocessor,
            )
            return dataset

        # Instantiate tarred dataset loader or normal dataset loader
        if config.get('is_tarred', False):
            if ('tarred_audio_filepaths' in config and config['tarred_audio_filepaths'] is None) or (
                'manifest_filepath' in config and config['manifest_filepath'] is None
            ):
                logging.warning(
                    "Could not load dataset as `manifest_filepath` was None or "
                    f"`tarred_audio_filepaths` is None. Provided config : {config}"
                )
                return None

            shuffle_n = config.get('shuffle_n', 4 * config['batch_size']) if shuffle else 0
            dataset = audio_to_text_dataset.get_tarred_dataset(
                config=config,
                shuffle_n=shuffle_n,
                global_rank=self.global_rank,
                world_size=self.world_size,
                augmentor=augmentor,
            )
            # Tarred datasets shuffle internally via shuffle_n; DataLoader must not re-shuffle.
            shuffle = False
        else:
            if 'manifest_filepath' in config and config['manifest_filepath'] is None:
                logging.warning(f"Could not load dataset as `manifest_filepath` was None. Provided config : {config}")
                return None

            dataset = audio_to_text_dataset.get_char_dataset(config=config, augmentor=augmentor)

        if hasattr(dataset, 'collate_fn'):
            collate_fn = dataset.collate_fn
        elif hasattr(dataset.datasets[0], 'collate_fn'):
            # support datasets that are lists of entries
            collate_fn = dataset.datasets[0].collate_fn
        else:
            # support datasets that are lists of lists
            collate_fn = dataset.datasets[0].datasets[0].collate_fn

        return torch.utils.data.DataLoader(
            dataset=dataset,
            batch_size=config['batch_size'],
            collate_fn=collate_fn,
            drop_last=config.get('drop_last', False),
            shuffle=shuffle,
            num_workers=config.get('num_workers', 0),
            pin_memory=config.get('pin_memory', False),
        )
    def setup_training_data(self, train_data_config: Optional[Union[DictConfig, Dict]]):
        """
        Sets up the training data loader via a Dict-like object.

        Args:
            train_data_config: A config that contains the information regarding construction
                of an ASR Training dataset.

        Supported Datasets:
            - :class:`~nemo.collections.asr.data.audio_to_text.AudioToCharDataset`
            - :class:`~nemo.collections.asr.data.audio_to_text.AudioToBPEDataset`
            - :class:`~nemo.collections.asr.data.audio_to_text.TarredAudioToCharDataset`
            - :class:`~nemo.collections.asr.data.audio_to_text.TarredAudioToBPEDataset`
            - :class:`~nemo.collections.asr.data.audio_to_text_dali.AudioToCharDALIDataset`
        """
        # Training data shuffles by default unless the config says otherwise.
        if 'shuffle' not in train_data_config:
            train_data_config['shuffle'] = True

        # preserve config
        self._update_dataset_config(dataset_name='train', config=train_data_config)

        self._train_dl = self._setup_dataloader_from_config(config=train_data_config)

        # Need to set this because if using an IterableDataset, the length of the dataloader is the total number
        # of samples rather than the number of batches, and this messes up the tqdm progress bar.
        # So we set the number of steps manually (to the correct number) to fix this.
        if (
            self._train_dl is not None
            and hasattr(self._train_dl, 'dataset')
            and isinstance(self._train_dl.dataset, torch.utils.data.IterableDataset)
        ):
            # We also need to check if limit_train_batches is already set.
            # If it's an int, we assume that the user has set it to something sane, i.e. <= # training batches,
            # and don't change it. Otherwise, adjust batches accordingly if it's a float (including 1.0).
            if self._trainer is not None and isinstance(self._trainer.limit_train_batches, float):
                self._trainer.limit_train_batches = int(
                    self._trainer.limit_train_batches
                    * ceil((len(self._train_dl.dataset) / self.world_size) / train_data_config['batch_size'])
                )
            elif self._trainer is None:
                logging.warning(
                    "Model Trainer was not set before constructing the dataset, incorrect number of "
                    "training batches will be used. Please set the trainer and rebuild the dataset."
                )
[docs] def setup_validation_data(self, val_data_config: Optional[Union[DictConfig, Dict]]): """ Sets up the validation data loader via a Dict-like object. Args: val_data_config: A config that contains the information regarding construction of an ASR Training dataset. Supported Datasets: - :class:`~nemo.collections.asr.data.audio_to_text.AudioToCharDataset` - :class:`~nemo.collections.asr.data.audio_to_text.AudioToBPEDataset` - :class:`~nemo.collections.asr.data.audio_to_text.TarredAudioToCharDataset` - :class:`~nemo.collections.asr.data.audio_to_text.TarredAudioToBPEDataset` - :class:`~nemo.collections.asr.data.audio_to_text_dali.AudioToCharDALIDataset` """ if 'shuffle' not in val_data_config: val_data_config['shuffle'] = False # preserve config self._update_dataset_config(dataset_name='validation', config=val_data_config) self._validation_dl = self._setup_dataloader_from_config(config=val_data_config) # Need to set this because if using an IterableDataset, the length of the dataloader is the total number # of samples rather than the number of batches, and this messes up the tqdm progress bar. # So we set the number of steps manually (to the correct number) to fix this. if ( self._validation_dl is not None and hasattr(self._validation_dl, 'dataset') and isinstance(self._validation_dl.dataset, torch.utils.data.IterableDataset) ): # We also need to check if limit_train_batches is already set. # If it's an int, we assume that the user has set it to something sane, i.e. <= # training batches, # and don't change it. Otherwise, adjust batches accordingly if it's a float (including 1.0). if isinstance(self._trainer.limit_val_batches, float): self._trainer.limit_val_batches = int( self._trainer.limit_val_batches * ceil((len(self._validation_dl.dataset) / self.world_size) / val_data_config['batch_size']) )
@property def input_types(self) -> Optional[Dict[str, NeuralType]]: if hasattr(self.preprocessor, '_sample_rate'): input_signal_eltype = AudioSignal(freq=self.preprocessor._sample_rate) else: input_signal_eltype = AudioSignal() return { "input_signal": NeuralType(('B', 'T'), input_signal_eltype, optional=True), "input_signal_length": NeuralType(tuple('B'), LengthsType(), optional=True), "processed_signal": NeuralType(('B', 'D', 'T'), SpectrogramType(), optional=True), "processed_signal_length": NeuralType(tuple('B'), LengthsType(), optional=True), "targets": NeuralType(('B', 'T'), LabelsType(), optional=True), "target_lengths": NeuralType(tuple('B'), LengthsType(), optional=True), } @property def output_types(self) -> Optional[Dict[str, NeuralType]]: return { "spectrograms": NeuralType(('B', 'D', 'T'), SpectrogramType()), "spec_masks": NeuralType(('B', 'D', 'T'), SpectrogramType()), "encoded": NeuralType(('B', 'D', 'T'), AcousticEncodedRepresentation()), "encoded_len": NeuralType(tuple('B'), LengthsType()), }
    @typecheck()
    def forward(
        self,
        input_signal=None,
        input_signal_length=None,
        processed_signal=None,
        processed_signal_length=None,
    ):
        """
        Forward pass of the model.

        Args:
            input_signal: Tensor that represents a batch of raw audio signals, of shape [B, T].
                T here represents timesteps, with 1 second of audio represented as
                `self.sample_rate` number of floating point values.
            input_signal_length: Vector of length B, that contains the individual lengths of
                the audio sequences.
            processed_signal: Tensor that represents a batch of processed audio signals,
                of shape (B, D, T) that has undergone processing via some DALI preprocessor.
            processed_signal_length: Vector of length B, that contains the individual lengths
                of the processed audio sequences.

        Returns:
            A tuple of 4 elements -
            1) Processed spectrograms of shape [B, D, T].
            2) Masks applied to spectrograms of shape [B, D, T].
            3) The encoded features tensor of shape [B, D, T].
            4) The lengths of the acoustic sequence after propagation through the encoder,
               of shape [B].
        """
        # Reset access registry
        if self.is_access_enabled(self.model_guid):
            self.reset_registry()

        # Check for special flag for validation step
        if hasattr(self, '_in_validation_step'):
            in_validation_step = self._in_validation_step
        else:
            in_validation_step = False

        # reset module registry from AccessMixin
        # Enable tensor capture only when some decoder taps an intermediate encoder layer.
        if (
            (self.training or in_validation_step)
            and self.decoder_losses is not None
            and self.output_from_layer is not None
            and len(self.output_from_layer) > 0
        ):
            layer_names = list(self.output_from_layer.values())
            register_layer = any([name is not None for name in layer_names])

            if register_layer:
                self.access_cfg['save_encoder_tensors'] = True
                self.set_access_enabled(access_enabled=True, guid=self.model_guid)

        has_input_signal = input_signal is not None and input_signal_length is not None
        has_processed_signal = processed_signal is not None and processed_signal_length is not None
        # Exactly one of (raw audio, processed spectrograms) must be provided.
        if (has_input_signal ^ has_processed_signal) == False:
            raise ValueError(
                f"{self} Arguments ``input_signal`` and ``input_signal_length`` are mutually exclusive "
                " with ``processed_signal`` and ``processed_signal_len`` arguments."
            )

        if not has_processed_signal:
            processed_signal, processed_signal_length = self.preprocessor(
                input_signal=input_signal,
                length=input_signal_length,
            )

        # L2 feature penalty (Wav2Vec-style regularization), accumulated for the loss step.
        if self.pen_factor:
            self.feat_pen = processed_signal.float().pow(2).mean() * self.pen_factor
        # Keep an unmasked copy as reconstruction targets for the decoders.
        spectrograms = processed_signal.detach().clone()

        if self.dropout_features:
            processed_signal = self.dropout_features(processed_signal)
        if self.dropout_features_q:
            spectrograms = self.dropout_features_q(spectrograms)

        if self.apply_masking:
            processed_signal = self.spec_augmentation(input_spec=processed_signal, length=processed_signal_length)

        # Recover the mask locations by finding near-zero entries — assumes SpecAugment
        # zeroes the masked positions exactly (NOTE(review): values within ±1e-5 of zero
        # in unmasked regions would be misclassified as masked).
        masked_spectrograms = processed_signal.detach()
        spec_masks = torch.logical_and(masked_spectrograms < 1e-5, masked_spectrograms > -1e-5).float()
        # Padding beyond each sequence's true length is not a mask.
        for idx, proc_len in enumerate(processed_signal_length):
            spec_masks[idx, :, proc_len:] = 0.0

        encoded, encoded_len = self.encoder(audio_signal=processed_signal, length=processed_signal_length)

        return spectrograms, spec_masks, encoded, encoded_len
    def decoder_loss_step(self, spectrograms, spec_masks, encoded, encoded_len, targets=None, target_lengths=None):
        """
        Forward pass through all decoders and calculate corresponding losses.

        Args:
            spectrograms: Processed spectrograms of shape [B, D, T].
            spec_masks: Masks applied to spectrograms of shape [B, D, T].
            encoded: The encoded features tensor of shape [B, D, T].
            encoded_len: The lengths of the acoustic sequence after propagation through the
                encoder, of shape [B].
            targets: Optional target labels of shape [B, T]
            target_lengths: Optional target label lengths of shape [B]

        Returns:
            A tuple of 2 elements -
            1) Total sum of losses weighted by corresponding loss_alphas
            2) Dictionary of unweighted losses
        """
        loss_val_dict = {}

        if self.decoder_losses is None:
            # Single decoder/loss mode.
            if hasattr(self.decoder_ssl, "needs_labels") and self.decoder_ssl.needs_labels:
                outputs = self.decoder_ssl(encoder_output=encoded, targets=targets, target_lengths=target_lengths)
            else:
                outputs = self.decoder_ssl(encoder_output=encoded)
            if self.loss.needs_labels:
                loss_value = self.loss(
                    spec_masks=spec_masks,
                    decoder_outputs=outputs,
                    targets=targets,
                    decoder_lengths=encoded_len,
                    target_lengths=target_lengths,
                )
            else:
                loss_value = self.loss(spectrograms=spectrograms, spec_masks=spec_masks, decoder_outputs=outputs)
        else:
            # Multi-decoder mode: accumulate the weighted sum of all active losses.
            loss_value = encoded.new_zeros(1)
            outputs = {}
            # AccessMixin registry holds intermediate encoder-layer tensors captured in forward().
            registry = self.get_module_registry(self.encoder)

            for dec_loss_name, dec_loss in self.decoder_losses.items():
                # loop through decoders and corresponding losses
                if not self.decoder_losses_active[dec_loss_name]:
                    continue

                if self.output_from_layer[dec_loss_name] is None:
                    dec_input = encoded
                else:
                    # extract output from specified layer using AccessMixin registry
                    dec_input = registry[self.output_from_layer[dec_loss_name]]['encoder'][-1]
                    if self.transpose_encoded[dec_loss_name]:
                        dec_input = dec_input.transpose(-2, -1)

                if self.targets_from_loss[dec_loss_name] is not None:
                    # extract targets from specified loss (e.g. quantizer token ids
                    # produced by a sibling loss earlier in this step)
                    target_loss = self.targets_from_loss[dec_loss_name]
                    targets = self.decoder_losses[target_loss]['loss'].target_ids
                    target_lengths = self.decoder_losses[target_loss]['loss'].target_lengths
                    if target_lengths is None:
                        target_lengths = encoded_len

                if hasattr(dec_loss['decoder'], "needs_labels") and dec_loss['decoder'].needs_labels:
                    # if we are using a decoder which needs labels, provide them
                    outputs[dec_loss_name] = dec_loss['decoder'](
                        encoder_output=dec_input, targets=targets, target_lengths=target_lengths
                    )
                else:
                    outputs[dec_loss_name] = dec_loss['decoder'](encoder_output=dec_input)

                current_loss = dec_loss['loss']
                if current_loss.needs_labels:
                    # if we are using a loss which needs labels, provide them
                    current_loss_value = current_loss(
                        spec_masks=spec_masks,
                        decoder_outputs=outputs[dec_loss_name],
                        targets=targets,
                        decoder_lengths=encoded_len,
                        target_lengths=target_lengths,
                    )
                else:
                    current_loss_value = current_loss(
                        spectrograms=spectrograms,
                        spec_masks=spec_masks,
                        decoder_outputs=outputs[dec_loss_name],
                        decoder_lengths=encoded_len,
                    )
                loss_value = loss_value + current_loss_value * self.loss_alphas[dec_loss_name]
                loss_val_dict[dec_loss_name] = current_loss_value

        return loss_value, loss_val_dict
# PTL-specific methods
[docs] def training_step(self, batch, batch_nb): signal, signal_len, targets, target_lengths = batch if isinstance(batch, DALIOutputs) and batch.has_processed_signal: spectrograms, spec_masks, encoded, encoded_len = self.forward( processed_signal=signal, processed_signal_length=signal_len, ) else: spectrograms, spec_masks, encoded, encoded_len = self.forward( input_signal=signal, input_signal_length=signal_len, ) if self.decoder_losses is not None: for dec_loss_name, dec_loss in self.decoder_losses.items(): self.decoder_losses_active[dec_loss_name] = self.trainer.global_step >= self.start_step[dec_loss_name] loss = dec_loss['loss'] if hasattr(loss, "set_num_updates"): loss.set_num_updates(self.trainer.global_step) else: if hasattr(self.loss, "set_num_updates"): self.loss.set_num_updates(self.trainer.global_step) loss_value, loss_val_dict = self.decoder_loss_step( spectrograms, spec_masks, encoded, encoded_len, targets, target_lengths ) tensorboard_logs = { 'learning_rate': self._optimizer.param_groups[0]['lr'], 'global_step': self.trainer.global_step, } for loss_name, loss_val in loss_val_dict.items(): tensorboard_logs['train_' + loss_name] = loss_val if self.feat_pen: loss_value += self.feat_pen # Reset access registry self.reset_registry() return {'loss': loss_value, 'log': tensorboard_logs}
[docs] def validation_pass(self, batch, batch_idx, dataloader_idx=0): # Set flag to register tensors self._in_validation_step = True signal, signal_len, targets, target_lengths = batch if isinstance(batch, DALIOutputs) and batch.has_processed_signal: spectrograms, spec_masks, encoded, encoded_len = self.forward( processed_signal=signal, processed_signal_length=signal_len, ) else: spectrograms, spec_masks, encoded, encoded_len = self.forward( input_signal=signal, input_signal_length=signal_len, ) if self.decoder_losses is not None: for dec_loss_name, dec_loss in self.decoder_losses.items(): self.decoder_losses_active[dec_loss_name] = self.trainer.global_step >= self.start_step[dec_loss_name] loss_value, _ = self.decoder_loss_step(spectrograms, spec_masks, encoded, encoded_len, targets, target_lengths) if self.feat_pen: loss_value += self.feat_pen # reset access registry self.reset_registry() del self._in_validation_step metrics = {'val_loss': loss_value} return metrics
[docs] def validation_step(self, batch, batch_idx, dataloader_idx=0): metrics = self.validation_pass(batch, batch_idx, dataloader_idx) if type(self.trainer.val_dataloaders) == list and len(self.trainer.val_dataloaders) > 1: self.validation_step_outputs[dataloader_idx].append(metrics) else: self.validation_step_outputs.append(metrics) return metrics
[docs] def multi_validation_epoch_end(self, outputs, dataloader_idx: int = 0): val_loss_mean = torch.stack([x['val_loss'] for x in outputs]).mean() tensorboard_logs = {'val_loss': val_loss_mean} return {'val_loss': val_loss_mean, 'log': tensorboard_logs}
class EncDecMaskedTokenPredModel(SpeechEncDecSelfSupervisedModel):
    """
    Speech self-supervised model that performs masked token prediction on the encoder output.

    The input is quantized into discrete tokens, part of the (pre- or post-convolution)
    features is masked, and the decoder is trained to predict the tokens at the masked
    positions from the encoder output.
    """

    def transfer_batch_to_device(self, batch: Any, device: torch.device, dataloader_idx: int) -> Any:
        """
        PTL hook: https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html#transfer-batch-to-device
        """
        batch = move_data_to_device(batch, device)
        return batch

    @classmethod
    def list_available_models(cls) -> List[PretrainedModelInfo]:
        """
        This method returns a list of pre-trained model which can be instantiated directly
        from NVIDIA's NGC cloud.

        Returns:
            List of available pre-trained models (currently empty).
        """
        results = []
        return results

    def __init__(self, cfg: DictConfig, trainer: Trainer = None):
        super().__init__(cfg, trainer)
        del self.decoder_ssl  # delete unused decoder from parent class

        if self.cfg.get("mask_position", "pre_conv") == "post_conv":
            # adjust config for post-convolution masking: features live in the encoder's
            # d_model space and the time axis is reduced by the subsampling factor
            self.cfg.quantizer.feat_in = self.cfg.encoder.d_model
            self.cfg.masking.feat_in = self.cfg.encoder.d_model
            self.cfg.masking.block_size = self.cfg.masking.block_size // self.cfg.encoder.subsampling_factor
            self.cfg.loss.combine_time_steps = 1

        self.quantizer = self.from_config_dict(self.cfg.quantizer)
        self.mask_processor = self.from_config_dict(self.cfg.masking)
        self.encoder = self.from_config_dict(self.cfg.encoder)
        self.decoder = self.from_config_dict(self.cfg.decoder)
        self.loss = self.from_config_dict(self.cfg.loss)

        self.pre_encoder = None
        if self.cfg.get("mask_position", "pre_conv") == "post_conv":
            # hacked to mask features after convolutional sub-sampling:
            # wrap the encoder's pre_encode so masking happens on its output
            self.pre_encoder = ConvFeatureMaksingWrapper(self.encoder.pre_encode, self.mask_processor)
            self.encoder.pre_encode = self.pre_encoder

    @property
    def oomptimizer_schema(self) -> dict:
        """
        Return a typing schema for optimal batch size calibration for various sequence
        lengths using OOMptimizer.
        """
        return {
            "cls": tuple,
            "inputs": [
                {"type": NeuralType(("B", "T"), AudioSignal()), "seq_length": "input"},
                {"type": NeuralType(("B",), LengthsType()), "seq_length": "input"},
            ],
        }

    @property
    def input_types(self) -> Optional[Dict[str, NeuralType]]:
        """Neural types of ``forward`` inputs; raw and preprocessed signals are both optional."""
        if hasattr(self.preprocessor, '_sample_rate'):
            input_signal_eltype = AudioSignal(freq=self.preprocessor._sample_rate)
        else:
            input_signal_eltype = AudioSignal()
        return {
            "input_signal": NeuralType(('B', 'T'), input_signal_eltype, optional=True),
            "input_signal_length": NeuralType(tuple('B'), LengthsType(), optional=True),
            "processed_signal": NeuralType(('B', 'D', 'T'), SpectrogramType(), optional=True),
            "processed_signal_length": NeuralType(tuple('B'), LengthsType(), optional=True),
            "apply_mask": NeuralType(optional=True),
        }

    @property
    def output_types(self) -> Optional[Dict[str, NeuralType]]:
        """Neural types of ``forward`` outputs; the extra 'H' axis appears with multiple codebooks."""
        if self.cfg.num_books == 1 and self.cfg.squeeze_single:
            logprobs = NeuralType(('B', 'T', 'C'), LogprobsType())
            tokens = NeuralType(('B', 'T'), LabelsType())
        else:
            logprobs = NeuralType(('B', 'T', 'C', 'H'), LogprobsType())
            tokens = NeuralType(('B', 'T', 'H'), LabelsType())
        return {
            "logprobs": logprobs,
            "encoded_len": NeuralType(tuple('B'), LengthsType()),
            "masks": NeuralType(('B', 'D', 'T'), SpectrogramType()),
            "tokens": tokens,
        }

    @typecheck()
    def forward(
        self,
        input_signal=None,
        input_signal_length=None,
        processed_signal=None,
        processed_signal_length=None,
        apply_mask=False,
    ):
        """Quantize targets, optionally mask the features, encode, and decode log-probs.

        Exactly one of (``input_signal``, ``input_signal_length``) or
        (``processed_signal``, ``processed_signal_length``) must be provided.

        Returns:
            Tuple of (log_probs, encoded_len, masks, tokens).
        """
        has_input_signal = input_signal is not None and input_signal_length is not None
        has_processed_signal = processed_signal is not None and processed_signal_length is not None
        # Equality of the two booleans means either both or neither were supplied.
        if has_input_signal == has_processed_signal:
            raise ValueError(
                f"{self} Arguments ``input_signal`` and ``input_signal_length`` are mutually exclusive "
                " with ``processed_signal`` and ``processed_signal_len`` arguments."
            )

        if not has_processed_signal:
            processed_signal, processed_signal_length = self.preprocessor(
                input_signal=input_signal,
                length=input_signal_length,
            )

        if self.pre_encoder is not None:
            # mask after convolutional sub-sampling: the wrapper applies masking inside
            # the encoder, then exposes the mask and unmasked features for quantization
            self.pre_encoder.set_masking_enabled(apply_mask=apply_mask)
            encoded, encoded_len = self.encoder(audio_signal=processed_signal, length=processed_signal_length)
            masks = self.pre_encoder.get_current_mask()
            feats = self.pre_encoder.get_current_feat()
            _, tokens = self.quantizer(input_signal=feats.transpose(1, 2))
        else:
            # quantize the unmasked spectrogram so targets are clean, then mask the input
            _, tokens = self.quantizer(input_signal=processed_signal)
            if apply_mask:
                masked_signal, masks = self.mask_processor(
                    input_feats=processed_signal, input_lengths=processed_signal_length
                )
            else:
                masked_signal = processed_signal
                masks = torch.zeros_like(processed_signal)
            encoded, encoded_len = self.encoder(audio_signal=masked_signal, length=processed_signal_length)

        log_probs = self.decoder(encoder_output=encoded)

        return log_probs, encoded_len, masks, tokens

    def training_step(self, batch, batch_idx=0):
        """One masked-token-prediction training step; masking is always applied."""
        input_signal, input_signal_length = batch[0], batch[1]
        if isinstance(batch, DALIOutputs) and batch.has_processed_signal:
            log_probs, encoded_len, masks, tokens = self.forward(
                processed_signal=input_signal, processed_signal_length=input_signal_length, apply_mask=True
            )
        else:
            log_probs, encoded_len, masks, tokens = self.forward(
                input_signal=input_signal, input_signal_length=input_signal_length, apply_mask=True
            )

        loss_value = self.loss(masks=masks, decoder_outputs=log_probs, targets=tokens, decoder_lengths=encoded_len)

        tensorboard_logs = {
            'learning_rate': self._optimizer.param_groups[0]['lr'],
            'global_step': self.trainer.global_step,
            'train_loss': loss_value,
        }

        return {'loss': loss_value, 'log': tensorboard_logs}

    def inference_pass(self, batch, batch_idx=0, dataloader_idx=0, mode='val', apply_mask=False):
        """Shared val/test computation: forward pass plus loss, keyed by ``mode``."""
        input_signal, input_signal_length = batch[0], batch[1]
        if isinstance(batch, DALIOutputs) and batch.has_processed_signal:
            log_probs, encoded_len, masks, tokens = self.forward(
                processed_signal=input_signal, processed_signal_length=input_signal_length, apply_mask=apply_mask
            )
        else:
            log_probs, encoded_len, masks, tokens = self.forward(
                input_signal=input_signal, input_signal_length=input_signal_length, apply_mask=apply_mask
            )

        loss_value = self.loss(masks=masks, decoder_outputs=log_probs, targets=tokens, decoder_lengths=encoded_len)

        return {f'{mode}_loss': loss_value}

    def validation_step(self, batch, batch_idx=0, dataloader_idx=0):
        """PTL validation hook; collects metrics per dataloader when several are configured."""
        metrics = self.inference_pass(batch, batch_idx, dataloader_idx, apply_mask=True)
        # isinstance (not `type(...) == list`) also accepts list subclasses.
        if isinstance(self.trainer.val_dataloaders, list) and len(self.trainer.val_dataloaders) > 1:
            self.validation_step_outputs[dataloader_idx].append(metrics)
        else:
            self.validation_step_outputs.append(metrics)
        return metrics

    def test_step(self, batch, batch_idx=0, dataloader_idx=0):
        """PTL test hook; collects metrics per dataloader when several are configured."""
        metrics = self.inference_pass(batch, batch_idx, dataloader_idx, mode="test", apply_mask=True)
        # Fixed copy-paste bug: the original inspected `val_dataloaders` and appended to
        # `validation_step_outputs`, so `multi_test_epoch_end` never saw any `test_loss`.
        if isinstance(self.trainer.test_dataloaders, list) and len(self.trainer.test_dataloaders) > 1:
            self.test_step_outputs[dataloader_idx].append(metrics)
        else:
            self.test_step_outputs.append(metrics)
        return metrics

    def multi_validation_epoch_end(self, outputs: list, dataloader_idx: int = 0):
        """Aggregate per-batch validation losses, skipping malformed outputs with a warning."""
        loss_list = []
        for i, x in enumerate(outputs):
            if not isinstance(x, dict):
                logging.warning(f'Batch {i} output in validation dataloader {dataloader_idx} is not a dictionary: {x}')
                # Skip: the membership test below would raise TypeError on non-dicts.
                continue
            if 'val_loss' in x:
                loss_list.append(x['val_loss'])
            else:
                logging.warning(
                    f'Batch {i} output in validation dataloader {dataloader_idx} does not have key `val_loss`: {x}'
                )
        if len(loss_list) == 0:
            logging.warning(
                f'Epoch {self.current_epoch} received no batches for validation dataloader {dataloader_idx}.'
            )
            return {}

        val_loss_mean = torch.stack(loss_list).mean()
        tensorboard_logs = {'val_loss': val_loss_mean}
        return {'val_loss': val_loss_mean, 'log': tensorboard_logs}

    def multi_test_epoch_end(self, outputs, dataloader_idx: int = 0):
        """Aggregate per-batch test losses into a single mean for logging."""
        test_loss_mean = torch.stack([x['test_loss'] for x in outputs]).mean()
        tensorboard_logs = {'test_loss': test_loss_mean}
        return {'test_loss': test_loss_mean, 'log': tensorboard_logs}
class EncDecDenoiseMaskedTokenPredModel(EncDecMaskedTokenPredModel):
    """
    Model class that performs denoising and masked token prediction for speech
    self-supervised learning.
    Please refer to the NEST paper for more details: https://arxiv.org/abs/2408.13106
    """

    @property
    def oomptimizer_schema(self) -> dict:
        """
        Return a typing schema for optimal batch size calibration for various sequence
        lengths using OOMptimizer.
        """
        return {
            "cls": ssl_dataset.AudioNoiseBatch,
            "inputs": [
                {"type": NeuralType(("B", "T"), AudioSignal()), "seq_length": "input", "name": "audio"},
                {"type": NeuralType(("B",), LengthsType()), "seq_length": "input", "name": "audio_len"},
                {"type": NeuralType(("B", "T"), AudioSignal()), "seq_length": "input", "name": "noise"},
                {"type": NeuralType(("B",), LengthsType()), "seq_length": "input", "name": "noise_len"},
                {"type": NeuralType(("B", "T"), AudioSignal()), "seq_length": "input", "name": "noisy_audio"},
                {"type": NeuralType(("B",), LengthsType()), "seq_length": "input", "name": "noisy_audio_len"},
            ],
        }

    def __init__(self, cfg: DictConfig, trainer: Trainer = None):
        # No state beyond the parent masked-token-prediction model.
        super().__init__(cfg, trainer)

    def _setup_dataloader_from_config(self, config: Optional[Dict]):
        """Build a dataloader of paired clean/noise/noisy audio from a config dict.

        Supports a Lhotse-backed path (``use_lhotse``) and the classic
        map-/iterable-style dataset path with chained/concatenated datasets.
        """
        # Propagate the model-level sample_rate into the dataloader config if missing.
        audio_to_text_dataset.inject_dataloader_value_from_model_config(self.cfg, config, key='sample_rate')
        if config.get("use_lhotse"):
            return get_lhotse_dataloader_from_config(
                config,
                global_rank=self.global_rank,
                world_size=self.world_size,
                dataset=ssl_dataset.LhotseAudioNoiseDataset(
                    noise_manifest=config.get('noise_manifest', None),
                    batch_augmentor_cfg=config.get('batch_augmentor', None),
                ),
            )

        dataset = ssl_dataset.get_audio_noise_dataset_from_config(
            config,
            global_rank=self.global_rank,
            world_size=self.world_size,
        )

        shuffle = config['shuffle']
        if isinstance(dataset, torch.utils.data.IterableDataset):
            # Iterable datasets control their own ordering; DataLoader must not shuffle.
            shuffle = False

        if hasattr(dataset, 'collate_fn'):
            collate_fn = dataset.collate_fn
        elif hasattr(dataset.datasets[0], 'collate_fn'):
            # support datasets that are lists of entries
            collate_fn = dataset.datasets[0].collate_fn
        else:
            # support datasets that are lists of lists
            collate_fn = dataset.datasets[0].datasets[0].collate_fn

        return torch.utils.data.DataLoader(
            dataset=dataset,
            batch_size=config['batch_size'],
            collate_fn=collate_fn,
            drop_last=config.get('drop_last', False),
            shuffle=shuffle,
            num_workers=config.get('num_workers', 0),
            pin_memory=config.get('pin_memory', False),
        )

    @property
    def input_types(self) -> Optional[Dict[str, NeuralType]]:
        """Neural types of ``forward`` inputs: raw and preprocessed variants of the
        clean, noise, and noisy signals (all optional; see ``forward`` for exclusivity)."""
        if hasattr(self.preprocessor, '_sample_rate'):
            input_signal_eltype = AudioSignal(freq=self.preprocessor._sample_rate)
        else:
            input_signal_eltype = AudioSignal()
        return {
            "input_signal": NeuralType(('B', 'T'), input_signal_eltype, optional=True),
            "input_signal_length": NeuralType(tuple('B'), LengthsType(), optional=True),
            "processed_signal": NeuralType(('B', 'D', 'T'), SpectrogramType(), optional=True),
            "processed_signal_length": NeuralType(tuple('B'), LengthsType(), optional=True),
            "noise_signal": NeuralType(('B', 'T'), input_signal_eltype, optional=True),
            "noise_signal_length": NeuralType(tuple('B'), LengthsType(), optional=True),
            "processed_noise_signal": NeuralType(('B', 'D', 'T'), SpectrogramType(), optional=True),
            "processed_noise_length": NeuralType(tuple('B'), LengthsType(), optional=True),
            "noisy_input_signal": NeuralType(('B', 'T'), input_signal_eltype, optional=True),
            "noisy_input_signal_length": NeuralType(tuple('B'), LengthsType(), optional=True),
            "processed_noisy_signal": NeuralType(('B', 'D', 'T'), SpectrogramType(), optional=True),
            "processed_noisy_signal_length": NeuralType(tuple('B'), LengthsType(), optional=True),
            "apply_mask": NeuralType(optional=True),
        }

    @property
    def output_types(self) -> Optional[Dict[str, NeuralType]]:
        """Neural types of ``forward`` outputs; the extra 'H' axis appears with multiple codebooks."""
        if self.cfg.num_books == 1 and self.cfg.squeeze_single:
            logprobs = NeuralType(('B', 'T', 'C'), LogprobsType())
            tokens = NeuralType(('B', 'T'), LabelsType())
        else:
            logprobs = NeuralType(('B', 'T', 'C', 'H'), LogprobsType())
            tokens = NeuralType(('B', 'T', 'H'), LabelsType())
        return {
            "logprobs": logprobs,
            "encoded_len": NeuralType(tuple('B'), LengthsType()),
            "masks": NeuralType(('B', 'D', 'T'), SpectrogramType()),
            "tokens": tokens,
        }
@typecheck()
def forward(
    self,
    input_signal=None,
    input_signal_length=None,
    processed_signal=None,
    processed_signal_length=None,
    noise_signal=None,  # noqa
    noise_signal_length=None,  # noqa
    processed_noise_signal=None,  # noqa
    processed_noise_signal_length=None,  # noqa
    noisy_input_signal=None,
    noisy_input_signal_length=None,
    processed_noisy_input_signal=None,
    processed_noisy_input_signal_length=None,
    apply_mask=False,
):
    """Forward pass for denoising masked token prediction.

    Target tokens are quantized from the *clean* signal, while the encoder consumes
    the (optionally masked) *noisy* signal — so the model learns to predict clean
    targets from corrupted input. For each of the clean and noisy signals, exactly
    one of the raw or preprocessed variants must be provided; the ``noise_*``
    arguments are accepted for schema compatibility but unused here.

    Returns:
        Tuple of (log_probs, encoded_len, masks, tokens).
    """
    has_input_signal = input_signal is not None and input_signal_length is not None
    has_processed_signal = processed_signal is not None and processed_signal_length is not None
    # XOR == False means both or neither clean variant was supplied — reject.
    if (has_input_signal ^ has_processed_signal) == False:
        raise ValueError(
            f"{self} Arguments ``input_signal`` and ``input_signal_length`` are mutually exclusive "
            " with ``processed_signal`` and ``processed_signal_len`` arguments."
        )
    if not has_processed_signal:
        processed_signal, processed_signal_length = self.preprocessor(
            input_signal=input_signal,
            length=input_signal_length,
        )

    ### Following code snippet is not used but kept for future reference
    #
    # has_noise_signal = noise_signal is not None and noise_signal_length is not None
    # has_processed_noise_signal = processed_noise_signal is not None and processed_noise_signal_length is not None
    # if (has_noise_signal ^ has_processed_noise_signal) == False:
    #     raise ValueError(
    #         f"{self} Arguments ``noise_signal`` and ``noise_signal_length`` are mutually exclusive "
    #         " with ``processed_noise_signal`` and ``processed_noise_signal_len`` arguments."
    #     )
    # if not has_processed_noise_signal:
    #     processed_noise_signal, processed_noise_signal_length = self.preprocessor(
    #         input_signal=noise_signal,
    #         length=noise_signal_length,
    #     )

    has_noisy_input_signal = noisy_input_signal is not None and noisy_input_signal_length is not None
    has_processed_noisy_input_signal = (
        processed_noisy_input_signal is not None and processed_noisy_input_signal_length is not None
    )
    # Same exclusivity rule for the noisy signal.
    if (has_noisy_input_signal ^ has_processed_noisy_input_signal) == False:
        raise ValueError(
            f"{self} Arguments ``noisy_input_signal`` and ``noisy_input_signal_length`` are mutually exclusive "
            " with ``processed_noisy_input_signal`` and ``processed_noisy_input_signal_len`` arguments."
        )
    if not has_processed_noisy_input_signal:
        processed_noisy_input_signal, processed_noisy_input_signal_length = self.preprocessor(
            input_signal=noisy_input_signal,
            length=noisy_input_signal_length,
        )

    if self.pre_encoder is not None:
        # mask after convolutional sub-sampling: quantize tokens from the CLEAN
        # signal's sub-sampled features, then encode the NOISY signal with masking
        # applied inside the wrapped pre_encode.
        feats, _ = self.pre_encoder.pre_encode(x=processed_signal, lengths=processed_signal_length)
        # pre_encode output presumably (B, T, D); transposed to match the layout the
        # quantizer receives in the else-branch (processed_signal is (B, D, T)).
        _, tokens = self.quantizer(input_signal=feats.transpose(1, 2))
        self.pre_encoder.set_masking_enabled(apply_mask=apply_mask)
        encoded, encoded_len = self.encoder(
            audio_signal=processed_noisy_input_signal, length=processed_noisy_input_signal_length
        )
        masks = self.pre_encoder.get_current_mask()
    else:
        # Quantize clean spectrogram for targets; mask and encode the noisy one.
        _, tokens = self.quantizer(input_signal=processed_signal)
        if apply_mask:
            masked_signal, masks = self.mask_processor(
                input_feats=processed_noisy_input_signal, input_lengths=processed_noisy_input_signal_length
            )
        else:
            masked_signal = processed_noisy_input_signal
            masks = torch.zeros_like(processed_noisy_input_signal)
        encoded, encoded_len = self.encoder(audio_signal=masked_signal, length=processed_noisy_input_signal_length)

    log_probs = self.decoder(encoder_output=encoded)

    return log_probs, encoded_len, masks, tokens
def training_step(self, batch: ssl_dataset.AudioNoiseBatch, batch_idx: int):
    """Run one denoising masked-token-prediction training step.

    Args:
        batch: AudioNoiseBatch with clean, noise, and noisy audio plus lengths.
        batch_idx: batch index within the epoch (required by PTL).

    Returns:
        Dict with the training ``loss`` and TensorBoard ``log`` scalars.
    """
    logits, out_lens, mask_tensor, target_tokens = self.forward(
        input_signal=batch.audio,
        input_signal_length=batch.audio_len,
        noise_signal=batch.noise,
        noise_signal_length=batch.noise_len,
        noisy_input_signal=batch.noisy_audio,
        noisy_input_signal_length=batch.noisy_audio_len,
        apply_mask=True,
    )

    train_loss = self.loss(
        masks=mask_tensor, decoder_outputs=logits, targets=target_tokens, decoder_lengths=out_lens
    )

    log_dict = {
        'learning_rate': self._optimizer.param_groups[0]['lr'],
        'global_step': self.trainer.global_step,
        'train_loss': train_loss,
    }

    return {'loss': train_loss, 'log': log_dict}
def inference_pass(
    self,
    batch: ssl_dataset.AudioNoiseBatch,
    batch_idx: int,
    dataloader_idx: int = 0,
    mode: str = 'val',
    apply_mask: bool = True,
):
    """Shared evaluation computation: forward pass and loss, keyed by ``mode``.

    Args:
        batch: AudioNoiseBatch with clean, noise, and noisy audio plus lengths.
        batch_idx: batch index within the epoch.
        dataloader_idx: index of the active eval dataloader.
        mode: prefix for the returned metric key ('val' or 'test').
        apply_mask: whether to mask features before encoding.

    Returns:
        Dict mapping ``f'{mode}_loss'`` to the scalar loss.
    """
    logits, out_lens, mask_tensor, target_tokens = self.forward(
        input_signal=batch.audio,
        input_signal_length=batch.audio_len,
        noise_signal=batch.noise,
        noise_signal_length=batch.noise_len,
        noisy_input_signal=batch.noisy_audio,
        noisy_input_signal_length=batch.noisy_audio_len,
        apply_mask=apply_mask,
    )

    eval_loss = self.loss(
        masks=mask_tensor, decoder_outputs=logits, targets=target_tokens, decoder_lengths=out_lens
    )

    return {f'{mode}_loss': eval_loss}