Source code for nemo.collections.asr.models.msdd_models

# Copyright (c) 2022, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
import json
import os
import pickle as pkl
import tempfile
from collections import OrderedDict
from pathlib import Path
from statistics import mode
from typing import Any, Dict, List, Optional, Tuple, Union

import numpy as np
import torch
from hydra.utils import instantiate
from omegaconf import DictConfig, open_dict
from pyannote.core import Annotation
from pyannote.metrics.diarization import DiarizationErrorRate
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.utilities import rank_zero_only
from tqdm import tqdm

from nemo.collections.asr.data.audio_to_diar_label import AudioToSpeechMSDDInferDataset, AudioToSpeechMSDDTrainDataset
from nemo.collections.asr.metrics.der import score_labels
from nemo.collections.asr.metrics.multi_binary_acc import MultiBinaryAccuracy
from nemo.collections.asr.models import ClusteringDiarizer
from nemo.collections.asr.models.asr_model import ExportableEncDecModel
from nemo.collections.asr.models.clustering_diarizer import (
    _MODEL_CONFIG_YAML,
    _SPEAKER_MODEL,
    _VAD_MODEL,
    get_available_model_names,
)
from nemo.collections.asr.models.configs.diarizer_config import NeuralDiarizerInferenceConfig
from nemo.collections.asr.models.label_models import EncDecSpeakerLabelModel
from nemo.collections.asr.parts.preprocessing.features import WaveformFeaturizer
from nemo.collections.asr.parts.utils.speaker_utils import (
    audio_rttm_map,
    get_embs_and_timestamps,
    get_id_tup_dict,
    get_scale_mapping_argmat,
    get_uniq_id_list_from_manifest,
    labels_to_pyannote_object,
    make_rttm_with_overlap,
    parse_scale_configs,
    rttm_to_labels,
)
from nemo.core.classes import ModelPT
from nemo.core.classes.common import PretrainedModelInfo, typecheck
from nemo.core.neural_types import AudioSignal, LengthsType, NeuralType
from nemo.core.neural_types.elements import ProbsType
from nemo.utils import logging

try:
    from torch.cuda.amp import autocast
except ImportError:
    from contextlib import contextmanager

    @contextmanager
    def autocast(enabled=None):
        yield


__all__ = ['EncDecDiarLabelModel', 'ClusterEmbedding', 'NeuralDiarizer']


class EncDecDiarLabelModel(ModelPT, ExportableEncDecModel):
    """
    Encoder-decoder class for the multiscale diarization decoder (MSDD). This model class creates
    training and validation methods for setting up the data and performing the model forward pass.

    This model class expects a config dict with the following entries:
    * preprocessor
    * msdd_module
    * speaker_model
    """
    @classmethod
    def list_available_models(cls) -> List[PretrainedModelInfo]:
        """
        This method returns a list of pre-trained models which can be instantiated directly from NVIDIA's NGC cloud.

        Returns:
            List of available pre-trained models.
        """
        result = []
        model = PretrainedModelInfo(
            pretrained_model_name="diar_msdd_telephonic",
            location="https://api.ngc.nvidia.com/v2/models/nvidia/nemo/diar_msdd_telephonic/versions/1.0.1/files/diar_msdd_telephonic.nemo",
            description="For details about this model, please visit https://ngc.nvidia.com/catalog/models/nvidia:nemo:diar_msdd_telephonic",
        )
        result.append(model)
        return result
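    # Illustrative usage sketch (not part of the original source): the pretrained telephonic MSDD
    # checkpoint listed above could be loaded directly from NGC, assuming network access and a
    # default device mapping. `from_pretrained` is the standard NeMo ModelPT entry point and the
    # model name must match one of the entries returned by `list_available_models()`:
    #
    #     msdd_model = EncDecDiarLabelModel.from_pretrained(model_name="diar_msdd_telephonic")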
    def __init__(self, cfg: DictConfig, trainer: Trainer = None):
        """
        Initialize an MSDD model and the specified speaker embedding model. In this init function,
        training and validation datasets are prepared.
        """
        self._trainer = trainer if trainer else None
        self.cfg_msdd_model = cfg

        if self._trainer:
            self._init_segmentation_info()
            self.world_size = trainer.num_nodes * trainer.num_devices
            self.emb_batch_size = self.cfg_msdd_model.emb_batch_size
            self.pairwise_infer = False
        else:
            self.world_size = 1
            self.pairwise_infer = True
        super().__init__(cfg=self.cfg_msdd_model, trainer=trainer)

        window_length_in_sec = self.cfg_msdd_model.diarizer.speaker_embeddings.parameters.window_length_in_sec
        if isinstance(window_length_in_sec, int) or len(window_length_in_sec) <= 1:
            raise ValueError("window_length_in_sec should be a list containing multiple segment (window) lengths")
        else:
            self.cfg_msdd_model.scale_n = len(window_length_in_sec)
            self.cfg_msdd_model.msdd_module.scale_n = self.cfg_msdd_model.scale_n
            self.scale_n = self.cfg_msdd_model.scale_n

        self.preprocessor = EncDecSpeakerLabelModel.from_config_dict(self.cfg_msdd_model.preprocessor)
        self.frame_per_sec = int(1 / self.preprocessor._cfg.window_stride)
        self.msdd = EncDecDiarLabelModel.from_config_dict(self.cfg_msdd_model.msdd_module)

        if trainer is not None:
            self._init_speaker_model()
            self.add_speaker_model_config(cfg)
        else:
            self.msdd._speaker_model = EncDecSpeakerLabelModel.from_config_dict(cfg.speaker_model_cfg)

        # Call `self.save_hyperparameters` in modelPT.py again since cfg should contain speaker model's config.
        self.save_hyperparameters("cfg")

        self.loss = instantiate(self.cfg_msdd_model.loss)
        self._accuracy_test = MultiBinaryAccuracy()
        self._accuracy_train = MultiBinaryAccuracy()
        self._accuracy_valid = MultiBinaryAccuracy()
    def add_speaker_model_config(self, cfg):
        """
        Add the config dictionary of the speaker model to the MSDD model's config dictionary. This is required to
        save and load the speaker model together with the MSDD model.

        Args:
            cfg (DictConfig): DictConfig type variable that contains hyperparameters of the MSDD model.
        """
        with open_dict(cfg):
            cfg_cp = copy.copy(self.msdd._speaker_model.cfg)
            cfg.speaker_model_cfg = cfg_cp
            del cfg.speaker_model_cfg.train_ds
            del cfg.speaker_model_cfg.validation_ds
    def _init_segmentation_info(self):
        """Initialize segmentation settings: window, shift and multiscale weights.
        """
        self._diarizer_params = self.cfg_msdd_model.diarizer
        self.multiscale_args_dict = parse_scale_configs(
            self._diarizer_params.speaker_embeddings.parameters.window_length_in_sec,
            self._diarizer_params.speaker_embeddings.parameters.shift_length_in_sec,
            self._diarizer_params.speaker_embeddings.parameters.multiscale_weights,
        )

    def _init_speaker_model(self):
        """
        Initialize the speaker embedding model with the model name or path passed through the config. Note that the
        speaker embedding model is loaded into `self.msdd` to enable multi-GPU and multi-node training. In addition,
        the speaker embedding model is also saved with the MSDD model when `.ckpt` files are saved.
        """
        model_path = self.cfg_msdd_model.diarizer.speaker_embeddings.model_path
        self._diarizer_params = self.cfg_msdd_model.diarizer

        if not torch.cuda.is_available():
            rank_id = torch.device('cpu')
        elif self._trainer:
            rank_id = torch.device(self._trainer.global_rank)
        else:
            rank_id = None

        if model_path is not None and model_path.endswith('.nemo'):
            self.msdd._speaker_model = EncDecSpeakerLabelModel.restore_from(model_path, map_location=rank_id)
            logging.info("Speaker Model restored locally from {}".format(model_path))
        elif model_path.endswith('.ckpt'):
            self._speaker_model = EncDecSpeakerLabelModel.load_from_checkpoint(model_path, map_location=rank_id)
            logging.info("Speaker Model restored locally from {}".format(model_path))
        else:
            if model_path not in get_available_model_names(EncDecSpeakerLabelModel):
                logging.warning(
                    "Requested {} model name is not available in pretrained models; using titanet_large instead".format(
                        model_path
                    )
                )
                model_path = "titanet_large"
            logging.info("Loading pretrained {} model from NGC".format(model_path))
            self.msdd._speaker_model = EncDecSpeakerLabelModel.from_pretrained(
                model_name=model_path, map_location=rank_id
            )
        self._speaker_params = self.cfg_msdd_model.diarizer.speaker_embeddings.parameters

    def __setup_dataloader_from_config(self, config):
        featurizer = WaveformFeaturizer(
            sample_rate=config['sample_rate'], int_values=config.get('int_values', False), augmentor=None
        )

        if 'manifest_filepath' in config and config['manifest_filepath'] is None:
            logging.warning(f"Could not load dataset as `manifest_filepath` was None. Provided config : {config}")
            return None

        dataset = AudioToSpeechMSDDTrainDataset(
            manifest_filepath=config.manifest_filepath,
            emb_dir=config.emb_dir,
            multiscale_args_dict=self.multiscale_args_dict,
            soft_label_thres=config.soft_label_thres,
            featurizer=featurizer,
            window_stride=self.cfg_msdd_model.preprocessor.window_stride,
            emb_batch_size=config.emb_batch_size,
            pairwise_infer=False,
            global_rank=self._trainer.global_rank,
        )

        self.data_collection = dataset.collection
        collate_ds = dataset
        collate_fn = collate_ds.msdd_train_collate_fn
        batch_size = config['batch_size']
        return torch.utils.data.DataLoader(
            dataset=dataset,
            batch_size=batch_size,
            collate_fn=collate_fn,
            drop_last=config.get('drop_last', False),
            shuffle=False,
            num_workers=config.get('num_workers', 0),
            pin_memory=config.get('pin_memory', False),
        )

    def __setup_dataloader_from_config_infer(
        self, config: DictConfig, emb_dict: dict, emb_seq: dict, clus_label_dict: dict, pairwise_infer=False
    ):
        shuffle = config.get('shuffle', False)

        if 'manifest_filepath' in config and config['manifest_filepath'] is None:
            logging.warning(f"Could not load dataset as `manifest_filepath` was None. Provided config : {config}")
            return None

        dataset = AudioToSpeechMSDDInferDataset(
            manifest_filepath=config['manifest_filepath'],
            emb_dict=emb_dict,
            clus_label_dict=clus_label_dict,
            emb_seq=emb_seq,
            soft_label_thres=config.soft_label_thres,
            seq_eval_mode=config.seq_eval_mode,
            window_stride=self._cfg.preprocessor.window_stride,
            use_single_scale_clus=False,
            pairwise_infer=pairwise_infer,
        )
        self.data_collection = dataset.collection
        collate_ds = dataset
        collate_fn = collate_ds.msdd_infer_collate_fn
        batch_size = config['batch_size']
        return torch.utils.data.DataLoader(
            dataset=dataset,
            batch_size=batch_size,
            collate_fn=collate_fn,
            drop_last=config.get('drop_last', False),
            shuffle=shuffle,
            num_workers=config.get('num_workers', 0),
            pin_memory=config.get('pin_memory', False),
        )
    def setup_training_data(self, train_data_config: Optional[Union[DictConfig, Dict]]):
        self._train_dl = self.__setup_dataloader_from_config(config=train_data_config,)
    def setup_validation_data(self, val_data_layer_config: Optional[Union[DictConfig, Dict]]):
        self._validation_dl = self.__setup_dataloader_from_config(config=val_data_layer_config,)
    def setup_test_data(self, test_data_config: Optional[Union[DictConfig, Dict]]):
        if self.pairwise_infer:
            self._test_dl = self.__setup_dataloader_from_config_infer(
                config=test_data_config,
                emb_dict=self.emb_sess_test_dict,
                emb_seq=self.emb_seq_test,
                clus_label_dict=self.clus_test_label_dict,
                pairwise_infer=self.pairwise_infer,
            )
    def setup_multiple_test_data(self, test_data_config):
        """
        MSDD does not use the multiple_test_data template. This function is a placeholder for preventing errors.
        """
        return None
    def test_dataloader(self):
        if self._test_dl is not None:
            return self._test_dl
    @property
    def input_types(self) -> Optional[Dict[str, NeuralType]]:
        if hasattr(self.preprocessor, '_sample_rate'):
            audio_eltype = AudioSignal(freq=self.preprocessor._sample_rate)
        else:
            audio_eltype = AudioSignal()
        return {
            "features": NeuralType(('B', 'T'), audio_eltype),
            "feature_length": NeuralType(('B',), LengthsType()),
            "ms_seg_timestamps": NeuralType(('B', 'C', 'T', 'D'), LengthsType()),
            "ms_seg_counts": NeuralType(('B', 'C'), LengthsType()),
            "clus_label_index": NeuralType(('B', 'T'), LengthsType()),
            "scale_mapping": NeuralType(('B', 'C', 'T'), LengthsType()),
            "targets": NeuralType(('B', 'T', 'C'), ProbsType()),
        }

    @property
    def output_types(self) -> Dict[str, NeuralType]:
        return OrderedDict(
            {
                "probs": NeuralType(('B', 'T', 'C'), ProbsType()),
                "scale_weights": NeuralType(('B', 'T', 'C', 'D'), ProbsType()),
            }
        )
    def get_ms_emb_seq(
        self, embs: torch.Tensor, scale_mapping: torch.Tensor, ms_seg_counts: torch.Tensor
    ) -> torch.Tensor:
        """
        Reshape the given tensor and organize the embedding sequence based on the original sequence counts.
        Repeat the embeddings according to the scale_mapping information so that the final embedding sequence has
        the identical length for all scales.

        Args:
            embs (Tensor):
                Merged embeddings without zero-padding in the batch. See `ms_seg_counts` for details.
                Shape: (Total number of segments in the batch, emb_dim)
            scale_mapping (Tensor):
                The element at the m-th row and the n-th column of the scale mapping matrix indicates the (m+1)-th
                scale segment index which has the closest center distance with the (n+1)-th segment in the base scale.
                Example:
                    scale_mapping_argmat[2][101] = 85
                In the above example, the 86-th segment in the 3rd scale (python index is 2) is mapped to the 102-th
                segment in the base scale. Thus, the longer segments are bound to have more repeated indices, since
                multiple base-scale segments (the base scale has the shortest length) fall into the range of the
                longer segments. At the same time, each row contains N indices where N is the number of segments in
                the base scale (i.e., the finest scale).
                Shape: (batch_size, scale_n, self.diar_window_length)
            ms_seg_counts (Tensor):
                Cumulative sum of the number of segments in each scale. This information is needed to reconstruct
                the multi-scale input matrix during forward propagation.

                Example: `batch_size=3, scale_n=6, emb_dim=192`
                    ms_seg_counts =
                        [[8,  9, 12, 16, 25, 51],
                         [11, 13, 14, 17, 25, 51],
                         [ 9,  9, 11, 16, 23, 50]]

                In this function, `ms_seg_counts` is used to get the actual length of each embedding sequence without
                zero-padding.

        Returns:
            ms_emb_seq (Tensor):
                Multi-scale embedding sequence that is mapped, matched and repeated. The longer scales are less
                repeated, while shorter scales are more frequently repeated following the scale mapping tensor.
        """
        scale_n, batch_size = scale_mapping[0].shape[0], scale_mapping.shape[0]
        split_emb_tup = torch.split(embs, ms_seg_counts.view(-1).tolist(), dim=0)
        batch_emb_list = [split_emb_tup[i : i + scale_n] for i in range(0, len(split_emb_tup), scale_n)]
        ms_emb_seq_list = []
        for batch_idx in range(batch_size):
            feats_list = []
            for scale_index in range(scale_n):
                repeat_mat = scale_mapping[batch_idx][scale_index]
                feats_list.append(batch_emb_list[batch_idx][scale_index][repeat_mat, :])
            repp = torch.stack(feats_list).permute(1, 0, 2)
            ms_emb_seq_list.append(repp)
        ms_emb_seq = torch.stack(ms_emb_seq_list)
        return ms_emb_seq
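    # A minimal, hypothetical illustration (not part of the original source) of how `ms_seg_counts`
    # drives the reshaping in `get_ms_emb_seq` and `get_cluster_avg_embs_model`: the flat embedding
    # matrix is split into per-sample, per-scale chunks whose sizes come from `ms_seg_counts`.
    #
    #     embs = torch.randn(370, 192)                       # total segments x emb_dim
    #     ms_seg_counts = torch.tensor([[8, 9, 12, 16, 25, 51],
    #                                   [11, 13, 14, 17, 25, 51],
    #                                   [9, 9, 11, 16, 23, 50]])
    #     chunks = torch.split(embs, ms_seg_counts.view(-1).tolist(), dim=0)
    #     # len(chunks) == batch_size * scale_n == 18; the chunks are then regrouped per sample.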
    @torch.no_grad()
    def get_cluster_avg_embs_model(
        self, embs: torch.Tensor, clus_label_index: torch.Tensor, ms_seg_counts: torch.Tensor, scale_mapping
    ) -> torch.Tensor:
        """
        Calculate the cluster-average speaker embedding based on the ground-truth speaker labels (i.e., cluster labels).

        Args:
            embs (Tensor):
                Merged embeddings without zero-padding in the batch. See `ms_seg_counts` for details.
                Shape: (Total number of segments in the batch, emb_dim)
            clus_label_index (Tensor):
                Merged ground-truth cluster labels from all scales with zero-padding. Each scale's index can be
                retrieved by using the segment index in `ms_seg_counts`.
                Shape: (batch_size, maximum total segment count among the samples in the batch)
            ms_seg_counts (Tensor):
                Cumulative sum of the number of segments in each scale. This information is needed to reconstruct
                multi-scale input tensors during forward propagation.

                Example: `batch_size=3, scale_n=6, emb_dim=192`
                    ms_seg_counts =
                        [[8,  9, 12, 16, 25, 51],
                         [11, 13, 14, 17, 25, 51],
                         [ 9,  9, 11, 16, 23, 50]]

                    Counts of merged segments: (121, 131, 118)
                    embs has a shape of (370, 192)
                    clus_label_index has a shape of (3, 131)

                Shape: (batch_size, scale_n)

        Returns:
            ms_avg_embs (Tensor):
                Multi-scale cluster-average speaker embedding vectors. These embedding vectors are used as reference
                for each speaker to predict the speaker label for the given multi-scale embedding sequences.
                Shape: (batch_size, scale_n, emb_dim, self.num_spks_per_model)
        """
        scale_n, batch_size = scale_mapping[0].shape[0], scale_mapping.shape[0]
        split_emb_tup = torch.split(embs, ms_seg_counts.view(-1).tolist(), dim=0)
        batch_emb_list = [split_emb_tup[i : i + scale_n] for i in range(0, len(split_emb_tup), scale_n)]
        ms_avg_embs_list = []
        for batch_idx in range(batch_size):
            oracle_clus_idx = clus_label_index[batch_idx]
            max_seq_len = sum(ms_seg_counts[batch_idx])
            clus_label_index_batch = torch.split(oracle_clus_idx[:max_seq_len], ms_seg_counts[batch_idx].tolist())
            session_avg_emb_set_list = []
            for scale_index in range(scale_n):
                spk_set_list = []
                for idx in range(self.cfg_msdd_model.max_num_of_spks):
                    _where = (clus_label_index_batch[scale_index] == idx).clone().detach()
                    if not torch.any(_where):
                        avg_emb = torch.zeros(self.msdd._speaker_model._cfg.decoder.emb_sizes).to(embs.device)
                    else:
                        avg_emb = torch.mean(batch_emb_list[batch_idx][scale_index][_where], dim=0)
                    spk_set_list.append(avg_emb)
                session_avg_emb_set_list.append(torch.stack(spk_set_list))
            session_avg_emb_set = torch.stack(session_avg_emb_set_list)
            ms_avg_embs_list.append(session_avg_emb_set)

        ms_avg_embs = torch.stack(ms_avg_embs_list).permute(0, 1, 3, 2)
        ms_avg_embs = ms_avg_embs.float().detach().to(embs.device)
        assert (
            not ms_avg_embs.requires_grad
        ), "ms_avg_embs.requires_grad = True. ms_avg_embs should be detached from the torch graph."
        return ms_avg_embs

    @torch.no_grad()
    def get_ms_mel_feat(
        self,
        processed_signal: torch.Tensor,
        processed_signal_len: torch.Tensor,
        ms_seg_timestamps: torch.Tensor,
        ms_seg_counts: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """
        Load acoustic features from audio segments for each scale and save them into a torch.tensor matrix. In
        addition, create variables containing the information of the multiscale subsegmentation.

        Note: `self.emb_batch_size` determines the number of embedding tensors attached to the computational graph.
        If `self.emb_batch_size` is greater than 0, the speaker embedding model is simultaneously trained. Due to
        the constraint of GPU memory size, only a subset of embedding tensors can be attached to the computational
        graph. By default, the graph-attached embeddings are selected randomly by `torch.randperm`. The default
        value of `self.emb_batch_size` is 0.

        Args:
            processed_signal (Tensor):
                Zero-padded feature input.
                Shape: (batch_size, feat_dim, the longest feature sequence length)
            processed_signal_len (Tensor):
                The actual length of the feature input without zero-padding.
                Shape: (batch_size,)
            ms_seg_timestamps (Tensor):
                Timestamps of the base-scale segments.
                Shape: (batch_size, scale_n, number of base-scale segments, self.num_spks_per_model)
            ms_seg_counts (Tensor):
                Cumulative sum of the number of segments in each scale. This information is needed to reconstruct
                the multi-scale input matrix during forward propagation.
                Shape: (batch_size, scale_n)

        Returns:
            ms_mel_feat (Tensor):
                Feature input stream split into the same length.
                Shape: (total number of segments, feat_dim, self.frame_per_sec * the-longest-scale-length)
            ms_mel_feat_len (Tensor):
                The actual length of features without zero-padding.
                Shape: (total number of segments,)
            seq_len (Tensor):
                The length of the input embedding sequences.
                Shape: (total number of segments,)
            detach_ids (tuple):
                Tuple containing both detached embedding indices and attached embedding indices.
        """
        device = processed_signal.device
        _emb_batch_size = min(self.emb_batch_size, ms_seg_counts.sum().item())
        feat_dim = self.preprocessor._cfg.features
        max_sample_count = int(self.multiscale_args_dict["scale_dict"][0][0] * self.frame_per_sec)
        ms_mel_feat_len_list, sequence_lengths_list, ms_mel_feat_list = [], [], []
        total_seg_count = torch.sum(ms_seg_counts)

        batch_size = processed_signal.shape[0]
        for batch_idx in range(batch_size):
            for scale_idx in range(self.scale_n):
                scale_seg_num = ms_seg_counts[batch_idx][scale_idx]
                for k, (stt, end) in enumerate(ms_seg_timestamps[batch_idx][scale_idx][:scale_seg_num]):
                    stt, end = int(stt.detach().item()), int(end.detach().item())
                    end = min(end, stt + max_sample_count)
                    _features = torch.zeros(feat_dim, max_sample_count).to(torch.float32).to(device)
                    _features[:, : (end - stt)] = processed_signal[batch_idx][:, stt:end]
                    ms_mel_feat_list.append(_features)
                    ms_mel_feat_len_list.append(end - stt)
            sequence_lengths_list.append(ms_seg_counts[batch_idx][-1])
        ms_mel_feat = torch.stack(ms_mel_feat_list).to(device)
        ms_mel_feat_len = torch.tensor(ms_mel_feat_len_list).to(device)
        seq_len = torch.tensor(sequence_lengths_list).to(device)

        if _emb_batch_size == 0:
            attached, _emb_batch_size = torch.tensor([]), 0
            detached = torch.arange(total_seg_count)
        else:
            torch.manual_seed(self._trainer.current_epoch)
            attached = torch.randperm(total_seg_count)[:_emb_batch_size]
            detached = torch.randperm(total_seg_count)[_emb_batch_size:]
        detach_ids = (attached, detached)
        return ms_mel_feat, ms_mel_feat_len, seq_len, detach_ids
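    # Hypothetical illustration (not from the original source) of the attached/detached split that
    # `get_ms_mel_feat` produces and `forward` consumes: with `emb_batch_size > 0`, only a random
    # subset of segment features keeps gradients flowing into the speaker embedding model, while
    # the remaining segments are embedded under `torch.no_grad()`.
    #
    #     total_seg_count, emb_batch_size = 370, 64
    #     perm = torch.randperm(total_seg_count)
    #     attached, detached = perm[:emb_batch_size], perm[emb_batch_size:]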
    def forward_infer(self, input_signal, input_signal_length, emb_vectors, targets):
        """
        Wrapper function for the inference case.
        """
        preds, scale_weights = self.msdd(
            ms_emb_seq=input_signal, length=input_signal_length, ms_avg_embs=emb_vectors, targets=targets
        )
        return preds, scale_weights
    @typecheck()
    def forward(
        self, features, feature_length, ms_seg_timestamps, ms_seg_counts, clus_label_index, scale_mapping, targets
    ):
        processed_signal, processed_signal_len = self.msdd._speaker_model.preprocessor(
            input_signal=features, length=feature_length
        )
        audio_signal, audio_signal_len, sequence_lengths, detach_ids = self.get_ms_mel_feat(
            processed_signal, processed_signal_len, ms_seg_timestamps, ms_seg_counts
        )

        # For detached embeddings
        with torch.no_grad():
            self.msdd._speaker_model.eval()
            logits, embs_d = self.msdd._speaker_model.forward_for_export(
                processed_signal=audio_signal[detach_ids[1]], processed_signal_len=audio_signal_len[detach_ids[1]]
            )
            embs = torch.zeros(audio_signal.shape[0], embs_d.shape[1]).to(embs_d.device)
            embs[detach_ids[1], :] = embs_d.detach()

        # For attached embeddings
        self.msdd._speaker_model.train()
        if len(detach_ids[0]) > 1:
            logits, embs_a = self.msdd._speaker_model.forward_for_export(
                processed_signal=audio_signal[detach_ids[0]], processed_signal_len=audio_signal_len[detach_ids[0]]
            )
            embs[detach_ids[0], :] = embs_a

        ms_emb_seq = self.get_ms_emb_seq(embs, scale_mapping, ms_seg_counts)
        ms_avg_embs = self.get_cluster_avg_embs_model(embs, clus_label_index, ms_seg_counts, scale_mapping)
        preds, scale_weights = self.msdd(
            ms_emb_seq=ms_emb_seq, length=sequence_lengths, ms_avg_embs=ms_avg_embs, targets=targets
        )
        return preds, scale_weights
    def training_step(self, batch: list, batch_idx: int):
        features, feature_length, ms_seg_timestamps, ms_seg_counts, clus_label_index, scale_mapping, targets = batch
        sequence_lengths = torch.tensor([x[-1] for x in ms_seg_counts.detach()])
        preds, _ = self.forward(
            features=features,
            feature_length=feature_length,
            ms_seg_timestamps=ms_seg_timestamps,
            ms_seg_counts=ms_seg_counts,
            clus_label_index=clus_label_index,
            scale_mapping=scale_mapping,
            targets=targets,
        )
        loss = self.loss(probs=preds, labels=targets, signal_lengths=sequence_lengths)
        self._accuracy_train(preds, targets, sequence_lengths)
        torch.cuda.empty_cache()
        f1_acc = self._accuracy_train.compute()
        self.log('loss', loss, sync_dist=True)
        self.log('learning_rate', self._optimizer.param_groups[0]['lr'], sync_dist=True)
        self.log('train_f1_acc', f1_acc, sync_dist=True)
        self._accuracy_train.reset()
        return {'loss': loss}
    def validation_step(self, batch: list, batch_idx: int, dataloader_idx: int = 0):
        features, feature_length, ms_seg_timestamps, ms_seg_counts, clus_label_index, scale_mapping, targets = batch
        sequence_lengths = torch.tensor([x[-1] for x in ms_seg_counts])
        preds, _ = self.forward(
            features=features,
            feature_length=feature_length,
            ms_seg_timestamps=ms_seg_timestamps,
            ms_seg_counts=ms_seg_counts,
            clus_label_index=clus_label_index,
            scale_mapping=scale_mapping,
            targets=targets,
        )
        loss = self.loss(probs=preds, labels=targets, signal_lengths=sequence_lengths)
        self._accuracy_valid(preds, targets, sequence_lengths)
        f1_acc = self._accuracy_valid.compute()
        self.log('val_loss', loss, sync_dist=True)
        self.log('val_f1_acc', f1_acc, sync_dist=True)
        return {
            'val_loss': loss,
            'val_f1_acc': f1_acc,
        }
    def multi_validation_epoch_end(self, outputs: list, dataloader_idx: int = 0):
        val_loss_mean = torch.stack([x['val_loss'] for x in outputs]).mean()
        f1_acc = self._accuracy_valid.compute()
        self._accuracy_valid.reset()

        self.log('val_loss', val_loss_mean, sync_dist=True)
        self.log('val_f1_acc', f1_acc, sync_dist=True)
        return {
            'val_loss': val_loss_mean,
            'val_f1_acc': f1_acc,
        }
    def multi_test_epoch_end(self, outputs: List[Dict[str, torch.Tensor]], dataloader_idx: int = 0):
        test_loss_mean = torch.stack([x['test_loss'] for x in outputs]).mean()
        f1_acc = self._accuracy_test.compute()
        self._accuracy_test.reset()
        self.log('test_f1_acc', f1_acc, sync_dist=True)
        return {
            'test_loss': test_loss_mean,
            'test_f1_acc': f1_acc,
        }
    def compute_accuracies(self):
        """
        Calculate the F1 score and accuracy of the predicted sigmoid values.

        Returns:
            f1_score (float):
                F1 score of the estimated diarized speaker label sequences.
            simple_acc (float):
                Accuracy of predicted speaker labels: (total # of correct labels)/(total # of sigmoid values)
        """
        f1_score = self._accuracy_test.compute()
        num_correct = torch.sum(self._accuracy_test.true.bool())
        total_count = torch.prod(torch.tensor(self._accuracy_test.targets.shape))
        simple_acc = num_correct / total_count
        return f1_score, simple_acc

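# Illustrative sketch (not part of the original source): per the class docstring above, an MSDD
# training config needs `preprocessor`, `msdd_module`, `loss` and `speaker_model` related entries,
# plus a `diarizer` block that lists more than one window/shift length. The values below are
# hypothetical placeholders, not the released defaults.
#
#     from omegaconf import OmegaConf
#
#     example_diarizer_block = OmegaConf.create({
#         "speaker_embeddings": {
#             "model_path": "titanet_large",
#             "parameters": {
#                 "window_length_in_sec": [1.5, 1.0, 0.5],
#                 "shift_length_in_sec": [0.75, 0.5, 0.25],
#                 "multiscale_weights": [1, 1, 1],
#             },
#         },
#     })
#     # A full config additionally needs `preprocessor`, `msdd_module`, `loss`, `emb_batch_size`
#     # and `train_ds`/`validation_ds` entries before `EncDecDiarLabelModel(cfg)` can be built.
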
class ClusterEmbedding(torch.nn.Module):
    """
    This class is built for calculating cluster-average embeddings, segmentation and load/save of the estimated
    cluster labels. The methods in this class are used for the inference of MSDD models.

    Args:
        cfg_diar_infer (DictConfig):
            Config dictionary from the diarization inference YAML file
        cfg_msdd_model (DictConfig):
            Config dictionary from the MSDD model checkpoint file

    Class Variables:
        self.cfg_diar_infer (DictConfig):
            Config dictionary from the diarization inference YAML file
        cfg_msdd_model (DictConfig):
            Config dictionary from the MSDD model checkpoint file
        self._speaker_model (class `EncDecSpeakerLabelModel`):
            This is a placeholder for a class instance of `EncDecSpeakerLabelModel`
        self.scale_window_length_list (list):
            List containing the window lengths (i.e., scale length) of each scale.
        self.scale_n (int):
            Number of scales for the multi-scale clustering diarizer
        self.base_scale_index (int):
            The index of the base scale, which is the shortest scale among the given multiple scales
    """

    def __init__(
        self, cfg_diar_infer: DictConfig, cfg_msdd_model: DictConfig, speaker_model: Optional[EncDecSpeakerLabelModel]
    ):
        super().__init__()
        self.cfg_diar_infer = cfg_diar_infer
        self._cfg_msdd = cfg_msdd_model
        self._speaker_model = speaker_model
        self.scale_window_length_list = list(
            self.cfg_diar_infer.diarizer.speaker_embeddings.parameters.window_length_in_sec
        )
        self.scale_n = len(self.scale_window_length_list)
        self.base_scale_index = len(self.scale_window_length_list) - 1
        self.clus_diar_model = ClusteringDiarizer(cfg=self.cfg_diar_infer, speaker_model=self._speaker_model)

    def prepare_cluster_embs_infer(self):
        """
        Launch the clustering diarizer to prepare embedding vectors and clustering results.
        """
        self.max_num_speakers = self.cfg_diar_infer.diarizer.clustering.parameters.max_num_speakers
        self.emb_sess_test_dict, self.emb_seq_test, self.clus_test_label_dict, _ = self.run_clustering_diarizer(
            self._cfg_msdd.test_ds.manifest_filepath, self._cfg_msdd.test_ds.emb_dir
        )

    def assign_labels_to_longer_segs(self, base_clus_label_dict: Dict, session_scale_mapping_dict: Dict):
        """
        In a multi-scale speaker diarization system, the clustering result is solely based on the base scale
        (the shortest scale). To calculate cluster-average speaker embeddings for each scale that is longer than
        the base scale, this function assigns the base-scale clustering results to the longer scales by measuring
        the distance between subsegment timestamps in the base scale and the non-base scales.

        Args:
            base_clus_label_dict (dict):
                Dictionary containing clustering results for base-scale segments. Indexed by `uniq_id` string.
            session_scale_mapping_dict (dict):
                Dictionary containing multiscale mapping information for each session. Indexed by `uniq_id` string.

        Returns:
            all_scale_clus_label_dict (dict):
                Dictionary containing clustering labels of all scales. Indexed by scale_index in integer format.
        """
        all_scale_clus_label_dict = {scale_index: {} for scale_index in range(self.scale_n)}
        for uniq_id, uniq_scale_mapping_dict in session_scale_mapping_dict.items():
            base_scale_clus_label = np.array([x[-1] for x in base_clus_label_dict[uniq_id]])
            all_scale_clus_label_dict[self.base_scale_index][uniq_id] = base_scale_clus_label
            for scale_index in range(self.scale_n - 1):
                new_clus_label = []
                assert (
                    uniq_scale_mapping_dict[scale_index].shape[0] == base_scale_clus_label.shape[0]
                ), "The number of base scale labels does not match the segment numbers in uniq_scale_mapping_dict"
                max_index = max(uniq_scale_mapping_dict[scale_index])
                for seg_idx in range(max_index + 1):
                    if seg_idx in uniq_scale_mapping_dict[scale_index]:
                        seg_clus_label = mode(base_scale_clus_label[uniq_scale_mapping_dict[scale_index] == seg_idx])
                    else:
                        seg_clus_label = 0 if len(new_clus_label) == 0 else new_clus_label[-1]
                    new_clus_label.append(seg_clus_label)
                all_scale_clus_label_dict[scale_index][uniq_id] = new_clus_label
        return all_scale_clus_label_dict

    def get_base_clus_label_dict(self, clus_labels: List[str], emb_scale_seq_dict: Dict[int, dict]):
        """
        Retrieve base-scale clustering labels from `emb_scale_seq_dict`.

        Args:
            clus_labels (list):
                List containing cluster results generated by the clustering diarizer.
            emb_scale_seq_dict (dict):
                Dictionary containing multiscale embedding input sequences.

        Returns:
            base_clus_label_dict (dict):
                Dictionary containing the start and end of base-scale segments and their cluster labels.
                Indexed by `uniq_id`.
            emb_dim (int):
                Embedding dimension in integer.
        """
        base_clus_label_dict = {key: [] for key in emb_scale_seq_dict[self.base_scale_index].keys()}
        for line in clus_labels:
            uniq_id = line.split()[0]
            label = int(line.split()[-1].split('_')[-1])
            stt, end = [round(float(x), 2) for x in line.split()[1:3]]
            base_clus_label_dict[uniq_id].append([stt, end, label])
        emb_dim = emb_scale_seq_dict[0][uniq_id][0].shape[0]
        return base_clus_label_dict, emb_dim

    def get_cluster_avg_embs(
        self, emb_scale_seq_dict: Dict, clus_labels: List, speaker_mapping_dict: Dict, session_scale_mapping_dict: Dict
    ):
        """
        MSDD requires cluster-average speaker embedding vectors for each scale. This function calculates an average
        embedding vector for each cluster (speaker) and each scale.

        Args:
            emb_scale_seq_dict (dict):
                Dictionary containing the embedding sequence for each scale. Keys are scale indices in integer.
            clus_labels (list):
                Clustering results from the clustering diarizer including all the sessions provided in the input
                manifest files.
            speaker_mapping_dict (dict):
                Speaker mapping dictionary in case RTTM files are provided. This is a mapping between integer-based
                speaker indices and speaker ID tokens in RTTM files.
                Example:
                    {'en_0638': {'speaker_0': 'en_0638_A', 'speaker_1': 'en_0638_B'},
                     'en_4065': {'speaker_0': 'en_4065_B', 'speaker_1': 'en_4065_A'}, ...,}
            session_scale_mapping_dict (dict):
                Dictionary containing multiscale mapping information for each session. Indexed by `uniq_id` string.

        Returns:
            emb_sess_avg_dict (dict):
                Dictionary containing speaker mapping information and cluster-average speaker embedding vectors.
                Each session-level dictionary is indexed by scale index in integer.
            output_clus_label_dict (dict):
                Subsegmentation timestamps in float type and clustering results in integer type.
                Indexed by `uniq_id` keys.
        """
        self.scale_n = len(emb_scale_seq_dict.keys())
        emb_sess_avg_dict = {
            scale_index: {key: [] for key in emb_scale_seq_dict[self.scale_n - 1].keys()}
            for scale_index in emb_scale_seq_dict.keys()
        }
        output_clus_label_dict, emb_dim = self.get_base_clus_label_dict(clus_labels, emb_scale_seq_dict)
        all_scale_clus_label_dict = self.assign_labels_to_longer_segs(
            output_clus_label_dict, session_scale_mapping_dict
        )
        for scale_index in emb_scale_seq_dict.keys():
            for uniq_id, _emb_tensor in emb_scale_seq_dict[scale_index].items():
                if type(_emb_tensor) == list:
                    emb_tensor = torch.tensor(np.array(_emb_tensor))
                else:
                    emb_tensor = _emb_tensor
                clus_label_list = all_scale_clus_label_dict[scale_index][uniq_id]
                spk_set = set(clus_label_list)

                # Create a label array which identifies the clustering result for each segment.
                label_array = torch.Tensor(clus_label_list)
                avg_embs = torch.zeros(emb_dim, self.max_num_speakers)
                for spk_idx in spk_set:
                    selected_embs = emb_tensor[label_array == spk_idx]
                    avg_embs[:, spk_idx] = torch.mean(selected_embs, dim=0)

                if speaker_mapping_dict is not None:
                    inv_map = {clus_key: rttm_key for rttm_key, clus_key in speaker_mapping_dict[uniq_id].items()}
                else:
                    inv_map = None

                emb_sess_avg_dict[scale_index][uniq_id] = {'mapping': inv_map, 'avg_embs': avg_embs}
        return emb_sess_avg_dict, output_clus_label_dict

    def run_clustering_diarizer(self, manifest_filepath: str, emb_dir: str):
        """
        If no pre-existing data is provided, run the clustering diarizer from scratch. This will create scale-wise
        speaker embedding sequences, cluster-average embeddings, scale mapping and base-scale clustering labels.
        Note that the speaker embedding `state_dict` is loaded from the `state_dict` in the provided MSDD checkpoint.

        Args:
            manifest_filepath (str):
                Input manifest file for creating the audio-to-RTTM mapping.
            emb_dir (str):
                Output directory where embedding files and timestamp files are saved.

        Returns:
            emb_sess_avg_dict (dict):
                Dictionary containing cluster-average embeddings for each session.
            emb_scale_seq_dict (dict):
                Dictionary containing embedding tensors which are indexed by scale numbers.
            base_clus_label_dict (dict):
                Dictionary containing clustering results. Clustering results are cluster labels for the base-scale
                segments.
        """
        self.cfg_diar_infer.diarizer.manifest_filepath = manifest_filepath
        self.cfg_diar_infer.diarizer.out_dir = emb_dir

        # Run ClusteringDiarizer which includes system VAD or oracle VAD.
        self._out_dir = self.clus_diar_model._diarizer_params.out_dir
        self.out_rttm_dir = os.path.join(self._out_dir, 'pred_rttms')
        os.makedirs(self.out_rttm_dir, exist_ok=True)

        self.clus_diar_model._cluster_params = self.cfg_diar_infer.diarizer.clustering.parameters
        self.clus_diar_model.multiscale_args_dict[
            "multiscale_weights"
        ] = self.cfg_diar_infer.diarizer.speaker_embeddings.parameters.multiscale_weights
        self.clus_diar_model._diarizer_params.speaker_embeddings.parameters = (
            self.cfg_diar_infer.diarizer.speaker_embeddings.parameters
        )
        cluster_params = self.clus_diar_model._cluster_params
        cluster_params = dict(cluster_params) if isinstance(cluster_params, DictConfig) else cluster_params.dict()
        clustering_params_str = json.dumps(cluster_params, indent=4)

        logging.info(f"Multiscale Weights: {self.clus_diar_model.multiscale_args_dict['multiscale_weights']}")
        logging.info(f"Clustering Parameters: {clustering_params_str}")
        scores = self.clus_diar_model.diarize(batch_size=self.cfg_diar_infer.batch_size)

        # If RTTM (ground-truth diarization annotation) files do not exist, scores is None.
        if scores is not None:
            metric, speaker_mapping_dict, _ = scores
        else:
            metric, speaker_mapping_dict = None, None

        # Get the mapping between segments in different scales.
        self._embs_and_timestamps = get_embs_and_timestamps(
            self.clus_diar_model.multiscale_embeddings_and_timestamps, self.clus_diar_model.multiscale_args_dict
        )
        session_scale_mapping_dict = self.get_scale_map(self._embs_and_timestamps)
        emb_scale_seq_dict = self.load_emb_scale_seq_dict(emb_dir)
        clus_labels = self.load_clustering_labels(emb_dir)
        emb_sess_avg_dict, base_clus_label_dict = self.get_cluster_avg_embs(
            emb_scale_seq_dict, clus_labels, speaker_mapping_dict, session_scale_mapping_dict
        )
        emb_scale_seq_dict['session_scale_mapping'] = session_scale_mapping_dict
        return emb_sess_avg_dict, emb_scale_seq_dict, base_clus_label_dict, metric

    def get_scale_map(self, embs_and_timestamps):
        """
        Save multiscale mapping data into a dictionary format.

        Args:
            embs_and_timestamps (dict):
                Dictionary containing embedding tensors and timestamp tensors. Indexed by `uniq_id` string.

        Returns:
            session_scale_mapping_dict (dict):
                Dictionary containing multiscale mapping information for each session. Indexed by `uniq_id` string.
        """
        session_scale_mapping_dict = {}
        for uniq_id, uniq_embs_and_timestamps in embs_and_timestamps.items():
            scale_mapping_dict = get_scale_mapping_argmat(uniq_embs_and_timestamps)
            session_scale_mapping_dict[uniq_id] = scale_mapping_dict
        return session_scale_mapping_dict

    def check_clustering_labels(self, out_dir):
        """
        Check whether the loaded clustering label file includes clustering results for all sessions.
        This function is used for the inference mode of MSDD.

        Args:
            out_dir (str):
                Path to the directory where the clustering result files are saved.

        Returns:
            file_exists (bool):
                Boolean that indicates whether the clustering result file exists.
            clus_label_path (str):
                Path to the clustering label output file.
        """
        clus_label_path = os.path.join(
            out_dir, 'speaker_outputs', f'subsegments_scale{self.base_scale_index}_cluster.label'
        )
        file_exists = os.path.exists(clus_label_path)
        if not file_exists:
            logging.info(f"Clustering label file {clus_label_path} does not exist.")
        return file_exists, clus_label_path

    def load_clustering_labels(self, out_dir):
        """
        Load clustering labels generated by the clustering diarizer. This function is used for the inference mode
        of MSDD.

        Args:
            out_dir (str):
                Path to the directory where the clustering result files are saved.

        Returns:
            clus_labels (list):
                List containing clustering results in string format.
        """
        file_exists, clus_label_path = self.check_clustering_labels(out_dir)
        logging.info(f"Loading cluster label file from {clus_label_path}")
        with open(clus_label_path) as f:
            clus_labels = f.readlines()
        return clus_labels

    def load_emb_scale_seq_dict(self, out_dir):
        """
        Load saved embeddings generated by the clustering diarizer. This function is used for the inference mode
        of MSDD.

        Args:
            out_dir (str):
                Path to the directory where the embedding pickle files are saved.

        Returns:
            emb_scale_seq_dict (dict):
                Dictionary containing embedding tensors which are indexed by scale numbers.
        """
        window_len_list = list(self.cfg_diar_infer.diarizer.speaker_embeddings.parameters.window_length_in_sec)
        emb_scale_seq_dict = {scale_index: None for scale_index in range(len(window_len_list))}
        for scale_index in range(len(window_len_list)):
            pickle_path = os.path.join(
                out_dir, 'speaker_outputs', 'embeddings', f'subsegments_scale{scale_index}_embeddings.pkl'
            )
            logging.info(f"Loading embedding pickle file of scale:{scale_index} at {pickle_path}")
            with open(pickle_path, "rb") as input_file:
                emb_dict = pkl.load(input_file)
            for key, val in emb_dict.items():
                emb_dict[key] = val
            emb_scale_seq_dict[scale_index] = emb_dict
        return emb_scale_seq_dict


class NeuralDiarizer(LightningModule):
    """
    Class for inference based on the multiscale diarization decoder (MSDD). MSDD requires initializing clustering
    results from a clustering diarizer. The overlap-aware diarizer requires separate RTTM generation and evaluation
    modules to check the effect of overlap detection in speaker diarization.
    """

    def __init__(self, cfg: Union[DictConfig, NeuralDiarizerInferenceConfig]):
        super().__init__()
        self._cfg = cfg

        # Parameter settings for MSDD model
        self.use_speaker_model_from_ckpt = cfg.diarizer.msdd_model.parameters.get('use_speaker_model_from_ckpt', True)
        self.use_clus_as_main = cfg.diarizer.msdd_model.parameters.get('use_clus_as_main', False)
        self.max_overlap_spks = cfg.diarizer.msdd_model.parameters.get('max_overlap_spks', 2)
        self.num_spks_per_model = cfg.diarizer.msdd_model.parameters.get('num_spks_per_model', 2)
        self.use_adaptive_thres = cfg.diarizer.msdd_model.parameters.get('use_adaptive_thres', True)
        self.max_pred_length = cfg.diarizer.msdd_model.parameters.get('max_pred_length', 0)
        self.diar_eval_settings = cfg.diarizer.msdd_model.parameters.get(
            'diar_eval_settings', [(0.25, True), (0.25, False), (0.0, False)]
        )

        self._init_msdd_model(cfg)
        self.diar_window_length = cfg.diarizer.msdd_model.parameters.diar_window_length
        self.msdd_model.cfg = self.transfer_diar_params_to_model_params(self.msdd_model, cfg)

        # Initialize clustering and embedding preparation instance (as a diarization encoder).
        self.clustering_embedding = ClusterEmbedding(
            cfg_diar_infer=cfg, cfg_msdd_model=self.msdd_model.cfg, speaker_model=self._speaker_model
        )

        # Parameters for creating diarization results from MSDD outputs.
        self.clustering_max_spks = self.msdd_model._cfg.max_num_of_spks
        self.overlap_infer_spk_limit = cfg.diarizer.msdd_model.parameters.get(
            'overlap_infer_spk_limit', self.clustering_max_spks
        )

    def transfer_diar_params_to_model_params(self, msdd_model, cfg):
        """
        Transfer the parameters that are needed for MSDD inference from the diarization inference config files to
        the MSDD model config `msdd_model.cfg`.
        """
        msdd_model.cfg.diarizer.out_dir = cfg.diarizer.out_dir
        msdd_model.cfg.test_ds.manifest_filepath = cfg.diarizer.manifest_filepath
        msdd_model.cfg.test_ds.emb_dir = cfg.diarizer.out_dir
        msdd_model.cfg.test_ds.batch_size = cfg.diarizer.msdd_model.parameters.infer_batch_size
        msdd_model.cfg.test_ds.seq_eval_mode = cfg.diarizer.msdd_model.parameters.seq_eval_mode
        msdd_model._cfg.max_num_of_spks = cfg.diarizer.clustering.parameters.max_num_speakers
        return msdd_model.cfg

    @rank_zero_only
    def save_to(self, save_path: str):
        """
        Save model instances (weights and configuration) into an EFF archive.
        You can use the "restore_from" method to fully restore the instance from the .nemo file.

        A .nemo file is an archive (tar.gz) with the following:
            model_config.yaml - model configuration in .yaml format.
                You can deserialize this into the cfg argument for the model's constructor.
            model_weights.ckpt - model checkpoint

        Args:
            save_path: Path to the .nemo file where the model instance should be saved
        """
        self.clus_diar = self.clustering_embedding.clus_diar_model
        _NEURAL_DIAR_MODEL = "msdd_model.nemo"

        with tempfile.TemporaryDirectory() as tmpdir:
            config_yaml = os.path.join(tmpdir, _MODEL_CONFIG_YAML)
            spkr_model = os.path.join(tmpdir, _SPEAKER_MODEL)
            neural_diar_model = os.path.join(tmpdir, _NEURAL_DIAR_MODEL)

            self.clus_diar.to_config_file(path2yaml_file=config_yaml)
            if self.clus_diar.has_vad_model:
                vad_model = os.path.join(tmpdir, _VAD_MODEL)
                self.clus_diar._vad_model.save_to(vad_model)
            self.clus_diar._speaker_model.save_to(spkr_model)
            self.msdd_model.save_to(neural_diar_model)
            self.clus_diar.__make_nemo_file_from_folder(filename=save_path, source_dir=tmpdir)

    def extract_standalone_speaker_model(self, prefix: str = 'msdd._speaker_model.') -> EncDecSpeakerLabelModel:
        """
        An MSDD model file contains both a speaker embedding model and an MSDD model. This function extracts the
        standalone speaker model so it can be loaded separately for the clustering diarizer.

        Args:
            prefix (str):
                Prefix of the speaker embedding model parameters in the MSDD model's `state_dict`.

        Returns:
            _speaker_model (EncDecSpeakerLabelModel):
                Standalone speaker embedding model without the MSDD module.
        """
        model_state_dict = self.msdd_model.state_dict()
        spk_emb_module_names = []
        for name in model_state_dict.keys():
            if prefix in name:
                spk_emb_module_names.append(name)

        spk_emb_state_dict = {}
        for name in spk_emb_module_names:
            org_name = name.replace(prefix, '')
            spk_emb_state_dict[org_name] = model_state_dict[name]

        _speaker_model = EncDecSpeakerLabelModel.from_config_dict(self.msdd_model.cfg.speaker_model_cfg)
        _speaker_model.load_state_dict(spk_emb_state_dict)
        return _speaker_model

    def _init_msdd_model(self, cfg: Union[DictConfig, NeuralDiarizerInferenceConfig]):
        """
        Initialize the MSDD model with the provided config. Load either from a `.nemo` file or a `.ckpt`
        checkpoint file.
        """
        model_path = cfg.diarizer.msdd_model.model_path
        if model_path.endswith('.nemo'):
            logging.info(f"Using local nemo file from {model_path}")
            self.msdd_model = EncDecDiarLabelModel.restore_from(restore_path=model_path, map_location=cfg.device)
        elif model_path.endswith('.ckpt'):
            logging.info(f"Using local checkpoint from {model_path}")
            self.msdd_model = EncDecDiarLabelModel.load_from_checkpoint(
                checkpoint_path=model_path, map_location=cfg.device
            )
        else:
            if model_path not in get_available_model_names(EncDecDiarLabelModel):
                logging.warning(f"Requested {model_path} model name is not available in pretrained models.")
            logging.info("Loading pretrained {} model from NGC".format(model_path))
            self.msdd_model = EncDecDiarLabelModel.from_pretrained(model_name=model_path, map_location=cfg.device)
        # Load the speaker embedding model state_dict from the MSDD checkpoint if requested.
        if self.use_speaker_model_from_ckpt:
            self._speaker_model = self.extract_standalone_speaker_model()
        else:
            self._speaker_model = None

    def get_pred_mat(self, data_list: List[Union[Tuple[int], List[torch.Tensor]]]) -> torch.Tensor:
        """
        This module puts together the pairwise, two-speaker predicted results to form a finalized matrix that has
        a dimension of `(total_len, n_est_spks)`. The pairwise results are eventually averaged. For example, in the
        4-speaker case (speakers 1, 2, 3, 4), the sum of the pairwise results (1, 2), (1, 3), (1, 4) is divided by 3
        to take the average of the sigmoid values.

        Args:
            data_list (list):
                List containing data points from the `test_data_collection` variable. `data_list` has sublists
                `data` as follows:
                    data[0]: `target_spks` tuple
                        Examples: (0, 1, 2)
                    data[1]: Tensor containing estimated sigmoid values.
                        [[0.0264, 0.9995], [0.0112, 1.0000], ..., [1.0000, 0.0512]]

        Returns:
            sum_pred (Tensor):
                Tensor containing the averaged sigmoid values for each speaker.
        """
        all_tups = tuple()
        for data in data_list:
            all_tups += data[0]
        n_est_spks = len(set(all_tups))
        digit_map = dict(zip(sorted(set(all_tups)), range(n_est_spks)))
        total_len = max([sess[1].shape[1] for sess in data_list])
        sum_pred = torch.zeros(total_len, n_est_spks)
        for (_dim_tup, pred_mat) in data_list:
            dim_tup = [digit_map[x] for x in _dim_tup]
            if len(pred_mat.shape) == 3:
                pred_mat = pred_mat.squeeze(0)
            if n_est_spks <= self.num_spks_per_model:
                sum_pred = pred_mat
            else:
                _end = pred_mat.shape[0]
                sum_pred[:_end, dim_tup] += pred_mat.cpu().float()
        sum_pred = sum_pred / (n_est_spks - 1)
        return sum_pred

    def get_integrated_preds_list(
        self, uniq_id_list: List[str], test_data_collection: List[Any], preds_list: List[torch.Tensor]
    ) -> List[torch.Tensor]:
        """
        Merge multiple sequence inference outputs into a session-level result.

        Args:
            uniq_id_list (list):
                List containing `uniq_id` values.
            test_data_collection (collections.DiarizationLabelEntity):
                Class instance that contains session information such as targeted speaker indices, audio file paths
                and RTTM file paths.
            preds_list (list):
                List containing tensors filled with sigmoid values.

        Returns:
            output_list (list):
                List containing session-level estimated prediction matrices.
        """
        session_dict = get_id_tup_dict(uniq_id_list, test_data_collection, preds_list)
        output_dict = {uniq_id: [] for uniq_id in uniq_id_list}
        for uniq_id, data_list in session_dict.items():
            sum_pred = self.get_pred_mat(data_list)
            output_dict[uniq_id] = sum_pred.unsqueeze(0)
        output_list = [output_dict[uniq_id] for uniq_id in uniq_id_list]
        return output_list

    def get_emb_clus_infer(self, cluster_embeddings):
        """Assign dictionaries containing the clustering results from the class instance `cluster_embeddings`.
        """
        self.msdd_model.emb_sess_test_dict = cluster_embeddings.emb_sess_test_dict
        self.msdd_model.clus_test_label_dict = cluster_embeddings.clus_test_label_dict
        self.msdd_model.emb_seq_test = cluster_embeddings.emb_seq_test

    @torch.no_grad()
    def diarize(self) -> Optional[List[Optional[List[Tuple[DiarizationErrorRate, Dict]]]]]:
        """
        Launch the diarization pipeline, which starts from VAD (or oracle VAD stamp generation), initial clustering
        and the multiscale diarization decoder (MSDD). Note that the result of MSDD can include multiple speakers at
        the same time. Therefore, the RTTM output of MSDD needs to be based on the `make_rttm_with_overlap()`
        function that can generate overlapping timestamps. The `self.run_overlap_aware_eval()` function performs
        DER evaluation.
        """
        self.clustering_embedding.prepare_cluster_embs_infer()
        self.msdd_model.pairwise_infer = True
        self.get_emb_clus_infer(self.clustering_embedding)
        preds_list, targets_list, signal_lengths_list = self.run_pairwise_diarization()
        thresholds = list(self._cfg.diarizer.msdd_model.parameters.sigmoid_threshold)
        return [self.run_overlap_aware_eval(preds_list, threshold) for threshold in thresholds]

    def get_range_average(
        self, signals: torch.Tensor, emb_vectors: torch.Tensor, diar_window_index: int, test_data_collection: List[Any]
    ) -> Tuple[torch.Tensor, torch.Tensor, int]:
        """
        This function is only used when `split_infer=True`. This module calculates cluster-average embeddings for
        the given short range. The range length is set by `self.diar_window_length`, and each cluster-average is
        only calculated for the specified range. Note that if the specified range does not contain some speakers
        (e.g., the range contains speakers 1 and 3) compared to the global speaker set (e.g., speakers 1, 2, 3, 4),
        then the missing speakers (e.g., speakers 2 and 4) are assigned zero-filled cluster-average speaker
        embeddings.

        Args:
            signals (Tensor):
                Zero-padded input multi-scale embedding sequences.
                Shape: (length, scale_n, emb_vectors, emb_dim)
            emb_vectors (Tensor):
                Cluster-average multi-scale embedding vectors.
                Shape: (length, scale_n, emb_vectors, emb_dim)
            diar_window_index (int):
                Index of the split diarization windows.
            test_data_collection (collections.DiarizationLabelEntity):
                Class instance that contains session information such as targeted speaker indices, audio file path
                and RTTM file path.

        Returns:
            emb_vectors_split (Tensor):
                Cluster-average speaker embedding vectors for each scale.
            emb_seq (Tensor):
                Zero-padded multi-scale embedding sequences.
            seq_len (int):
                Length of the sequence determined by the `self.diar_window_length` variable.
        """
        emb_vectors_split = torch.zeros_like(emb_vectors)
        uniq_id = os.path.splitext(os.path.basename(test_data_collection.audio_file))[0]
        clus_label_tensor = torch.tensor([x[-1] for x in self.msdd_model.clus_test_label_dict[uniq_id]])
        for spk_idx in range(len(test_data_collection.target_spks)):
            stt, end = (
                diar_window_index * self.diar_window_length,
                min((diar_window_index + 1) * self.diar_window_length, clus_label_tensor.shape[0]),
            )
            seq_len = end - stt
            if stt < clus_label_tensor.shape[0]:
                target_clus_label_tensor = clus_label_tensor[stt:end]
                emb_seq, seg_length = (
                    signals[stt:end, :, :],
                    min(
                        self.diar_window_length,
                        clus_label_tensor.shape[0] - diar_window_index * self.diar_window_length,
                    ),
                )
                target_clus_label_bool = target_clus_label_tensor == test_data_collection.target_spks[spk_idx]

                # There are cases where there is no corresponding speaker in the split range, so
                # any(target_clus_label_bool) could be False.
                if any(target_clus_label_bool):
                    emb_vectors_split[:, :, spk_idx] = torch.mean(emb_seq[target_clus_label_bool], dim=0)

                # In case the loop reaches the end of the sequence
                if seq_len < self.diar_window_length:
                    emb_seq = torch.cat(
                        [
                            emb_seq,
                            torch.zeros(self.diar_window_length - seq_len, emb_seq.shape[1], emb_seq.shape[2]).to(
                                signals.device
                            ),
                        ],
                        dim=0,
                    )
            else:
                emb_seq = torch.zeros(self.diar_window_length, emb_vectors.shape[0], emb_vectors.shape[1]).to(
                    signals.device
                )
                seq_len = 0
        return emb_vectors_split, emb_seq, seq_len

    def get_range_clus_avg_emb(
        self, test_batch: List[torch.Tensor], _test_data_collection: List[Any], device: torch.device('cpu')
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        This function is only used when the `get_range_average` function is called. This module calculates
        cluster-average embeddings for the given short range. The range length is set by `self.diar_window_length`,
        and each cluster-average is only calculated for the specified range.

        Args:
            test_batch: (list)
                List containing embedding sequences, lengths of embedding sequences, ground-truth labels
                (if they exist) and initializing embedding vectors.
            test_data_collection: (list)
                List containing test-set dataloader contents. test_data_collection includes wav file paths,
                RTTM file paths and clustered speaker indices.

        Returns:
            sess_emb_vectors (Tensor):
                Tensor of cluster-average speaker embedding vectors.
                Shape: (batch_size, scale_n, emb_dim, 2*num_of_spks)
            sess_emb_seq (Tensor):
                Tensor of input multi-scale embedding sequences.
                Shape: (batch_size, length, scale_n, emb_dim)
            sess_sig_lengths (Tensor):
                Tensor of the actual sequence lengths without zero-padding.
                Shape: (batch_size)
        """
        _signals, signal_lengths, _targets, _emb_vectors = test_batch
        sess_emb_vectors, sess_emb_seq, sess_sig_lengths = [], [], []
        split_count = torch.ceil(torch.tensor(_signals.shape[1] / self.diar_window_length)).int()
        self.max_pred_length = max(self.max_pred_length, self.diar_window_length * split_count)
        for k in range(_signals.shape[0]):
            signals, emb_vectors, test_data_collection = _signals[k], _emb_vectors[k], _test_data_collection[k]
            for diar_window_index in range(split_count):
                emb_vectors_split, emb_seq, seq_len = self.get_range_average(
                    signals, emb_vectors, diar_window_index, test_data_collection
                )
                sess_emb_vectors.append(emb_vectors_split)
                sess_emb_seq.append(emb_seq)
                sess_sig_lengths.append(seq_len)
        sess_emb_vectors = torch.stack(sess_emb_vectors).to(device)
        sess_emb_seq = torch.stack(sess_emb_seq).to(device)
        sess_sig_lengths = torch.tensor(sess_sig_lengths).to(device)
        return sess_emb_vectors, sess_emb_seq, sess_sig_lengths

    def diar_infer(
        self, test_batch: List[torch.Tensor], test_data_collection: List[Any]
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Launch the `forward_infer()` function by feeding the session-wise embedding sequences to get pairwise
        speaker prediction values. If `split_infer` is True, the input audio clips are broken into short sequences,
        then cluster-average embeddings are calculated for inference. Split-infer might result in improved results
        if calculating the clustering average on a shorter time-span can help speaker assignment.

        Args:
            test_batch: (list)
                List containing embedding sequences, lengths of embedding sequences, ground-truth labels
                (if they exist) and initializing embedding vectors.
            test_data_collection: (list)
                List containing test-set dataloader contents. test_data_collection includes wav file paths,
                RTTM file paths and clustered speaker indices.

        Returns:
            preds (Tensor):
                Tensor containing predicted values which are generated from the MSDD model.
            targets (Tensor):
                Tensor containing binary ground-truth values.
            signal_lengths (Tensor):
                The actual session length (number of steps = number of base-scale segments) without zero-padding.
        """
        signals, signal_lengths, _targets, emb_vectors = test_batch
        if self._cfg.diarizer.msdd_model.parameters.split_infer:
            split_count = torch.ceil(torch.tensor(signals.shape[1] / self.diar_window_length)).int()
            sess_emb_vectors, sess_emb_seq, sess_sig_lengths = self.get_range_clus_avg_emb(
                test_batch, test_data_collection, device=self.msdd_model.device
            )
            with autocast():
                _preds, scale_weights = self.msdd_model.forward_infer(
                    input_signal=sess_emb_seq,
                    input_signal_length=sess_sig_lengths,
                    emb_vectors=sess_emb_vectors,
                    targets=None,
                )
            _preds = _preds.reshape(len(signal_lengths), split_count * self.diar_window_length, -1)
            _preds = _preds[:, : signals.shape[1], :]
        else:
            with autocast():
                _preds, scale_weights = self.msdd_model.forward_infer(
                    input_signal=signals, input_signal_length=signal_lengths, emb_vectors=emb_vectors, targets=None
                )
        self.max_pred_length = max(_preds.shape[1], self.max_pred_length)
        preds = torch.zeros(_preds.shape[0], self.max_pred_length, _preds.shape[2])
        targets = torch.zeros(_preds.shape[0], self.max_pred_length, _preds.shape[2])
        preds[:, : _preds.shape[1], :] = _preds
        return preds, targets, signal_lengths

    @torch.no_grad()
    def run_pairwise_diarization(self) -> Tuple[List[torch.Tensor], List[torch.Tensor], List[torch.Tensor]]:
        """
        Set up the parameters needed for batch inference and run batch inference. Note that each sample is a
        pairwise speaker input. The pairwise inference results are reconstructed to make session-wise prediction
        results.

        Returns:
            integrated_preds_list: (list)
                List containing the session-wise speaker predictions in torch.tensor format.
            targets_list: (list)
                List containing the ground-truth labels in matrix format filled with 0 or 1.
            signal_lengths_list: (list)
                List containing the actual length of each sequence in the session.
        """
        self.out_rttm_dir = self.clustering_embedding.out_rttm_dir
        self.msdd_model.setup_test_data(self.msdd_model.cfg.test_ds)
        self.msdd_model.eval()
        cumul_sample_count = [0]
        preds_list, targets_list, signal_lengths_list = [], [], []
        uniq_id_list = get_uniq_id_list_from_manifest(self.msdd_model.cfg.test_ds.manifest_filepath)
        test_data_collection = [d for d in self.msdd_model.data_collection]
        for sidx, test_batch in enumerate(tqdm(self.msdd_model.test_dataloader())):
            signals, signal_lengths, _targets, emb_vectors = test_batch
            cumul_sample_count.append(cumul_sample_count[-1] + signal_lengths.shape[0])
            preds, targets, signal_lengths = self.diar_infer(
                test_batch, test_data_collection[cumul_sample_count[-2] : cumul_sample_count[-1]]
            )
            if self._cfg.diarizer.msdd_model.parameters.seq_eval_mode:
                self.msdd_model._accuracy_test(preds, targets, signal_lengths)

            preds_list.extend(list(torch.split(preds, 1)))
            targets_list.extend(list(torch.split(targets, 1)))
            signal_lengths_list.extend(list(torch.split(signal_lengths, 1)))

        if self._cfg.diarizer.msdd_model.parameters.seq_eval_mode:
            f1_score, simple_acc = self.msdd_model.compute_accuracies()
            logging.info(f"Test Inference F1 score. {f1_score:.4f}, simple Acc. {simple_acc:.4f}")
        integrated_preds_list = self.get_integrated_preds_list(uniq_id_list, test_data_collection, preds_list)
        return integrated_preds_list, targets_list, signal_lengths_list

    def run_overlap_aware_eval(
        self, preds_list: List[torch.Tensor], threshold: float
    ) -> List[Optional[Tuple[DiarizationErrorRate, Dict]]]:
        """
        Based on the predicted sigmoid values, render RTTM files then evaluate the overlap-aware diarization
        results.

        Args:
            preds_list: (list)
                List containing predicted pairwise speaker labels.
            threshold: (float)
                A floating-point threshold value that determines overlapped speech detection.
                    - If threshold is 1.0, no overlapping speech is detected and only the major speaker is kept.
                    - If threshold is 0.0, all speakers are considered active at any time step.
        """
        logging.info(
            f"     [Threshold: {threshold:.4f}] [use_clus_as_main={self.use_clus_as_main}] [diar_window={self.diar_window_length}]"
        )
        outputs = []
        manifest_filepath = self.msdd_model.cfg.test_ds.manifest_filepath
        rttm_map = audio_rttm_map(manifest_filepath)
        for k, (collar, ignore_overlap) in enumerate(self.diar_eval_settings):
            all_reference, all_hypothesis = make_rttm_with_overlap(
                manifest_filepath,
                self.msdd_model.clus_test_label_dict,
                preds_list,
                threshold=threshold,
                infer_overlap=True,
                use_clus_as_main=self.use_clus_as_main,
                overlap_infer_spk_limit=self.overlap_infer_spk_limit,
                use_adaptive_thres=self.use_adaptive_thres,
                max_overlap_spks=self.max_overlap_spks,
                out_rttm_dir=self.out_rttm_dir,
            )
            output = score_labels(
                rttm_map,
                all_reference,
                all_hypothesis,
                collar=collar,
                ignore_overlap=ignore_overlap,
                verbose=self._cfg.verbose,
            )
            outputs.append(output)
        logging.info(f"  \n")
        return outputs

    @classmethod
    def from_pretrained(
        cls,
        model_name: str,
        vad_model_name: str = 'vad_multilingual_marblenet',
        map_location: Optional[str] = None,
        verbose: bool = False,
    ):
        """
        Instantiate a `NeuralDiarizer` to run speaker diarization.

        Args:
            model_name (str): Path/Name of the neural diarization model to load.
            vad_model_name (str): Path/Name of the voice activity detection (VAD) model to load.
            map_location (str): Optional str to map the instantiated model to a device (cpu, cuda).
                By default (None), it will select a GPU if available, falling back to CPU otherwise.
            verbose (bool): Enable verbose logging when loading models/running diarization.

        Returns:
            `NeuralDiarizer`
        """
        logging.setLevel(logging.INFO if verbose else logging.WARNING)
        cfg = NeuralDiarizerInferenceConfig.init_config(
            diar_model_path=model_name, vad_model_path=vad_model_name, map_location=map_location, verbose=verbose,
        )
        return cls(cfg)

    def __call__(
        self,
        audio_filepath: str,
        batch_size: int = 64,
        num_workers: int = 1,
        max_speakers: Optional[int] = None,
        num_speakers: Optional[int] = None,
        out_dir: Optional[str] = None,
        verbose: bool = False,
    ) -> Union[Annotation, List[Annotation]]:
        """
        Run the `NeuralDiarizer` inference pipeline.

        Args:
            audio_filepath (str, list): Audio path to run speaker diarization on.
            max_speakers (int): If known, the max number of speakers in the file(s).
            num_speakers (int): If known, the exact number of speakers in the file(s).
            batch_size (int): Batch size when running inference.
            num_workers (int): Number of workers to use in data-loading.
            out_dir (str): Path to store intermediate files during inference (default temp directory).

        Returns:
            `pyannote.Annotation` for each audio path, containing speaker labels and segment timestamps.
        """
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        with tempfile.TemporaryDirectory(dir=out_dir) as tmpdir:
            manifest_path = os.path.join(tmpdir, 'manifest.json')
            meta = [
                {
                    'audio_filepath': audio_filepath,
                    'offset': 0,
                    'duration': None,
                    'label': 'infer',
                    'text': '-',
                    'num_speakers': num_speakers,
                    'rttm_filepath': None,
                    'uem_filepath': None,
                }
            ]
            with open(manifest_path, 'w') as f:
                f.write('\n'.join(json.dumps(x) for x in meta))

            self._initialize_configs(
                manifest_path=manifest_path,
                max_speakers=max_speakers,
                num_speakers=num_speakers,
                tmpdir=tmpdir,
                batch_size=batch_size,
                num_workers=num_workers,
                verbose=verbose,
            )

            self.msdd_model.cfg.test_ds.manifest_filepath = manifest_path
            self.diarize()

            pred_labels_clus = rttm_to_labels(f'{tmpdir}/pred_rttms/{Path(audio_filepath).stem}.rttm')
        return labels_to_pyannote_object(pred_labels_clus)

    def _initialize_configs(
        self,
        manifest_path: str,
        max_speakers: Optional[int],
        num_speakers: Optional[int],
        tmpdir: tempfile.TemporaryDirectory,
        batch_size: int,
        num_workers: int,
        verbose: bool,
    ) -> None:
        self._cfg.batch_size = batch_size
        self._cfg.num_workers = num_workers
        self._cfg.diarizer.manifest_filepath = manifest_path
        self._cfg.diarizer.out_dir = tmpdir
        self._cfg.verbose = verbose
        self._cfg.diarizer.clustering.parameters.oracle_num_speakers = num_speakers is not None
        if max_speakers:
            self._cfg.diarizer.clustering.parameters.max_num_speakers = max_speakers
        self.transfer_diar_params_to_model_params(self.msdd_model, self._cfg)

    @classmethod
    def list_available_models(cls) -> List[PretrainedModelInfo]:
        """
        This method returns a list of pre-trained models which can be instantiated directly from NVIDIA's NGC cloud.

        Returns:
            List of available pre-trained models.
        """
        return EncDecDiarLabelModel.list_available_models()