# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Copyright (c) 2018 Ryan Leary
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
# This file contains code artifacts adapted from https://github.com/ryanleary/patter
import copy
import io
import os
import random
import subprocess
from tempfile import NamedTemporaryFile
from typing import List, Optional, Union
import librosa
import numpy as np
import soundfile as sf
import webdataset as wd
from omegaconf import DictConfig, OmegaConf
from scipy import signal
from torch.utils.data import IterableDataset
from nemo.collections.asr.parts import collections, parsers
from nemo.collections.asr.parts.segment import AudioSegment
from nemo.utils import logging
try:
from nemo.collections.asr.parts import numba_utils
HAVE_NUMBA = True
except (ImportError, ModuleNotFoundError):
HAVE_NUMBA = False
def read_one_audiosegment(manifest, target_sr, rng, tarred_audio=False, audio_dataset=None):
    """Load a single augmentation clip and return it as an ``AudioSegment``.

    Args:
        manifest: Manifest collection. In the non-tarred case one record is drawn
            from ``manifest.data``; in the tarred case the record is looked up via
            ``manifest.mapping[file_id]``.
        target_sr: Sampling rate forwarded to ``AudioSegment.from_file``.
        rng: Random generator providing ``sample``; only used when ``tarred_audio``
            is False.
        tarred_audio: Whether the audio comes from ``audio_dataset`` instead of
            being picked at random from the manifest.
        audio_dataset: Iterator yielding ``(audio_file, file_id)`` pairs. Required
            when ``tarred_audio`` is True.

    Returns:
        The result of ``AudioSegment.from_file`` for the selected clip.

    Raises:
        TypeError: If ``tarred_audio`` is True but ``audio_dataset`` is None.
    """
    if not tarred_audio:
        entry = rng.sample(manifest.data, 1)[0]
        audio_file = entry.audio_file
    else:
        if audio_dataset is None:
            raise TypeError("Expected augmentation dataset but got None")
        audio_file, file_id = next(audio_dataset)
        entry = manifest[manifest.mapping[file_id]]

    # A missing offset/duration in the manifest is passed on as 0.
    offset = entry.offset if entry.offset is not None else 0
    duration = entry.duration if entry.duration is not None else 0
    return AudioSegment.from_file(audio_file, target_sr=target_sr, offset=offset, duration=duration)
class Perturbation:
    """Base class for audio perturbations; subclasses override ``perturb``."""

    def max_augmentation_length(self, length):
        """Return the maximum length after perturbation.

        The base implementation returns ``length`` unchanged.
        """
        return length

    def perturb(self, data):
        """Apply the perturbation to ``data``. Must be implemented by subclasses."""
        raise NotImplementedError
class SpeedPerturbation(Perturbation):
    """
    Performs Speed Augmentation by re-sampling the data to a different sampling rate,
    which does not preserve pitch.

    Note: This is a very slow operation for online augmentation. If space allows,
    it is preferable to pre-compute and save the files to augment the dataset.

    Args:
        sr: Original sampling rate.
        resample_type: Type of resampling operation that will be performed.
            For better speed using `resampy`'s fast resampling method, use `resample_type='kaiser_fast'`.
            For high-quality resampling, set `resample_type='kaiser_best'`.
            To use `scipy.signal.resample`, set `resample_type='fft'` or `resample_type='scipy'`
        min_speed_rate: Minimum sampling rate modifier. Must be > 0.
        max_speed_rate: Maximum sampling rate modifier. Must be > 0.
        num_rates: Number of discrete rates to allow. Can be a positive or negative
            integer.
            If a positive integer greater than 0 is provided, the range of
            speed rates will be discretized into `num_rates` values.
            If a negative integer or 0 is provided, the full range of speed rates
            will be sampled uniformly.
            Note: If a positive integer is provided and the resultant discretized
            range of rates contains the value '1.0', then those samples with rate=1.0,
            will not be augmented at all and simply skipped. This is to avoid unnecessary
            augmentation and increased computation time. Effective augmentation chance
            in such a case is = `prob * ((num_rates - 1) / num_rates) * 100`% chance
            where `prob` is the global probability of a sample being augmented.
        rng: Random number generator providing `uniform` and `choice`
            (defaults to a fresh `random.Random()`).

    Raises:
        ValueError: If the smaller of the two rate modifiers is not strictly positive,
            or if `resample_type` is not one of the supported values.
    """

    def __init__(self, sr, resample_type, min_speed_rate=0.9, max_speed_rate=1.1, num_rates=5, rng=None):
        min_rate = min(min_speed_rate, max_speed_rate)
        # A rate of 0 would produce a target sampling rate of 0, so reject it as well.
        if min_rate <= 0.0:
            raise ValueError("Minimum sampling rate modifier must be > 0.")

        if resample_type not in ('kaiser_best', 'kaiser_fast', 'fft', 'scipy'):
            raise ValueError("Supported `resample_type` values are ('kaiser_best', 'kaiser_fast', 'fft', 'scipy')")

        self._sr = sr
        self._min_rate = min_speed_rate
        self._max_rate = max_speed_rate
        self._num_rates = num_rates
        # The discrete grid only exists when a positive number of rates was requested.
        if num_rates > 0:
            self._rates = np.linspace(self._min_rate, self._max_rate, self._num_rates, endpoint=True)
        self._res_type = resample_type
        self._rng = random.Random() if rng is None else rng

    def max_augmentation_length(self, length):
        """Return `length` scaled by the maximum rate modifier."""
        return length * self._max_rate

    def perturb(self, data):
        """Resample `data._samples` in place by a randomly selected speed rate."""
        # Select speed rate either from the discrete grid or by uniform sampling.
        # `num_rates <= 0` (including 0) means "sample uniformly"; the grid is only
        # built for positive values.
        if self._num_rates > 0:
            speed_rate = self._rng.choice(self._rates)
        else:
            speed_rate = self._rng.uniform(self._min_rate, self._max_rate)

        # Skip perturbation in case of identity speed rate
        if speed_rate == 1.0:
            return

        new_sr = int(self._sr * speed_rate)
        # Keyword arguments keep the call valid on librosa versions where the
        # sampling rates are keyword-only.
        data._samples = librosa.core.resample(
            data._samples, orig_sr=self._sr, target_sr=new_sr, res_type=self._res_type
        )
class TimeStretchPerturbation(Perturbation):
    """
    Time-stretch an audio series by a fixed rate while preserving pitch, based on [1, 2].

    Note:
        This is a simplified implementation, intended primarily for reference and pedagogical purposes.
        It makes no attempt to handle transients, and is likely to produce audible artifacts.

    Reference
        [1] [Ellis, D. P. W. "A phase vocoder in Matlab." Columbia University, 2002.]
        (http://www.ee.columbia.edu/~dpwe/resources/matlab/pvoc/)
        [2] [librosa.effects.time_stretch]
        (https://librosa.github.io/librosa/generated/librosa.effects.time_stretch.html)

    Args:
        min_speed_rate: Minimum sampling rate modifier. Must be > 0.
        max_speed_rate: Maximum sampling rate modifier. Must be > 0.
        num_rates: Number of discrete rates to allow. Can be a positive or negative
            integer.
            If a positive integer greater than 0 is provided, the range of
            speed rates will be discretized into `num_rates` values.
            If a negative integer or 0 is provided, the full range of speed rates
            will be sampled uniformly.
            Note: If a positive integer is provided and the resultant discretized
            range of rates contains the value '1.0', then those samples with rate=1.0,
            will not be augmented at all and simply skipped. This is to avoid unnecessary
            augmentation and increased computation time. Effective augmentation chance
            in such a case is = `prob * ((num_rates - 1) / num_rates) * 100`% chance
            where `prob` is the global probability of a sample being augmented.
        n_fft: Number of fft filters to be computed.
        rng: Random number generator providing `uniform` and `choice`
            (defaults to a fresh `random.Random()`).

    Raises:
        ValueError: If the smaller of the two rate modifiers is not strictly positive.
    """

    def __init__(self, min_speed_rate=0.9, max_speed_rate=1.1, num_rates=5, n_fft=512, rng=None):
        min_rate = min(min_speed_rate, max_speed_rate)
        # The stretched length is `len / rate`, so a rate of 0 is invalid too.
        if min_rate <= 0.0:
            raise ValueError("Minimum sampling rate modifier must be > 0.")

        self._min_rate = min_speed_rate
        self._max_rate = max_speed_rate
        self._num_rates = num_rates
        # The discrete grid only exists when a positive number of rates was requested.
        if num_rates > 0:
            self._rates = np.linspace(self._min_rate, self._max_rate, self._num_rates, endpoint=True)
        self._rng = random.Random() if rng is None else rng

        # Pre-compute constants
        self._n_fft = int(n_fft)
        self._hop_length = int(n_fft // 2)

        # Pre-allocate buffers
        # "fast" buffers match the STFT used when speeding up (n_fft, hop = n_fft // 2);
        # "slow" buffers match the doubled sizes used when slowing down.
        self._phi_advance_fast = np.linspace(0, np.pi * self._hop_length, self._hop_length + 1)
        self._scale_buffer_fast = np.empty(self._hop_length + 1, dtype=np.float32)

        self._phi_advance_slow = np.linspace(0, np.pi * self._n_fft, self._n_fft + 1)
        self._scale_buffer_slow = np.empty(self._n_fft + 1, dtype=np.float32)

    def max_augmentation_length(self, length):
        """Return the longest possible output length for an input of `length`.

        `perturb` produces `len / speed_rate` samples, so the longest output
        corresponds to the smallest rate.
        """
        return length / min(self._min_rate, self._max_rate)

    def perturb(self, data):
        """Time-stretch `data._samples` in place by a randomly selected rate."""
        # Select speed rate either from the discrete grid or by uniform sampling.
        # `num_rates <= 0` (including 0) means "sample uniformly"; the grid is only
        # built for positive values.
        if self._num_rates > 0:
            speed_rate = self._rng.choice(self._rates)
        else:
            speed_rate = self._rng.uniform(self._min_rate, self._max_rate)

        # Skip perturbation in case of identity speed rate
        if speed_rate == 1.0:
            return

        # Increase `n_fft` based on task (speed up or slow down audio)
        # This greatly reduces upper bound of maximum time taken
        # to compute slowed down audio segments.
        if speed_rate >= 1.0:  # Speed up audio
            fft_multiplier = 1
            phi_advance = self._phi_advance_fast
            scale_buffer = self._scale_buffer_fast
        else:  # Slow down audio
            fft_multiplier = 2
            phi_advance = self._phi_advance_slow
            scale_buffer = self._scale_buffer_slow

        n_fft = int(self._n_fft * fft_multiplier)
        hop_length = int(self._hop_length * fft_multiplier)

        # Perform short-term Fourier transform (STFT)
        stft = librosa.core.stft(data._samples, n_fft=n_fft, hop_length=hop_length)

        # Stretch by phase vocoding
        if HAVE_NUMBA:
            stft_stretch = numba_utils.phase_vocoder(stft, speed_rate, phi_advance, scale_buffer)
        else:
            # Keyword arguments keep the call valid on librosa versions where
            # `rate` and `hop_length` are keyword-only.
            stft_stretch = librosa.core.phase_vocoder(stft, rate=speed_rate, hop_length=hop_length)

        # Predict the length of y_stretch
        len_stretch = int(round(len(data._samples) / speed_rate))

        # Invert the STFT
        y_stretch = librosa.core.istft(
            stft_stretch, dtype=data._samples.dtype, hop_length=hop_length, length=len_stretch
        )

        data._samples = y_stretch
class GainPerturbation(Perturbation):
    """
    Scales the audio by a randomly drawn gain.

    Args:
        min_gain_dbfs (float): Lower bound of the gain, in dB
        max_gain_dbfs (float): Upper bound of the gain, in dB
        rng: Random number generator providing `uniform`
            (defaults to a fresh `random.Random()`)
    """

    def __init__(self, min_gain_dbfs=-10, max_gain_dbfs=10, rng=None):
        self._rng = rng if rng is not None else random.Random()
        self._min_gain_dbfs = min_gain_dbfs
        self._max_gain_dbfs = max_gain_dbfs

    def perturb(self, data):
        """Multiply `data._samples` by a gain drawn uniformly (in dB) from the configured range."""
        gain_db = self._rng.uniform(self._min_gain_dbfs, self._max_gain_dbfs)
        # Convert dB to a linear amplitude factor.
        linear_gain = 10.0 ** (gain_db / 20.0)
        data._samples = data._samples * linear_gain
class ImpulsePerturbation(Perturbation):
    """
    Convolves audio with a Room Impulse Response.

    Args:
        manifest_path (list): Manifest file for RIRs
        rng: Random number generator used to pick an RIR from the manifest when the
            audio is not tarred (defaults to a fresh `random.Random()`)
        audio_tar_filepaths (list): Tar files, if RIR audio files are tarred
        shuffle_n (int): Shuffle parameter for shuffling buffered files from the tar files
        shift_impulse (bool): Shift impulse response to adjust for delay at the beginning
    """

    def __init__(self, manifest_path=None, rng=None, audio_tar_filepaths=None, shuffle_n=128, shift_impulse=False):
        self._manifest = collections.ASRAudioText(manifest_path, parser=parsers.make_parser([]), index_by_file_id=True)
        self._audiodataset = None
        self._tarred_audio = False
        self._shift_impulse = shift_impulse
        self._data_iterator = None

        # Tarred RIRs are streamed through an AugmentationDataset iterator.
        if audio_tar_filepaths:
            self._tarred_audio = True
            self._audiodataset = AugmentationDataset(manifest_path, audio_tar_filepaths, shuffle_n)
            self._data_iterator = iter(self._audiodataset)

        self._rng = random.Random() if rng is None else rng

    def perturb(self, data):
        """Convolve `data._samples` with one impulse response, replacing the samples."""
        impulse = read_one_audiosegment(
            self._manifest,
            data.sample_rate,
            self._rng,
            tarred_audio=self._tarred_audio,
            audio_dataset=self._data_iterator,
        )

        # Min-max normalise the impulse once. NumPy reductions avoid iterating
        # the sample array element by element in Python.
        impulse_samples = impulse.samples
        lowest = np.min(impulse_samples)
        highest = np.max(impulse_samples)
        impulse_norm = (impulse_samples - lowest) / (highest - lowest)

        if not self._shift_impulse:
            data._samples = signal.fftconvolve(data._samples, impulse_norm, "same")
            return

        # Find peak and shift peak to left
        max_ind = np.argmax(np.abs(impulse_norm))
        impulse_resp = impulse_norm[max_ind:]
        delay_after = len(impulse_resp)
        data._samples = signal.fftconvolve(data._samples, impulse_resp, "full")[:-delay_after]
class ShiftPerturbation(Perturbation):
    """
    Shifts the audio in time by a random amount drawn between min_shift_ms and max_shift_ms.
    The length of the audio stays the same; the vacated region is filled with zeros.

    Args:
        min_shift_ms (float): Minimum time in milliseconds by which audio will be shifted
        max_shift_ms (float): Maximum time in milliseconds by which audio will be shifted
        rng: Random number generator providing `uniform`
            (defaults to a fresh `random.Random()`)
    """

    def __init__(self, min_shift_ms=-5.0, max_shift_ms=5.0, rng=None):
        self._rng = rng if rng is not None else random.Random()
        self._min_shift_ms = min_shift_ms
        self._max_shift_ms = max_shift_ms

    def perturb(self, data):
        """Shift `data._samples` in place by a randomly drawn number of samples."""
        shift_ms = self._rng.uniform(self._min_shift_ms, self._max_shift_ms)
        if abs(shift_ms) / 1000 > data.duration:
            # TODO: do something smarter than just ignore this condition
            return

        shift_samples = int(shift_ms * data.sample_rate // 1000)
        if shift_samples == 0:
            return

        samples = data._samples
        if shift_samples < 0:
            # Negative shift: move the content towards the end, zero the beginning.
            samples[-shift_samples:] = samples[:shift_samples]
            samples[:-shift_samples] = 0
        else:
            # Positive shift: move the content towards the start, zero the end.
            samples[:-shift_samples] = samples[shift_samples:]
            samples[-shift_samples:] = 0
class NoisePerturbation(Perturbation):
    """
    Perturbation that adds noise to input audio.

    Args:
        manifest_path (str): Manifest file with paths to noise files
        min_snr_db (float): Minimum SNR of audio after noise is added
        max_snr_db (float): Maximum SNR of audio after noise is added
        max_gain_db (float): Maximum gain that can be applied on the noise sample
        audio_tar_filepaths (list) : Tar files, if noise audio files are tarred
        shuffle_n (int): Shuffle parameter for shuffling buffered files from the tar files
        orig_sr (int): Original sampling rate of the noise files
        rng: Random number generator providing `uniform`, `randint` and `sample`
            (defaults to a fresh `random.Random()`)
    """

    def __init__(
        self,
        manifest_path=None,
        min_snr_db=10,
        max_snr_db=50,
        max_gain_db=300.0,
        rng=None,
        audio_tar_filepaths=None,
        shuffle_n=100,
        orig_sr=16000,
    ):
        self._manifest = collections.ASRAudioText(manifest_path, parser=parsers.make_parser([]), index_by_file_id=True)
        self._audiodataset = None
        self._tarred_audio = False
        self._orig_sr = orig_sr
        self._data_iterator = None

        # Tarred noise files are streamed through an AugmentationDataset iterator.
        if audio_tar_filepaths:
            self._tarred_audio = True
            self._audiodataset = AugmentationDataset(manifest_path, audio_tar_filepaths, shuffle_n)
            self._data_iterator = iter(self._audiodataset)

        self._rng = random.Random() if rng is None else rng
        self._min_snr_db = min_snr_db
        self._max_snr_db = max_snr_db
        self._max_gain_db = max_gain_db

    @property
    def orig_sr(self):
        """Original sampling rate of the noise files, as passed to the constructor."""
        return self._orig_sr

    def get_one_noise_sample(self, target_sr):
        """Read one noise clip at `target_sr` from the manifest or the tarred dataset."""
        return read_one_audiosegment(
            self._manifest, target_sr, self._rng, tarred_audio=self._tarred_audio, audio_dataset=self._data_iterator
        )

    def perturb(self, data):
        """Read one noise clip at the sampling rate of `data` and add it to `data`."""
        noise = read_one_audiosegment(
            self._manifest,
            data.sample_rate,
            self._rng,
            tarred_audio=self._tarred_audio,
            audio_dataset=self._data_iterator,
        )
        self.perturb_with_input_noise(data, noise)

    def perturb_with_input_noise(self, data, noise, data_rms=None):
        """Add `noise` to `data._samples` in place at a randomly drawn SNR.

        If the noise is longer than the audio, a random window of the audio's
        length is taken from it; otherwise the whole noise clip is added at a
        random position inside the audio.

        Args:
            data: Audio segment to perturb; `_samples` is modified in place.
            noise: Audio segment supplying the noise; it is not modified.
            data_rms: RMS level of `data` in dB. If None, `data.rms_db` is used.
        """
        snr_db = self._rng.uniform(self._min_snr_db, self._max_snr_db)
        if data_rms is None:
            data_rms = data.rms_db

        # Gain that brings the noise to the drawn SNR, capped at `max_gain_db`.
        noise_gain_db = min(data_rms - noise.rms_db - snr_db, self._max_gain_db)

        data_len = data._samples.shape[0]
        noise_len = noise._samples.shape[0]
        if noise_len > data_len:
            start_sample = self._rng.randint(0, noise_len - data_len)
            noise_samples = np.copy(noise._samples[start_sample : start_sample + data_len])
        else:
            noise_samples = np.copy(noise._samples)

        # adjust gain for snr purposes and superimpose
        noise_samples *= 10.0 ** (noise_gain_db / 20.0)

        noise_idx = self._rng.randint(0, data_len - noise_samples.shape[0])
        data._samples[noise_idx : noise_idx + noise_samples.shape[0]] += noise_samples

    def perturb_with_foreground_noise(
        self, data, noise, data_rms=None, max_noise_dur=2, max_additions=1,
    ):
        """Add between 1 and `max_additions` short excerpts of `noise` to `data` in place.

        Args:
            data: Audio segment to perturb; `_samples` is modified in place.
            noise: Audio segment supplying the noise excerpts.
            data_rms: RMS level of `data` in dB. If None, `data.rms_db` is used.
            max_noise_dur: Maximum duration of each excerpt, in the units of `noise.duration`.
            max_additions: Maximum number of excerpts to add.
        """
        snr_db = self._rng.uniform(self._min_snr_db, self._max_snr_db)
        # `is None` rather than truthiness: an RMS of exactly 0.0 dB is a valid value.
        if data_rms is None:
            data_rms = data.rms_db

        noise_gain_db = min(data_rms - noise.rms_db - snr_db, self._max_gain_db)
        n_additions = self._rng.randint(1, max_additions)

        for _ in range(n_additions):
            noise_dur = self._rng.uniform(0.0, max_noise_dur)
            start_time = self._rng.uniform(0.0, noise.duration)
            start_sample = int(round(start_time * noise.sample_rate))
            end_sample = int(round(min(noise.duration, (start_time + noise_dur)) * noise.sample_rate))
            noise_samples = np.copy(noise._samples[start_sample:end_sample])

            # adjust gain for snr purposes and superimpose
            noise_samples *= 10.0 ** (noise_gain_db / 20.0)

            # Never add more noise samples than the audio can hold.
            if noise_samples.shape[0] > data._samples.shape[0]:
                noise_samples = noise_samples[0 : data._samples.shape[0]]

            noise_idx = self._rng.randint(0, data._samples.shape[0] - noise_samples.shape[0])
            data._samples[noise_idx : noise_idx + noise_samples.shape[0]] += noise_samples
class WhiteNoisePerturbation(Perturbation):
    """
    Adds white (Gaussian) noise at a randomly chosen level to the audio.

    Args:
        min_level (int): Minimum level in dB at which white noise should be added
        max_level (int): Maximum level in dB at which white noise should be added
        rng: Random number generator providing NumPy-style `randint` and `randn`
            (defaults to a fresh `np.random.RandomState()`)
    """

    def __init__(self, min_level=-90, max_level=-46, rng=None):
        self._rng = rng if rng is not None else np.random.RandomState()
        self.min_level = int(min_level)
        self.max_level = int(max_level)

    def perturb(self, data):
        """Add Gaussian noise to `data._samples` in place."""
        level_db = self._rng.randint(self.min_level, self.max_level, dtype='int32')
        num_samples = data._samples.shape[0]
        # Unit-variance noise scaled from dB to linear amplitude.
        white_noise = self._rng.randn(num_samples) * (10.0 ** (level_db / 20.0))
        data._samples += white_noise
class RirAndNoisePerturbation(Perturbation):
    """
    RIR augmentation with additive foreground and background noise.
    In this implementation audio data is augmented by first convolving the audio with a Room Impulse Response
    and then adding foreground noise and background noise at various SNRs. RIR, foreground and background noises
    should either be supplied with a manifest file or as tarred audio files (faster).

    Different sets of noise audio files can be supplied based on the original sampling rate of the noise. This is
    useful while training a mixed sample rate model. For example, when training a mixed model with 8 kHz and
    16 kHz audio with a target sampling rate of 16 kHz, one would want to augment 8 kHz data with 8 kHz noise
    rather than 16 kHz noise.

    The per-manifest settings (`min_snr_db`, `max_snr_db`, `noise_tar_filepaths`, `orig_sample_rate` and their
    `bg_` counterparts) may each be a list/tuple with one entry per manifest, or a single value that is applied
    to every manifest.

    Args:
        rir_manifest_path: Manifest file for RIRs
        rir_tar_filepaths: Tar files, if RIR audio files are tarred
        rir_shuffle_n: Shuffle parameter passed to the RIR perturbation
        rir_prob: Probability of applying a RIR
        noise_manifest_paths: Foreground noise manifest path
        min_snr_db: Min SNR for foreground noise
        max_snr_db: Max SNR for foreground noise
        noise_tar_filepaths: Tar files, if noise files are tarred
        apply_noise_rir: Whether to convolve foreground noise with a random RIR
        orig_sample_rate: Original sampling rate of foreground noise audio
        max_additions: Max number of times foreground noise is added to an utterance
        max_duration: Max duration of foreground noise
        bg_noise_manifest_paths: Background noise manifest path
        bg_min_snr_db: Min SNR for background noise
        bg_max_snr_db: Max SNR for background noise
        bg_noise_tar_filepaths: Tar files, if noise files are tarred
        bg_orig_sample_rate: Original sampling rate of background noise audio
    """

    def __init__(
        self,
        rir_manifest_path=None,
        rir_prob=0.5,
        noise_manifest_paths=None,
        min_snr_db=0,
        max_snr_db=50,
        rir_tar_filepaths=None,
        rir_shuffle_n=100,
        noise_tar_filepaths=None,
        apply_noise_rir=False,
        orig_sample_rate=None,
        max_additions=5,
        max_duration=2.0,
        bg_noise_manifest_paths=None,
        bg_min_snr_db=10,
        bg_max_snr_db=50,
        bg_noise_tar_filepaths=None,
        bg_orig_sample_rate=None,
    ):
        logging.info("Called Rir aug init")
        self._rir_prob = rir_prob
        self._rng = random.Random()
        self._rir_perturber = ImpulsePerturbation(
            manifest_path=rir_manifest_path,
            audio_tar_filepaths=rir_tar_filepaths,
            shuffle_n=rir_shuffle_n,
            shift_impulse=True,
        )
        # Noise perturbers are keyed by the original sampling rate of their noise files.
        self._fg_noise_perturbers = self._build_noise_perturbers(
            noise_manifest_paths, min_snr_db, max_snr_db, noise_tar_filepaths, orig_sample_rate
        )
        self._max_additions = max_additions
        self._max_duration = max_duration
        self._bg_noise_perturbers = self._build_noise_perturbers(
            bg_noise_manifest_paths, bg_min_snr_db, bg_max_snr_db, bg_noise_tar_filepaths, bg_orig_sample_rate
        )
        self._apply_noise_rir = apply_noise_rir

    @staticmethod
    def _setting_for(value, index):
        """Return `value[index]` for list/tuple settings, otherwise `value` itself."""
        if isinstance(value, (list, tuple)):
            return value[index]
        return value

    @classmethod
    def _build_noise_perturbers(cls, manifest_paths, min_snr_db, max_snr_db, tar_filepaths, orig_sample_rates):
        """Create one NoisePerturbation per manifest, keyed by original sampling rate."""
        perturbers = {}
        if not manifest_paths:
            return perturbers

        for i, manifest_path in enumerate(manifest_paths):
            # Without an explicit original sampling rate the noise is assumed to be 16 kHz.
            orig_sr = 16000 if orig_sample_rates is None else cls._setting_for(orig_sample_rates, i)
            perturbers[orig_sr] = NoisePerturbation(
                manifest_path=manifest_path,
                min_snr_db=cls._setting_for(min_snr_db, i),
                max_snr_db=cls._setting_for(max_snr_db, i),
                audio_tar_filepaths=cls._setting_for(tar_filepaths, i),
                orig_sr=orig_sr,
            )
        return perturbers

    @staticmethod
    def _select_perturber(perturbers, orig_sr):
        """Pick the perturber for `orig_sr`, else the highest-rate one, else None if none exist."""
        if not perturbers:
            return None
        if orig_sr not in perturbers:
            orig_sr = max(perturbers.keys())
        return perturbers[orig_sr]

    def perturb(self, data):
        """Apply the RIR (with probability `rir_prob`), then foreground and background noise.

        Noise stages for which no perturber was configured are skipped.
        """
        prob = self._rng.uniform(0.0, 1.0)

        if prob < self._rir_prob:
            self._rir_perturber.perturb(data)

        orig_sr = data.orig_sr
        fg_perturber = self._select_perturber(self._fg_noise_perturbers, orig_sr)
        bg_perturber = self._select_perturber(self._bg_noise_perturbers, orig_sr)

        # Measured once, before any noise is added, and shared by both noise stages.
        data_rms = data.rms_db

        if fg_perturber is not None:
            noise = fg_perturber.get_one_noise_sample(data.sample_rate)
            if self._apply_noise_rir:
                self._rir_perturber.perturb(noise)
            fg_perturber.perturb_with_foreground_noise(
                data, noise, data_rms=data_rms, max_noise_dur=self._max_duration, max_additions=self._max_additions
            )

        if bg_perturber is not None:
            noise = bg_perturber.get_one_noise_sample(data.sample_rate)
            bg_perturber.perturb_with_input_noise(data, noise, data_rms=data_rms)
class TranscodePerturbation(Perturbation):
    """
    Audio codec augmentation. This implementation uses sox to transcode audio with low rate audio codecs,
    so users need to make sure that the installed sox version supports the codecs used here (G711 and amr-nb).

    Args:
        rng: Random number generator
    """

    def __init__(self, rng=None):
        # NOTE(review): `perturb` draws the codec and bit rate from the module-level
        # `random` generator, so `self._rng` is currently not consulted.
        self._rng = np.random.RandomState() if rng is None else rng
        self._codecs = ["g711", "amr-nb"]

    def perturb(self, data):
        """Round-trip `data._samples` through a randomly chosen codec using sox.

        The samples are peak-normalised to 0.8, written as a 16 kHz wav file,
        transcoded, read back at 16 kHz and truncated to the original length.
        Audio whose peak level is 0 is left unchanged.
        """
        att_factor = 0.8
        max_level = np.max(np.abs(data._samples))
        if max_level == 0:
            # Silent input: normalising would divide by zero and produce NaNs.
            return
        norm_factor = att_factor / max_level
        norm_samples = norm_factor * data._samples

        codec_ind = random.randint(0, len(self._codecs) - 1)
        codec = self._codecs[codec_ind]
        suffix = "_amr.wav" if codec == "amr-nb" else "_g711.wav"

        # Context managers close (and thereby delete) both temporary files even
        # when sox fails.
        with NamedTemporaryFile(suffix=".wav") as orig_f, NamedTemporaryFile(suffix=suffix) as transcoded_f:
            sf.write(orig_f.name, norm_samples.transpose(), 16000)

            if codec == "amr-nb":
                rates = list(range(0, 8))
                rate = rates[random.randint(0, len(rates) - 1)]
                _ = subprocess.check_output(
                    f"sox {orig_f.name} -V0 -C {rate} -t amr-nb - | sox -t amr-nb - -V0 -b 16 -r 16000 {transcoded_f.name}",
                    shell=True,
                )
            elif codec == "g711":
                _ = subprocess.check_output(
                    f"sox {orig_f.name} -V0 -r 8000 -c 1 -e a-law {transcoded_f.name}", shell=True
                )

            # Read the result back while the temporary file still exists.
            new_data = AudioSegment.from_file(transcoded_f.name, target_sr=16000)

        data._samples = new_data._samples[0 : data._samples.shape[0]]
# Registry mapping augmentation names (as used in configs) to perturbation classes.
# `register_perturbation` adds entries; `AudioAugmentor.from_config` and
# `process_augmentations` look names up here.
perturbation_types = {
    "speed": SpeedPerturbation,
    "time_stretch": TimeStretchPerturbation,
    "gain": GainPerturbation,
    "impulse": ImpulsePerturbation,
    "shift": ShiftPerturbation,
    "noise": NoisePerturbation,
    "white_noise": WhiteNoisePerturbation,
    "rir_noise_aug": RirAndNoisePerturbation,
    "transcode_aug": TranscodePerturbation,
}
def register_perturbation(name: str, perturbation: Perturbation):
    """Add `perturbation` to the `perturbation_types` registry under `name`.

    Args:
        name: Registry key for the perturbation.
        perturbation: Object stored under `name`; registry entries are called
            with keyword arguments to build a perturbation.

    Raises:
        KeyError: If `name` is already registered.
    """
    if name not in perturbation_types:
        perturbation_types[name] = perturbation
        return

    raise KeyError(
        f"Perturbation with the name {name} exists. " f"Type of perturbation : {perturbation_types[name]}."
    )
class AudioAugmentor:
    """Applies a pipeline of `(probability, perturbation)` pairs to audio segments."""

    def __init__(self, perturbations=None, rng=None):
        self._pipeline = [] if perturbations is None else perturbations
        self._rng = rng if rng is not None else random.Random()

    def perturb(self, segment):
        """Run each perturbation on `segment` with its configured probability."""
        for prob, perturbation in self._pipeline:
            # One random draw per pipeline entry decides whether it is applied.
            if self._rng.random() < prob:
                perturbation.perturb(segment)

    def max_augmentation_length(self, length):
        """Chain `max_augmentation_length` through every perturbation in the pipeline."""
        longest = length
        for _prob, perturbation in self._pipeline:
            longest = perturbation.max_augmentation_length(longest)
        return longest

    @classmethod
    def from_config(cls, config):
        """Build an augmentor from entries with `aug_type`, `prob` and `cfg` keys.

        Entries whose `aug_type` is not in `perturbation_types` are skipped with a warning.
        """
        pipeline = []
        for entry in config:
            aug_type = entry['aug_type']
            if aug_type in perturbation_types:
                factory = perturbation_types[aug_type]
                pipeline.append((entry['prob'], factory(**entry['cfg'])))
            else:
                logging.warning("%s perturbation not known. Skipping.", aug_type)
        return cls(perturbations=pipeline)
def process_augmentations(augmenter) -> Optional[AudioAugmentor]:
    """Process list of online data augmentations.

    Accepts either an AudioAugmentor object with pre-defined augmentations,
    or a dictionary that points to augmentations that have been defined.

    If a dictionary is passed, must follow the below structure:
    Dict[str, Dict[str, Any]]: Which refers to a dictionary of string
    names for augmentations, defined in `asr/parts/perturb.py`.

    The inner dictionary may contain key-value arguments of the specific
    augmentation, along with an essential key `prob`. `prob` declares the
    probability of the augmentation being applied, and must be a float
    value in the range [0, 1].

    # Example in YAML config file

    Augmentations are generally applied only during training, so we can add
    these augmentations to our yaml config file, and modify the behaviour
    for training and evaluation.

    ```yaml
    AudioToTextDataLayer:
        ...  # Parameters shared between train and evaluation time
        train:
            augmentor:
                shift:
                    prob: 0.5
                    min_shift_ms: -5.0
                    max_shift_ms: 5.0
                white_noise:
                    prob: 1.0
                    min_level: -90
                    max_level: -46
                ...
        eval:
            ...
    ```

    Then in the training script,

    ```python
    import copy
    from ruamel.yaml import YAML

    yaml = YAML(typ="safe")
    with open(model_config) as f:
        params = yaml.load(f)

    # Train Config for Data Loader
    train_dl_params = copy.deepcopy(params["AudioToTextDataLayer"])
    train_dl_params.update(params["AudioToTextDataLayer"]["train"])
    del train_dl_params["train"]
    del train_dl_params["eval"]

    data_layer_train = nemo_asr.AudioToTextDataLayer(
        ...,
        **train_dl_params,
    )

    # Evaluation Config for Data Loader
    eval_dl_params = copy.deepcopy(params["AudioToTextDataLayer"])
    eval_dl_params.update(params["AudioToTextDataLayer"]["eval"])
    del eval_dl_params["train"]
    del eval_dl_params["eval"]

    data_layer_eval = nemo_asr.AudioToTextDataLayer(
        ...,
        **eval_dl_params,
    )
    ```

    # Registering your own Augmentations

    To register custom augmentations to obtain the above convenience of
    declaring the augmentations in YAML, you can put additional keys in
    `perturbation_types` dictionary as follows.

    ```python
    from nemo.collections.asr.parts import perturb

    # Define your own perturbation here
    class CustomPerturbation(perturb.Perturbation):
        ...

    perturb.register_perturbation(name_of_perturbation, CustomPerturbation)
    ```

    Args:
        augmenter: AudioAugmentor object or
            dictionary of str -> kwargs (dict) which is parsed and used
            to initialize an AudioAugmentor.
            Note: It is crucial that each individual augmentation has
            a keyword `prob`, that defines a float probability in the
            range [0, 1] of this augmentation being applied.

    Returns: AudioAugmentor object, or None if `augmenter` is None.

    Raises:
        ValueError: If `augmenter` is not a dict, DictConfig or AudioAugmentor,
            or if a `prob` value lies outside [0, 1].
        KeyError: If an augmentation has no `prob` keyword, or if its name is
            not present in `perturbation_types`.
    """
    if augmenter is None:
        return None

    if isinstance(augmenter, AudioAugmentor):
        return augmenter

    if isinstance(augmenter, DictConfig):
        augmenter = OmegaConf.to_container(augmenter, resolve=True)
    elif not isinstance(augmenter, dict):
        raise ValueError("Cannot parse augmenter. Must be a dict or an AudioAugmentor object ")

    # Work on a copy so that removing `prob` does not mutate the caller's config.
    augmenter = copy.deepcopy(augmenter)

    augmentations = []
    for augment_name, augment_kwargs in augmenter.items():
        prob = augment_kwargs.pop('prob', None)

        if prob is None:
            raise KeyError(
                f'Augmentation "{augment_name}" will not be applied as '
                f'keyword argument "prob" was not defined for this augmentation.'
            )

        if prob < 0.0 or prob > 1.0:
            raise ValueError("`prob` must be a float value between 0 and 1.")

        # Check the name explicitly instead of catching KeyError around the
        # constructor call, so a KeyError raised inside a perturbation's own
        # initialisation is not misreported as an unknown name.
        if augment_name not in perturbation_types:
            raise KeyError(f"Invalid perturbation name. Allowed values : {perturbation_types.keys()}")

        augmentation = perturbation_types[augment_name](**augment_kwargs)
        augmentations.append([prob, augmentation])

    return AudioAugmentor(perturbations=augmentations)
class AugmentationDataset(IterableDataset):
    """
    A class that loads tarred audio files and cycles over the files in the dataset.

    Accepts a single comma-separated JSON manifest file (in the same style as for the
    AudioToCharDataset/AudioToBPEDataset), as well as the path(s) to the tarball(s) containing the wav files.
    Each line of the manifest should contain the information for one audio file, including at least the
    transcript and name of the audio file within the tarball.

    Valid formats for the audio_tar_filepaths argument include:
    (1) a single string that can be brace-expanded, e.g. 'path/to/audio.tar' or 'path/to/audio_{1..100}.tar.gz', or
    (2) a list of file paths that will not be brace-expanded, e.g. ['audio_1.tar', 'audio_2.tar', ...].

    Note: For brace expansion in (1), there may be cases where `{x..y}` syntax cannot be used due to shell
    interference. This occurs most commonly inside SLURM scripts. Therefore we provide a few equivalent
    replacements.
    Supported opening braces - { <=> (, [, < and the special tag _OP_.
    Supported closing braces - } <=> ), ], > and the special tag _CL_.
    For SLURM based tasks, we suggest the use of the special tags for ease of use.

    See the WebDataset documentation for more information about accepted data and input formats.
    """

    def __init__(self, manifest_path: str, tar_filepaths: Union[str, List[str]], shuffle_n: int = 128):
        self._manifest = collections.ASRAudioText(manifest_path, parser=parsers.make_parser([]), index_by_file_id=True)

        if isinstance(tar_filepaths, str):
            # Normalise every supported opening-brace alias to '{' ...
            for open_token in ('(', '[', '<', '_OP_'):
                tar_filepaths = tar_filepaths.replace(open_token, "{")
            # ... and every supported closing-brace alias to '}'.
            for close_token in (')', ']', '>', '_CL_'):
                tar_filepaths = tar_filepaths.replace(close_token, "}")

        dataset = wd.WebDataset(tar_filepaths)

        if shuffle_n > 0:
            dataset = dataset.shuffle(shuffle_n)
        else:
            logging.info("WebDataset will not shuffle files within the tar files.")

        self.audio_dataset = dataset.rename(audio='wav', key='__key__').to_tuple('audio', 'key')
        self.audio_iter = iter(self.audio_dataset)

    def __len__(self):
        """Return the number of entries in the manifest."""
        return len(self._manifest)

    def __iter__(self):
        """The dataset is its own iterator."""
        return self

    def __next__(self):
        """Return `(audio_file, file_id)` for the next clip, restarting the dataset when exhausted."""
        try:
            audio_bytes, audio_filename = next(self.audio_iter)
        except StopIteration:
            # Start over from the beginning so the dataset cycles.
            self.audio_iter = iter(self.audio_dataset)
            audio_bytes, audio_filename = next(self.audio_iter)

        # The file id is the base name without its extension.
        file_id = os.path.splitext(os.path.basename(audio_filename))[0]

        # Convert audio bytes to IO stream for processing (for SoundFile to read)
        audio_file = io.BytesIO(audio_bytes)
        return audio_file, file_id