torchaudio testing

This commit is contained in:
Si11ium
2020-12-17 08:02:29 +01:00
parent 95dcf22f3d
commit 68431b848e
13 changed files with 578 additions and 418 deletions

104
datasets/base_dataset.py Normal file
View File

@@ -0,0 +1,104 @@
import pickle
from pathlib import Path
from typing import Union
from abc import ABC
import variables as V
from torch.utils.data import Dataset
class BaseAudioToMelDataset(Dataset, ABC):
    """Abstract base class for datasets that turn raw ``.wav`` audio into mel spectrograms.

    Subclasses must implement ``_build_dataset`` (returning the backing dataset
    object), ``__len__``, ``__getitem__`` and the ``_fingerprint`` property.
    """

    @property
    def task_type(self):
        # Task variant chosen at construction time (binary vs. multiclass).
        return self._task_type

    @property
    def classes(self):
        # NOTE(review): always returns the multiclass label set, even when
        # task_type is binary -- confirm this is intended.
        return V.multi_classes

    @property
    def n_classes(self):
        return V.N_CLASS_binary if self.task_type == V.TASK_OPTION_binary else V.N_CLASS_multi

    @property
    def sample_shape(self):
        # Shape of the data part of the first sample; assumes __getitem__
        # returns a (data, label) pair -- TODO confirm in subclasses.
        return self[0][0].shape

    @property
    def _fingerprint(self):
        """Serializable description of the preprocessing parameters.

        Compared against the cached ``data_params.pik`` to detect stale mel
        files; must be provided by subclasses.
        """
        # Fixed: removed an unreachable ``return str(self._mel_transform)``
        # that followed this raise (no ``_mel_transform`` attribute exists
        # on this class).
        raise NotImplementedError

    # Data Structures -- on-disk layout.
    @property
    def mel_folder(self):
        # Cached mel containers live under ``<data_root>/mel``.
        return self.data_root / 'mel'

    @property
    def wav_folder(self):
        return self.data_root / self._wav_folder_name

    @property
    def _container_ext(self):
        # File extension used for serialized mel containers.
        return '.mel'

    def __init__(self, data_root: Union[str, Path], task_type, mel_kwargs,
                 mel_augmentations=None, audio_augmentations=None, reset=False,
                 wav_folder_name='wav', **_):
        """Collect raw wav files under *data_root* and build the backing dataset.

        :param data_root: root folder containing the wav (and mel cache) folders.
        :param task_type: one of the V.TASK_OPTION_* constants.
        :param mel_kwargs: parameters forwarded to the mel transformation.
        :param mel_augmentations: optional transform applied to mel samples.
        :param audio_augmentations: optional transform applied to raw audio.
        :param reset: when True, cached mel files are candidates for deletion
            (see ``_check_reset_and_clean_up``).
        :param wav_folder_name: name of the sub-folder holding the wav files.
        """
        super(BaseAudioToMelDataset, self).__init__()
        # Dataset Parameters
        self.data_root = Path(data_root)
        self._wav_folder_name = wav_folder_name
        self.reset = reset
        self.mel_kwargs = mel_kwargs

        # Transformations
        self.mel_augmentations = mel_augmentations
        self.audio_augmentations = audio_augmentations
        self._task_type = task_type

        # Find all raw files and turn generator to persistent list:
        self._wav_files = list(self.wav_folder.rglob('*.wav'))

        # NOTE(review): ``reset`` is stored but _check_reset_and_clean_up() is
        # never invoked here -- presumably subclasses are expected to call it;
        # verify against callers.

        # Build the Dataset
        self._dataset = self._build_dataset()

    def __len__(self):
        raise NotImplementedError

    def __getitem__(self, item):
        raise NotImplementedError

    def _build_dataset(self):
        """Construct and return the backing dataset object."""
        raise NotImplementedError

    def _check_reset_and_clean_up(self, reset):
        """Validate cached preprocessing parameters and purge stale mel files.

        For every mel folder derived from the wav folders, compares the stored
        parameter fingerprint with ``self._fingerprint``; on mismatch (or when
        no fingerprint exists) the fingerprint is (re)written and the cached
        mel containers are deleted. When the fingerprint matches, deletion
        happens only if *reset* is True.
        """
        # Map every wav parent folder to its sibling 'mel' folder.
        # NOTE(review): the string replace touches the whole path, so a parent
        # directory whose name also contains ``_wav_folder_name`` would be
        # rewritten too -- confirm data_root layouts avoid that.
        all_mel_folders = {str(x.parent).replace(self._wav_folder_name, 'mel')
                           for x in self._wav_files}
        for mel_folder in all_mel_folders:
            param_storage = Path(mel_folder) / 'data_params.pik'
            param_storage.parent.mkdir(parents=True, exist_ok=True)
            try:
                pik_data = param_storage.read_bytes()
                fingerprint = pickle.loads(pik_data)
                if fingerprint == self._fingerprint:
                    # Parameters match the cache; only purge when asked to.
                    this_reset = reset
                else:
                    print('Diverging parameters were found; Refreshing...')
                    param_storage.unlink()
                    pik_data = pickle.dumps(self._fingerprint)
                    param_storage.write_bytes(pik_data)
                    this_reset = True
            except FileNotFoundError:
                # First run for this folder: record the fingerprint and rebuild.
                pik_data = pickle.dumps(self._fingerprint)
                param_storage.write_bytes(pik_data)
                this_reset = True

            if this_reset:
                # NOTE(review): deletes *all* cached containers under
                # ``self.mel_folder``, not only this sub-folder's -- confirm
                # that is intended for multi-folder layouts.
                all_mel_files = self.mel_folder.rglob(f'*{self._container_ext}')
                for mel_file in all_mel_files:
                    mel_file.unlink()

View File

@@ -20,17 +20,20 @@ class BinaryMasksDataset(Dataset):
@property
def _fingerprint(self):
return dict(**self._mel_kwargs, normalize=self.normalize)
return dict(**self._mel_kwargs if self._mel_kwargs else dict())
def __init__(self, data_root, setting, mel_transforms, transforms=None, stretch_dataset=False,
use_preprocessed=True):
use_preprocessed=True, mel_kwargs=None):
self.stretch = stretch_dataset
assert isinstance(setting, str), f'Setting has to be a string, but was: {type(setting)}.'
assert setting in V.DATA_OPTIONS, f'Setting must match one of: {V.DATA_OPTIONS}.'
super(BinaryMasksDataset, self).__init__()
self.task = V.TASK_OPTION_binary
self.data_root = Path(data_root) / 'ComParE2020_Mask'
self.setting = setting
self._mel_kwargs = mel_kwargs
self._wav_folder = self.data_root / 'wav'
self._mel_folder = self.data_root / 'mel'
self.container_ext = '.pik'

View File

@@ -1,140 +1,78 @@
import pickle
from pathlib import Path
import multiprocessing as mp
from typing import Union, List
import librosa as librosa
from torch.utils.data import Dataset, ConcatDataset
import multiprocessing as mp
from torch.utils.data import ConcatDataset
import torch
from tqdm import tqdm
import variables as V
from ml_lib.audio_toolset.mel_dataset import TorchMelDataset
from ml_lib.modules.util import F_x
from datasets.base_dataset import BaseAudioToMelDataset
from ml_lib.audio_toolset.audio_to_mel_dataset import LibrosaAudioToMelDataset, PyTorchAudioToMelDataset
class Urban8K(Dataset):
try:
torch.multiprocessing.set_sharing_strategy('file_system')
except AttributeError:
pass
@property
def sample_shape(self):
return self[0][0].shape
class Urban8K(BaseAudioToMelDataset):
    @property
    def _fingerprint(self):
        # Identifies the mel-transform configuration, used for cache
        # invalidation of preprocessed mel files.
        # NOTE(review): ``_mel_transform`` is not assigned anywhere visible in
        # the new Urban8K __init__ -- verify this property is still reachable.
        return str(self._mel_transform)
def __init__(self, data_root, setting, mel_transforms, fold=1, transforms=None,
use_preprocessed=True, audio_segment_len=62, audio_hop_len=30, num_worker=mp.cpu_count(),
**_):
def __init__(self,
data_root, setting, fold: Union[int, List]=1, num_worker=mp.cpu_count(),
reset=False, sample_segment_len=50, sample_hop_len=20,
**kwargs):
self.num_worker = num_worker
assert isinstance(setting, str), f'Setting has to be a string, but was: {type(setting)}.'
assert setting in V.DATA_OPTIONS, f'Setting must match one of: {V.DATA_OPTIONS}.'
assert fold in range(1, 11)
super(Urban8K, self).__init__()
assert fold in range(1, 11) if isinstance(fold, int) else all([f in range(1, 11) for f in fold])
self.data_root = Path(data_root) / 'UrbanSound8K'
#Dataset Paramters
self.setting = setting
self.num_worker = num_worker
self.fold = fold if self.setting == V.DATA_OPTIONS.train else 10
self.use_preprocessed = use_preprocessed
self._wav_folder = self.data_root / 'audio' / f'fold{self.fold}'
self._mel_folder = self.data_root / 'mel' / f'fold{self.fold}'
self.container_ext = '.pik'
self._mel_transform = mel_transforms
fold = fold if self.setting != V.DATA_OPTION_test else 10
self.fold = fold if isinstance(fold, list) else [fold]
self._labels = self._build_labels()
self._wav_files = list(sorted(self._labels.keys()))
transforms = transforms or F_x(in_shape=None)
self.sample_segment_len = sample_segment_len
self.sample_hop_len = sample_hop_len
param_storage = self._mel_folder / 'data_params.pik'
self._mel_folder.mkdir(parents=True, exist_ok=True)
try:
pik_data = param_storage.read_bytes()
fingerprint = pickle.loads(pik_data)
if fingerprint == self._fingerprint:
self.use_preprocessed = use_preprocessed
else:
print('Diverging parameters were found; Refreshing...')
param_storage.unlink()
pik_data = pickle.dumps(self._fingerprint)
param_storage.write_bytes(pik_data)
self.use_preprocessed = False
# Dataset specific super init
super(Urban8K, self).__init__(Path(data_root) / 'UrbanSound8K',
V.TASK_OPTION_multiclass, reset=reset, wav_folder_name='audio', **kwargs
)
except FileNotFoundError:
pik_data = pickle.dumps(self._fingerprint)
param_storage.write_bytes(pik_data)
self.use_preprocessed = False
    def _build_subdataset(self, row):
        """Build a per-file mel dataset from one UrbanSound8K metadata CSV row.

        :param row: one raw CSV line of ``UrbanSound8K.csv`` (header excluded).
        :return: a ``PyTorchAudioToMelDataset`` for files whose fold is in
            ``self.fold``, otherwise ``None`` (filtered out by the caller).
        """
        slice_file_name, fs_id, start, end, salience, fold, class_id, class_name = row.strip().split(',')
        fold, class_id = (int(x) for x in (fold, class_id))
        # ``int(fold)`` below is redundant -- fold was converted just above.
        if int(fold) in self.fold:
            audio_file_path = self.wav_folder / f'fold{fold}' / slice_file_name
            # NOTE(review): splatting the full instance __dict__ forwards every
            # attribute as a keyword argument -- confirm
            # PyTorchAudioToMelDataset tolerates the extras.
            return PyTorchAudioToMelDataset(audio_file_path, class_id, **self.__dict__)
        else:
            return None
while True:
if not self.use_preprocessed:
self._pre_process()
try:
self._dataset = ConcatDataset(
[TorchMelDataset(identifier=key, mel_path=self._mel_folder, transform=transforms,
segment_len=audio_segment_len, hop_len=audio_hop_len,
label=self._labels[key]['label']
) for key in self._labels.keys()]
)
break
except IOError:
self.use_preprocessed = False
pass
def _build_labels(self):
labeldict = dict()
def _build_dataset(self):
dataset= list()
with open(Path(self.data_root) / 'metadata' / 'UrbanSound8K.csv', mode='r') as f:
# Exclude the header
_ = next(f)
for row in f:
slice_file_name, fs_id, start, end, salience, fold, class_id, class_name = row.strip().split(',')
if int(fold) == self.fold:
key = slice_file_name.replace('.wav', '')
labeldict[key] = dict(label=int(class_id), fold=int(fold))
all_rows = list(f)
chunksize = len(all_rows) // max(self.num_worker, 1)
with mp.Pool(processes=self.num_worker) as pool:
with tqdm(total=len(all_rows)) as pbar:
for i, sub_dataset in enumerate(
pool.imap_unordered(self._build_subdataset, all_rows, chunksize=chunksize)):
pbar.update()
dataset.append(sub_dataset)
# Delete File if one exists.
if not self.use_preprocessed:
for key in labeldict.keys():
for mel_file in self._mel_folder.rglob(f'{key}_*'):
try:
mel_file.unlink(missing_ok=True)
except FileNotFoundError:
pass
return labeldict
dataset = ConcatDataset([x for x in dataset if x is not None])
return dataset
    def __len__(self):
        # Length of the concatenated per-file mel dataset built in
        # _build_dataset.
        return len(self._dataset)
    def _pre_process(self):
        """Precompute and cache mel files for every labelled key via a worker pool."""
        print('Preprocessing Mel Files....')
        with mp.Pool(processes=self.num_worker) as pool:
            with tqdm(total=len(self._labels)) as pbar:
                # Order does not matter and results are discarded:
                # _build_mel writes each mel file to disk as a side effect.
                for i, _ in enumerate(pool.imap_unordered(self._build_mel, self._labels.keys())):
                    pbar.update()
    def _build_mel(self, filename):
        """Load (or compute and cache) the mel spectrogram for *filename*.

        :param filename: label-dict key derived from the wav file name.
        :return: tuple of (mel_sample, mel_file path).
        """
        # The 'X' strip mirrors how keys are expected to be stored --
        # TODO confirm where 'X' is introduced into the keys.
        wav_file = self._wav_folder / (filename.replace('X', '') + '.wav')
        # NOTE(review): the literal '(unknown)' below looks like an extraction
        # artifact of an f-string placeholder (probably ``{filename}``); as
        # written the glob does not depend on *filename* -- verify against the
        # original source before relying on this.
        mel_file = list(self._mel_folder.glob(f'(unknown)_*'))
        if not mel_file:
            # No cached file: compute the mel sample and pickle it to disk.
            raw_sample, sr = librosa.core.load(wav_file)
            mel_sample = self._mel_transform(raw_sample)
            m, n = mel_sample.shape
            # Shape is encoded in the file name for later segment slicing.
            mel_file = self._mel_folder / f'(unknown)_{m}_{n}'
            self._mel_folder.mkdir(exist_ok=True, parents=True)
            with mel_file.open(mode='wb') as f:
                pickle.dump(mel_sample, f, protocol=pickle.HIGHEST_PROTOCOL)
        else:
            # print(f"Already existed.. Skipping (unknown)")
            mel_file = mel_file[0]
            with mel_file.open(mode='rb') as f:
                mel_sample = pickle.load(f, fix_imports=True)
        return mel_sample, mel_file
def __getitem__(self, item):
transformed_samples, label = self._dataset[item]
label = torch.as_tensor(label, dtype=torch.float)
label = torch.as_tensor(label, dtype=torch.int)
return transformed_samples, label

View File

@@ -1,140 +0,0 @@
import pickle
from pathlib import Path
import multiprocessing as mp
import librosa as librosa
from torch.utils.data import Dataset, ConcatDataset
import torch
from tqdm import tqdm
import variables as V
from ml_lib.audio_toolset.mel_dataset import TorchMelDataset
from ml_lib.modules.util import F_x
class Urban8K_TO(Dataset):
    """UrbanSound8K dataset that serves segmented mel spectrograms.

    Raw wav files of one fold are converted to pickled mel samples on first
    use; a pickled parameter fingerprint guards the cache against stale
    preprocessing parameters.
    """

    @property
    def sample_shape(self):
        # Shape of the data part of the first sample.
        return self[0][0].shape

    @property
    def _fingerprint(self):
        # String form of the mel transform identifies the preprocessing
        # configuration for cache invalidation.
        return str(self._mel_transform)

    def __init__(self, data_root, setting, mel_transforms, fold=1, transforms=None,
                 use_preprocessed=True, audio_segment_len=1, audio_hop_len=1, num_worker=mp.cpu_count(),
                 **_):
        """Build the fold-specific mel dataset, preprocessing wavs when needed.

        :param data_root: root folder; data is expected under 'UrbanSound8K'.
        :param setting: one of V.DATA_OPTIONS (train/devel/test style split).
        :param mel_transforms: callable turning a raw waveform into a mel array.
        :param fold: fold number 1..10; forced to 10 for non-train settings.
        :param transforms: optional per-segment transform (identity if None).
        :param use_preprocessed: reuse cached mel files when the fingerprint
            matches; otherwise the cache is rebuilt.
        """
        assert isinstance(setting, str), f'Setting has to be a string, but was: {type(setting)}.'
        assert setting in V.DATA_OPTIONS, f'Setting must match one of: {V.DATA_OPTIONS}.'
        assert fold in range(1, 11)
        super(Urban8K_TO, self).__init__()
        self.data_root = Path(data_root) / 'UrbanSound8K'
        self.setting = setting
        self.num_worker = num_worker
        # Non-train settings always evaluate on fold 10.
        self.fold = fold if self.setting == V.DATA_OPTIONS.train else 10
        self.use_preprocessed = use_preprocessed
        self._wav_folder = self.data_root / 'audio' / f'fold{self.fold}'
        self._mel_folder = self.data_root / 'mel' / f'fold{self.fold}'
        self.container_ext = '.pik'
        self._mel_transform = mel_transforms
        self._labels = self._build_labels()
        self._wav_files = list(sorted(self._labels.keys()))
        # Identity transform when none is given.
        transforms = transforms or F_x(in_shape=None)
        # Compare the cached parameter fingerprint with the current one;
        # on mismatch or first run, force re-preprocessing.
        param_storage = self._mel_folder / 'data_params.pik'
        self._mel_folder.mkdir(parents=True, exist_ok=True)
        try:
            pik_data = param_storage.read_bytes()
            fingerprint = pickle.loads(pik_data)
            if fingerprint == self._fingerprint:
                self.use_preprocessed = use_preprocessed
            else:
                print('Diverging parameters were found; Refreshing...')
                param_storage.unlink()
                pik_data = pickle.dumps(self._fingerprint)
                param_storage.write_bytes(pik_data)
                self.use_preprocessed = False
        except FileNotFoundError:
            pik_data = pickle.dumps(self._fingerprint)
            param_storage.write_bytes(pik_data)
            self.use_preprocessed = False

        # Retry loop: build the concatenated dataset from cached mel files;
        # on any IOError fall back to re-preprocessing and try again.
        while True:
            if not self.use_preprocessed:
                self._pre_process()
            try:
                self._dataset = ConcatDataset(
                    [TorchMelDataset(identifier=key, mel_path=self._mel_folder, transform=transforms,
                                     segment_len=audio_segment_len, hop_len=audio_hop_len,
                                     label=self._labels[key]['label']
                                     ) for key in self._labels.keys()]
                )
                break
            except IOError:
                self.use_preprocessed = False
                pass

    def _build_labels(self):
        """Parse UrbanSound8K.csv and map file keys of this fold to labels.

        :return: dict key -> dict(label=class_id, fold=fold).
        """
        labeldict = dict()
        with open(Path(self.data_root) / 'metadata' / 'UrbanSound8K.csv', mode='r') as f:
            # Exclude the header
            _ = next(f)
            for row in f:
                slice_file_name, fs_id, start, end, salience, fold, class_id, class_name = row.strip().split(',')
                if int(fold) == self.fold:
                    key = slice_file_name.replace('.wav', '')
                    labeldict[key] = dict(label=int(class_id), fold=int(fold))
        # Delete File if one exists.
        # When the cache is being rebuilt, remove any stale mel files for the
        # collected keys.
        if not self.use_preprocessed:
            for key in labeldict.keys():
                for mel_file in self._mel_folder.rglob(f'{key}_*'):
                    try:
                        # NOTE(review): missing_ok=True already suppresses
                        # FileNotFoundError -- the surrounding try is redundant
                        # but kept as-is.
                        mel_file.unlink(missing_ok=True)
                    except FileNotFoundError:
                        pass
        return labeldict

    def __len__(self):
        # Number of mel segments in the concatenated dataset.
        return len(self._dataset)

    def _pre_process(self):
        """Precompute and cache mel files for every labelled key via a worker pool."""
        print('Preprocessing Mel Files....')
        with mp.Pool(processes=self.num_worker) as pool:
            with tqdm(total=len(self._labels)) as pbar:
                # Results are discarded; _build_mel writes to disk as a side
                # effect.
                for i, _ in enumerate(pool.imap_unordered(self._build_mel, self._labels.keys())):
                    pbar.update()

    def _build_mel(self, filename):
        """Load (or compute and cache) the mel spectrogram for *filename*.

        :return: tuple of (mel_sample, mel_file path).
        """
        # The 'X' strip mirrors how keys are expected to be stored --
        # TODO confirm where 'X' is introduced into the keys.
        wav_file = self._wav_folder / (filename.replace('X', '') + '.wav')
        # NOTE(review): the literal '(unknown)' looks like an extraction
        # artifact of an f-string placeholder (probably ``{filename}``); as
        # written the glob does not depend on *filename* -- verify against
        # the original source.
        mel_file = list(self._mel_folder.glob(f'(unknown)_*'))
        if not mel_file:
            raw_sample, sr = librosa.core.load(wav_file)
            mel_sample = self._mel_transform(raw_sample)
            m, n = mel_sample.shape
            # Shape is encoded in the file name for later segment slicing.
            mel_file = self._mel_folder / f'(unknown)_{m}_{n}'
            self._mel_folder.mkdir(exist_ok=True, parents=True)
            with mel_file.open(mode='wb') as f:
                pickle.dump(mel_sample, f, protocol=pickle.HIGHEST_PROTOCOL)
        else:
            # print(f"Already existed.. Skipping (unknown)")
            mel_file = mel_file[0]
            with mel_file.open(mode='rb') as f:
                mel_sample = pickle.load(f, fix_imports=True)
        return mel_sample, mel_file

    def __getitem__(self, item):
        # Labels are emitted as float tensors in this (removed) variant.
        transformed_samples, label = self._dataset[item]
        label = torch.as_tensor(label, dtype=torch.float)
        return transformed_samples, label