# Source code for nemo.collections.asr.models.clustering_diarizer

# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import os
import pickle as pkl
import shutil
import tarfile
import tempfile
from collections import defaultdict
from typing import List, Optional

import torch
from omegaconf import DictConfig, OmegaConf, open_dict
from omegaconf.listconfig import ListConfig
from pytorch_lightning.utilities import rank_zero_only
from tqdm import tqdm

from nemo.collections.asr.models.classification_models import EncDecClassificationModel
from nemo.collections.asr.models.label_models import ExtractSpeakerEmbeddingsModel
from nemo.collections.asr.parts.mixins import DiarizationMixin
from nemo.collections.asr.parts.speaker_utils import audio_rttm_map, perform_diarization, write_rttm2manifest
from nemo.collections.asr.parts.vad_utils import (
    generate_overlap_vad_seq,
    generate_vad_segment_table,
    get_vad_stream_status,
    prepare_manifest,
)
from nemo.core.classes import Model
from nemo.utils import logging, model_utils
from nemo.utils.exp_manager import NotFoundError

try:
    # Mixed-precision context manager used around the model forward passes below.
    from torch.cuda.amp import autocast
except ImportError:
    # torch builds without ``torch.cuda.amp``: substitute a context manager that does nothing,
    # so every ``with autocast():`` block still runs (in default precision).
    from contextlib import contextmanager

    @contextmanager
    def autocast(enabled=None):
        """No-op stand-in for ``torch.cuda.amp.autocast``; ``enabled`` is accepted and ignored."""
        yield


__all__ = ["ClusteringDiarizer"]

# Member file names inside a .nemo archive written by ClusteringDiarizer.save_to.
_MODEL_CONFIG_YAML = "model_config.yaml"  # serialized diarizer configuration
_VAD_MODEL = "vad_model.nemo"  # nested VAD model; only present when a VAD model was loaded
_SPEAKER_MODEL = "speaker_model.nemo"  # nested speaker embedding model


def get_available_model_names(class_name):
    """Return the ``pretrained_model_name`` of every pretrained model listed by ``class_name``.

    Args:
        class_name: a model class exposing a ``list_available_models()`` classmethod whose
            entries carry a ``pretrained_model_name`` attribute.

    Returns:
        A list of model-name strings. Empty when ``list_available_models()`` returns ``None``
        (some model classes return ``None`` rather than an empty list).
    """
    available_models = class_name.list_available_models()
    if available_models is None:
        return []
    return [model.pretrained_model_name for model in available_models]


class ClusteringDiarizer(Model, DiarizationMixin):
    """Clustering-based speaker diarizer.

    Combines an optional voice activity detection (VAD) model with a speaker embedding model.
    :meth:`diarize` either runs VAD to find speech segments or uses the oracle VAD manifest from
    the config. It then extracts speaker embeddings for those segments and calls
    ``perform_diarization``, which writes predicted RTTM files under ``<out_dir>/pred_rttms``.
    """

    def __init__(self, cfg: DictConfig):
        """Build the diarizer from ``cfg``.

        Reads ``cfg.diarizer`` (``out_dir``, ``vad``, ``speaker_embeddings``, ``paths2audio_files``,
        ``num_speakers``) plus ``cfg.sample_rate`` / ``cfg.num_workers`` later on. Loads the VAD model
        only when ``cfg.diarizer.vad.model_path`` is set; always loads the speaker embedding model.
        """
        cfg = model_utils.convert_model_config_to_dict_config(cfg)
        # Convert config to support Hydra 1.0+ instantiation
        cfg = model_utils.maybe_update_config_version(cfg)
        self._cfg = cfg
        self._out_dir = self._cfg.diarizer.out_dir
        # makedirs (not mkdir) so nested output paths whose parents do not exist yet also work.
        os.makedirs(self._out_dir, exist_ok=True)

        # VAD state; set to True by _init_vad_model once a VAD model is loaded.
        self.has_vad_model = False
        self.has_vad_model_to_save = False

        # Without a VAD model, speech segments come from this oracle manifest; _run_vad replaces it
        # with the VAD-generated manifest.
        self._speaker_manifest_path = self._cfg.diarizer.speaker_embeddings.oracle_vad_manifest
        self.AUDIO_RTTM_MAP = None
        self.paths2audio_files = self._cfg.diarizer.paths2audio_files

        if self._cfg.diarizer.vad.model_path is not None:
            # Also (re)creates the VAD output directory.
            self._init_vad_model()

        # init speaker model
        self._init_speaker_model()

        self._num_speakers = self._cfg.diarizer.num_speakers
        self._device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    @classmethod
    def list_available_models(cls):
        """This class lists no pretrained checkpoints; always returns ``None``."""
        return None

    def _init_speaker_model(self):
        """Load the speaker embedding model from a local ``.nemo`` file or from NGC by name.

        An unknown (or ``None``) model name falls back to ``speakerdiarization_speakernet``.
        """
        model_path = self._cfg.diarizer.speaker_embeddings.model_path
        if model_path is not None and model_path.endswith('.nemo'):
            self._speaker_model = ExtractSpeakerEmbeddingsModel.restore_from(model_path)
            logging.info("Speaker Model restored locally from {}".format(model_path))
        else:
            if model_path not in get_available_model_names(ExtractSpeakerEmbeddingsModel):
                logging.warning(
                    "requested {} model name not available in pretrained models, instead".format(model_path)
                )
                model_path = "speakerdiarization_speakernet"
            logging.info("Loading pretrained {} model from NGC".format(model_path))
            self._speaker_model = ExtractSpeakerEmbeddingsModel.from_pretrained(model_name=model_path)

        self._speaker_dir = os.path.join(self._out_dir, 'speaker_outputs')

    def set_vad_model(self, vad_config):
        """Replace ``cfg.diarizer.vad`` with ``vad_config`` and (re)load the VAD model from it."""
        with open_dict(self._cfg):
            self._cfg.diarizer.vad = vad_config
        self._init_vad_model()

    def _init_vad_model(self):
        """Load the VAD model and prepare a fresh ``<out_dir>/vad_outputs`` directory.

        A local ``.nemo`` path is restored directly. Any other value, including ``None``, is treated
        as an NGC model name. Unknown names fall back to ``vad_telephony_marblenet``.
        """
        model_path = self._cfg.diarizer.vad.model_path
        if model_path is not None and model_path.endswith('.nemo'):
            self._vad_model = EncDecClassificationModel.restore_from(model_path)
            logging.info("VAD model loaded locally from {}".format(model_path))
        else:
            if model_path not in get_available_model_names(EncDecClassificationModel):
                logging.warning(
                    "requested {} model name not available in pretrained models, instead".format(model_path)
                )
                model_path = "vad_telephony_marblenet"
            logging.info("Loading pretrained {} model from NGC".format(model_path))
            self._vad_model = EncDecClassificationModel.from_pretrained(model_name=model_path)

        self._vad_window_length_in_sec = self._cfg.diarizer.vad.window_length_in_sec
        self._vad_shift_length_in_sec = self._cfg.diarizer.vad.shift_length_in_sec
        self.has_vad_model_to_save = True
        self.has_vad_model = True

        # The output directory is set up here, not only in __init__, so a VAD model attached later
        # through set_vad_model also gets the paths _run_vad writes to.
        self._vad_dir = os.path.join(self._out_dir, 'vad_outputs')
        self._vad_out_file = os.path.join(self._vad_dir, "vad_out.json")
        shutil.rmtree(self._vad_dir, ignore_errors=True)
        os.makedirs(self._vad_dir)

    def _setup_vad_test_data(self, manifest_vad_input):
        """Configure the VAD model's streaming test dataloader over ``manifest_vad_input``."""
        vad_dl_config = {
            'manifest_filepath': manifest_vad_input,
            'sample_rate': self._cfg.sample_rate,
            'vad_stream': True,
            'labels': ['infer',],
            'time_length': self._vad_window_length_in_sec,
            'shift_length': self._vad_shift_length_in_sec,
            'trim_silence': False,
            'num_workers': self._cfg.num_workers,
        }
        self._vad_model.setup_test_data(test_data_config=vad_dl_config)

    def _setup_spkr_test_data(self, manifest_file):
        """Configure the speaker model's test dataloader (batch size 1) over ``manifest_file``."""
        spk_dl_config = {
            'manifest_filepath': manifest_file,
            'sample_rate': self._cfg.sample_rate,
            'batch_size': 1,
            'time_length': self._cfg.diarizer.speaker_embeddings.window_length_in_sec,
            'shift_length': self._cfg.diarizer.speaker_embeddings.shift_length_in_sec,
            'trim_silence': False,
            'embedding_dir': self._speaker_dir,
            'labels': None,
            'task': "diarization",
            'num_workers': self._cfg.num_workers,
        }
        self._speaker_model.setup_test_data(spk_dl_config)

    def _run_vad(self, manifest_file):
        """Run VAD over ``manifest_file`` and write a speech-segment manifest to ``self._vad_out_file``.

        Writes per-recording frame-level speech probabilities to ``<vad_dir>/<name>.frame``, and
        optionally smooths them over overlapping segments. They are then thresholded into segment
        tables and converted into a manifest that becomes ``self._speaker_manifest_path``.
        """
        self._vad_model = self._vad_model.to(self._device)
        self._vad_model.eval()

        # Frame predictions per window. Consecutive split segments of one recording overlap, so a
        # segment drops `trunc` / `trunc_l` predictions at its edges depending on its stream position.
        time_unit = int(self._vad_window_length_in_sec / self._vad_shift_length_in_sec)
        trunc = int(time_unit / 2)
        trunc_l = time_unit - trunc

        data = []
        with open(manifest_file, 'r') as manifest:
            for line in manifest:
                file = os.path.basename(json.loads(line)['audio_filepath'])
                data.append(os.path.splitext(file)[0])

        # Frame files are opened in append mode because split segments of one recording share a
        # file. Remove leftovers from a previous diarize() call so they are not appended to.
        for name in set(data):
            stale_path = os.path.join(self._vad_dir, name + ".frame")
            if os.path.exists(stale_path):
                os.remove(stale_path)

        status = get_vad_stream_status(data)
        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            for i, test_batch in enumerate(tqdm(self._vad_model.test_dataloader())):
                test_batch = [x.to(self._device) for x in test_batch]
                with autocast():
                    log_probs = self._vad_model(input_signal=test_batch[0], input_signal_length=test_batch[1])
                    probs = torch.softmax(log_probs, dim=-1)
                    pred = probs[:, 1]
                    num_frames = len(pred)
                    # Explicit stop indices instead of `pred[:-trunc]`: with a zero trim, the negative
                    # form `[:-0]` would yield an empty slice and silently drop every prediction.
                    if status[i] == 'start':
                        to_save = pred[: max(num_frames - trunc, 0)]
                    elif status[i] == 'next':
                        to_save = pred[trunc : max(num_frames - trunc_l, 0)]
                    elif status[i] == 'end':
                        to_save = pred[trunc_l:]
                    else:
                        to_save = pred
                outpath = os.path.join(self._vad_dir, data[i] + ".frame")
                with open(outpath, "a") as fout:
                    fout.writelines('{0:0.4f}\n'.format(prob) for prob in to_save.tolist())
                del test_batch

        if not self._cfg.diarizer.vad.vad_decision_smoothing:
            # Shift the window by 10ms to generate the frame and use the prediction of the window to represent the label for the frame;
            self.vad_pred_dir = self._vad_dir
        else:
            # Generate predictions with overlapping input segments. Then a smoothing filter is applied to decide the label for a frame spanned by multiple segments.
            # smoothing_method would be either in majority vote (median) or average (mean)
            logging.info("Generating predictions with overlapping input segments")
            smoothing_pred_dir = generate_overlap_vad_seq(
                frame_pred_dir=self._vad_dir,
                smoothing_method=self._cfg.diarizer.vad.smoothing_params.method,
                overlap=self._cfg.diarizer.vad.smoothing_params.overlap,
                seg_len=self._vad_window_length_in_sec,
                shift_len=self._vad_shift_length_in_sec,
                num_workers=self._cfg.num_workers,
            )
            self.vad_pred_dir = smoothing_pred_dir

        logging.info("Converting frame level prediction to speech/no-speech segment in start and end times format.")

        table_out_dir = generate_vad_segment_table(
            vad_pred_dir=self.vad_pred_dir,
            threshold=self._cfg.diarizer.vad.threshold,
            shift_len=self._vad_shift_length_in_sec,
            num_workers=self._cfg.num_workers,
        )

        vad_table_list = [os.path.join(table_out_dir, key + ".txt") for key in self.AUDIO_RTTM_MAP]
        write_rttm2manifest(self._cfg.diarizer.paths2audio_files, vad_table_list, self._vad_out_file)
        self._speaker_manifest_path = self._vad_out_file

    def _extract_embeddings(self, manifest_file):
        """Extract speaker embeddings for every manifest entry and pickle them.

        The pickle maps each audio file's base name (extension stripped) to a list of embedding
        vectors. It is written to ``<speaker_dir>/embeddings/<manifest stem>_embeddings.pkl``, and
        its path is stored in ``self._embeddings_file``.
        """
        logging.info("Extracting embeddings for Diarization")
        self._setup_spkr_test_data(manifest_file)
        uniq_names = []
        out_embeddings = defaultdict(list)
        self._speaker_model = self._speaker_model.to(self._device)
        self._speaker_model.eval()

        with open(manifest_file, 'r') as manifest:
            for line in manifest.readlines():
                line = line.strip()
                dic = json.loads(line)
                uniq_names.append(dic['audio_filepath'].split('/')[-1].rsplit('.', 1)[0])

        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            # The dataloader uses batch_size 1, so batch index i lines up with manifest entry i.
            for i, test_batch in enumerate(tqdm(self._speaker_model.test_dataloader())):
                test_batch = [x.to(self._device) for x in test_batch]
                audio_signal, audio_signal_len, _labels, _slices = test_batch
                with autocast():
                    _, embs = self._speaker_model.forward(input_signal=audio_signal, input_signal_length=audio_signal_len)
                    emb_shape = embs.shape[-1]
                    embs = embs.type(torch.float32)
                    embs = embs.view(-1, emb_shape).cpu().detach().numpy()
                    out_embeddings[uniq_names[i]].extend(embs)
                del test_batch

        embedding_dir = os.path.join(self._speaker_dir, 'embeddings')
        os.makedirs(embedding_dir, exist_ok=True)

        # splitext also copes with manifest names that have no extension
        # (the former rsplit('.', 1)[-2] raised IndexError there).
        prefix = os.path.splitext(os.path.basename(manifest_file))[0]
        name = os.path.join(embedding_dir, prefix)
        self._embeddings_file = name + '_embeddings.pkl'
        with open(self._embeddings_file, 'wb') as fout:
            pkl.dump(out_embeddings, fout)
        logging.info("Saved embedding files to {}".format(embedding_dir))

    def path2audio_files_to_manifest(self, paths2audio_files):
        """Write a JSON-lines inference manifest for ``paths2audio_files`` to ``<out_dir>/manifest.json``.

        Returns:
            Path of the written manifest.
        """
        mfst_file = os.path.join(self._out_dir, 'manifest.json')
        with open(mfst_file, 'w') as fp:
            for audio_file in paths2audio_files:
                audio_file = audio_file.strip()
                entry = {'audio_filepath': audio_file, 'offset': 0.0, 'duration': None, 'text': '-', 'label': 'infer'}
                fp.write(json.dumps(entry) + '\n')

        return mfst_file

    def diarize(self, paths2audio_files: Optional[List[str]] = None, batch_size: int = 1):
        """Run the diarization pipeline and write predicted RTTMs to ``<out_dir>/pred_rttms``.

        Args:
            paths2audio_files: list of audio file paths, or a path to a text file with one audio path
                per line. When empty or ``None``, ``cfg.diarizer.paths2audio_files`` is used.
            batch_size: accepted for interface compatibility; not used by this implementation.

        Raises:
            ValueError: if no audio files are given anywhere, or they have an unsupported type.
            NotFoundError: if no VAD model is loaded and the oracle VAD manifest is unset or missing.
        """
        if paths2audio_files:
            self.paths2audio_files = paths2audio_files
        else:
            if self._cfg.diarizer.paths2audio_files is None:
                raise ValueError("Pass path2audio files either through config or to diarize method")
            self.paths2audio_files = self._cfg.diarizer.paths2audio_files

        if isinstance(self.paths2audio_files, str) and os.path.isfile(self.paths2audio_files):
            with open(self.paths2audio_files, 'r') as path2file:
                # Skip blank lines; they would otherwise become empty audio paths.
                paths2audio_files = [line.strip() for line in path2file.readlines() if line.strip()]
        elif isinstance(self.paths2audio_files, (list, ListConfig)):
            paths2audio_files = list(self.paths2audio_files)
        else:
            raise ValueError("paths2audio_files must be of type list or path to file containing audio files")

        self.AUDIO_RTTM_MAP = audio_rttm_map(paths2audio_files, self._cfg.diarizer.path2groundtruth_rttm_files)

        if self.has_vad_model:
            logging.info("Performing VAD")
            mfst_file = self.path2audio_files_to_manifest(paths2audio_files)
            self._dont_auto_split = False
            self._split_duration = 50
            manifest_vad_input = mfst_file

            if not self._dont_auto_split:
                logging.info("Split long audio file to avoid CUDA memory issue")
                logging.debug("Try smaller split_duration if you still have CUDA memory issue")
                config = {
                    'manifest_filepath': mfst_file,
                    'time_length': self._vad_window_length_in_sec,
                    'split_duration': self._split_duration,
                    'num_workers': self._cfg.num_workers,
                }
                manifest_vad_input = prepare_manifest(config)
            else:
                logging.warning(
                    "If you encounter CUDA memory issue, try splitting manifest entry by split_duration to avoid it."
                )

            self._setup_vad_test_data(manifest_vad_input)
            self._run_vad(manifest_vad_input)
        else:
            # Guard against an unset path too: os.path.exists(None) would raise TypeError instead.
            if not self._speaker_manifest_path or not os.path.exists(self._speaker_manifest_path):
                raise NotFoundError("Oracle VAD based manifest file not found")

        self._extract_embeddings(self._speaker_manifest_path)
        out_rttm_dir = os.path.join(self._out_dir, 'pred_rttms')
        os.makedirs(out_rttm_dir, exist_ok=True)

        perform_diarization(
            embeddings_file=self._embeddings_file,
            reco2num=self._num_speakers,
            manifest_path=self._speaker_manifest_path,
            sample_rate=self._cfg.sample_rate,
            window=self._cfg.diarizer.speaker_embeddings.window_length_in_sec,
            shift=self._cfg.diarizer.speaker_embeddings.shift_length_in_sec,
            audio_rttm_map=self.AUDIO_RTTM_MAP,
            out_rttm_dir=out_rttm_dir,
        )

    @staticmethod
    def __make_nemo_file_from_folder(filename, source_dir):
        """Pack the contents of ``source_dir`` into the gzipped tar ``filename`` at the archive root."""
        with tarfile.open(filename, "w:gz") as tar:
            tar.add(source_dir, arcname="./")

    @rank_zero_only
    def save_to(self, save_path: str):
        """Save the diarizer into a ``.nemo`` archive that :meth:`restore_from` can load.

        The archive (tar.gz) contains:
            model_config.yaml - the diarizer configuration in .yaml format; it can be deserialized
                into the ``cfg`` argument of the constructor.
            speaker_model.nemo - the nested speaker embedding model.
            vad_model.nemo - the nested VAD model (only when a VAD model is loaded).

        Args:
            save_path: Path to .nemo file where model instance should be saved
        """
        with tempfile.TemporaryDirectory() as tmpdir:
            config_yaml = os.path.join(tmpdir, _MODEL_CONFIG_YAML)
            spkr_model = os.path.join(tmpdir, _SPEAKER_MODEL)

            self.to_config_file(path2yaml_file=config_yaml)
            if self.has_vad_model:
                vad_model = os.path.join(tmpdir, _VAD_MODEL)
                self._vad_model.save_to(vad_model)
            self._speaker_model.save_to(spkr_model)
            self.__make_nemo_file_from_folder(filename=save_path, source_dir=tmpdir)

    @staticmethod
    def __unpack_nemo_file(path2file: str, out_folder: str) -> str:
        """Extract the gzipped tar ``path2file`` into ``out_folder`` and return ``out_folder``.

        Raises:
            FileNotFoundError: if ``path2file`` does not exist.
        """
        if not os.path.exists(path2file):
            raise FileNotFoundError(f"{path2file} does not exist")
        # NOTE(review): extractall trusts member paths; only restore .nemo archives from trusted
        # sources (a crafted archive could write outside out_folder).
        with tarfile.open(path2file, "r:gz") as tar:
            tar.extractall(path=out_folder)
        return out_folder

    @classmethod
    def restore_from(
        cls,
        restore_path: str,
        override_config_path: Optional[str] = None,
        map_location: Optional[torch.device] = None,
        strict: bool = False,
    ):
        """Restore a diarizer from a ``.nemo`` archive written by :meth:`save_to`.

        Args:
            restore_path: path to the ``.nemo`` archive.
            override_config_path: optional .yaml config to use instead of the archived one.
            map_location: stored on the config as ``restore_map_location``.
            strict: accepted for interface compatibility; not used here.

        Returns:
            The restored ``ClusteringDiarizer`` instance.
        """
        # Get path where the command is executed - the artifacts will be "retrieved" there
        # (original .nemo behavior)
        cwd = os.getcwd()

        with tempfile.TemporaryDirectory() as tmpdir:
            try:
                cls.__unpack_nemo_file(path2file=restore_path, out_folder=tmpdir)
                os.chdir(tmpdir)
                if override_config_path is None:
                    config_yaml = os.path.join(tmpdir, _MODEL_CONFIG_YAML)
                else:
                    config_yaml = override_config_path
                conf = OmegaConf.load(config_yaml)
                if os.path.exists(os.path.join(tmpdir, _VAD_MODEL)):
                    conf.diarizer.vad.model_path = os.path.join(tmpdir, _VAD_MODEL)
                else:
                    logging.info(
                        f'Model {cls.__name__} does not contain a VAD model. A VAD model or manifest file with '
                        f'speech segments is needed for diarization with this model'
                    )

                conf.diarizer.speaker_embeddings.model_path = os.path.join(tmpdir, _SPEAKER_MODEL)
                conf.restore_map_location = map_location
                OmegaConf.set_struct(conf, True)
                # Nested models are restored inside __init__, while tmpdir still exists.
                instance = cls(cfg=conf)

                logging.info(f'Model {cls.__name__} was successfully restored from {restore_path}.')
            finally:
                os.chdir(cwd)

        return instance