CCS intergration dataloader

This commit is contained in:
Steffen
2021-03-19 17:17:16 +01:00
parent 6ace861016
commit d4059779c4
8 changed files with 213 additions and 35 deletions

View File

@@ -17,7 +17,13 @@ data_options = [DATA_OPTION_test, DATA_OPTION_train, DATA_OPTION_devel]
class PrimatesLibrosaDatamodule(_BaseDataModule):
class_names = {key: val for val, key in enumerate(['background', 'chimpanze', 'geunon', 'mandrille', 'redcap'])}
@property
def class_names(self):
return {key: val for val, key in enumerate(['background', 'chimpanze', 'geunon', 'mandrille', 'redcap'])}
@property
def n_classes(self):
return len(self.class_names)
@property
def shape(self):
@@ -33,19 +39,16 @@ class PrimatesLibrosaDatamodule(_BaseDataModule):
return self.root / 'wav'
def __init__(self, data_root, batch_size, num_worker, sr, n_mels, n_fft, hop_length, sampler=None,
sample_segment_len=40, sample_hop_len=15, random_apply_chance=0.5,
target_mel_length_in_seconds=0.7, random_apply_chance=0.5,
loudness_ratio=0.3, shift_ratio=0.3, noise_ratio=0.3, mask_ratio=0.3):
super(PrimatesLibrosaDatamodule, self).__init__()
self.sampler = sampler
self.samplers = None
self.sample_hop_len = sample_hop_len
self.sample_segment_len = sample_segment_len
self.num_worker = num_worker or 1
self.batch_size = batch_size
self.root = Path(data_root) / 'primates'
self.mel_length_in_seconds = 0.7
self.target_mel_length_in_seconds = target_mel_length_in_seconds
# Mel Transforms - will be pushed with all other paramters by self.__dict__ to subdataset-class
self.mel_kwargs = dict(sr=sr, n_mels=n_mels, n_fft=n_fft, hop_length=hop_length)
@@ -89,7 +92,7 @@ class PrimatesLibrosaDatamodule(_BaseDataModule):
kwargs.update(mel_augmentations=self.utility_transforms)
# DATA OPTION DIFFERENTIATION !!!!!!!!!!! - End
target_frames = self.mel_length_in_seconds * self.mel_kwargs['sr']
target_frames = self.target_mel_length_in_seconds * self.mel_kwargs['sr']
sample_segment_length = target_frames // self.mel_kwargs['hop_length'] + 1
kwargs.update(sample_segment_len=sample_segment_length, sample_hop_len=sample_segment_length//2)
mel_dataset = LibrosaAudioToMelDataset(audio_file_path, class_id, **kwargs)