From 93103aba01d5ce883efdcfa57f175fcea7e1134a Mon Sep 17 00:00:00 2001 From: Si11ium Date: Thu, 17 Dec 2020 11:00:42 +0100 Subject: [PATCH] Repair of ML Lib -> Transformations back to np from torch --- audio_toolset/audio_to_mel_dataset.py | 68 --------------------------- audio_toolset/mel_augmentation.py | 31 ++++++------ experiments.py | 67 ++++++++++++++++++++++++++ 3 files changed, 81 insertions(+), 85 deletions(-) create mode 100644 experiments.py diff --git a/audio_toolset/audio_to_mel_dataset.py b/audio_toolset/audio_to_mel_dataset.py index eb2d7f4..6f045f0 100644 --- a/audio_toolset/audio_to_mel_dataset.py +++ b/audio_toolset/audio_to_mel_dataset.py @@ -93,71 +93,3 @@ class LibrosaAudioToMelDataset(_AudioToMelDataset): pass return self.mel_file_path.exists() - - -import torchaudio -if sys.platform =='windows': - torchaudio.set_audio_backend('soundfile') -else: - torchaudio.set_audio_backend('sox_io') - - -class PyTorchAudioToMelDataset(_AudioToMelDataset): - - @property - def audio_file_duration(self): - info_obj = torchaudio.info(self.audio_path) - return info_obj.num_frames / info_obj.sample_rate - - @property - def sampling_rate(self): - return self.mel_kwargs['sample_rate'] - - def __init__(self, audio_file_path, *args, **kwargs): - super(PyTorchAudioToMelDataset, self).__init__(audio_file_path, *args, **kwargs) - - audio_file_path = Path(audio_file_path) - # audio_file, sampling_rate = librosa.load(self.audio_path, sr=sampling_rate) - - from torchaudio.transforms import MelSpectrogram - self._mel_transform = Compose([MelSpectrogram(**self.mel_kwargs), - MelToImage() - ]) - - def _build_mel(self): - if self.reset: - self.mel_file_path.unlink(missing_ok=True) - if not self.mel_file_path.exists(): - self.mel_file_path.parent.mkdir(parents=True, exist_ok=True) - lock_file = Path(str(self.mel_file_path).replace(self.mel_file_path.suffix, '.lock')) - lock_file.touch(exist_ok=False) - - try: - audio_sample, sample_rate = torchaudio.load(self.audio_path) - except RuntimeError: - import soundfile - - data, samplerate = soundfile.read(self.audio_path) - # sf.available_formats() - # sf.available_subtypes() - soundfile.write(self.audio_path, data, samplerate, subtype='PCM_32') - - audio_sample, sample_rate = torchaudio.load(self.audio_path) - if sample_rate != self.sampling_rate: - resample = torchaudio.transforms.Resample(orig_freq=int(sample_rate), new_freq=int(self.sampling_rate)) - audio_sample = resample(audio_sample) - if audio_sample.shape[0] > 1: - # Transform Stereo to Mono - audio_sample = audio_sample.mean(dim=0, keepdim=True) - mel_sample = self._mel_transform(audio_sample) - with self.mel_file_path.open('wb') as mel_file: - pickle.dump(mel_sample, mel_file, protocol=pickle.HIGHEST_PROTOCOL) - lock_file.unlink() - else: - # print(f"Already existed.. Skipping {filename}") - # mel_file = mel_file - pass - - # with mel_file.open(mode='rb') as f: - # mel_sample = pickle.load(f, fix_imports=True) - return self.mel_file_path.exists() diff --git a/audio_toolset/mel_augmentation.py b/audio_toolset/mel_augmentation.py index ccb16cf..9f7516c 100644 --- a/audio_toolset/mel_augmentation.py +++ b/audio_toolset/mel_augmentation.py @@ -1,4 +1,3 @@ -import torch import numpy as np from ml_lib.utils.transforms import _BaseTransformation @@ -13,10 +12,12 @@ class NoiseInjection(_BaseTransformation): self.sigma = sigma self.noise_factor = noise_factor - def __call__(self, x): + def __call__(self, x: np.ndarray): if self.noise_factor: - noise = torch.normal(self.mu, self.sigma, size=x.shape, device=x.device) * self.noise_factor + noise = np.random.normal(self.mu, self.sigma, size=x.shape) * self.noise_factor augmented_data = x + x * noise + # Cast back to same data type + augmented_data = augmented_data.astype(x.dtype) return augmented_data else: return x @@ -32,7 +33,9 @@ class LoudnessManipulator(_BaseTransformation): def __call__(self, x): if self.max_factor: - augmented_data = x + x * (torch.rand(1, device=x.device) * self.max_factor) + augmented_data = x + x * (np.random.random() * self.max_factor) + # Cast back to same data type + augmented_data = augmented_data.astype(x.dtype) return augmented_data else: return x @@ -49,18 +52,17 @@ class ShiftTime(_BaseTransformation): self.max_shift_ratio = max_shift_ratio self.shift_direction = shift_direction.lower() - def __call__(self, x): + def __call__(self, x: np.ndarray): if self.max_shift_ratio: - shift = torch.randint(max(int(self.max_shift_ratio * x.shape[-1]), 1), (1,)).item() + shift = np.random.randint(max(int(self.max_shift_ratio * x.shape[-1]), 1)) if self.shift_direction == 'right': shift = -1 * shift elif self.shift_direction == 'any': - # The ugly pytorch alternative - # direction = [-1, 1][torch.multinomial(torch.as_tensor([1, 2]).float(), 1).item()] direction = np.asscalar(np.random.choice([1, -1], 1)) shift = direction * shift - augmented_data = torch.roll(x, shift, dims=-1) + augmented_data = np.roll(x, shift) # Set to silence for heading/ tailing + shift = int(shift) if shift > 0: augmented_data[:shift, :] = 0 else: @@ -89,20 +91,15 @@ class MaskAug(_BaseTransformation): else (duration_ratio_max, duration_ratio_max) def __call__(self, x): - assert x.ndim == 3, "This function was made to wotk with two-dimensional inputs" for dim in (self.w_idx, self.h_idx): if self.duration_ratio_max[dim]: - if dim == self.w_idx and x.shape[dim] == 0: - print(x) start = np.asscalar(np.random.choice(x.shape[dim], 1)) v_max = int(x.shape[dim] * self.duration_ratio_max[dim]) - size = torch.randint(0, v_max, (1,)).item() + size = np.asscalar(np.random.randint(0, v_max, 1)) end = int(min(start + size, x.shape[dim])) size = end - start if dim == self.w_idx: - mask = torch.randn(size=(x.shape[self.h_idx], size), device=x.device) if self.mask_with_noise else 0 - x[:, :, start:end] = mask + x[:, start:end] = np.random.random((x.shape[self.h_idx], size)) if self.mask_with_noise else 0 else: - mask = torch.randn((size, x.shape[self.w_idx]), device=x.device) if self.mask_with_noise else 0 - x[:, start:end, :] = mask + x[start:end, :] = np.random.random((size, x.shape[self.w_idx])) if self.mask_with_noise else 0 return x diff --git a/experiments.py b/experiments.py new file mode 100644 index 0000000..66d4171 --- /dev/null +++ b/experiments.py @@ -0,0 +1,67 @@ + +import torchaudio +if sys.platform =='windows': + torchaudio.set_audio_backend('soundfile') +else: + torchaudio.set_audio_backend('sox_io') + + +class PyTorchAudioToMelDataset(_AudioToMelDataset): + + @property + def audio_file_duration(self): + info_obj = torchaudio.info(self.audio_path) + return info_obj.num_frames / info_obj.sample_rate + + @property + def sampling_rate(self): + return self.mel_kwargs['sample_rate'] + + def __init__(self, audio_file_path, *args, **kwargs): + super(PyTorchAudioToMelDataset, self).__init__(audio_file_path, *args, **kwargs) + + audio_file_path = Path(audio_file_path) + # audio_file, sampling_rate = librosa.load(self.audio_path, sr=sampling_rate) + + from torchaudio.transforms import MelSpectrogram + self._mel_transform = Compose([MelSpectrogram(**self.mel_kwargs), + MelToImage() + ]) + + def _build_mel(self): + if self.reset: + self.mel_file_path.unlink(missing_ok=True) + if not self.mel_file_path.exists(): + self.mel_file_path.parent.mkdir(parents=True, exist_ok=True) + lock_file = Path(str(self.mel_file_path).replace(self.mel_file_path.suffix, '.lock')) + lock_file.touch(exist_ok=False) + + try: + audio_sample, sample_rate = torchaudio.load(self.audio_path) + except RuntimeError: + import soundfile + + data, samplerate = soundfile.read(self.audio_path) + # sf.available_formats() + # sf.available_subtypes() + soundfile.write(self.audio_path, data, samplerate, subtype='PCM_32') + + audio_sample, sample_rate = torchaudio.load(self.audio_path) + if sample_rate != self.sampling_rate: + resample = torchaudio.transforms.Resample(orig_freq=int(sample_rate), new_freq=int(self.sampling_rate)) + audio_sample = resample(audio_sample) + if audio_sample.shape[0] > 1: + # Transform Stereo to Mono + audio_sample = audio_sample.mean(dim=0, keepdim=True) + mel_sample = self._mel_transform(audio_sample) + with self.mel_file_path.open('wb') as mel_file: + pickle.dump(mel_sample, mel_file, protocol=pickle.HIGHEST_PROTOCOL) + lock_file.unlink() + else: + # print(f"Already existed.. Skipping {filename}") + # mel_file = mel_file + pass + + # with mel_file.open(mode='rb') as f: + # mel_sample = pickle.load(f, fix_imports=True) + return self.mel_file_path.exists()