import pickle
from pathlib import Path
import multiprocessing as mp

import librosa
from torch.utils.data import Dataset, ConcatDataset
import torch
from tqdm import tqdm

import variables as V
from ml_lib.audio_toolset.mel_dataset import TorchMelDataset
from ml_lib.modules.util import F_x


class Urban8K(Dataset):

    @property
    def sample_shape(self):
        return self[0][0].shape

    @property
    def _fingerprint(self):
        # The mel transform's string representation identifies the preprocessing parameters.
        return str(self._mel_transform)

    def __init__(self, data_root, setting, mel_transforms, fold=1, transforms=None,
                 use_preprocessed=True, audio_segment_len=62, audio_hop_len=30, num_worker=mp.cpu_count(),
                 **_):
        assert isinstance(setting, str), f'Setting has to be a string, but was: {type(setting)}.'
        assert setting in V.DATA_OPTIONS, f'Setting must match one of: {V.DATA_OPTIONS}.'
        assert fold in range(1, 11), f'Fold has to be in range(1, 11), but was: {fold}.'
        super(Urban8K, self).__init__()

        self.data_root = Path(data_root) / 'UrbanSound8K'
        self.setting = setting
        self.num_worker = num_worker
        # Non-train settings always evaluate on fold 10.
        self.fold = fold if self.setting == V.DATA_OPTIONS.train else 10
        self.use_preprocessed = use_preprocessed
        self._wav_folder = self.data_root / 'audio' / f'fold{self.fold}'
        self._mel_folder = self.data_root / 'mel' / f'fold{self.fold}'
        self.container_ext = '.pik'
        self._mel_transform = mel_transforms

        self._labels = self._build_labels()
        self._wav_files = list(sorted(self._labels.keys()))
        transforms = transforms or F_x(in_shape=None)

        # Compare the stored preprocessing fingerprint against the current mel transform;
        # rebuild the mel cache whenever the parameters diverge or no fingerprint exists yet.
        param_storage = self._mel_folder / 'data_params.pik'
        self._mel_folder.mkdir(parents=True, exist_ok=True)
        try:
            pik_data = param_storage.read_bytes()
            fingerprint = pickle.loads(pik_data)
            if fingerprint == self._fingerprint:
                self.use_preprocessed = use_preprocessed
            else:
                print('Diverging parameters were found; refreshing...')
                param_storage.unlink()
                param_storage.write_bytes(pickle.dumps(self._fingerprint))
                self.use_preprocessed = False
        except FileNotFoundError:
            param_storage.write_bytes(pickle.dumps(self._fingerprint))
            self.use_preprocessed = False

        # Build the concatenated per-file mel dataset; if reading a cached mel file fails,
        # force preprocessing and retry.
        while True:
            if not self.use_preprocessed:
                self._pre_process()
            try:
                self._dataset = ConcatDataset(
                    [TorchMelDataset(identifier=key, mel_path=self._mel_folder, transform=transforms,
                                     segment_len=audio_segment_len, hop_len=audio_hop_len,
                                     label=self._labels[key]['label'])
                     for key in self._labels.keys()]
                )
                break
            except IOError:
                self.use_preprocessed = False

    def _build_labels(self):
        labeldict = dict()
        with open(Path(self.data_root) / 'metadata' / 'UrbanSound8K.csv', mode='r') as f:
            # Skip the CSV header.
            _ = next(f)
            for row in f:
                slice_file_name, fs_id, start, end, salience, fold, class_id, class_name = row.strip().split(',')
                if int(fold) == self.fold:
                    key = slice_file_name.replace('.wav', '')
                    labeldict[key] = dict(label=int(class_id), fold=int(fold))

        # Delete any previously cached mel files when preprocessing is forced.
        if not self.use_preprocessed:
            for key in labeldict.keys():
                for mel_file in self._mel_folder.rglob(f'{key}_*'):
                    try:
                        mel_file.unlink(missing_ok=True)
                    except FileNotFoundError:
                        pass

        return labeldict

    def __len__(self):
        return len(self._dataset)

    def _pre_process(self):
        print('Preprocessing mel files...')
        # Build the mel spectrogram cache in parallel, one task per audio file.
        with mp.Pool(processes=self.num_worker) as pool:
            with tqdm(total=len(self._labels)) as pbar:
                for _ in pool.imap_unordered(self._build_mel, self._labels.keys()):
                    pbar.update()

    def _build_mel(self, filename):
        wav_file = self._wav_folder / (filename.replace('X', '') + '.wav')
        mel_file = list(self._mel_folder.glob(f'{filename}_*'))

        if not mel_file:
            # No cached spectrogram yet: load the wav, apply the mel transform and cache the result.
            raw_sample, sr = librosa.core.load(wav_file)
            mel_sample = self._mel_transform(raw_sample)
            m, n = mel_sample.shape
            mel_file = self._mel_folder / f'{filename}_{m}_{n}'
            self._mel_folder.mkdir(exist_ok=True, parents=True)
            with mel_file.open(mode='wb') as f:
                pickle.dump(mel_sample, f, protocol=pickle.HIGHEST_PROTOCOL)
        else:
            # A cached spectrogram already exists; reuse it.
            mel_file = mel_file[0]

        with mel_file.open(mode='rb') as f:
            mel_sample = pickle.load(f, fix_imports=True)
        return mel_sample, mel_file

    def __getitem__(self, item):
        transformed_samples, label = self._dataset[item]
        label = torch.as_tensor(label, dtype=torch.float)
        return transformed_samples, label
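

# --- Usage sketch (not part of the original module) --------------------------
# A minimal, hedged example of how this dataset might be wired up. The data
# root path, the toy MelTransform class, and the DataLoader settings below are
# illustrative assumptions only; in practice the mel transform would come from
# the surrounding project. Any callable mapping a 1-D waveform to a 2-D mel
# spectrogram should work, ideally with a stable str() representation, since
# that is what the fingerprint check above compares across runs.
if __name__ == '__main__':
    from torch.utils.data import DataLoader

    class MelTransform:
        """Toy stand-in transform with a deterministic repr for the fingerprint check."""

        def __init__(self, sr=22050, n_mels=64):
            self.sr = sr
            self.n_mels = n_mels

        def __call__(self, y):
            return librosa.feature.melspectrogram(y=y, sr=self.sr, n_mels=self.n_mels)

        def __repr__(self):
            return f'MelTransform(sr={self.sr}, n_mels={self.n_mels})'

    # 'data' is a placeholder path; it must contain the extracted UrbanSound8K folder.
    dataset = Urban8K(data_root='data', setting=V.DATA_OPTIONS.train,
                      mel_transforms=MelTransform(), fold=1)
    print('Dataset length:', len(dataset))
    print('Sample shape:', dataset.sample_shape)

    loader = DataLoader(dataset, batch_size=16, shuffle=True)
    batch_samples, batch_labels = next(iter(loader))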