NeMo Speaker Diarization API#
Model Classes#
- class nemo.collections.asr.models.ClusteringDiarizer(
- cfg: DictConfig | Any,
- speaker_model=None,
Bases:
Module, Model, DiarizationMixin

Inference model class for offline speaker diarization. This class handles the functionality required for diarization: speech activity detection, segmentation, embedding extraction, clustering, resegmentation, and scoring. All parameters are passed through a config file.
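Since all behavior is driven by the config, a YAML fragment is the natural entry point. The sketch below follows the layout of typical NeMo offline-diarization configs, but the model names, paths, and parameter values are illustrative assumptions, not defaults:

```yaml
# Illustrative ClusteringDiarizer config sketch; key names follow typical
# NeMo offline-diarization configs, all values are placeholders.
diarizer:
  manifest_filepath: /path/to/manifest.json
  out_dir: /path/to/output
  vad:
    model_path: vad_multilingual_marblenet   # pretrained VAD model (assumed name)
    parameters:
      onset: 0.8                             # speech onset threshold
      offset: 0.6                            # speech offset threshold
  speaker_embeddings:
    model_path: titanet_large                # pretrained embedding model (assumed name)
    parameters:
      window_length_in_sec: 1.5              # embedding extraction window
      shift_length_in_sec: 0.75              # window shift
  clustering:
    parameters:
      oracle_num_speakers: false
      max_num_speakers: 8
```

The resulting DictConfig is passed to the `cfg` argument of the class constructor.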
- diarize(
- paths2audio_files: List[str] = None,
- batch_size: int = 0,
Diarize files provided through paths2audio_files or a manifest file.
- Parameters:
paths2audio_files (List[str]) – list of paths to audio files
batch_size (int) – batch size used for extraction of speaker embeddings and VAD computation
- classmethod list_available_models()[source]#
Should list all pre-trained models available via the NVIDIA NGC cloud. Note: there is no check that requires model names and aliases to be unique. In the case of a collision, whichever model (or alias) is listed first in the returned list will be instantiated.
- Returns:
A list of PretrainedModelInfo entries
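The first-listed-wins resolution described in the note above can be sketched in pure Python; the model names and locations below are hypothetical placeholders, not entries from the real NGC catalog:

```python
# Minimal sketch of first-match resolution over a pretrained-model list.
# The (name, location) pairs are hypothetical placeholders.
models = [
    ("diar_model_a", "ngc://checkpoint_v1"),
    ("diar_model_a", "ngc://checkpoint_v2"),  # colliding name, never selected
]

def resolve(name, available):
    """Return the first entry whose name matches, mirroring the
    'first listed wins' behavior described above."""
    for model_name, location in available:
        if model_name == name:
            return location
    raise KeyError(name)

print(resolve("diar_model_a", models))  # ngc://checkpoint_v1
```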
- classmethod restore_from(
- restore_path: str,
- override_config_path: str | None = None,
- map_location: device | None = None,
Restores model instance (weights and configuration) from a .nemo file
- Parameters:
restore_path – path to .nemo file from which model should be instantiated
override_config_path – path to a yaml config that will override the internal config file or an OmegaConf / DictConfig object representing the model config.
map_location – Optional torch.device() to map the instantiated model to a device. By default (None), it will select a GPU if available, falling back to CPU otherwise.
strict – Passed to load_state_dict. By default True
return_config – If set to True, will return just the underlying config of the restored model as an OmegaConf DictConfig object without instantiating the model.
trainer – An optional Trainer object, passed to the model constructor.
save_restore_connector – An optional SaveRestoreConnector object that defines the implementation of the restore_from() method.
- save_to(save_path: str)[source]#
- Saves model instance (weights and configuration) into a .nemo file.
You can use the restore_from() method to fully restore the instance from a .nemo file.
- .nemo file is an archive (tar.gz) with the following:
model_config.yaml – model configuration in .yaml format; you can deserialize this into the cfg argument for the model’s constructor
model_weights.ckpt – model checkpoint
- Parameters:
save_path – Path to .nemo file where model instance should be saved
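The tar.gz layout described above can be demonstrated with the standard library alone. The sketch below builds a mock archive with the two documented members and lists them; the payloads are placeholders, not a real model:

```python
import io
import tarfile

# Build a mock .nemo-style archive in memory: a tar.gz containing the two
# files the documentation describes (payloads here are placeholders).
buf = io.BytesIO()
with tarfile.open(fileobj=buf, mode="w:gz") as tar:
    for name, payload in [
        ("model_config.yaml", b"# model configuration (placeholder)\n"),
        ("model_weights.ckpt", b"\x00placeholder checkpoint bytes\x00"),
    ]:
        info = tarfile.TarInfo(name=name)
        info.size = len(payload)
        tar.addfile(info, io.BytesIO(payload))

# Reopen and list members, as restore_from() would when unpacking.
buf.seek(0)
with tarfile.open(fileobj=buf, mode="r:gz") as tar:
    members = tar.getnames()
print(members)  # ['model_config.yaml', 'model_weights.ckpt']
```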
- property verbose: bool#
- class nemo.collections.asr.models.SortformerEncLabelModel(
- cfg: DictConfig,
- trainer: Trainer = None,
Bases:
ModelPT, ExportableEncDecModel, SpkDiarizationMixin

Encoder class for the Sortformer diarization model. The model class creates training and validation methods for setting up data and performing the model forward pass.
- This model class expects a config dict containing:
preprocessor
Transformer Encoder
FastConformer Encoder
Sortformer Modules
- add_rttms_mask_mats(
- rttms_mask_mats,
- device: device,
Check whether rttms_mask_mats is empty, then add it to the list.
- Parameters:
rttms_mask_mats (List[torch.Tensor]) – List of PyTorch tensors containing the rttms mask matrices.
- diarize(
- audio: str | List[str] | ndarray | DataLoader,
- sample_rate: int | None = None,
- batch_size: int = 1,
- include_tensor_outputs: bool = False,
- postprocessing_yaml: str | None = None,
- num_workers: int = 0,
- verbose: bool = True,
- override_config: DiarizeConfig | None = None,
One-click runner function for diarization.
- Parameters:
audio – (a single or list of) paths to audio files, or a path to a manifest file.
batch_size – (int) Batch size to use during inference. Larger values yield better throughput but use more memory.
include_tensor_outputs – (bool) Include raw speaker activity probabilities in the output. See Returns for more details.
postprocessing_yaml – Optional(str) Path to .yaml file with postprocessing parameters.
num_workers – (int) Number of workers for DataLoader.
verbose – (bool) Whether to display tqdm progress bar.
override_config – (Optional[DiarizeConfig]) A config to override the default config.
- Returns:
A list of lists of speech segments with a corresponding speaker index, in the format [begin_seconds, end_seconds, speaker_index]. If include_tensor_outputs is True, a tuple of the above list and a list of tensors of raw speaker activity probabilities.
- Return type:
The list of segments if include_tensor_outputs is False; otherwise the tuple described above.
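The documented `[begin_seconds, end_seconds, speaker_index]` segment format maps naturally onto RTTM, the usual interchange format for diarization output. The helper below is an illustrative sketch (its name and the session URI are assumptions, not part of the NeMo API):

```python
def segments_to_rttm(segments, uri="session0"):
    """Convert [begin, end, speaker_index] segments (the format documented
    for diarize()) into RTTM SPEAKER lines. Helper name is illustrative."""
    lines = []
    for begin, end, speaker in segments:
        duration = end - begin
        lines.append(
            f"SPEAKER {uri} 1 {begin:.3f} {duration:.3f} "
            f"<NA> <NA> speaker_{speaker} <NA> <NA>"
        )
    return lines

segs = [[0.0, 1.5, 0], [1.5, 3.25, 1]]
for line in segments_to_rttm(segs):
    print(line)
```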
- forward(audio_signal, audio_signal_length)[source]#
Forward pass for training and inference.
- Parameters:
audio_signal (torch.Tensor) – Tensor containing audio waveform Shape: (batch_size, num_samples)
audio_signal_length (torch.Tensor) – Tensor containing lengths of audio waveforms Shape: (batch_size,)
- Returns:
- Sorted tensor containing predicted speaker labels
Shape: (batch_size, max. diar frame count, num_speakers)
- Return type:
preds (torch.Tensor)
- forward_for_export(
- chunk,
- chunk_lengths,
- spkcache,
- spkcache_lengths,
- fifo,
- fifo_lengths,
This forward pass is for ONNX model export.
- Parameters:
chunk (torch.Tensor) – Tensor containing audio waveform. The term “chunk” refers to the “input buffer” in the speech processing pipeline. The size of chunk (input buffer) determines the latency introduced by buffering. Shape: (batch_size, feature frame count, dimension)
chunk_lengths (torch.Tensor) – Tensor containing lengths of audio waveforms Shape: (batch_size,)
spkcache (torch.Tensor) – Tensor containing speaker cache embeddings from start Shape: (batch_size, spkcache_len, emb_dim)
spkcache_lengths (torch.Tensor) – Tensor containing lengths of speaker cache Shape: (batch_size,)
fifo (torch.Tensor) – Tensor containing embeddings from latest chunks Shape: (batch_size, fifo_len, emb_dim)
fifo_lengths (torch.Tensor) – Tensor containing lengths of FIFO queue embeddings Shape: (batch_size,)
- Returns:
- Sorted tensor containing predicted speaker labels
Shape: (batch_size, max. diar frame count, num_speakers)
- chunk_pre_encode_embs (torch.Tensor): Tensor containing pre-encoded embeddings from the chunk
Shape: (batch_size, num_frames, emb_dim)
- chunk_pre_encode_lengths (torch.Tensor): Tensor containing lengths of pre-encoded embeddings
from the chunk (=input buffer). Shape: (batch_size,)
- Return type:
spkcache_fifo_chunk_preds (torch.Tensor)
- forward_infer(emb_seq, emb_seq_length)[source]#
The main forward pass for offline diarization inference.
- Parameters:
emb_seq (torch.Tensor) – Tensor containing FastConformer encoder states (embedding vectors). Shape: (batch_size, diar_frame_count, emb_dim)
emb_seq_length (torch.Tensor) – Tensor containing lengths of FastConformer encoder states. Shape: (batch_size,)
- Returns:
- Sorted tensor containing Sigmoid values for predicted speaker labels.
Shape: (batch_size, diar_frame_count, num_speakers)
- Return type:
preds (torch.Tensor)
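Since forward_infer returns per-frame sigmoid probabilities of shape (batch_size, diar_frame_count, num_speakers), a common post-step is thresholding them into binary speaker-activity labels. The pure-Python sketch below stands in for that tensor operation; the 0.5 threshold is an illustrative choice, not a library default:

```python
def threshold_preds(preds, threshold=0.5):
    """Binarize per-frame sigmoid speaker probabilities.
    preds: nested lists shaped (batch, frame, speaker), standing in for the
    (batch_size, diar_frame_count, num_speakers) tensor described above."""
    return [
        [[1 if p >= threshold else 0 for p in frame] for frame in sample]
        for sample in preds
    ]

preds = [[[0.91, 0.12], [0.47, 0.88]]]  # one sample, two frames, two speakers
print(threshold_preds(preds))           # [[[1, 0], [0, 1]]]
```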
- forward_streaming(
- processed_signal,
- processed_signal_length,
The main forward pass for diarization inference in streaming mode.
- Parameters:
processed_signal (torch.Tensor) – Tensor containing audio waveform Shape: (batch_size, num_samples)
processed_signal_length (torch.Tensor) – Tensor containing lengths of audio waveforms Shape: (batch_size,)
- Returns:
- Tensor containing predicted speaker labels for the current chunk
and all previous chunks Shape: (batch_size, pred_len, num_speakers)
- Return type:
total_preds (torch.Tensor)
- forward_streaming_step(
- processed_signal,
- processed_signal_length,
- streaming_state,
- total_preds,
- drop_extra_pre_encoded=0,
- left_offset=0,
- right_offset=0,
One-step forward pass for diarization inference in streaming mode.
- Parameters:
processed_signal (torch.Tensor) – Tensor containing audio waveform Shape: (batch_size, num_samples)
processed_signal_length (torch.Tensor) – Tensor containing lengths of audio waveforms Shape: (batch_size,)
streaming_state (SortformerStreamingState) –
- Tensor variables that contain the streaming state of the model.
Find more details in the SortformerStreamingState class in sortformer_modules.py. Its fields include:
spkcache – speaker cache storing embeddings from the start of the stream
spkcache_lengths – lengths of the speaker cache
spkcache_preds – speaker predictions for the speaker cache parts
fifo – FIFO queue storing embeddings from the latest chunks
fifo_lengths – lengths of the FIFO queue
fifo_preds – speaker predictions for the FIFO queue parts
spk_perm – speaker permutation information for the speaker cache
total_preds (torch.Tensor) – Tensor containing total predicted speaker activity probabilities Shape: (batch_size, cumulative pred length, num_speakers)
left_offset (int) – left offset for the current chunk
right_offset (int) – right offset for the current chunk
- Returns:
- Tensor variables that contain the updated streaming state of the model from
this function call.
- total_preds (torch.Tensor):
Tensor containing the updated total predicted speaker activity probabilities. Shape: (batch_size, cumulative pred length, num_speakers)
- Return type:
streaming_state (SortformerStreamingState)
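The bookkeeping above (a bounded FIFO of recent chunk embeddings plus predictions accumulated across chunks) can be sketched in miniature. This toy class only illustrates the pattern; the class name, FIFO size, and eviction behavior are assumptions, not the real SortformerStreamingState logic:

```python
from collections import deque

class ToyStreamingState:
    """Illustrative stand-in for the streaming bookkeeping described above:
    a bounded FIFO of recent chunk embeddings plus cumulative predictions.
    Sizes and eviction policy are assumptions, not the Sortformer internals."""
    def __init__(self, fifo_len=4):
        self.fifo = deque(maxlen=fifo_len)  # embeddings from the latest chunks
        self.total_preds = []               # cumulative per-frame predictions

    def step(self, chunk_embs, chunk_preds):
        # Append this chunk's embeddings; the deque evicts the oldest entries
        # once fifo_len is exceeded (a stand-in for moving old embeddings
        # into the speaker cache).
        self.fifo.extend(chunk_embs)
        # Accumulate predictions along the time axis.
        self.total_preds.extend(chunk_preds)
        return self.total_preds

state = ToyStreamingState(fifo_len=4)
state.step(chunk_embs=[[0.1], [0.2], [0.3]], chunk_preds=[[1, 0]] * 3)
state.step(chunk_embs=[[0.4], [0.5], [0.6]], chunk_preds=[[0, 1]] * 3)
print(len(state.fifo), len(state.total_preds))  # 4 6
```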
- frontend_encoder(
- processed_signal,
- processed_signal_length,
- bypass_pre_encode: bool = False,
Generate encoder outputs from frontend encoder.
- Parameters:
processed_signal (torch.Tensor) – tensor containing audio-feature (mel spectrogram, mfcc, etc.).
processed_signal_length (torch.Tensor) – tensor containing lengths of audio signal in integers.
- Returns:
tensor containing encoder outputs. emb_seq_length (torch.Tensor): tensor containing lengths of encoder outputs.
- Return type:
emb_seq (torch.Tensor)
- property input_names#
- property input_types: Dict[str, NeuralType] | None#
Define these to enable input neural type checks
- classmethod list_available_models() List[PretrainedModelInfo][source]#
This method returns a list of pre-trained model which can be instantiated directly from NVIDIA’s NGC cloud.
- Returns:
List of available pre-trained models.
- multi_validation_epoch_end(
- outputs: list,
- dataloader_idx: int = 0,
Adds support for multiple validation datasets. Should be overridden by subclasses to obtain appropriate logs for each of the dataloaders.
- Parameters:
outputs – Same as that provided by LightningModule.on_validation_epoch_end() for a single dataloader.
dataloader_idx – int representing the index of the dataloader.
- Returns:
A dictionary of values, optionally containing a sub-dict log, such that the values in the log will be pre-pended by the dataloader prefix.
- on_validation_epoch_end() dict[str, dict[str, Tensor]] | None[source]#
Run validation with sync_dist=True.
- oom_safe_feature_extraction(
- input_signal,
- input_signal_length,
This function divides the input signal into smaller sub-batches and processes them sequentially to prevent out-of-memory errors during feature extraction.
- Parameters:
input_signal (torch.Tensor) – The input audio signal.
input_signal_length (torch.Tensor) – The lengths of the input audio signals.
- Returns:
A tuple of (processed_signal, processed_signal_length), where processed_signal is the aggregated audio signal tensor (length matches the original batch size) and processed_signal_length contains the lengths of the processed signals.
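The split-process-aggregate pattern described above can be sketched without torch. The helper name, sub-batch size, and the stand-in "feature extractor" below are all placeholders for illustration:

```python
def oom_safe_apply(fn, batch, sub_batch_size=2):
    """Apply fn to `batch` in sub-batches of `sub_batch_size` and concatenate
    the results, preserving the original batch size. A sketch of the
    split/process/aggregate pattern described above; the helper name and
    chunk size are illustrative."""
    outputs = []
    for start in range(0, len(batch), sub_batch_size):
        outputs.extend(fn(batch[start:start + sub_batch_size]))
    return outputs

# Placeholder "feature extractor": squares each sample.
features = oom_safe_apply(lambda xs: [x * x for x in xs], [1, 2, 3, 4, 5])
print(features)  # [1, 4, 9, 16, 25]
```

Processing sub-batches sequentially bounds peak memory at the cost of some throughput, which is the trade-off oom_safe_feature_extraction makes.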
- property output_names#
- property output_types: Dict[str, NeuralType]#
Define these to enable output neural type checks
- process_signal(
- audio_signal,
- audio_signal_length,
Extract audio features from time-series signal for further processing in the model.
This function performs the following steps: 1. Moves the audio signal to the correct device. 2. Normalizes the time-series audio signal. 3. Extracts audio features from the time-series audio signal using the model’s preprocessor.
- Parameters:
audio_signal (torch.Tensor) – The input audio signal. Shape: (batch_size, num_samples)
audio_signal_length (torch.Tensor) – The length of each audio signal in the batch. Shape: (batch_size,)
- Returns:
- The preprocessed audio signal.
Shape: (batch_size, num_features, num_frames)
- processed_signal_length (torch.Tensor): The length of each processed signal.
Shape: (batch_size,)
- Return type:
processed_signal (torch.Tensor)
- setup_test_data(
- test_data_config: DictConfig | Dict | None,
(Optionally) sets up the data loader to be used in testing.
- Parameters:
test_data_layer_config – test data layer parameters.
- setup_training_data(
- train_data_config: DictConfig | Dict | None,
Sets up the data loader to be used in training.
- Parameters:
train_data_layer_config – training data layer parameters.
- setup_validation_data(
- val_data_layer_config: DictConfig | Dict | None,
Sets up the data loader to be used in validation.
- Parameters:
val_data_layer_config – validation data layer parameters.
- test_batch()[source]#
Perform batch testing on the model.
This method iterates through the test data loader, making predictions for each batch, and calculates various evaluation metrics. It handles both single and multi-sample batches.
- test_step(
- batch: list,
- batch_idx: int,
- dataloader_idx: int = 0,
Performs a single test step.
This method processes a batch of data during the test phase. It forward passes the audio signal through the model, computes various evaluation metrics, and stores these metrics for later aggregation.
- Parameters:
batch (list) – A list containing the following elements: - audio_signal (torch.Tensor): The input audio signal. - audio_signal_length (torch.Tensor): The length of each audio signal in the batch. - targets (torch.Tensor): The target labels for the batch. - target_lens (torch.Tensor): The length of each target sequence in the batch.
batch_idx (int) – The index of the current batch.
dataloader_idx (int, optional) – The index of the dataloader in case of multiple validation dataloaders. Defaults to 0.
- Returns:
A dictionary containing various evaluation metrics for this batch.
- Return type:
dict
- training_step(
- batch: list,
- batch_idx: int,
Performs a single training step.
- Parameters:
batch (list) – A list containing the following elements: - audio_signal (torch.Tensor): The input audio signal in time-series format. - audio_signal_length (torch.Tensor): The length of each audio signal in the batch. - targets (torch.Tensor): The target labels for the batch. - target_lens (torch.Tensor): The length of each target sequence in the batch.
batch_idx (int) – The index of the current batch.
- Returns:
A dictionary containing the ‘loss’ key with the calculated loss value.
- Return type:
(dict)
- validation_step(
- batch: list,
- batch_idx: int,
- dataloader_idx: int = 0,
Performs a single validation step.
This method processes a batch of data during the validation phase. It forward passes the audio signal through the model, computes various validation metrics, and stores these metrics for later aggregation.
- Parameters:
batch (list) – A list containing the following elements: - audio_signal (torch.Tensor): The input audio signal. - audio_signal_length (torch.Tensor): The length of each audio signal in the batch. - targets (torch.Tensor): The target labels for the batch. - target_lens (torch.Tensor): The length of each target sequence in the batch.
batch_idx (int) – The index of the current batch.
dataloader_idx (int, optional) – The index of the dataloader in case of multiple validation dataloaders. Defaults to 0.
- Returns:
A dictionary containing various validation metrics for this batch.
- Return type:
dict
Mixins#
- class nemo.collections.asr.parts.mixins.diarization.SpkDiarizationMixin[source]#
Bases:
ABC

An abstract class for diarize-able models.
Creates a template function diarize() that provides an interface to perform diarization of audio tensors or filepaths.
- diarize(
- audio: str | List[str] | ndarray | List[ndarray] | DataLoader,
- sample_rate: int | None = None,
- batch_size: int = 1,
- include_tensor_outputs: bool = False,
- postprocessing_yaml: str | None = None,
- num_workers: int = 1,
- verbose: bool = False,
- override_config: DiarizeConfig | None = None,
- **config_kwargs,
Takes paths to audio files and returns speaker labels.
- diarize_generator(
- audio: str | List[str] | ndarray | List[ndarray] | DataLoader,
- override_config: DiarizeConfig | None,
A generator version of the diarize() function.