Dataset ready
This commit is contained in:
@ -15,11 +15,17 @@ import multiprocessing as mp
|
||||
|
||||
|
||||
# All dataset split identifiers this module iterates over (test / train / devel).
# The DATA_OPTION_* constants are defined/imported outside this view — TODO confirm.
data_options = [DATA_OPTION_test, DATA_OPTION_train, DATA_OPTION_devel]
|
||||
# Lookup table from primate class label to its integer class id
# ('background' -> 0, 'chimpanze' -> 1, ..., 'redcap' -> 4).
class_names = {label: idx for idx, label in
               enumerate(['background', 'chimpanze', 'geunon', 'mandrille', 'redcap'])}
|
||||
|
||||
|
||||
class PrimatesLibrosaDatamodule(_BaseDataModule):
|
||||
|
||||
# Label -> integer-id mapping kept on the class so methods can reach it
# via self.class_names (used by _build_subdataset for the class id lookup).
class_names = {label: idx for idx, label in
               enumerate(['background', 'chimpanze', 'geunon', 'mandrille', 'redcap'])}
|
||||
|
||||
@property
def shape(self):
    """Shape of a single sample, probed from the training split.

    Walks into the first sub-dataset of the (concatenated) train split,
    takes its first item, and reports the shape of that item's index-1
    element.  NOTE(review): index 1 presumably holds the mel array —
    confirm against the sub-dataset's item layout.
    """
    train_split = self.datasets[DATA_OPTION_train]
    first_item = train_split.datasets[0][0]
    return first_item[1].shape
|
||||
|
||||
@property
def mel_folder(self):
    """Directory below the dataset root where mel data is kept."""
    return self.root.joinpath('mel_folder')
|
||||
@ -28,14 +34,15 @@ class PrimatesLibrosaDatamodule(_BaseDataModule):
|
||||
def wav_folder(self):
    """Directory below the dataset root holding the raw .wav files.

    NOTE(review): a decorator (likely @property, matching mel_folder) is
    cut off by the diff-hunk boundary just above — confirm in the full file.
    """
    return self.root.joinpath('wav')
|
||||
|
||||
def __init__(self, root, batch_size, num_worker, sr, n_mels, n_fft, hop_length,
|
||||
def __init__(self, data_root, batch_size, num_worker, sr, n_mels, n_fft, hop_length,
|
||||
sample_segment_len=40, sample_hop_len=15):
|
||||
super(PrimatesLibrosaDatamodule, self).__init__()
|
||||
self.sample_hop_len = sample_hop_len
|
||||
self.sample_segment_len = sample_segment_len
|
||||
self.num_worker = num_worker
|
||||
self.num_worker = num_worker or 1
|
||||
self.batch_size = batch_size
|
||||
self.root = Path(root) / 'primates'
|
||||
self.root = Path(data_root) / 'primates'
|
||||
self.mel_length_in_seconds = 0.7
|
||||
|
||||
# Mel Transforms - will be pushed with all other paramters by self.__dict__ to subdataset-class
|
||||
self.mel_kwargs = dict(sr=sr, n_mels=n_mels, n_fft=n_fft, hop_length=hop_length)
|
||||
@ -70,13 +77,17 @@ class PrimatesLibrosaDatamodule(_BaseDataModule):
|
||||
|
||||
def _build_subdataset(self, row, build=False):
|
||||
slice_file_name, class_name = row.strip().split(',')
|
||||
class_id = class_names.get(class_name, -1)
|
||||
class_id = self.class_names.get(class_name, -1)
|
||||
audio_file_path = self.wav_folder / slice_file_name
|
||||
# DATA OPTION DIFFERENTIATION !!!!!!!!!!! - Begin
|
||||
kwargs = self.__dict__
|
||||
if any([x in slice_file_name for x in [DATA_OPTION_devel, DATA_OPTION_test]]):
|
||||
kwargs.update(mel_augmentations=self.utility_transforms)
|
||||
# DATA OPTION DIFFERENTIATION !!!!!!!!!!! - End
|
||||
|
||||
target_frames = self.mel_length_in_seconds * self.mel_kwargs['sr']
|
||||
sample_segment_length = target_frames // self.mel_kwargs['hop_length'] + 1
|
||||
kwargs.update(sample_segment_len=sample_segment_length, sample_hop_len=sample_segment_length//2)
|
||||
mel_dataset = LibrosaAudioToMelDataset(audio_file_path, class_id, **kwargs)
|
||||
if build:
|
||||
assert mel_dataset.build_mel()
|
||||
@ -101,7 +112,7 @@ class PrimatesLibrosaDatamodule(_BaseDataModule):
|
||||
chunksize=chunksize)
|
||||
for sub_dataset in results.get():
|
||||
dataset.append(sub_dataset)
|
||||
tqdm.update() # FIXME: will i ever get this to work?
|
||||
update() # FIXME: will i ever get this to work?
|
||||
datasets[data_option] = ConcatDataset(dataset)
|
||||
self.datasets = datasets
|
||||
return datasets
|
||||
|
Reference in New Issue
Block a user