2020-12-17 08:02:29 +01:00

105 lines
3.1 KiB
Python

import pickle
from pathlib import Path
from typing import Union
from abc import ABC
import variables as V
from torch.utils.data import Dataset
class BaseAudioToMelDataset(Dataset, ABC):
    """Base class for datasets built from ``.wav`` files below a data root.

    On construction the instance stores its parameters, collects every
    ``*.wav`` file found recursively below :attr:`wav_folder` and then calls
    :meth:`_build_dataset`.  Subclasses have to provide ``__len__``,
    ``__getitem__``, ``_build_dataset`` and ``_fingerprint``; the versions
    defined here only raise ``NotImplementedError``.
    """

    def __init__(self, data_root: Union[str, Path], task_type, mel_kwargs,
                 mel_augmentations=None, audio_augmentations=None, reset=False,
                 wav_folder_name='wav', **_):
        """Store the configuration, index the wav files and build the dataset.

        Args:
            data_root: Directory that contains the wav folder (and the
                ``mel`` folder).
            task_type: Task selector; compared against
                ``V.TASK_OPTION_binary`` by :attr:`n_classes`.
            mel_kwargs: Stored as ``self.mel_kwargs``; not interpreted here.
            mel_augmentations: Stored as-is; not interpreted here.
            audio_augmentations: Stored as-is; not interpreted here.
            reset: Stored as ``self.reset``; this class does not pass it to
                :meth:`_check_reset_and_clean_up` on its own.
            wav_folder_name: Name of the wav folder relative to ``data_root``.
            **_: Additional keyword arguments are accepted and ignored.
        """
        super().__init__()

        # Dataset Parameters
        self.data_root = Path(data_root)
        self._wav_folder_name = wav_folder_name
        self.reset = reset
        self.mel_kwargs = mel_kwargs

        # Transformations
        self.mel_augmentations = mel_augmentations
        self.audio_augmentations = audio_augmentations
        self._task_type = task_type

        # Find all raw files and turn the generator into a persistent list:
        self._wav_files = list(self.wav_folder.rglob('*.wav'))

        # Build the Dataset (implemented by the subclass)
        self._dataset = self._build_dataset()

    @property
    def task_type(self):
        """The task type that was passed to the constructor."""
        return self._task_type

    @property
    def classes(self):
        """The class collection ``V.multi_classes``."""
        return V.multi_classes

    @property
    def n_classes(self):
        """Number of classes: binary count for the binary task, else multi."""
        if self.task_type == V.TASK_OPTION_binary:
            return V.N_CLASS_binary
        return V.N_CLASS_multi

    @property
    def sample_shape(self):
        """``shape`` of the first element of the first dataset item."""
        return self[0][0].shape

    @property
    def _fingerprint(self):
        """Picklable value identifying the parameters used to create mel files.

        Must be overridden by subclasses (for instance with a string
        rendering of the mel transform); the base class has nothing to
        derive it from.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError

    # Data Structures
    @property
    def mel_folder(self):
        """``<data_root>/mel``."""
        return self.data_root / 'mel'

    @property
    def wav_folder(self):
        """``<data_root>/<wav_folder_name>``."""
        return self.data_root / self._wav_folder_name

    @property
    def _container_ext(self):
        """File extension of the mel containers removed on reset."""
        return '.mel'

    def __len__(self):
        """Subclass responsibility."""
        raise NotImplementedError

    def __getitem__(self, item):
        """Subclass responsibility."""
        raise NotImplementedError

    def _build_dataset(self):
        """Subclass responsibility; the result is stored in ``self._dataset``."""
        raise NotImplementedError

    def _check_reset_and_clean_up(self, reset):
        """Refresh stored fingerprints and delete mel files if they are stale.

        Every folder that holds wav files is mirrored below
        :attr:`mel_folder`; each mirrored folder keeps a pickled copy of
        :attr:`_fingerprint` in ``data_params.pik``.  A missing or diverging
        copy is rewritten and forces a reset, a matching copy leaves the
        decision to ``reset``.  If a reset is due for at least one folder,
        every file ending in :attr:`_container_ext` below :attr:`mel_folder`
        is deleted.

        The mirrored folder is derived from the path relative to
        :attr:`wav_folder`, so only the leading wav folder is exchanged for
        ``mel``; other path components that happen to contain the wav folder
        name (in the data root or in sub folders) are left untouched.

        Args:
            reset: Request deletion of the mel files even when all stored
                fingerprints match.
        """
        wav_root = self.wav_folder
        mel_root = self.mel_folder
        mirrored_folders = {
            mel_root / wav_file.parent.relative_to(wav_root)
            for wav_file in self._wav_files
        }
        if not mirrored_folders:
            # Nothing to check; do not touch `_fingerprint` or the disk.
            return

        current_fingerprint = self._fingerprint
        needs_reset = False
        for folder in mirrored_folders:
            folder.mkdir(parents=True, exist_ok=True)
            param_storage = folder / 'data_params.pik'
            try:
                # NOTE: unpickling can execute arbitrary code; the data
                # folder has to come from a trusted source.
                stored_fingerprint = pickle.loads(param_storage.read_bytes())
            except FileNotFoundError:
                is_up_to_date = False
            else:
                is_up_to_date = stored_fingerprint == current_fingerprint
                if not is_up_to_date:
                    print('Diverging parameters were found; Refreshing...')

            if is_up_to_date:
                needs_reset = needs_reset or reset
            else:
                # write_bytes truncates, so an outdated file is replaced.
                param_storage.write_bytes(pickle.dumps(current_fingerprint))
                needs_reset = True

        if needs_reset:
            # The purge always covers the whole mel folder, so run it once.
            for mel_file in mel_root.rglob(f'*{self._container_ext}'):
                mel_file.unlink()