105 lines
3.1 KiB
Python
105 lines
3.1 KiB
Python
import pickle
|
|
from pathlib import Path
|
|
from typing import Union
|
|
from abc import ABC
|
|
|
|
import variables as V
|
|
|
|
from torch.utils.data import Dataset
|
|
|
|
class BaseAudioToMelDataset(Dataset, ABC):
    """Base class for datasets that read ``.wav`` files and cache mel data.

    The constructor stores the configuration, collects every ``*.wav`` file
    below :attr:`wav_folder` and then calls :meth:`_build_dataset`.
    Subclasses must implement ``__len__``, ``__getitem__``,
    ``_build_dataset`` and the ``_fingerprint`` property; the versions
    defined here raise ``NotImplementedError``.
    """

    @property
    def task_type(self):
        """Task type that was passed to the constructor."""
        return self._task_type

    @property
    def classes(self):
        """Return ``V.multi_classes`` (independent of the task type)."""
        return V.multi_classes

    @property
    def n_classes(self):
        """Number of classes: binary count for the binary task, else multi."""
        return V.N_CLASS_binary if self.task_type == V.TASK_OPTION_binary else V.N_CLASS_multi

    @property
    def sample_shape(self):
        """Shape of the first element of the first dataset item."""
        return self[0][0].shape

    @property
    def _fingerprint(self):
        """Picklable value identifying the parameters the cache was built with.

        It is compared (``==``) against the value stored in
        ``data_params.pik`` to decide whether cached mel files are stale.

        Raises:
            NotImplementedError: always; subclasses must override this.
        """
        raise NotImplementedError

    # Data Structures
    @property
    def mel_folder(self):
        """Folder ``<data_root>/mel``."""
        return self.data_root / 'mel'

    @property
    def wav_folder(self):
        """Folder ``<data_root>/<wav_folder_name>`` that is scanned for wavs."""
        return self.data_root / self._wav_folder_name

    @property
    def _container_ext(self):
        """File extension of the cached mel files."""
        return '.mel'

    def __init__(self, data_root: Union[str, Path], task_type, mel_kwargs,
                 mel_augmentations=None, audio_augmentations=None, reset=False,
                 wav_folder_name='wav', **_):
        """Store the configuration, index the wav files and build the dataset.

        Args:
            data_root: Root folder of the dataset (converted to ``Path``).
            task_type: Task identifier; compared with ``V.TASK_OPTION_binary``
                by :attr:`n_classes`.
            mel_kwargs: Stored as ``self.mel_kwargs``; not used in this class.
            mel_augmentations: Stored as ``self.mel_augmentations``.
            audio_augmentations: Stored as ``self.audio_augmentations``.
            reset: Stored as ``self.reset``; not evaluated in this class.
            wav_folder_name: Name of the wav folder below ``data_root``.
            **_: Additional keyword arguments are accepted and ignored.
        """
        super().__init__()

        # Dataset Parameters
        self.data_root = Path(data_root)
        self._wav_folder_name = wav_folder_name
        self.reset = reset
        self.mel_kwargs = mel_kwargs

        # Transformations
        self.mel_augmentations = mel_augmentations
        self.audio_augmentations = audio_augmentations
        self._task_type = task_type

        # Find all raw files and turn generator to persistent list:
        self._wav_files = list(self.wav_folder.rglob('*.wav'))

        # Build the Dataset
        self._dataset = self._build_dataset()

    def __len__(self):
        """Number of items; must be implemented by subclasses."""
        raise NotImplementedError

    def __getitem__(self, item):
        """Return the item at ``item``; must be implemented by subclasses."""
        raise NotImplementedError

    def _build_dataset(self):
        """Create the value stored in ``self._dataset``; subclass hook."""
        raise NotImplementedError

    def _check_reset_and_clean_up(self, reset):
        """Synchronise stored fingerprints and purge stale cached mel files.

        For every folder that contains wav files, the matching folder below
        :attr:`mel_folder` is created and its ``data_params.pik`` is compared
        with ``self._fingerprint``. A missing, unreadable or diverging file
        is (re)written with the current fingerprint. If that happened for
        any folder, or if ``reset`` is true, every ``*.mel`` file below
        :attr:`mel_folder` is deleted. Nothing happens when there are no
        wav files.

        Args:
            reset: Force deletion of the cached mel files even when all
                stored fingerprints match.
        """
        # Mirror the wav sub-folder structure below the mel folder. Working
        # relative to ``wav_folder`` only swaps the leading folder; a plain
        # string replace would also rewrite parent or child folder names
        # that merely contain the wav folder name.
        mel_dirs = {
            self.mel_folder / wav_file.parent.relative_to(self.wav_folder)
            for wav_file in self._wav_files
        }

        needs_reset = False
        for mel_dir in mel_dirs:
            param_storage = mel_dir / 'data_params.pik'
            param_storage.parent.mkdir(parents=True, exist_ok=True)
            try:
                # NOTE(security): unpickling executes arbitrary code; the
                # parameter file must only ever be written by this class.
                stored_fingerprint = pickle.loads(param_storage.read_bytes())
            except FileNotFoundError:
                # First run for this folder: nothing to compare against.
                up_to_date = False
            except (pickle.UnpicklingError, EOFError):
                # Truncated/corrupt parameter file: treat it as stale.
                print('Diverging parameters were found; Refreshing...')
                up_to_date = False
            else:
                up_to_date = stored_fingerprint == self._fingerprint
                if not up_to_date:
                    print('Diverging parameters were found; Refreshing...')

            if up_to_date:
                needs_reset = needs_reset or bool(reset)
            else:
                # write_bytes overwrites, so no prior unlink is required.
                param_storage.write_bytes(pickle.dumps(self._fingerprint))
                needs_reset = True

        if needs_reset:
            # A single purge is enough: it covers the whole mel folder tree.
            for mel_file in self.mel_folder.rglob(f'*{self._container_ext}'):
                mel_file.unlink()
|
|
|