Dataset rdy

This commit is contained in:
Steffen Illium
2021-02-16 10:18:04 +01:00
parent 151b22a2c3
commit 7edd3834a1
11 changed files with 350 additions and 15 deletions

View File

@ -15,11 +15,17 @@ import multiprocessing as mp
data_options = [DATA_OPTION_test, DATA_OPTION_train, DATA_OPTION_devel]
class_names = {key: val for val, key in enumerate(['background', 'chimpanze', 'geunon', 'mandrille', 'redcap'])}
class PrimatesLibrosaDatamodule(_BaseDataModule):
class_names = {key: val for val, key in enumerate(['background', 'chimpanze', 'geunon', 'mandrille', 'redcap'])}
@property
def shape(self):
return self.datasets[DATA_OPTION_train].datasets[0][0][1].shape
@property
def mel_folder(self):
return self.root / 'mel_folder'
@ -28,14 +34,15 @@ class PrimatesLibrosaDatamodule(_BaseDataModule):
def wav_folder(self):
return self.root / 'wav'
def __init__(self, root, batch_size, num_worker, sr, n_mels, n_fft, hop_length,
def __init__(self, data_root, batch_size, num_worker, sr, n_mels, n_fft, hop_length,
sample_segment_len=40, sample_hop_len=15):
super(PrimatesLibrosaDatamodule, self).__init__()
self.sample_hop_len = sample_hop_len
self.sample_segment_len = sample_segment_len
self.num_worker = num_worker
self.num_worker = num_worker or 1
self.batch_size = batch_size
self.root = Path(root) / 'primates'
self.root = Path(data_root) / 'primates'
self.mel_length_in_seconds = 0.7
# Mel Transforms - will be pushed with all other paramters by self.__dict__ to subdataset-class
self.mel_kwargs = dict(sr=sr, n_mels=n_mels, n_fft=n_fft, hop_length=hop_length)
@ -70,13 +77,17 @@ class PrimatesLibrosaDatamodule(_BaseDataModule):
def _build_subdataset(self, row, build=False):
slice_file_name, class_name = row.strip().split(',')
class_id = class_names.get(class_name, -1)
class_id = self.class_names.get(class_name, -1)
audio_file_path = self.wav_folder / slice_file_name
# DATA OPTION DIFFERENTIATION !!!!!!!!!!! - Begin
kwargs = self.__dict__
if any([x in slice_file_name for x in [DATA_OPTION_devel, DATA_OPTION_test]]):
kwargs.update(mel_augmentations=self.utility_transforms)
# DATA OPTION DIFFERENTIATION !!!!!!!!!!! - End
target_frames = self.mel_length_in_seconds * self.mel_kwargs['sr']
sample_segment_length = target_frames // self.mel_kwargs['hop_length'] + 1
kwargs.update(sample_segment_len=sample_segment_length, sample_hop_len=sample_segment_length//2)
mel_dataset = LibrosaAudioToMelDataset(audio_file_path, class_id, **kwargs)
if build:
assert mel_dataset.build_mel()
@ -101,7 +112,7 @@ class PrimatesLibrosaDatamodule(_BaseDataModule):
chunksize=chunksize)
for sub_dataset in results.get():
dataset.append(sub_dataset)
tqdm.update() # FIXME: will i ever get this to work?
update() # FIXME: will i ever get this to work?
datasets[data_option] = ConcatDataset(dataset)
self.datasets = datasets
return datasets