79 lines
3.0 KiB
Python
79 lines
3.0 KiB
Python
from pathlib import Path
|
|
from typing import Union, List
|
|
|
|
import multiprocessing as mp
|
|
from torch.utils.data import ConcatDataset
|
|
import torch
|
|
from tqdm import tqdm
|
|
|
|
import variables as V
|
|
from datasets.base_dataset import BaseAudioToMelDataset
|
|
from ml_lib.audio_toolset.audio_to_mel_dataset import LibrosaAudioToMelDataset, PyTorchAudioToMelDataset
|
|
|
|
|
|
# Ask torch.multiprocessing to use the 'file_system' sharing strategy as soon
# as this module is imported (module-level side effect).
# NOTE(review): presumably chosen to avoid running out of file descriptors
# when many tensors cross process boundaries -- confirm against the torch docs.
try:
    torch.multiprocessing.set_sharing_strategy('file_system')
except AttributeError:
    # Best effort only: an AttributeError here is ignored on purpose and the
    # strategy is then left untouched (presumably for torch builds that lack
    # this API -- TODO confirm).
    pass
|
|
|
|
class Urban8K(BaseAudioToMelDataset):
    """UrbanSound8K audio clips, restricted to a selection of folds.

    Every row of ``metadata/UrbanSound8K.csv`` whose fold is selected is
    wrapped in a ``PyTorchAudioToMelDataset``; all of those are concatenated
    into a single ``ConcatDataset``.

    NOTE(review): ``self.data_root``, ``self.wav_folder`` and ``self._dataset``
    are not assigned in this class; they are presumably provided by
    ``BaseAudioToMelDataset`` (which presumably also calls
    ``_build_dataset``) -- confirm against the base class.
    """

    def __init__(self,
                 data_root, setting, fold: Union[int, List] = 1, num_worker=mp.cpu_count(),
                 reset=False, sample_segment_len=50, sample_hop_len=20,
                 **kwargs):
        """Validate the arguments, store the dataset parameters, run the base init.

        Args:
            data_root: Directory that contains the ``UrbanSound8K`` folder.
            setting: One of ``V.DATA_OPTIONS``. For ``V.DATA_OPTION_test`` the
                ``fold`` argument is ignored and fold 10 is used.
            fold: A single fold number or an iterable of fold numbers, each
                within 1..10.
            num_worker: Number of worker processes used to build the dataset;
                values below 1 are treated as 1.
            reset: Forwarded unchanged to the base class.
            sample_segment_len: Stored as ``self.sample_segment_len``.
            sample_hop_len: Stored as ``self.sample_hop_len``.
            **kwargs: Forwarded unchanged to the base class.

        Raises:
            AssertionError: If ``setting`` is not a known string option or a
                fold lies outside 1..10.
        """
        self.num_worker = num_worker
        assert isinstance(setting, str), f'Setting has to be a string, but was: {type(setting)}.'
        assert setting in V.DATA_OPTIONS, f'Setting must match one of: {V.DATA_OPTIONS}.'

        # Normalise to a list up front: any non-int iterable (tuple, set, ...)
        # must end up as a flat list of fold numbers, otherwise the membership
        # test in `_build_subdataset` can never succeed.
        folds = [fold] if isinstance(fold, int) else list(fold)
        assert all(f in range(1, 11) for f in folds), \
            f'Fold(s) must be within 1..10, but was: {fold}.'

        # Dataset parameters
        self.setting = setting
        # The test setting always evaluates on fold 10, whatever was requested.
        self.fold = [10] if self.setting == V.DATA_OPTION_test else folds

        self.sample_segment_len = sample_segment_len
        self.sample_hop_len = sample_hop_len

        # Dataset specific super init. Kept last, as in the original, so that
        # every attribute above already exists when the base class runs.
        super().__init__(Path(data_root) / 'UrbanSound8K',
                         V.TASK_OPTION_multiclass, reset=reset, wav_folder_name='audio', **kwargs
                         )

    def _build_subdataset(self, row):
        """Turn one metadata CSV row into a per-file dataset.

        Args:
            row: One raw line of ``UrbanSound8K.csv`` with exactly eight
                comma separated fields.

        Returns:
            A ``PyTorchAudioToMelDataset`` for the referenced audio file, or
            ``None`` when the row's fold is not among ``self.fold``.

        Raises:
            ValueError: If the row does not split into eight fields or the
                fold / class id are not integers.
        """
        (slice_file_name, _fs_id, _start, _end,
         _salience, fold, class_id, _class_name) = row.strip().split(',')
        fold, class_id = int(fold), int(class_id)
        if fold not in self.fold:
            return None
        audio_file_path = self.wav_folder / f'fold{fold}' / slice_file_name
        # The complete instance dict is handed over as keyword arguments.
        return PyTorchAudioToMelDataset(audio_file_path, class_id, **self.__dict__)

    def _build_dataset(self):
        """Read the metadata CSV and build all sub-datasets in a process pool.

        Returns:
            A ``ConcatDataset`` of the sub-datasets of all selected rows.
        """
        with open(Path(self.data_root) / 'metadata' / 'UrbanSound8K.csv', mode='r') as f:
            # Exclude the header
            _ = next(f)
            all_rows = list(f)

        # Use the same clamped worker count for the pool and for the chunk
        # size, and never let the chunk size drop to 0 (fewer rows than
        # workers): `imap_unordered` requires a chunksize of at least 1.
        num_worker = max(self.num_worker, 1)
        chunksize = max(len(all_rows) // num_worker, 1)

        dataset = []
        with mp.Pool(processes=num_worker) as pool:
            with tqdm(total=len(all_rows)) as pbar:
                for sub_dataset in pool.imap_unordered(
                        self._build_subdataset, all_rows, chunksize=chunksize):
                    pbar.update()
                    # Rows of unselected folds come back as None; drop them.
                    if sub_dataset is not None:
                        dataset.append(sub_dataset)

        return ConcatDataset(dataset)

    def __len__(self):
        """Return the number of samples in the concatenated dataset."""
        return len(self._dataset)

    def __getitem__(self, item):
        """Return ``(transformed_samples, label)`` with the label as an int tensor."""
        transformed_samples, label = self._dataset[item]

        label = torch.as_tensor(label, dtype=torch.int)

        return transformed_samples, label
|