NeMo SSL collection API#
Model Classes#
- class nemo.collections.asr.models.SpeechEncDecSelfSupervisedModel(*args: Any, **kwargs: Any)[source]#
Bases: nemo.core.classes.modelPT.ModelPT, nemo.collections.asr.parts.mixins.mixins.ASRModuleMixin, nemo.core.classes.mixins.access_mixins.AccessMixin
Base class for encoder-decoder models used for self-supervised encoder pre-training.
- decoder_loss_step(spectrograms, spec_masks, encoded, encoded_len, targets=None, target_lengths=None)[source]#
Forward pass through all decoders and calculate corresponding losses.
- Parameters
spectrograms – Processed spectrograms of shape [B, D, T].
spec_masks – Masks applied to spectrograms of shape [B, D, T].
encoded – The encoded features tensor of shape [B, D, T].
encoded_len – The lengths of the acoustic sequence after propagation through the encoder, of shape [B].
targets – Optional target labels of shape [B, T].
target_lengths – Optional target label lengths of shape [B].
- Returns
A tuple of 2 elements: 1) the total sum of losses, weighted by the corresponding loss_alphas; 2) a dictionary of unweighted losses.
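The return contract above can be illustrated with a minimal, self-contained sketch. The decoder names and numeric values below are invented for illustration and do not come from NeMo; the point is only the relationship between the unweighted loss dictionary and the alpha-weighted total.

```python
# Hypothetical illustration of decoder_loss_step's return contract:
# the scalar total is each decoder's unweighted loss scaled by its
# loss_alpha, then summed. Names and values here are made up.
loss_alphas = {"contrastive": 1.0, "mlm": 0.3}
unweighted_losses = {"contrastive": 2.0, "mlm": 5.0}

total_loss = sum(loss_alphas[name] * value
                 for name, value in unweighted_losses.items())

print(total_loss)  # 1.0 * 2.0 + 0.3 * 5.0 = 3.5
```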
- forward(input_signal=None, input_signal_length=None, processed_signal=None, processed_signal_length=None)[source]#
Forward pass of the model.
- Parameters
input_signal – Tensor that represents a batch of raw audio signals, of shape [B, T]. T here represents timesteps, with 1 second of audio represented as self.sample_rate number of floating point values.
input_signal_length – Vector of length B, that contains the individual lengths of the audio sequences.
processed_signal – Tensor that represents a batch of processed audio signals, of shape [B, D, T], that has undergone processing via some DALI preprocessor.
processed_signal_length – Vector of length B, that contains the individual lengths of the processed audio sequences.
- Returns
A tuple of 4 elements - 1) Processed spectrograms of shape [B, D, T]. 2) Masks applied to spectrograms of shape [B, D, T]. 3) The encoded features tensor of shape [B, D, T]. 4) The lengths of the acoustic sequence after propagation through the encoder, of shape [B].
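The input_signal convention described in the parameters (one second of audio corresponds to self.sample_rate floating point values) can be made concrete with assumed values; the sample rate and duration below are illustrative, not defaults from NeMo.

```python
# Worked example of the [B, T] input_signal shape convention:
# T grows linearly with audio duration at self.sample_rate samples/sec.
sample_rate = 16000   # assumed value of self.sample_rate
duration_s = 2.5      # assumed clip length in seconds

T = int(duration_s * sample_rate)  # timesteps in input_signal
print(T)  # 40000 samples for a 2.5 s clip at 16 kHz
```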
- property input_types: Optional[Dict[str, nemo.core.neural_types.neural_type.NeuralType]]#
Define these to enable input neural type checks
- classmethod list_available_models() Optional[nemo.core.classes.common.PretrainedModelInfo] [source]#
This method returns a list of pre-trained models which can be instantiated directly from NVIDIA’s NGC cloud.
- Returns
List of available pre-trained models.
- multi_validation_epoch_end(outputs, dataloader_idx: int = 0)[source]#
Adds support for multiple validation datasets. Should be overridden by subclasses to obtain appropriate logs for each of the dataloaders.
- Parameters
outputs – Same as that provided by LightningModule.validation_epoch_end() for a single dataloader.
dataloader_idx – int representing the index of the dataloader.
- Returns
A dictionary of values, optionally containing a sub-dict log, such that the values in the log will be prepended by the dataloader prefix.
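The prefixing behaviour described above can be sketched with a small helper. This is an illustrative stand-in, not NeMo's actual implementation; the function name, the result dict, and the prefix string are all assumptions.

```python
# Hypothetical sketch of the log-prefixing contract: keys inside the
# "log" sub-dict get the dataloader prefix prepended, so metrics from
# different validation dataloaders do not collide.
def prefix_logs(result: dict, dataloader_prefix: str) -> dict:
    log = result.get("log", {})
    return {**result,
            "log": {f"{dataloader_prefix}_{k}": v for k, v in log.items()}}

out = prefix_logs({"val_loss": 0.7, "log": {"val_loss": 0.7}},
                  "dataloader1")
print(out["log"])  # {'dataloader1_val_loss': 0.7}
```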
- property output_types: Optional[Dict[str, nemo.core.neural_types.neural_type.NeuralType]]#
Define these to enable output neural type checks
- setup_training_data(train_data_config: Optional[Union[omegaconf.DictConfig, Dict]])[source]#
Sets up the training data loader via a Dict-like object.
- Parameters
train_data_config – A config that contains the information regarding construction of an ASR Training dataset.
- Supported Datasets:
AudioToCharDALIDataset
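A train_data_config is a Dict-like (e.g. omegaconf.DictConfig) object describing the dataset. The fragment below is a hedged sketch only: the exact keys accepted for an AudioToCharDALIDataset should be checked against the NeMo configs, and the paths and values here are placeholders.

```yaml
# Illustrative train_data_config sketch (keys and values are assumptions)
train_ds:
  manifest_filepath: /path/to/train_manifest.json
  sample_rate: 16000
  batch_size: 32
  shuffle: true
```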
- setup_validation_data(val_data_config: Optional[Union[omegaconf.DictConfig, Dict]])[source]#
Sets up the validation data loader via a Dict-like object.
- Parameters
val_data_config – A config that contains the information regarding construction of an ASR validation dataset.
- Supported Datasets:
AudioToCharDALIDataset
Mixins#
- class nemo.collections.asr.parts.mixins.mixins.ASRModuleMixin[source]#
Bases: nemo.collections.asr.parts.mixins.asr_adapter_mixins.ASRAdapterModelMixin
ASRModuleMixin is a mixin class added to ASR models in order to add methods that are specific to a particular instantiation of a module inside of an ASRModel.
Each method should first check that the module is present within the subclass, and support additional functionality if the corresponding module is present.
- change_conv_asr_se_context_window(context_window: int, update_config: bool = True)[source]#
Update the context window of the SqueezeExcitation module if the provided model contains an encoder which is an instance of ConvASREncoder.
- Parameters
context_window –
An integer representing the number of input timeframes that will be used to compute the context. Each timeframe corresponds to a single window stride of the STFT features.
If window_stride = 0.01 s, then a context window of 128 represents 128 * 0.01 = 1.28 s of context used to compute the Squeeze step.
update_config – Whether to update the config or not with the new context window.
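The timeframe-to-seconds relationship described above is simple arithmetic; the sketch below uses the values from the docstring (window_stride of 0.01 s, context window of 128) purely as a worked example.

```python
# Each timeframe spans one STFT window stride, so the effective
# SqueezeExcitation context in seconds is context_window * window_stride.
window_stride = 0.01   # seconds per timeframe (from the example above)
context_window = 128   # timeframes used to compute the Squeeze step

context_seconds = context_window * window_stride
print(context_seconds)  # 1.28 seconds of audio context
```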
- class nemo.core.classes.mixins.access_mixins.AccessMixin[source]#
Bases: abc.ABC
Allows access to the output of intermediate layers of a model.
- property access_cfg#
Returns: The global access config shared across all access mixin modules.
- classmethod get_module_registry(module: torch.nn.Module)[source]#
Extract all registries from named submodules, return dictionary where the keys are the flattened module names, the values are the internal registry of each such module.
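The shape of the returned dictionary can be sketched without torch. This is not NeMo's implementation: a plain list of (flattened name, registry) pairs stands in for iterating torch.nn.Module submodules, and the module names and registry contents are invented.

```python
# Illustrative sketch of get_module_registry's output shape: a dict
# keyed by flattened submodule name, mapping to each module's internal
# registry. Modules with no registered outputs are skipped.
def get_module_registry(named_modules):
    """named_modules: iterable of (flattened_name, registry_or_None)."""
    return {name: registry
            for name, registry in named_modules
            if registry}

modules = [("encoder.layers.0", {"encoder_out": [1, 2]}),
           ("encoder.layers.1", None),      # nothing registered
           ("decoder", {"logits": [3]})]
print(get_module_registry(modules))
```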