# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from math import ceil
from typing import Any, Dict, List, Optional, Union
import torch
import torch.nn as nn
from lightning.pytorch import Trainer
from omegaconf import DictConfig
from nemo.collections.asr.data import audio_to_text_dataset, ssl_dataset
from nemo.collections.asr.data.audio_to_text_dali import DALIOutputs
from nemo.collections.asr.data.audio_to_text_lhotse import LhotseSpeechToTextBpeDataset
from nemo.collections.asr.modules.ssl_modules.masking import ConvFeatureMaksingWrapper
from nemo.collections.asr.parts.mixins import ASRModuleMixin
from nemo.collections.asr.parts.preprocessing.perturb import process_augmentations
from nemo.collections.common.data.lhotse import get_lhotse_dataloader_from_config
from nemo.collections.common.data.utils import move_data_to_device
from nemo.collections.common.parts.preprocessing.parsers import make_parser
from nemo.core.classes import ModelPT
from nemo.core.classes.common import PretrainedModelInfo, typecheck
from nemo.core.classes.mixins import AccessMixin, set_access_cfg
from nemo.core.neural_types import (
AcousticEncodedRepresentation,
AudioSignal,
LabelsType,
LengthsType,
LogprobsType,
NeuralType,
SpectrogramType,
)
from nemo.utils import logging
__all__ = ['SpeechEncDecSelfSupervisedModel', 'EncDecMaskedTokenPredModel', 'EncDecDenoiseMaskedTokenPredModel']
class SpeechEncDecSelfSupervisedModel(ModelPT, ASRModuleMixin, AccessMixin):
"""Base class for encoder-decoder models used for self-supervised encoder pre-training"""
@classmethod
def list_available_models(cls) -> List[PretrainedModelInfo]:
    """
    List the pre-trained checkpoints that can be instantiated directly from NVIDIA's NGC cloud.

    Returns:
        List of available pre-trained models.
    """
    return [
        PretrainedModelInfo(
            pretrained_model_name="ssl_en_conformer_large",
            description="For details about this model, please visit https://ngc.nvidia.com/catalog/models/nvidia:nemo:ssl_en_conformer_large",
            location="https://api.ngc.nvidia.com/v2/models/nvidia/nemo/ssl_en_conformer_large/versions/1.10.1/files/ssl_en_conformer_large.nemo",
        ),
        PretrainedModelInfo(
            pretrained_model_name="ssl_en_conformer_xlarge",
            description="For details about this model, please visit https://ngc.nvidia.com/catalog/models/nvidia:nemo:ssl_en_conformer_xlarge",
            location="https://api.ngc.nvidia.com/v2/models/nvidia/nemo/ssl_en_conformer_xlarge/versions/1.10.0/files/ssl_en_conformer_xlarge.nemo",
        ),
    ]
def __init__(self, cfg: DictConfig, trainer: Trainer = None):
    """
    Build preprocessor, encoder, and either a single decoder+loss pair or a
    dict of named decoder/loss pairs (``loss_list``) with per-pair scheduling.

    Args:
        cfg: Model configuration (preprocessor/encoder sections plus either
            ``decoder``+``loss`` or a ``loss_list`` mapping).
        trainer: Optional Lightning Trainer; used here to derive ``world_size``
            for IterableDataset partitioning.
    """
    # Get global rank and total number of GPU workers for IterableDataset partitioning, if applicable
    # Global_rank and local_rank is set by LightningModule in Lightning 1.2.0
    self.world_size = 1
    if trainer is not None:
        self.world_size = trainer.world_size
    super().__init__(cfg=cfg, trainer=trainer)
    self.preprocessor = SpeechEncDecSelfSupervisedModel.from_config_dict(self._cfg.preprocessor)
    self.encoder = SpeechEncDecSelfSupervisedModel.from_config_dict(self._cfg.encoder)
    self.decoder_losses = None
    if "loss_list" in self._cfg:
        # Multi-decoder setup: every entry pairs a decoder with its loss, plus
        # per-entry bookkeeping (loss weight, start step, encoder tap layer,
        # optional transpose, and targets borrowed from another loss).
        self.decoder_losses = {}
        self.loss_alphas = {}
        self.start_step = {}
        self.output_from_layer = {}
        self.transpose_encoded = {}
        self.targets_from_loss = {}
        self.decoder_losses_active = {}
        # need to be separate for moduledict
        for decoder_loss_name, decoder_loss_cfg in self._cfg.loss_list.items():
            if not decoder_loss_cfg.get("is_active", True):  # active by default
                continue
            new_decoder_loss = {
                'decoder': SpeechEncDecSelfSupervisedModel.from_config_dict(decoder_loss_cfg.decoder),
                'loss': SpeechEncDecSelfSupervisedModel.from_config_dict(decoder_loss_cfg.loss),
            }
            new_decoder_loss = nn.ModuleDict(new_decoder_loss)
            self.decoder_losses[decoder_loss_name] = new_decoder_loss
            self.loss_alphas[decoder_loss_name] = decoder_loss_cfg.get("loss_alpha", 1.0)
            self.output_from_layer[decoder_loss_name] = decoder_loss_cfg.get("output_from_layer", None)
            self.targets_from_loss[decoder_loss_name] = decoder_loss_cfg.get("targets_from_loss", None)
            self.start_step[decoder_loss_name] = decoder_loss_cfg.get("start_step", 0)
            self.transpose_encoded[decoder_loss_name] = decoder_loss_cfg.get("transpose_encoded", False)
            self.decoder_losses_active[decoder_loss_name] = True
        # Wrap in ModuleDict so all decoders/losses register as submodules.
        self.decoder_losses = nn.ModuleDict(self.decoder_losses)
    else:
        # Single decoder/loss setup.
        self.decoder_ssl = SpeechEncDecSelfSupervisedModel.from_config_dict(self._cfg.decoder)
        self.loss = SpeechEncDecSelfSupervisedModel.from_config_dict(self._cfg.loss)
    self.spec_augmentation = SpeechEncDecSelfSupervisedModel.from_config_dict(self._cfg.spec_augment)
    # dropout for features/spectrograms (applied before masking)
    self.dropout_features = (
        torch.nn.Dropout(self._cfg.dropout_features) if "dropout_features" in self._cfg else None
    )
    # dropout for targets (applied before quantization)
    self.dropout_features_q = (
        torch.nn.Dropout(self._cfg.dropout_features_q) if "dropout_features_q" in self._cfg else None
    )
    # Feature penalty for preprocessor encodings (for Wav2Vec training)
    if "feature_penalty" in self._cfg:
        self.feat_pen, self.pen_factor = 0.0, self._cfg.feature_penalty
    else:
        self.feat_pen, self.pen_factor = None, None
    if "access" in self._cfg:
        # Register the AccessMixin config so intermediate encoder tensors can
        # be captured for layer-tapped decoders.
        set_access_cfg(self._cfg.access, self.model_guid)
    # SpecAugment masking is applied in forward() unless externally disabled.
    self.apply_masking = True
def _setup_dataloader_from_config(self, config: Optional[Dict]):
    """
    Build a train/validation DataLoader from a dataset config dict.

    Dispatches, in order, to: a Lhotse dataloader, a DALI pipeline, a tarred
    (webdataset-style) dataset, or a plain character dataset. Returns ``None``
    when required manifest/tarball paths are missing from the config.
    """
    if 'augmentor' in config:
        augmentor = process_augmentations(config['augmentor'])
    else:
        augmentor = None
    # Automatically inject args from model config to dataloader config
    audio_to_text_dataset.inject_dataloader_value_from_model_config(self.cfg, config, key='sample_rate')
    if config.get("use_lhotse"):
        # Lhotse handles sharding/batching internally; return its dataloader directly.
        return get_lhotse_dataloader_from_config(
            config,
            global_rank=self.global_rank,
            world_size=self.world_size,
            dataset=LhotseSpeechToTextBpeDataset(
                tokenizer=make_parser(
                    labels=config.get('labels', None),
                    name=config.get('parser', 'en'),
                    unk_id=config.get('unk_index', -1),
                    blank_id=config.get('blank_index', -1),
                    do_normalize=config.get('normalize_transcripts', False),
                ),
            ),
        )
    shuffle = config['shuffle']
    device = 'gpu' if torch.cuda.is_available() else 'cpu'
    if config.get('use_dali', False):
        device_id = self.local_rank if device == 'gpu' else None
        dataset = audio_to_text_dataset.get_dali_char_dataset(
            config=config,
            shuffle=shuffle,
            device_id=device_id,
            global_rank=self.global_rank,
            world_size=self.world_size,
            preprocessor_cfg=self._cfg.preprocessor,
        )
        # DALI pipelines act as their own dataloader.
        return dataset
    # Instantiate tarred dataset loader or normal dataset loader
    if config.get('is_tarred', False):
        if ('tarred_audio_filepaths' in config and config['tarred_audio_filepaths'] is None) or (
            'manifest_filepath' in config and config['manifest_filepath'] is None
        ):
            logging.warning(
                "Could not load dataset as `manifest_filepath` was None or "
                f"`tarred_audio_filepaths` is None. Provided config : {config}"
            )
            return None
        shuffle_n = config.get('shuffle_n', 4 * config['batch_size']) if shuffle else 0
        dataset = audio_to_text_dataset.get_tarred_dataset(
            config=config,
            shuffle_n=shuffle_n,
            global_rank=self.global_rank,
            world_size=self.world_size,
            augmentor=augmentor,
        )
        # Tarred datasets shuffle internally via shuffle_n; DataLoader must not.
        shuffle = False
    else:
        if 'manifest_filepath' in config and config['manifest_filepath'] is None:
            logging.warning(f"Could not load dataset as `manifest_filepath` was None. Provided config : {config}")
            return None
        dataset = audio_to_text_dataset.get_char_dataset(config=config, augmentor=augmentor)
    if hasattr(dataset, 'collate_fn'):
        collate_fn = dataset.collate_fn
    elif hasattr(dataset.datasets[0], 'collate_fn'):
        # support datasets that are lists of entries
        collate_fn = dataset.datasets[0].collate_fn
    else:
        # support datasets that are lists of lists
        collate_fn = dataset.datasets[0].datasets[0].collate_fn
    return torch.utils.data.DataLoader(
        dataset=dataset,
        batch_size=config['batch_size'],
        collate_fn=collate_fn,
        drop_last=config.get('drop_last', False),
        shuffle=shuffle,
        num_workers=config.get('num_workers', 0),
        pin_memory=config.get('pin_memory', False),
    )
def setup_training_data(self, train_data_config: Optional[Union[DictConfig, Dict]]):
    """
    Sets up the training data loader via a Dict-like object.

    Args:
        train_data_config: A config that contains the information regarding construction
            of an ASR Training dataset.

    Supported Datasets:
        - :class:`~nemo.collections.asr.data.audio_to_text.AudioToCharDataset`
        - :class:`~nemo.collections.asr.data.audio_to_text.AudioToBPEDataset`
        - :class:`~nemo.collections.asr.data.audio_to_text.TarredAudioToCharDataset`
        - :class:`~nemo.collections.asr.data.audio_to_text.TarredAudioToBPEDataset`
        - :class:`~nemo.collections.asr.data.audio_to_text_dali.AudioToCharDALIDataset`
    """
    if 'shuffle' not in train_data_config:
        # Training data shuffles by default.
        train_data_config['shuffle'] = True
    # preserve config
    self._update_dataset_config(dataset_name='train', config=train_data_config)
    self._train_dl = self._setup_dataloader_from_config(config=train_data_config)
    # Need to set this because if using an IterableDataset, the length of the dataloader is the total number
    # of samples rather than the number of batches, and this messes up the tqdm progress bar.
    # So we set the number of steps manually (to the correct number) to fix this.
    if (
        self._train_dl is not None
        and hasattr(self._train_dl, 'dataset')
        and isinstance(self._train_dl.dataset, torch.utils.data.IterableDataset)
    ):
        # We also need to check if limit_train_batches is already set.
        # If it's an int, we assume that the user has set it to something sane, i.e. <= # training batches,
        # and don't change it. Otherwise, adjust batches accordingly if it's a float (including 1.0).
        if self._trainer is not None and isinstance(self._trainer.limit_train_batches, float):
            self._trainer.limit_train_batches = int(
                self._trainer.limit_train_batches
                * ceil((len(self._train_dl.dataset) / self.world_size) / train_data_config['batch_size'])
            )
        elif self._trainer is None:
            logging.warning(
                "Model Trainer was not set before constructing the dataset, incorrect number of "
                "training batches will be used. Please set the trainer and rebuild the dataset."
            )
def setup_validation_data(self, val_data_config: Optional[Union[DictConfig, Dict]]):
    """
    Sets up the validation data loader via a Dict-like object.

    Args:
        val_data_config: A config that contains the information regarding construction
            of an ASR validation dataset.

    Supported Datasets:
        - :class:`~nemo.collections.asr.data.audio_to_text.AudioToCharDataset`
        - :class:`~nemo.collections.asr.data.audio_to_text.AudioToBPEDataset`
        - :class:`~nemo.collections.asr.data.audio_to_text.TarredAudioToCharDataset`
        - :class:`~nemo.collections.asr.data.audio_to_text.TarredAudioToBPEDataset`
        - :class:`~nemo.collections.asr.data.audio_to_text_dali.AudioToCharDALIDataset`
    """
    if 'shuffle' not in val_data_config:
        # Validation data is not shuffled by default.
        val_data_config['shuffle'] = False
    # preserve config
    self._update_dataset_config(dataset_name='validation', config=val_data_config)
    self._validation_dl = self._setup_dataloader_from_config(config=val_data_config)
    # Need to set this because if using an IterableDataset, the length of the dataloader is the total number
    # of samples rather than the number of batches, and this messes up the tqdm progress bar.
    # So we set the number of steps manually (to the correct number) to fix this.
    if (
        self._validation_dl is not None
        and hasattr(self._validation_dl, 'dataset')
        and isinstance(self._validation_dl.dataset, torch.utils.data.IterableDataset)
    ):
        # If limit_val_batches is an int, we assume the user has set it to something
        # sane (<= # validation batches) and don't change it. Otherwise, adjust
        # batches accordingly if it's a float (including 1.0).
        # Fix: guard against self._trainer being None, mirroring setup_training_data —
        # the previous code dereferenced self._trainer unconditionally and crashed
        # with AttributeError when no trainer was attached.
        if self._trainer is not None and isinstance(self._trainer.limit_val_batches, float):
            self._trainer.limit_val_batches = int(
                self._trainer.limit_val_batches
                * ceil((len(self._validation_dl.dataset) / self.world_size) / val_data_config['batch_size'])
            )
        elif self._trainer is None:
            logging.warning(
                "Model Trainer was not set before constructing the dataset, incorrect number of "
                "validation batches will be used. Please set the trainer and rebuild the dataset."
            )
@property
def input_types(self) -> Optional[Dict[str, NeuralType]]:
    """Neural types of ``forward()`` inputs; every entry is optional."""
    if hasattr(self.preprocessor, '_sample_rate'):
        audio_eltype = AudioSignal(freq=self.preprocessor._sample_rate)
    else:
        audio_eltype = AudioSignal()
    types = {
        "input_signal": NeuralType(('B', 'T'), audio_eltype, optional=True),
        "input_signal_length": NeuralType(tuple('B'), LengthsType(), optional=True),
        "processed_signal": NeuralType(('B', 'D', 'T'), SpectrogramType(), optional=True),
        "processed_signal_length": NeuralType(tuple('B'), LengthsType(), optional=True),
        "targets": NeuralType(('B', 'T'), LabelsType(), optional=True),
        "target_lengths": NeuralType(tuple('B'), LengthsType(), optional=True),
    }
    return types
@property
def output_types(self) -> Optional[Dict[str, NeuralType]]:
    """Neural types of the 4-tuple returned by ``forward()``."""
    out_types = {}
    out_types["spectrograms"] = NeuralType(('B', 'D', 'T'), SpectrogramType())
    out_types["spec_masks"] = NeuralType(('B', 'D', 'T'), SpectrogramType())
    out_types["encoded"] = NeuralType(('B', 'D', 'T'), AcousticEncodedRepresentation())
    out_types["encoded_len"] = NeuralType(tuple('B'), LengthsType())
    return out_types
@typecheck()
def forward(
    self,
    input_signal=None,
    input_signal_length=None,
    processed_signal=None,
    processed_signal_length=None,
):
    """
    Forward pass of the model.

    Args:
        input_signal: Tensor that represents a batch of raw audio signals,
            of shape [B, T]. T here represents timesteps, with 1 second of audio represented as
            `self.sample_rate` number of floating point values.
        input_signal_length: Vector of length B, that contains the individual lengths of the audio
            sequences.
        processed_signal: Tensor that represents a batch of processed audio signals,
            of shape (B, D, T) that has undergone processing via some DALI preprocessor.
        processed_signal_length: Vector of length B, that contains the individual lengths of the
            processed audio sequences.

    Returns:
        A tuple of 4 elements -
        1) Processed spectrograms of shape [B, D, T].
        2) Masks applied to spectrograms of shape [B, D, T].
        3) The encoded features tensor of shape [B, D, T].
        4) The lengths of the acoustic sequence after propagation through the encoder, of shape [B].
    """
    # Reset access registry
    if self.is_access_enabled(self.model_guid):
        self.reset_registry()
    # Check for special flag for validation step
    if hasattr(self, '_in_validation_step'):
        in_validation_step = self._in_validation_step
    else:
        in_validation_step = False
    # reset module registry from AccessMixin
    if (
        (self.training or in_validation_step)
        and self.decoder_losses is not None
        and self.output_from_layer is not None
        and len(self.output_from_layer) > 0
    ):
        # Enable intermediate-tensor capture only when some decoder taps a
        # specific encoder layer (output_from_layer set in its config).
        layer_names = list(self.output_from_layer.values())
        register_layer = any([name is not None for name in layer_names])
        if register_layer:
            self.access_cfg['save_encoder_tensors'] = True
            self.set_access_enabled(access_enabled=True, guid=self.model_guid)
    has_input_signal = input_signal is not None and input_signal_length is not None
    has_processed_signal = processed_signal is not None and processed_signal_length is not None
    # Exactly one of (raw audio, precomputed features) must be provided.
    if (has_input_signal ^ has_processed_signal) == False:
        raise ValueError(
            f"{self} Arguments ``input_signal`` and ``input_signal_length`` are mutually exclusive "
            " with ``processed_signal`` and ``processed_signal_len`` arguments."
        )
    if not has_processed_signal:
        processed_signal, processed_signal_length = self.preprocessor(
            input_signal=input_signal,
            length=input_signal_length,
        )
    if self.pen_factor:
        # Wav2Vec-style L2 penalty on the (pre-mask) features; consumed in the
        # training/validation steps by adding it to the loss.
        self.feat_pen = processed_signal.float().pow(2).mean() * self.pen_factor
    # Keep an unmasked copy of the features to serve as targets.
    spectrograms = processed_signal.detach().clone()
    if self.dropout_features:
        processed_signal = self.dropout_features(processed_signal)
    if self.dropout_features_q:
        spectrograms = self.dropout_features_q(spectrograms)
    if self.apply_masking:
        processed_signal = self.spec_augmentation(input_spec=processed_signal, length=processed_signal_length)
    # Recover mask locations: SpecAugment zeroes masked cells, so near-zero
    # entries (|x| < 1e-5) are treated as masked.
    masked_spectrograms = processed_signal.detach()
    spec_masks = torch.logical_and(masked_spectrograms < 1e-5, masked_spectrograms > -1e-5).float()
    for idx, proc_len in enumerate(processed_signal_length):
        # Padding beyond each utterance's true length does not count as mask.
        spec_masks[idx, :, proc_len:] = 0.0
    encoded, encoded_len = self.encoder(audio_signal=processed_signal, length=processed_signal_length)
    return spectrograms, spec_masks, encoded, encoded_len
def decoder_loss_step(self, spectrograms, spec_masks, encoded, encoded_len, targets=None, target_lengths=None):
    """
    Forward pass through all decoders and calculate corresponding losses.

    Args:
        spectrograms: Processed spectrograms of shape [B, D, T].
        spec_masks: Masks applied to spectrograms of shape [B, D, T].
        encoded: The encoded features tensor of shape [B, D, T].
        encoded_len: The lengths of the acoustic sequence after propagation through the encoder, of shape [B].
        targets: Optional target labels of shape [B, T]
        target_lengths: Optional target label lengths of shape [B]

    Returns:
        A tuple of 2 elements -
        1) Total sum of losses weighted by corresponding loss_alphas
        2) Dictionary of unweighted losses
    """
    loss_val_dict = {}
    if self.decoder_losses is None:
        # Single-decoder path: one decoder + one loss.
        if hasattr(self.decoder_ssl, "needs_labels") and self.decoder_ssl.needs_labels:
            outputs = self.decoder_ssl(encoder_output=encoded, targets=targets, target_lengths=target_lengths)
        else:
            outputs = self.decoder_ssl(encoder_output=encoded)
        if self.loss.needs_labels:
            loss_value = self.loss(
                spec_masks=spec_masks,
                decoder_outputs=outputs,
                targets=targets,
                decoder_lengths=encoded_len,
                target_lengths=target_lengths,
            )
        else:
            loss_value = self.loss(spectrograms=spectrograms, spec_masks=spec_masks, decoder_outputs=outputs)
    else:
        loss_value = encoded.new_zeros(1)
        outputs = {}
        registry = self.get_module_registry(self.encoder)
        for dec_loss_name, dec_loss in self.decoder_losses.items():
            # loop through decoders and corresponding losses
            if not self.decoder_losses_active[dec_loss_name]:
                # Skipped until its start_step is reached (set in training_step).
                continue
            if self.output_from_layer[dec_loss_name] is None:
                dec_input = encoded
            else:
                # extract output from specified layer using AccessMixin registry
                dec_input = registry[self.output_from_layer[dec_loss_name]]['encoder'][-1]
            if self.transpose_encoded[dec_loss_name]:
                dec_input = dec_input.transpose(-2, -1)
            if self.targets_from_loss[dec_loss_name] is not None:
                # extract targets from specified loss (e.g. quantizer-produced ids)
                target_loss = self.targets_from_loss[dec_loss_name]
                targets = self.decoder_losses[target_loss]['loss'].target_ids
                target_lengths = self.decoder_losses[target_loss]['loss'].target_lengths
                if target_lengths is None:
                    target_lengths = encoded_len
            if hasattr(dec_loss['decoder'], "needs_labels") and dec_loss['decoder'].needs_labels:
                # if we are using a decoder which needs labels, provide them
                outputs[dec_loss_name] = dec_loss['decoder'](
                    encoder_output=dec_input, targets=targets, target_lengths=target_lengths
                )
            else:
                outputs[dec_loss_name] = dec_loss['decoder'](encoder_output=dec_input)
            current_loss = dec_loss['loss']
            if current_loss.needs_labels:
                # if we are using a loss which needs labels, provide them
                current_loss_value = current_loss(
                    spec_masks=spec_masks,
                    decoder_outputs=outputs[dec_loss_name],
                    targets=targets,
                    decoder_lengths=encoded_len,
                    target_lengths=target_lengths,
                )
            else:
                current_loss_value = current_loss(
                    spectrograms=spectrograms,
                    spec_masks=spec_masks,
                    decoder_outputs=outputs[dec_loss_name],
                    decoder_lengths=encoded_len,
                )
            # Weighted sum for optimization; raw value kept for logging.
            loss_value = loss_value + current_loss_value * self.loss_alphas[dec_loss_name]
            loss_val_dict[dec_loss_name] = current_loss_value
    return loss_value, loss_val_dict
# PTL-specific methods
def training_step(self, batch, batch_nb):
    """
    Single training step: forward pass, per-decoder loss activation by
    ``start_step``, weighted loss aggregation, and tensorboard log assembly.
    """
    signal, signal_len, targets, target_lengths = batch
    if isinstance(batch, DALIOutputs) and batch.has_processed_signal:
        # DALI already produced features; skip the preprocessor.
        spectrograms, spec_masks, encoded, encoded_len = self.forward(
            processed_signal=signal,
            processed_signal_length=signal_len,
        )
    else:
        spectrograms, spec_masks, encoded, encoded_len = self.forward(
            input_signal=signal,
            input_signal_length=signal_len,
        )
    if self.decoder_losses is not None:
        # Activate each decoder loss only once its configured start_step is reached.
        for dec_loss_name, dec_loss in self.decoder_losses.items():
            self.decoder_losses_active[dec_loss_name] = self.trainer.global_step >= self.start_step[dec_loss_name]
            loss = dec_loss['loss']
            if hasattr(loss, "set_num_updates"):
                # Some losses anneal internal parameters by update count.
                loss.set_num_updates(self.trainer.global_step)
    else:
        if hasattr(self.loss, "set_num_updates"):
            self.loss.set_num_updates(self.trainer.global_step)
    loss_value, loss_val_dict = self.decoder_loss_step(
        spectrograms, spec_masks, encoded, encoded_len, targets, target_lengths
    )
    tensorboard_logs = {
        'learning_rate': self._optimizer.param_groups[0]['lr'],
        'global_step': self.trainer.global_step,
    }
    # Log each unweighted per-decoder loss separately.
    for loss_name, loss_val in loss_val_dict.items():
        tensorboard_logs['train_' + loss_name] = loss_val
    if self.feat_pen:
        # Add the Wav2Vec-style feature penalty computed in forward().
        loss_value += self.feat_pen
    # Reset access registry
    self.reset_registry()
    return {'loss': loss_value, 'log': tensorboard_logs}
def validation_pass(self, batch, batch_idx, dataloader_idx=0):
    """
    Shared validation logic: forward pass plus loss computation.

    Returns:
        Metrics dict containing 'val_loss'.
    """
    # Set flag to register tensors (read by forward() to enable layer capture)
    self._in_validation_step = True
    signal, signal_len, targets, target_lengths = batch
    if isinstance(batch, DALIOutputs) and batch.has_processed_signal:
        # DALI already produced features; skip the preprocessor.
        spectrograms, spec_masks, encoded, encoded_len = self.forward(
            processed_signal=signal,
            processed_signal_length=signal_len,
        )
    else:
        spectrograms, spec_masks, encoded, encoded_len = self.forward(
            input_signal=signal,
            input_signal_length=signal_len,
        )
    if self.decoder_losses is not None:
        # Mirror the training-time loss activation schedule.
        for dec_loss_name, dec_loss in self.decoder_losses.items():
            self.decoder_losses_active[dec_loss_name] = self.trainer.global_step >= self.start_step[dec_loss_name]
    loss_value, _ = self.decoder_loss_step(spectrograms, spec_masks, encoded, encoded_len, targets, target_lengths)
    if self.feat_pen:
        loss_value += self.feat_pen
    # reset access registry
    self.reset_registry()
    del self._in_validation_step
    metrics = {'val_loss': loss_value}
    return metrics
def validation_step(self, batch, batch_idx, dataloader_idx=0):
    """
    Run one validation step and collect metrics for epoch-end aggregation.

    With multiple validation dataloaders the outputs are bucketed per
    dataloader index, as Lightning's multi-dataloader hooks expect.
    """
    metrics = self.validation_pass(batch, batch_idx, dataloader_idx)
    # Fix: use isinstance() instead of `type(x) == list` (PEP 8 / E721);
    # the exact-type check would reject list subclasses.
    if isinstance(self.trainer.val_dataloaders, list) and len(self.trainer.val_dataloaders) > 1:
        self.validation_step_outputs[dataloader_idx].append(metrics)
    else:
        self.validation_step_outputs.append(metrics)
    return metrics
def multi_validation_epoch_end(self, outputs, dataloader_idx: int = 0):
    """Average per-batch 'val_loss' values into one scalar for logging."""
    losses = [out['val_loss'] for out in outputs]
    val_loss_mean = torch.stack(losses).mean()
    logs = {'val_loss': val_loss_mean}
    return {'val_loss': val_loss_mean, 'log': logs}
class EncDecMaskedTokenPredModel(SpeechEncDecSelfSupervisedModel):
"""
Speech self-supervised model that performs masked token prediction on the encoder output.
"""
def transfer_batch_to_device(self, batch: Any, device: torch.device, dataloader_idx: int) -> Any:
    """
    PTL hook: https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html#transfer-batch-to-device
    """
    return move_data_to_device(batch, device)
@classmethod
def list_available_models(cls) -> List[PretrainedModelInfo]:
    """
    List the pre-trained checkpoints that can be instantiated directly from NVIDIA's NGC cloud.

    Returns:
        List of available pre-trained models (none published for this class yet).
    """
    return []
def __init__(self, cfg: DictConfig, trainer: Trainer = None):
    """
    Build quantizer, masking module, encoder, decoder, and loss for masked
    token prediction.

    Args:
        cfg: Model config with quantizer/masking/encoder/decoder/loss sections;
            ``mask_position`` may be 'pre_conv' (default) or 'post_conv'.
        trainer: Optional Lightning Trainer.
    """
    super().__init__(cfg, trainer)
    del self.decoder_ssl  # delete unused decoder from parent class
    if self.cfg.get("mask_position", "pre_conv") == "post_conv":
        # adjust config for post-convolution masking: quantizer/masking then
        # operate on encoder-dimension features at the subsampled frame rate.
        self.cfg.quantizer.feat_in = self.cfg.encoder.d_model
        self.cfg.masking.feat_in = self.cfg.encoder.d_model
        self.cfg.masking.block_size = self.cfg.masking.block_size // self.cfg.encoder.subsampling_factor
        self.cfg.loss.combine_time_steps = 1
    self.quantizer = self.from_config_dict(self.cfg.quantizer)
    self.mask_processor = self.from_config_dict(self.cfg.masking)
    self.encoder = self.from_config_dict(self.cfg.encoder)
    self.decoder = self.from_config_dict(self.cfg.decoder)
    self.loss = self.from_config_dict(self.cfg.loss)
    self.pre_encoder = None
    if self.cfg.get("mask_position", "pre_conv") == "post_conv":
        # hacked to mask features after convolutional sub-sampling: wrap the
        # encoder's pre_encode so masking happens inside the encoder.
        self.pre_encoder = ConvFeatureMaksingWrapper(self.encoder.pre_encode, self.mask_processor)
        self.encoder.pre_encode = self.pre_encoder
@property
def oomptimizer_schema(self) -> dict:
    """
    Return a typing schema for optimal batch size calibration for various
    sequence lengths using OOMptimizer.
    """
    audio_spec = {"type": NeuralType(("B", "T"), AudioSignal()), "seq_length": "input"}
    length_spec = {"type": NeuralType(("B",), LengthsType()), "seq_length": "input"}
    return {"cls": tuple, "inputs": [audio_spec, length_spec]}
@property
def input_types(self) -> Optional[Dict[str, NeuralType]]:
    """Neural types of ``forward()`` inputs; every entry is optional."""
    if hasattr(self.preprocessor, '_sample_rate'):
        audio_eltype = AudioSignal(freq=self.preprocessor._sample_rate)
    else:
        audio_eltype = AudioSignal()
    types = {
        "input_signal": NeuralType(('B', 'T'), audio_eltype, optional=True),
        "input_signal_length": NeuralType(tuple('B'), LengthsType(), optional=True),
        "processed_signal": NeuralType(('B', 'D', 'T'), SpectrogramType(), optional=True),
        "processed_signal_length": NeuralType(tuple('B'), LengthsType(), optional=True),
        "apply_mask": NeuralType(optional=True),
    }
    return types
@property
def output_types(self) -> Optional[Dict[str, NeuralType]]:
    """Neural types of ``forward()`` outputs; shapes depend on codebook config."""
    single_squeezed_book = self.cfg.num_books == 1 and self.cfg.squeeze_single
    if single_squeezed_book:
        logprobs = NeuralType(('B', 'T', 'C'), LogprobsType())
        tokens = NeuralType(('B', 'T'), LabelsType())
    else:
        logprobs = NeuralType(('B', 'T', 'C', 'H'), LogprobsType())
        tokens = NeuralType(('B', 'T', 'H'), LabelsType())
    return {
        "logprobs": logprobs,
        "encoded_len": NeuralType(tuple('B'), LengthsType()),
        "masks": NeuralType(('B', 'D', 'T'), SpectrogramType()),
        "tokens": tokens,
    }
@typecheck()
def forward(
    self,
    input_signal=None,
    input_signal_length=None,
    processed_signal=None,
    processed_signal_length=None,
    apply_mask=False,
):
    """
    Quantize (unmasked) features into target tokens, mask the input, encode,
    and predict token log-probabilities.

    Args:
        input_signal: Raw audio batch [B, T]; mutually exclusive with
            ``processed_signal``.
        input_signal_length: Lengths [B] for ``input_signal``.
        processed_signal: Precomputed feature batch [B, D, T].
        processed_signal_length: Lengths [B] for ``processed_signal``.
        apply_mask: Whether to apply the masking module.

    Returns:
        Tuple of (log_probs, encoded_len, masks, tokens).
    """
    has_input_signal = input_signal is not None and input_signal_length is not None
    has_processed_signal = processed_signal is not None and processed_signal_length is not None
    # Exactly one of (raw audio, precomputed features) must be provided.
    if (has_input_signal ^ has_processed_signal) == False:
        raise ValueError(
            f"{self} Arguments ``input_signal`` and ``input_signal_length`` are mutually exclusive "
            " with ``processed_signal`` and ``processed_signal_len`` arguments."
        )
    if not has_processed_signal:
        processed_signal, processed_signal_length = self.preprocessor(
            input_signal=input_signal,
            length=input_signal_length,
        )
    if self.pre_encoder is not None:
        # mask after convolutional sub-sampling: the wrapper around
        # encoder.pre_encode applies the mask and records feats/mask for us.
        self.pre_encoder.set_masking_enabled(apply_mask=apply_mask)
        encoded, encoded_len = self.encoder(audio_signal=processed_signal, length=processed_signal_length)
        masks = self.pre_encoder.get_current_mask()
        feats = self.pre_encoder.get_current_feat()
        _, tokens = self.quantizer(input_signal=feats.transpose(1, 2))
    else:
        # Tokens are quantized from the unmasked features; the encoder then
        # sees the masked signal.
        _, tokens = self.quantizer(input_signal=processed_signal)
        if apply_mask:
            masked_signal, masks = self.mask_processor(
                input_feats=processed_signal, input_lengths=processed_signal_length
            )
        else:
            masked_signal = processed_signal
            masks = torch.zeros_like(processed_signal)
        encoded, encoded_len = self.encoder(audio_signal=masked_signal, length=processed_signal_length)
    log_probs = self.decoder(encoder_output=encoded)
    return log_probs, encoded_len, masks, tokens
def training_step(self, batch, batch_idx=0):
    """One training step: masked forward pass plus token-prediction loss."""
    input_signal, input_signal_length = batch[0], batch[1]
    # DALI batches may already carry preprocessed features.
    if isinstance(batch, DALIOutputs) and batch.has_processed_signal:
        forward_kwargs = dict(
            processed_signal=input_signal, processed_signal_length=input_signal_length, apply_mask=True
        )
    else:
        forward_kwargs = dict(
            input_signal=input_signal, input_signal_length=input_signal_length, apply_mask=True
        )
    log_probs, encoded_len, masks, tokens = self.forward(**forward_kwargs)
    loss_value = self.loss(masks=masks, decoder_outputs=log_probs, targets=tokens, decoder_lengths=encoded_len)
    tensorboard_logs = {
        'learning_rate': self._optimizer.param_groups[0]['lr'],
        'global_step': self.trainer.global_step,
        'train_loss': loss_value,
    }
    return {'loss': loss_value, 'log': tensorboard_logs}
def inference_pass(self, batch, batch_idx=0, dataloader_idx=0, mode='val', apply_mask=False):
    """Shared forward+loss used by validation_step and test_step."""
    input_signal, input_signal_length = batch[0], batch[1]
    use_processed = isinstance(batch, DALIOutputs) and batch.has_processed_signal
    if use_processed:
        log_probs, encoded_len, masks, tokens = self.forward(
            processed_signal=input_signal, processed_signal_length=input_signal_length, apply_mask=apply_mask
        )
    else:
        log_probs, encoded_len, masks, tokens = self.forward(
            input_signal=input_signal, input_signal_length=input_signal_length, apply_mask=apply_mask
        )
    loss_value = self.loss(masks=masks, decoder_outputs=log_probs, targets=tokens, decoder_lengths=encoded_len)
    return {f'{mode}_loss': loss_value}
def validation_step(self, batch, batch_idx=0, dataloader_idx=0):
    """
    Run one validation step (with masking) and collect metrics for
    epoch-end aggregation, bucketed per dataloader when several are used.
    """
    metrics = self.inference_pass(batch, batch_idx, dataloader_idx, apply_mask=True)
    # Fix: use isinstance() instead of `type(x) == list` (PEP 8 / E721);
    # the exact-type check would reject list subclasses.
    if isinstance(self.trainer.val_dataloaders, list) and len(self.trainer.val_dataloaders) > 1:
        self.validation_step_outputs[dataloader_idx].append(metrics)
    else:
        self.validation_step_outputs.append(metrics)
    return metrics
def test_step(self, batch, batch_idx=0, dataloader_idx=0):
    """
    Run one test step (with masking) and collect metrics for epoch-end
    aggregation, bucketed per dataloader when several are used.
    """
    metrics = self.inference_pass(batch, batch_idx, dataloader_idx, mode="test", apply_mask=True)
    # Fix copy-paste from validation_step: test metrics must be routed through
    # the TEST dataloaders/outputs so multi_test_epoch_end actually receives
    # them; also use isinstance() instead of `type(x) == list` (E721).
    if isinstance(self.trainer.test_dataloaders, list) and len(self.trainer.test_dataloaders) > 1:
        self.test_step_outputs[dataloader_idx].append(metrics)
    else:
        self.test_step_outputs.append(metrics)
    return metrics
def multi_validation_epoch_end(self, outputs: list, dataloader_idx: int = 0):
    """
    Average 'val_loss' over the batches of one validation dataloader,
    warning about (and skipping) malformed batch outputs.
    """
    loss_list = []
    for i, x in enumerate(outputs):
        if not isinstance(x, dict):
            logging.warning(f'Batch {i} output in validation dataloader {dataloader_idx} is not a dictionary: {x}')
        if 'val_loss' not in x:
            logging.warning(
                f'Batch {i} output in validation dataloader {dataloader_idx} does not have key `val_loss`: {x}'
            )
            continue
        loss_list.append(x['val_loss'])
    if not loss_list:
        logging.warning(
            f'Epoch {self.current_epoch} received no batches for validation dataloader {dataloader_idx}.'
        )
        return {}
    val_loss_mean = torch.stack(loss_list).mean()
    tensorboard_logs = {'val_loss': val_loss_mean}
    return {'val_loss': val_loss_mean, 'log': tensorboard_logs}
def multi_test_epoch_end(self, outputs, dataloader_idx: int = 0):
    """Average per-batch 'test_loss' values into one scalar for logging."""
    losses = [out['test_loss'] for out in outputs]
    test_loss_mean = torch.stack(losses).mean()
    logs = {'test_loss': test_loss_mean}
    return {'test_loss': test_loss_mean, 'log': logs}
class EncDecDenoiseMaskedTokenPredModel(EncDecMaskedTokenPredModel):
"""
Model class that performs denoising and masked token prediction for speech self-supervised learning.
Please refer to the NEST paper for more details: https://arxiv.org/abs/2408.13106
"""
@property
def oomptimizer_schema(self) -> dict:
    """
    Return a typing schema for optimal batch size calibration for various
    sequence lengths using OOMptimizer.
    """
    inputs = []
    for field in ("audio", "audio_len", "noise", "noise_len", "noisy_audio", "noisy_audio_len"):
        if field.endswith("_len"):
            spec = {"type": NeuralType(("B",), LengthsType()), "seq_length": "input", "name": field}
        else:
            spec = {"type": NeuralType(("B", "T"), AudioSignal()), "seq_length": "input", "name": field}
        inputs.append(spec)
    return {"cls": ssl_dataset.AudioNoiseBatch, "inputs": inputs}
def __init__(self, cfg: DictConfig, trainer: Trainer = None):
    # No extra state beyond the parent masked-token-prediction model; this
    # subclass only overrides dataloading, schemas, and input/output typing.
    super().__init__(cfg, trainer)
def _setup_dataloader_from_config(self, config: Optional[Dict]):
    """Build a dataloader yielding (audio, noise, noisy_audio) batches.

    Supports either a Lhotse-backed dataloader (config key ``use_lhotse``)
    or the native audio+noise dataset wrapped in a standard
    ``torch.utils.data.DataLoader``.
    """
    # Propagate the model-level sample_rate into the dataloader config.
    audio_to_text_dataset.inject_dataloader_value_from_model_config(self.cfg, config, key='sample_rate')

    if config.get("use_lhotse"):
        lhotse_dataset = ssl_dataset.LhotseAudioNoiseDataset(
            noise_manifest=config.get('noise_manifest', None),
            batch_augmentor_cfg=config.get('batch_augmentor', None),
        )
        return get_lhotse_dataloader_from_config(
            config,
            global_rank=self.global_rank,
            world_size=self.world_size,
            dataset=lhotse_dataset,
        )

    dataset = ssl_dataset.get_audio_noise_dataset_from_config(
        config,
        global_rank=self.global_rank,
        world_size=self.world_size,
    )

    shuffle = config['shuffle']
    if isinstance(dataset, torch.utils.data.IterableDataset):
        # Iterable datasets control their own ordering; DataLoader must not shuffle.
        shuffle = False

    # Resolve collate_fn: on the dataset itself, on its first child (list of
    # entries), or on the first grandchild (list of lists).
    if hasattr(dataset, 'collate_fn'):
        collate_fn = dataset.collate_fn
    else:
        child = dataset.datasets[0]
        collate_fn = child.collate_fn if hasattr(child, 'collate_fn') else child.datasets[0].collate_fn

    return torch.utils.data.DataLoader(
        dataset=dataset,
        batch_size=config['batch_size'],
        collate_fn=collate_fn,
        drop_last=config.get('drop_last', False),
        shuffle=shuffle,
        num_workers=config.get('num_workers', 0),
        pin_memory=config.get('pin_memory', False),
    )
@property
def input_types(self) -> Optional[Dict[str, NeuralType]]:
    """Declare accepted ``forward`` inputs for NeMo's ``@typecheck``.

    Keys must match ``forward``'s parameter names exactly, otherwise
    ``@typecheck()`` rejects the corresponding keyword arguments. The
    previous keys ``processed_noise_length``, ``processed_noisy_signal`` and
    ``processed_noisy_signal_length`` did not match any ``forward`` parameter
    (``processed_noise_signal_length``, ``processed_noisy_input_signal``,
    ``processed_noisy_input_signal_length``), which made it impossible to
    pass precomputed noisy features; they are renamed here to align.
    """
    if hasattr(self.preprocessor, '_sample_rate'):
        input_signal_eltype = AudioSignal(freq=self.preprocessor._sample_rate)
    else:
        input_signal_eltype = AudioSignal()
    return {
        "input_signal": NeuralType(('B', 'T'), input_signal_eltype, optional=True),
        "input_signal_length": NeuralType(tuple('B'), LengthsType(), optional=True),
        "processed_signal": NeuralType(('B', 'D', 'T'), SpectrogramType(), optional=True),
        "processed_signal_length": NeuralType(tuple('B'), LengthsType(), optional=True),
        "noise_signal": NeuralType(('B', 'T'), input_signal_eltype, optional=True),
        "noise_signal_length": NeuralType(tuple('B'), LengthsType(), optional=True),
        "processed_noise_signal": NeuralType(('B', 'D', 'T'), SpectrogramType(), optional=True),
        "processed_noise_signal_length": NeuralType(tuple('B'), LengthsType(), optional=True),
        "noisy_input_signal": NeuralType(('B', 'T'), input_signal_eltype, optional=True),
        "noisy_input_signal_length": NeuralType(tuple('B'), LengthsType(), optional=True),
        "processed_noisy_input_signal": NeuralType(('B', 'D', 'T'), SpectrogramType(), optional=True),
        "processed_noisy_input_signal_length": NeuralType(tuple('B'), LengthsType(), optional=True),
        "apply_mask": NeuralType(optional=True),
    }
@property
def output_types(self) -> Optional[Dict[str, NeuralType]]:
    """Declare ``forward`` outputs; shapes depend on the codebook configuration."""
    # With a single codebook and squeeze_single, the codebook axis 'H' is dropped.
    if self.cfg.num_books == 1 and self.cfg.squeeze_single:
        logprob_axes, token_axes = ('B', 'T', 'C'), ('B', 'T')
    else:
        logprob_axes, token_axes = ('B', 'T', 'C', 'H'), ('B', 'T', 'H')
    return {
        "logprobs": NeuralType(logprob_axes, LogprobsType()),
        "encoded_len": NeuralType(tuple('B'), LengthsType()),
        "masks": NeuralType(('B', 'D', 'T'), SpectrogramType()),
        "tokens": NeuralType(token_axes, LabelsType()),
    }
[docs]
@typecheck()
def forward(
    self,
    input_signal=None,
    input_signal_length=None,
    processed_signal=None,
    processed_signal_length=None,
    noise_signal=None,  # noqa
    noise_signal_length=None,  # noqa
    processed_noise_signal=None,  # noqa
    processed_noise_signal_length=None,  # noqa
    noisy_input_signal=None,
    noisy_input_signal_length=None,
    processed_noisy_input_signal=None,
    processed_noisy_input_signal_length=None,
    apply_mask=False,
):
    """Forward pass for denoising masked token prediction.

    The clean signal is quantized into target tokens, while the (optionally
    masked) noisy signal is encoded and decoded into token log-probabilities.

    Args:
        input_signal: clean waveform, shape (B, T); mutually exclusive with
            ``processed_signal``.
        input_signal_length: valid lengths of ``input_signal``, shape (B,).
        processed_signal: precomputed clean features, shape (B, D, T).
        processed_signal_length: valid lengths of ``processed_signal``.
        noise_signal, noise_signal_length: currently unused; kept for
            interface parity with the dataset batch.
        processed_noise_signal, processed_noise_signal_length: currently unused.
        noisy_input_signal: noisy waveform, shape (B, T); mutually exclusive
            with ``processed_noisy_input_signal``.
        noisy_input_signal_length: valid lengths of ``noisy_input_signal``.
        processed_noisy_input_signal: precomputed noisy features, (B, D, T).
        processed_noisy_input_signal_length: valid lengths thereof.
        apply_mask: whether to apply feature masking before encoding.

    Returns:
        Tuple of (log_probs, encoded_len, masks, tokens).

    Raises:
        ValueError: if both or neither of a raw/processed signal pair is given.
    """
    has_input_signal = input_signal is not None and input_signal_length is not None
    has_processed_signal = processed_signal is not None and processed_signal_length is not None
    # Exactly one of (raw, processed) clean inputs must be provided.
    if not (has_input_signal ^ has_processed_signal):
        raise ValueError(
            f"{self} Arguments ``input_signal`` and ``input_signal_length`` are mutually exclusive "
            " with ``processed_signal`` and ``processed_signal_len`` arguments."
        )
    if not has_processed_signal:
        processed_signal, processed_signal_length = self.preprocessor(
            input_signal=input_signal,
            length=input_signal_length,
        )

    ### Following code snippet is not used but kept for future reference
    #
    # has_noise_signal = noise_signal is not None and noise_signal_length is not None
    # has_processed_noise_signal = processed_noise_signal is not None and processed_noise_signal_length is not None
    # if not (has_noise_signal ^ has_processed_noise_signal):
    #     raise ValueError(
    #         f"{self} Arguments ``noise_signal`` and ``noise_signal_length`` are mutually exclusive "
    #         " with ``processed_noise_signal`` and ``processed_noise_signal_len`` arguments."
    #     )
    # if not has_processed_noise_signal:
    #     processed_noise_signal, processed_noise_signal_length = self.preprocessor(
    #         input_signal=noise_signal,
    #         length=noise_signal_length,
    #     )

    has_noisy_input_signal = noisy_input_signal is not None and noisy_input_signal_length is not None
    has_processed_noisy_input_signal = (
        processed_noisy_input_signal is not None and processed_noisy_input_signal_length is not None
    )
    # Exactly one of (raw, processed) noisy inputs must be provided.
    if not (has_noisy_input_signal ^ has_processed_noisy_input_signal):
        raise ValueError(
            f"{self} Arguments ``noisy_input_signal`` and ``noisy_input_signal_length`` are mutually exclusive "
            " with ``processed_noisy_input_signal`` and ``processed_noisy_input_signal_len`` arguments."
        )
    if not has_processed_noisy_input_signal:
        processed_noisy_input_signal, processed_noisy_input_signal_length = self.preprocessor(
            input_signal=noisy_input_signal,
            length=noisy_input_signal_length,
        )

    if self.pre_encoder is not None:
        # Masking is applied after convolutional sub-sampling: quantize the
        # clean pre-encoded features to get targets, then encode the noisy
        # signal with masking enabled inside the pre-encoder wrapper.
        feats, _ = self.pre_encoder.pre_encode(x=processed_signal, lengths=processed_signal_length)
        _, tokens = self.quantizer(input_signal=feats.transpose(1, 2))
        self.pre_encoder.set_masking_enabled(apply_mask=apply_mask)
        encoded, encoded_len = self.encoder(
            audio_signal=processed_noisy_input_signal, length=processed_noisy_input_signal_length
        )
        masks = self.pre_encoder.get_current_mask()
    else:
        # Quantize clean spectrogram features directly; mask the noisy
        # features (or pass them through untouched with an all-zero mask).
        _, tokens = self.quantizer(input_signal=processed_signal)
        if apply_mask:
            masked_signal, masks = self.mask_processor(
                input_feats=processed_noisy_input_signal, input_lengths=processed_noisy_input_signal_length
            )
        else:
            masked_signal = processed_noisy_input_signal
            masks = torch.zeros_like(processed_noisy_input_signal)
        encoded, encoded_len = self.encoder(audio_signal=masked_signal, length=processed_noisy_input_signal_length)

    log_probs = self.decoder(encoder_output=encoded)
    return log_probs, encoded_len, masks, tokens
[docs]
def training_step(self, batch: ssl_dataset.AudioNoiseBatch, batch_idx: int):
    """Run one training step: forward pass with masking enabled, then loss.

    Args:
        batch: batch of clean/noise/noisy audio tensors and their lengths.
        batch_idx: index of the batch within the epoch (unused).

    Returns:
        Dict with the training loss and a 'log' sub-dict for tensorboard.
    """
    log_probs, encoded_len, masks, tokens = self.forward(
        input_signal=batch.audio,
        input_signal_length=batch.audio_len,
        noise_signal=batch.noise,
        noise_signal_length=batch.noise_len,
        noisy_input_signal=batch.noisy_audio,
        noisy_input_signal_length=batch.noisy_audio_len,
        apply_mask=True,
    )
    train_loss = self.loss(masks=masks, decoder_outputs=log_probs, targets=tokens, decoder_lengths=encoded_len)
    logs = {
        'learning_rate': self._optimizer.param_groups[0]['lr'],
        'global_step': self.trainer.global_step,
        'train_loss': train_loss,
    }
    return {'loss': train_loss, 'log': logs}
[docs]
def inference_pass(
    self,
    batch: ssl_dataset.AudioNoiseBatch,
    batch_idx: int,
    dataloader_idx: int = 0,
    mode: str = 'val',
    apply_mask: bool = True,
):
    """Shared evaluation pass used by the validation and test loops.

    Args:
        batch: batch of clean/noise/noisy audio tensors and their lengths.
        batch_idx: index of the batch (unused; kept for Lightning's signature).
        dataloader_idx: index of the dataloader (unused here).
        mode: prefix for the returned loss key, e.g. 'val' or 'test'.
        apply_mask: whether masking is applied during the forward pass.

    Returns:
        Dict mapping f'{mode}_loss' to the computed loss tensor.
    """
    log_probs, encoded_len, masks, tokens = self.forward(
        input_signal=batch.audio,
        input_signal_length=batch.audio_len,
        noise_signal=batch.noise,
        noise_signal_length=batch.noise_len,
        noisy_input_signal=batch.noisy_audio,
        noisy_input_signal_length=batch.noisy_audio_len,
        apply_mask=apply_mask,
    )
    eval_loss = self.loss(masks=masks, decoder_outputs=log_probs, targets=tokens, decoder_lengths=encoded_len)
    return {f'{mode}_loss': eval_loss}