# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import os
import pickle as pkl
import shutil
import tarfile
import tempfile
from copy import deepcopy
from typing import List, Optional
import torch
from omegaconf import DictConfig, OmegaConf
from pytorch_lightning.utilities import rank_zero_only
from tqdm import tqdm
from nemo.collections.asr.models.classification_models import EncDecClassificationModel
from nemo.collections.asr.models.label_models import EncDecSpeakerLabelModel
from nemo.collections.asr.parts.mixins.mixins import DiarizationMixin
from nemo.collections.asr.parts.utils.speaker_utils import (
audio_rttm_map,
get_embs_and_timestamps,
get_uniqname_from_filepath,
parse_scale_configs,
perform_clustering,
score_labels,
segments_manifest_to_subsegments_manifest,
write_rttm2manifest,
)
from nemo.collections.asr.parts.utils.vad_utils import (
generate_overlap_vad_seq,
generate_vad_segment_table,
get_vad_stream_status,
prepare_manifest,
)
from nemo.core.classes import Model
from nemo.utils import logging, model_utils
try:
from torch.cuda.amp import autocast
except ImportError:
from contextlib import contextmanager
@contextmanager
def autocast(enabled=None):
yield
__all__ = ['ClusteringDiarizer']
_MODEL_CONFIG_YAML = "model_config.yaml"
_VAD_MODEL = "vad_model.nemo"
_SPEAKER_MODEL = "speaker_model.nemo"
def get_available_model_names(class_name):
"lists available pretrained model names from NGC"
available_models = class_name.list_available_models()
return list(map(lambda x: x.pretrained_model_name, available_models))
class ClusteringDiarizer(Model, DiarizationMixin):
"""
Inference model class for offline speaker diarization.
This class handles the functionality required for diarization: Speech Activity Detection, Segmentation,
Embedding Extraction, Clustering, Resegmentation and Scoring.
All parameters are passed through the config file.
"""
def __init__(self, cfg: DictConfig):
cfg = model_utils.convert_model_config_to_dict_config(cfg)
# Convert config to support Hydra 1.0+ instantiation
cfg = model_utils.maybe_update_config_version(cfg)
self._cfg = cfg
# Diarizer set up
self._diarizer_params = self._cfg.diarizer
# init vad model
self.has_vad_model = False
if not self._diarizer_params.oracle_vad:
if self._cfg.diarizer.vad.model_path is not None:
self._vad_params = self._cfg.diarizer.vad.parameters
self._init_vad_model()
# init speaker model
self.multiscale_embeddings_and_timestamps = {}
self._init_speaker_model()
self._speaker_params = self._cfg.diarizer.speaker_embeddings.parameters
self._speaker_dir = os.path.join(self._diarizer_params.out_dir, 'speaker_outputs')
shutil.rmtree(self._speaker_dir, ignore_errors=True)
os.makedirs(self._speaker_dir)
# Clustering params
self._cluster_params = self._diarizer_params.clustering.parameters
self._device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
@classmethod
def list_available_models(cls):
pass
def _init_vad_model(self):
"""
Initialize the VAD model from the model name or path passed through the config
"""
model_path = self._cfg.diarizer.vad.model_path
if model_path.endswith('.nemo'):
self._vad_model = EncDecClassificationModel.restore_from(model_path)
logging.info("VAD model loaded locally from {}".format(model_path))
else:
if model_path not in get_available_model_names(EncDecClassificationModel):
logging.warning(
"requested {} model name not available in pretrained models, instead".format(model_path)
)
model_path = "vad_telephony_marblenet"
logging.info("Loading pretrained {} model from NGC".format(model_path))
self._vad_model = EncDecClassificationModel.from_pretrained(model_name=model_path)
self._vad_window_length_in_sec = self._vad_params.window_length_in_sec
self._vad_shift_length_in_sec = self._vad_params.shift_length_in_sec
self.has_vad_model = True
def _init_speaker_model(self):
"""
Initialize the speaker embedding model from the model name or path passed through the config
"""
model_path = self._cfg.diarizer.speaker_embeddings.model_path
if model_path is not None and model_path.endswith('.nemo'):
self._speaker_model = EncDecSpeakerLabelModel.restore_from(model_path)
logging.info("Speaker Model restored locally from {}".format(model_path))
elif model_path is not None and model_path.endswith('.ckpt'):
self._speaker_model = EncDecSpeakerLabelModel.load_from_checkpoint(model_path)
logging.info("Speaker Model restored locally from {}".format(model_path))
else:
if model_path not in get_available_model_names(EncDecSpeakerLabelModel):
logging.warning(
"requested {} model name not available in pretrained models, instead".format(model_path)
)
model_path = "ecapa_tdnn"
logging.info("Loading pretrained {} model from NGC".format(model_path))
self._speaker_model = EncDecSpeakerLabelModel.from_pretrained(model_name=model_path)
self.multiscale_args_dict = parse_scale_configs(
self._diarizer_params.speaker_embeddings.parameters.window_length_in_sec,
self._diarizer_params.speaker_embeddings.parameters.shift_length_in_sec,
self._diarizer_params.speaker_embeddings.parameters.multiscale_weights,
)
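# Illustrative multiscale settings (example values, not defaults defined in this file):
#   window_length_in_sec: [1.5, 1.0, 0.5]
#   shift_length_in_sec:  [0.75, 0.5, 0.25]
#   multiscale_weights:   [1, 1, 1]
# Scalar values (e.g. window_length_in_sec: 1.5) select a single scale.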
def _setup_vad_test_data(self, manifest_vad_input):
vad_dl_config = {
'manifest_filepath': manifest_vad_input,
'sample_rate': self._cfg.sample_rate,
'batch_size': self._cfg.get('batch_size'),
'vad_stream': True,
'labels': ['infer',],
'window_length_in_sec': self._vad_window_length_in_sec,
'shift_length_in_sec': self._vad_shift_length_in_sec,
'trim_silence': False,
'num_workers': self._cfg.num_workers,
}
self._vad_model.setup_test_data(test_data_config=vad_dl_config)
def _setup_spkr_test_data(self, manifest_file):
spk_dl_config = {
'manifest_filepath': manifest_file,
'sample_rate': self._cfg.sample_rate,
'batch_size': self._cfg.get('batch_size'),
'trim_silence': False,
'labels': None,
'num_workers': self._cfg.num_workers,
}
self._speaker_model.setup_test_data(spk_dl_config)
def _run_vad(self, manifest_file):
"""
Run voice activity detection.
Gets the voice activity probabilities and smooths them using the post-processing parameters.
The resulting frame-level predictions are used to generate a manifest file for later speaker embedding extraction.
input:
manifest_file (str) : Manifest file containing the path to each audio file, labeled as 'infer'
"""
shutil.rmtree(self._vad_dir, ignore_errors=True)
os.makedirs(self._vad_dir)
self._vad_model = self._vad_model.to(self._device)
self._vad_model.eval()
time_unit = int(self._vad_window_length_in_sec / self._vad_shift_length_in_sec)
trunc = int(time_unit / 2)
trunc_l = time_unit - trunc
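# Example (illustrative values): with a 0.63 s window and a 0.01 s shift, time_unit is 63 frames.
# The start/next/end logic below trims roughly half a window of predictions at chunk boundaries,
# so predictions from split chunks of the same recording can be concatenated without duplicated frames.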
all_len = 0
data = []
for line in open(manifest_file, 'r', encoding='utf-8'):
file = json.loads(line)['audio_filepath']
data.append(get_uniqname_from_filepath(file))
status = get_vad_stream_status(data)
for i, test_batch in enumerate(tqdm(self._vad_model.test_dataloader())):
test_batch = [x.to(self._device) for x in test_batch]
with autocast():
log_probs = self._vad_model(input_signal=test_batch[0], input_signal_length=test_batch[1])
probs = torch.softmax(log_probs, dim=-1)
pred = probs[:, 1]
if status[i] == 'start':
to_save = pred[:-trunc]
elif status[i] == 'next':
to_save = pred[trunc:-trunc_l]
elif status[i] == 'end':
to_save = pred[trunc_l:]
else:
to_save = pred
all_len += len(to_save)
outpath = os.path.join(self._vad_dir, data[i] + ".frame")
with open(outpath, "a", encoding='utf-8') as fout:
for f in range(len(to_save)):
fout.write('{0:0.4f}\n'.format(to_save[f]))
del test_batch
if status[i] == 'end' or status[i] == 'single':
all_len = 0
if not self._vad_params.smoothing:
# The window is shifted (by the frame shift, e.g. 10 ms) to generate each frame, and the window's prediction is used as that frame's label.
self.vad_pred_dir = self._vad_dir
else:
# Generate predictions with overlapping input segments. Then a smoothing filter is applied to decide the label for a frame spanned by multiple segments.
# smoothing_method is either majority vote (median) or average (mean)
logging.info("Generating predictions with overlapping input segments")
smoothing_pred_dir = generate_overlap_vad_seq(
frame_pred_dir=self._vad_dir,
smoothing_method=self._vad_params.smoothing,
overlap=self._vad_params.overlap,
window_length_in_sec=self._vad_window_length_in_sec,
shift_length_in_sec=self._vad_shift_length_in_sec,
num_workers=self._cfg.num_workers,
)
self.vad_pred_dir = smoothing_pred_dir
logging.info("Converting frame level prediction to speech/no-speech segment in start and end times format.")
table_out_dir = generate_vad_segment_table(
vad_pred_dir=self.vad_pred_dir,
postprocessing_params=self._vad_params,
shift_length_in_sec=self._vad_shift_length_in_sec,
num_workers=self._cfg.num_workers,
)
AUDIO_VAD_RTTM_MAP = deepcopy(self.AUDIO_RTTM_MAP)
for key in AUDIO_VAD_RTTM_MAP:
AUDIO_VAD_RTTM_MAP[key]['rttm_filepath'] = os.path.join(table_out_dir, key + ".txt")
write_rttm2manifest(AUDIO_VAD_RTTM_MAP, self._vad_out_file)
self._speaker_manifest_path = self._vad_out_file
def _run_segmentation(self, window: float, shift: float, scale_tag: str = ''):
self.subsegments_manifest_path = os.path.join(self._speaker_dir, f'subsegments{scale_tag}.json')
logging.info(
f"Subsegmentation for embedding extraction:{scale_tag.replace('_',' ')}, {self.subsegments_manifest_path}"
)
self.subsegments_manifest_path = segments_manifest_to_subsegments_manifest(
segments_manifest_file=self._speaker_manifest_path,
subsegments_manifest_file=self.subsegments_manifest_path,
window=window,
shift=shift,
)
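# Illustrative example (not from the source): a 10 s speech segment with window=1.5 and
# shift=0.75 becomes subsegments [0.00-1.50], [0.75-2.25], ..., each of which later gets
# its own speaker embedding.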
return None
def _perform_speech_activity_detection(self):
"""
Checks the type of speech activity detection specified in the config. Choices are NeMo VAD,
an external VAD manifest, or oracle VAD (which generates speech activity labels from the provided RTTM files).
"""
if self.has_vad_model:
self._auto_split = True
self._split_duration = 50
manifest_vad_input = self._diarizer_params.manifest_filepath
if self._auto_split:
logging.info("Split long audio file to avoid CUDA memory issue")
logging.debug("Try smaller split_duration if you still have CUDA memory issue")
config = {
'input': manifest_vad_input,
'window_length_in_sec': self._vad_window_length_in_sec,
'split_duration': self._split_duration,
'num_workers': self._cfg.num_workers,
}
manifest_vad_input = prepare_manifest(config)
else:
logging.warning(
"If you encounter CUDA memory issue, try splitting manifest entry by split_duration to avoid it."
)
self._setup_vad_test_data(manifest_vad_input)
self._run_vad(manifest_vad_input)
elif self._diarizer_params.vad.external_vad_manifest is not None:
self._speaker_manifest_path = self._diarizer_params.vad.external_vad_manifest
elif self._diarizer_params.oracle_vad:
self._speaker_manifest_path = os.path.join(self._speaker_dir, 'oracle_vad_manifest.json')
self._speaker_manifest_path = write_rttm2manifest(self.AUDIO_RTTM_MAP, self._speaker_manifest_path)
else:
raise ValueError(
"Only one of diarizer.oracle_vad, vad.model_path or vad.external_vad_manifest must be passed"
)
def _extract_embeddings(self, manifest_file: str):
"""
This method extracts speaker embeddings from the segments passed through manifest_file.
Optionally, the intermediate speaker embeddings can be saved for debugging or other uses.
"""
logging.info("Extracting embeddings for Diarization")
self._setup_spkr_test_data(manifest_file)
self.embeddings = {}
self._speaker_model = self._speaker_model.to(self._device)
self._speaker_model.eval()
self.time_stamps = {}
all_embs = torch.empty([0])
for test_batch in tqdm(self._speaker_model.test_dataloader()):
test_batch = [x.to(self._device) for x in test_batch]
audio_signal, audio_signal_len, labels, slices = test_batch
with autocast():
_, embs = self._speaker_model.forward(input_signal=audio_signal, input_signal_length=audio_signal_len)
emb_shape = embs.shape[-1]
embs = embs.view(-1, emb_shape)
all_embs = torch.cat((all_embs, embs.cpu().detach()), dim=0)
del test_batch
with open(manifest_file, 'r', encoding='utf-8') as manifest:
for i, line in enumerate(manifest.readlines()):
line = line.strip()
dic = json.loads(line)
uniq_name = get_uniqname_from_filepath(dic['audio_filepath'])
if uniq_name in self.embeddings:
self.embeddings[uniq_name] = torch.cat((self.embeddings[uniq_name], all_embs[i].view(1, -1)))
else:
self.embeddings[uniq_name] = all_embs[i].view(1, -1)
if uniq_name not in self.time_stamps:
self.time_stamps[uniq_name] = []
start = dic['offset']
end = start + dic['duration']
stamp = '{:.3f} {:.3f} '.format(start, end)
self.time_stamps[uniq_name].append(stamp)
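# At this point (shapes for illustration): self.embeddings[uniq_name] is a tensor of shape
# [num_subsegments, emb_dim] and self.time_stamps[uniq_name] is a list of "start end " strings
# aligned with those rows.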
if self._speaker_params.save_embeddings:
embedding_dir = os.path.join(self._speaker_dir, 'embeddings')
if not os.path.exists(embedding_dir):
os.makedirs(embedding_dir, exist_ok=True)
prefix = get_uniqname_from_filepath(manifest_file)
name = os.path.join(embedding_dir, prefix)
self._embeddings_file = name + '_embeddings.pkl'
pkl.dump(self.embeddings, open(self._embeddings_file, 'wb'))
logging.info("Saved embedding files to {}".format(embedding_dir))
def path2audio_files_to_manifest(self, paths2audio_files, manifest_filepath):
with open(manifest_filepath, 'w', encoding='utf-8') as fp:
for audio_file in paths2audio_files:
audio_file = audio_file.strip()
entry = {'audio_filepath': audio_file, 'offset': 0.0, 'duration': None, 'text': '-', 'label': 'infer'}
fp.write(json.dumps(entry) + '\n')
def diarize(self, paths2audio_files: List[str] = None, batch_size: int = 0):
"""
Diarize the files provided through paths2audio_files or the manifest file.
input:
paths2audio_files (List[str]): list of paths to audio files to diarize
batch_size (int): batch size used for speaker embedding extraction and VAD computation
"""
self._out_dir = self._diarizer_params.out_dir
if not os.path.exists(self._out_dir):
os.mkdir(self._out_dir)
self._vad_dir = os.path.join(self._out_dir, 'vad_outputs')
self._vad_out_file = os.path.join(self._vad_dir, "vad_out.json")
if batch_size:
self._cfg.batch_size = batch_size
if paths2audio_files:
if isinstance(paths2audio_files, list):
self._diarizer_params.manifest_filepath = os.path.join(self._out_dir, 'paths2audio_filepath.json')
self.path2audio_files_to_manifest(paths2audio_files, self._diarizer_params.manifest_filepath)
else:
raise ValueError("paths2audio_files must be of type list of paths to file containing audio file")
self.AUDIO_RTTM_MAP = audio_rttm_map(self._diarizer_params.manifest_filepath)
out_rttm_dir = os.path.join(self._out_dir, 'pred_rttms')
os.makedirs(out_rttm_dir, exist_ok=True)
# Speech Activity Detection
self._perform_speech_activity_detection()
# Segmentation
for scale_idx, (window, shift) in self.multiscale_args_dict['scale_dict'].items():
# Segmentation for the current scale (scale_idx)
self._run_segmentation(window, shift, scale_tag=f'_scale{scale_idx}')
# Embedding Extraction for the current scale (scale_idx)
self._extract_embeddings(self.subsegments_manifest_path)
self.multiscale_embeddings_and_timestamps[scale_idx] = [self.embeddings, self.time_stamps]
embs_and_timestamps = get_embs_and_timestamps(
self.multiscale_embeddings_and_timestamps, self.multiscale_args_dict
)
# Clustering
all_reference, all_hypothesis = perform_clustering(
embs_and_timestamps=embs_and_timestamps,
AUDIO_RTTM_MAP=self.AUDIO_RTTM_MAP,
out_rttm_dir=out_rttm_dir,
clustering_params=self._cluster_params,
)
# TODO Resegmentation -> Coming Soon
# Scoring
score = score_labels(
self.AUDIO_RTTM_MAP,
all_reference,
all_hypothesis,
collar=self._diarizer_params.collar,
ignore_overlap=self._diarizer_params.ignore_overlap,
)
logging.info("Outputs are saved in {} directory".format(os.path.abspath(self._diarizer_params.out_dir)))
return score
@staticmethod
def __make_nemo_file_from_folder(filename, source_dir):
with tarfile.open(filename, "w:gz") as tar:
tar.add(source_dir, arcname="./")
@rank_zero_only
def save_to(self, save_path: str):
"""
Saves model instance (weights and configuration) into an EFF archive or a .nemo file.
You can use the "restore_from" method to fully restore the instance from a .nemo file.
A .nemo file is an archive (tar.gz) with the following:
model_config.yaml - model configuration in .yaml format; you can deserialize this into the cfg argument for the model's constructor
model_weights.ckpt - model checkpoint
Args:
save_path: Path to the .nemo file where the model instance should be saved
"""
# TODO: Why does this override the main save_to?
with tempfile.TemporaryDirectory() as tmpdir:
config_yaml = os.path.join(tmpdir, _MODEL_CONFIG_YAML)
spkr_model = os.path.join(tmpdir, _SPEAKER_MODEL)
self.to_config_file(path2yaml_file=config_yaml)
if self.has_vad_model:
vad_model = os.path.join(tmpdir, _VAD_MODEL)
self._vad_model.save_to(vad_model)
self._speaker_model.save_to(spkr_model)
self.__make_nemo_file_from_folder(filename=save_path, source_dir=tmpdir)
@staticmethod
def __unpack_nemo_file(path2file: str, out_folder: str) -> str:
if not os.path.exists(path2file):
raise FileNotFoundError(f"{path2file} does not exist")
tar = tarfile.open(path2file, "r:gz")
tar.extractall(path=out_folder)
tar.close()
return out_folder
@classmethod
def restore_from(
cls,
restore_path: str,
override_config_path: Optional[str] = None,
map_location: Optional[torch.device] = None,
strict: bool = False,
):
# Get path where the command is executed - the artifacts will be "retrieved" there
# (original .nemo behavior)
cwd = os.getcwd()
with tempfile.TemporaryDirectory() as tmpdir:
try:
cls.__unpack_nemo_file(path2file=restore_path, out_folder=tmpdir)
os.chdir(tmpdir)
if override_config_path is None:
config_yaml = os.path.join(tmpdir, _MODEL_CONFIG_YAML)
else:
config_yaml = override_config_path
conf = OmegaConf.load(config_yaml)
if os.path.exists(os.path.join(tmpdir, _VAD_MODEL)):
conf.diarizer.vad.model_path = os.path.join(tmpdir, _VAD_MODEL)
else:
logging.info(
f'Model {cls.__name__} does not contain a VAD model. A VAD model or a manifest file with '
f'speech segments is needed for diarization with this model.'
)
conf.diarizer.speaker_embeddings.model_path = os.path.join(tmpdir, _SPEAKER_MODEL)
conf.restore_map_location = map_location
OmegaConf.set_struct(conf, True)
instance = cls(cfg=conf)
logging.info(f'Model {cls.__name__} was successfully restored from {restore_path}.')
finally:
os.chdir(cwd)
return instance
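# Example restore-and-run sketch (the archive path is hypothetical):
#   diarizer = ClusteringDiarizer.restore_from('clustering_diarizer.nemo')
#   diarizer.diarize()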