# Snapshot metadata from the original page header (kept as a comment so the
# module stays importable): 2020-12-01 16:37:16 +01:00, 141 lines, 5.2 KiB, Python.
import pickle
from pathlib import Path
import multiprocessing as mp
import librosa as librosa
from torch.utils.data import Dataset, ConcatDataset
import torch
from tqdm import tqdm
import variables as V
from ml_lib.audio_toolset.mel_dataset import TorchMelDataset
from ml_lib.modules.util import F_x
class Urban8K(Dataset):
    """UrbanSound8K dataset for one cross-validation fold.

    Converts the fold's .wav clips into mel spectrograms, caches them as pickle
    files under ``<data_root>/UrbanSound8K/mel/fold<k>/``, and serves fixed-length
    segments via a ``ConcatDataset`` of ``TorchMelDataset`` instances.

    Expects the standard UrbanSound8K layout under ``data_root``:
    ``UrbanSound8K/audio/fold<k>/*.wav`` and
    ``UrbanSound8K/metadata/UrbanSound8K.csv``.
    """

    @property
    def sample_shape(self):
        # Shape of one transformed sample, probed from the first item.
        return self[0][0].shape

    @property
    def _fingerprint(self):
        # String identity of the mel transform; persisted next to the cache so a
        # change in transform parameters invalidates previously cached files.
        return str(self._mel_transform)

    def __init__(self, data_root, setting, mel_transforms, fold=1, transforms=None,
                 use_preprocessed=True, audio_segment_len=62, audio_hop_len=30, num_worker=mp.cpu_count(),
                 **_):
        """
        :param data_root: directory that contains the ``UrbanSound8K`` folder.
        :param setting: one of ``V.DATA_OPTIONS``; only the train setting honours
                        ``fold``, every other setting is pinned to fold 10.
        :param mel_transforms: callable mapping a raw waveform to a 2D mel spectrogram.
        :param fold: cross-validation fold in 1..10.
        :param transforms: optional per-segment transform; identity (``F_x``) if omitted.
        :param use_preprocessed: reuse cached mel files when the stored transform
                                 fingerprint matches the current one.
        :param audio_segment_len: segment length (mel frames) passed to TorchMelDataset.
        :param audio_hop_len: hop between consecutive segments (mel frames).
        :param num_worker: process count for parallel preprocessing.
        """
        assert isinstance(setting, str), f'Setting has to be a string, but was: {type(setting)}.'
        assert setting in V.DATA_OPTIONS, f'Setting must match one of: {V.DATA_OPTIONS}.'
        assert fold in range(1, 11)
        super(Urban8K, self).__init__()

        self.data_root = Path(data_root) / 'UrbanSound8K'
        self.setting = setting
        self.num_worker = num_worker
        # Non-train settings always evaluate on fold 10.
        self.fold = fold if self.setting == V.DATA_OPTIONS.train else 10
        self.use_preprocessed = use_preprocessed
        self._wav_folder = self.data_root / 'audio' / f'fold{self.fold}'
        self._mel_folder = self.data_root / 'mel' / f'fold{self.fold}'
        self.container_ext = '.pik'
        self._mel_transform = mel_transforms

        self._labels = self._build_labels()
        self._wav_files = list(sorted(self._labels.keys()))
        transforms = transforms or F_x(in_shape=None)

        # Compare the stored transform fingerprint against the current one; on a
        # mismatch or a missing fingerprint file, force a rebuild of the cache.
        param_storage = self._mel_folder / 'data_params.pik'
        self._mel_folder.mkdir(parents=True, exist_ok=True)
        try:
            fingerprint = pickle.loads(param_storage.read_bytes())
            if fingerprint != self._fingerprint:
                print('Diverging parameters were found; Refreshing...')
                param_storage.unlink()
                param_storage.write_bytes(pickle.dumps(self._fingerprint))
                self.use_preprocessed = False
        except FileNotFoundError:
            param_storage.write_bytes(pickle.dumps(self._fingerprint))
            self.use_preprocessed = False

        # Build the concatenated segment dataset. If cached files turn out to be
        # unreadable (IOError), drop back to re-preprocessing and retry.
        while True:
            if not self.use_preprocessed:
                self._pre_process()
            try:
                self._dataset = ConcatDataset(
                    [TorchMelDataset(identifier=key, mel_path=self._mel_folder, transform=transforms,
                                     segment_len=audio_segment_len, hop_len=audio_hop_len,
                                     label=self._labels[key]['label']
                                     ) for key in self._labels.keys()]
                )
                break
            except IOError:
                self.use_preprocessed = False

    def _build_labels(self):
        """Parse the UrbanSound8K metadata CSV and return ``{clip_id: dict(label, fold)}``
        for the clips of this instance's fold. When a fresh preprocessing run is
        requested, also delete any stale cached mel files for those clips."""
        labeldict = dict()
        with open(Path(self.data_root) / 'metadata' / 'UrbanSound8K.csv', mode='r') as f:
            # Exclude the header
            _ = next(f)
            for row in f:
                slice_file_name, fs_id, start, end, salience, fold, class_id, class_name = row.strip().split(',')
                if int(fold) == self.fold:
                    key = slice_file_name.replace('.wav', '')
                    labeldict[key] = dict(label=int(class_id), fold=int(fold))
        if not self.use_preprocessed:
            # Delete cached files so they are rebuilt from scratch.
            for key in labeldict.keys():
                for mel_file in self._mel_folder.rglob(f'{key}_*'):
                    try:
                        mel_file.unlink(missing_ok=True)
                    except FileNotFoundError:
                        pass
        return labeldict

    def __len__(self):
        return len(self._dataset)

    def _pre_process(self):
        """Convert every wav clip of this fold to a cached mel spectrogram,
        fanning the work out over a process pool."""
        print('Preprocessing Mel Files....')
        with mp.Pool(processes=self.num_worker) as pool:
            with tqdm(total=len(self._labels)) as pbar:
                for i, _ in enumerate(pool.imap_unordered(self._build_mel, self._labels.keys())):
                    pbar.update()

    def _build_mel(self, filename):
        """Return ``(mel_sample, mel_file)`` for one clip identifier, computing and
        caching the mel spectrogram if no cached file for it exists yet.

        The cache filename encodes the spectrogram shape: ``<clip>_<mels>_<frames>``.
        """
        wav_file = self._wav_folder / (filename.replace('X', '') + '.wav')
        # Look up an existing cache entry for this clip (any shape suffix).
        mel_file = list(self._mel_folder.glob(f'{filename}_*'))
        if not mel_file:
            raw_sample, sr = librosa.core.load(wav_file)
            mel_sample = self._mel_transform(raw_sample)
            m, n = mel_sample.shape
            mel_file = self._mel_folder / f'{filename}_{m}_{n}'
            self._mel_folder.mkdir(exist_ok=True, parents=True)
            with mel_file.open(mode='wb') as f:
                pickle.dump(mel_sample, f, protocol=pickle.HIGHEST_PROTOCOL)
        else:
            mel_file = mel_file[0]
            with mel_file.open(mode='rb') as f:
                mel_sample = pickle.load(f, fix_imports=True)
        return mel_sample, mel_file

    def __getitem__(self, item):
        transformed_samples, label = self._dataset[item]
        # Labels arrive as plain ints; hand them to the model as float tensors.
        label = torch.as_tensor(label, dtype=torch.float)
        return transformed_samples, label