CCS intergration dataloader
This commit is contained in:
@@ -17,7 +17,13 @@ data_options = [DATA_OPTION_test, DATA_OPTION_train, DATA_OPTION_devel]
|
||||
|
||||
class PrimatesLibrosaDatamodule(_BaseDataModule):
|
||||
|
||||
class_names = {key: val for val, key in enumerate(['background', 'chimpanze', 'geunon', 'mandrille', 'redcap'])}
|
||||
@property
|
||||
def class_names(self):
|
||||
return {key: val for val, key in enumerate(['background', 'chimpanze', 'geunon', 'mandrille', 'redcap'])}
|
||||
|
||||
@property
|
||||
def n_classes(self):
|
||||
return len(self.class_names)
|
||||
|
||||
@property
|
||||
def shape(self):
|
||||
@@ -33,19 +39,16 @@ class PrimatesLibrosaDatamodule(_BaseDataModule):
|
||||
return self.root / 'wav'
|
||||
|
||||
def __init__(self, data_root, batch_size, num_worker, sr, n_mels, n_fft, hop_length, sampler=None,
|
||||
sample_segment_len=40, sample_hop_len=15, random_apply_chance=0.5,
|
||||
target_mel_length_in_seconds=0.7, random_apply_chance=0.5,
|
||||
loudness_ratio=0.3, shift_ratio=0.3, noise_ratio=0.3, mask_ratio=0.3):
|
||||
super(PrimatesLibrosaDatamodule, self).__init__()
|
||||
self.sampler = sampler
|
||||
self.samplers = None
|
||||
|
||||
self.sample_hop_len = sample_hop_len
|
||||
self.sample_segment_len = sample_segment_len
|
||||
|
||||
self.num_worker = num_worker or 1
|
||||
self.batch_size = batch_size
|
||||
self.root = Path(data_root) / 'primates'
|
||||
self.mel_length_in_seconds = 0.7
|
||||
self.target_mel_length_in_seconds = target_mel_length_in_seconds
|
||||
|
||||
# Mel Transforms - will be pushed with all other paramters by self.__dict__ to subdataset-class
|
||||
self.mel_kwargs = dict(sr=sr, n_mels=n_mels, n_fft=n_fft, hop_length=hop_length)
|
||||
@@ -89,7 +92,7 @@ class PrimatesLibrosaDatamodule(_BaseDataModule):
|
||||
kwargs.update(mel_augmentations=self.utility_transforms)
|
||||
# DATA OPTION DIFFERENTIATION !!!!!!!!!!! - End
|
||||
|
||||
target_frames = self.mel_length_in_seconds * self.mel_kwargs['sr']
|
||||
target_frames = self.target_mel_length_in_seconds * self.mel_kwargs['sr']
|
||||
sample_segment_length = target_frames // self.mel_kwargs['hop_length'] + 1
|
||||
kwargs.update(sample_segment_len=sample_segment_length, sample_hop_len=sample_segment_length//2)
|
||||
mel_dataset = LibrosaAudioToMelDataset(audio_file_path, class_id, **kwargs)
|
||||
|
||||
Reference in New Issue
Block a user