# Snapshot metadata: 2020-12-17 08:02:29 +01:00, 79 lines, 3.0 KiB, Python.

from pathlib import Path
from typing import Union, List
import multiprocessing as mp
from torch.utils.data import ConcatDataset
import torch
from tqdm import tqdm
import variables as V
from datasets.base_dataset import BaseAudioToMelDataset
from ml_lib.audio_toolset.audio_to_mel_dataset import LibrosaAudioToMelDataset, PyTorchAudioToMelDataset
# Switch torch.multiprocessing's tensor sharing strategy to 'file_system',
# presumably because the dataset below is built with a multiprocessing pool
# and many objects cross process boundaries - TODO confirm the motivation.
try:
    torch.multiprocessing.set_sharing_strategy('file_system')
except AttributeError:
    # Best effort: silently skipped if this torch build lacks the API.
    pass
class Urban8K(BaseAudioToMelDataset):
    """UrbanSound8K exposed as a concatenation of per-clip mel datasets.

    Every row of ``metadata/UrbanSound8K.csv`` whose fold is selected becomes
    one ``PyTorchAudioToMelDataset``; all of them are joined into a single
    ``ConcatDataset``. The base class is presumably the one that calls
    ``_build_dataset`` and stores the result in ``self._dataset`` (not visible
    here - confirm in ``BaseAudioToMelDataset``).
    """

    def __init__(self,
                 data_root, setting, fold: Union[int, List] = 1, num_worker=mp.cpu_count(),
                 reset=False, sample_segment_len=50, sample_hop_len=20,
                 **kwargs):
        """Validate the configuration, then defer to the base-class initializer.

        Args:
            data_root: Directory that contains the ``UrbanSound8K`` folder.
            setting: One of ``V.DATA_OPTIONS``. For ``V.DATA_OPTION_test``
                the ``fold`` argument is ignored and fold 10 is used.
            fold: A fold number in 1..10, or an iterable (list, tuple, ...)
                of such numbers.
            num_worker: Number of worker processes used to build the per-clip
                datasets; values below 1 are treated as 1.
            reset: Forwarded to the base class.
            sample_segment_len: Stored on the instance; like every instance
                attribute it is forwarded to each ``PyTorchAudioToMelDataset``.
            sample_hop_len: Same as ``sample_segment_len``.
            **kwargs: Forwarded to the base class.

        Raises:
            AssertionError: If ``setting`` or ``fold`` is invalid.
        """
        self.num_worker = num_worker
        assert isinstance(setting, str), f'Setting has to be a string, but was: {type(setting)}.'
        assert setting in V.DATA_OPTIONS, f'Setting must match one of: {V.DATA_OPTIONS}.'
        # Normalize to a list so tuples and other iterables work too (previously
        # a tuple was wrapped as [(..)] and silently matched no rows).
        folds = [fold] if isinstance(fold, int) else list(fold)
        assert all(f in range(1, 11) for f in folds), f'Fold(s) must be in 1..10, but were: {fold}.'
        # Dataset parameters
        self.setting = setting
        # The test setting always uses fold 10, whatever fold(s) were requested.
        self.fold = [10] if self.setting == V.DATA_OPTION_test else folds
        self.sample_segment_len = sample_segment_len
        self.sample_hop_len = sample_hop_len
        # Dataset specific super init
        super(Urban8K, self).__init__(Path(data_root) / 'UrbanSound8K',
                                      V.TASK_OPTION_multiclass, reset=reset, wav_folder_name='audio', **kwargs
                                      )

    def _build_subdataset(self, row):
        """Build the mel dataset for one metadata CSV line.

        Args:
            row: One raw line of ``UrbanSound8K.csv`` with the comma-separated
                columns slice_file_name, fs_id, start, end, salience, fold,
                class_id, class_name.

        Returns:
            A ``PyTorchAudioToMelDataset`` for the clip, or ``None`` when the
            row's fold is not in ``self.fold``.
        """
        slice_file_name, fs_id, start, end, salience, fold, class_id, class_name = row.strip().split(',')
        fold, class_id = int(fold), int(class_id)
        if fold not in self.fold:
            return None
        audio_file_path = self.wav_folder / f'fold{fold}' / slice_file_name
        # Every instance attribute (setting, fold, sample_segment_len, ...) is
        # passed on as keyword arguments to the per-clip dataset.
        return PyTorchAudioToMelDataset(audio_file_path, class_id, **self.__dict__)

    def _build_dataset(self):
        """Build all per-clip datasets in a process pool and concatenate them.

        Returns:
            A ``ConcatDataset`` of the sub-datasets for the selected folds, in
            the order their rows appear in the metadata CSV.
        """
        with open(Path(self.data_root) / 'metadata' / 'UrbanSound8K.csv', mode='r', encoding='utf-8') as f:
            # Skip the header line (tolerates an empty file).
            next(f, None)
            # Blank lines (e.g. a trailing newline) cannot be unpacked into 8 fields.
            all_rows = [row for row in f if row.strip()]

        num_workers = max(self.num_worker, 1)
        # Pool.imap rejects chunksize < 1, which plain floor division yields
        # when there are fewer rows than workers.
        chunksize = max(len(all_rows) // num_workers, 1)
        dataset = []
        with mp.Pool(processes=num_workers) as pool, tqdm(total=len(all_rows)) as pbar:
            # Ordered imap (rather than imap_unordered) keeps the ConcatDataset
            # indices identical from run to run.
            for sub_dataset in pool.imap(self._build_subdataset, all_rows, chunksize=chunksize):
                pbar.update()
                if sub_dataset is not None:
                    dataset.append(sub_dataset)
        return ConcatDataset(dataset)

    def __len__(self):
        """Return ``len(self._dataset)``, presumably the ConcatDataset built above."""
        return len(self._dataset)

    def __getitem__(self, item):
        """Return ``(transformed_samples, label)`` with ``label`` as a ``torch.int`` tensor."""
        transformed_samples, label = self._dataset[item]
        # NOTE(review): torch.int is int32; losses such as CrossEntropyLoss
        # expect int64 class indices - confirm the training code casts it.
        label = torch.as_tensor(label, dtype=torch.int)
        return transformed_samples, label