from pathlib import Path
from functools import partial

from torch.utils.data import DataLoader, ConcatDataset
from torchvision.transforms import Compose, RandomApply
from tqdm import tqdm

from ml_lib.audio_toolset.audio_io import NormalizeLocal
from ml_lib.audio_toolset.audio_to_mel_dataset import LibrosaAudioToMelDataset
from ml_lib.audio_toolset.mel_augmentation import NoiseInjection, LoudnessManipulator, ShiftTime, MaskAug
from ml_lib.utils._basedatamodule import _BaseDataModule, DATA_OPTION_test, DATA_OPTION_train, DATA_OPTION_devel
from ml_lib.utils.transforms import ToTensor

import multiprocessing as mp

data_options = [DATA_OPTION_test, DATA_OPTION_train, DATA_OPTION_devel]


class PrimatesLibrosaDatamodule(_BaseDataModule):

    # Class labels exactly as they appear in the label CSV files, mapped to integer ids.
    class_names = {key: val for val, key in enumerate(['background', 'chimpanze', 'geunon', 'mandrille', 'redcap'])}

    @property
    def shape(self):
        return self.datasets[DATA_OPTION_train].datasets[0][0][1].shape

    @property
    def mel_folder(self):
        return self.root / 'mel_folder'

    @property
    def wav_folder(self):
        return self.root / 'wav'

    def __init__(self, data_root, batch_size, num_worker, sr, n_mels, n_fft, hop_length,
                 sample_segment_len=40, sample_hop_len=15):
        super(PrimatesLibrosaDatamodule, self).__init__()
        self.sample_hop_len = sample_hop_len
        self.sample_segment_len = sample_segment_len
        self.num_worker = num_worker or 1
        self.batch_size = batch_size
        self.root = Path(data_root) / 'primates'
        self.mel_length_in_seconds = 0.7

        # Mel transforms - passed with all other parameters through self.__dict__ to the sub-dataset class
        self.mel_kwargs = dict(sr=sr, n_mels=n_mels, n_fft=n_fft, hop_length=hop_length)

        # Utility transforms applied to every sample (train, devel and test)
        self.utility_transforms = Compose([NormalizeLocal(), ToTensor()])

        # Data augmentations applied to training samples only
        self.mel_augmentations = Compose([
            # ToDo: HP-search these parameters, make them adjustable from outside
            RandomApply([NoiseInjection(0.2)], p=0.3),
            RandomApply([LoudnessManipulator(0.5)], p=0.3),
            RandomApply([ShiftTime(0.4)], p=0.3),
            RandomApply([MaskAug(0.2)], p=0.3),
            self.utility_transforms])

    # Train Dataloader
    def train_dataloader(self):
        return DataLoader(dataset=self.datasets[DATA_OPTION_train], shuffle=True,
                          batch_size=self.batch_size, pin_memory=True,
                          num_workers=self.num_worker)

    # Validation Dataloader
    def val_dataloader(self):
        return DataLoader(dataset=self.datasets[DATA_OPTION_devel], shuffle=False, pin_memory=True,
                          batch_size=self.batch_size, num_workers=self.num_worker)

    # Test Dataloader
    def test_dataloader(self):
        return DataLoader(dataset=self.datasets[DATA_OPTION_test], shuffle=False,
                          batch_size=self.batch_size, pin_memory=True, num_workers=self.num_worker)

    def _build_subdataset(self, row, build=False):
        slice_file_name, class_name = row.strip().split(',')
        class_id = self.class_names.get(class_name, -1)
        audio_file_path = self.wav_folder / slice_file_name

        # DATA OPTION DIFFERENTIATION - Begin
        # Work on a copy so per-split overrides do not mutate the datamodule's own attributes.
        kwargs = dict(self.__dict__)
        if any(x in slice_file_name for x in [DATA_OPTION_devel, DATA_OPTION_test]):
            # Devel and test samples are only normalized, never augmented.
            kwargs.update(mel_augmentations=self.utility_transforms)
        # DATA OPTION DIFFERENTIATION - End

        # Convert the target segment length from seconds (raw audio samples) to mel frames.
        target_samples = self.mel_length_in_seconds * self.mel_kwargs['sr']
        sample_segment_length = int(target_samples // self.mel_kwargs['hop_length']) + 1
        kwargs.update(sample_segment_len=sample_segment_length, sample_hop_len=sample_segment_length // 2)
        mel_dataset = LibrosaAudioToMelDataset(audio_file_path, class_id, **kwargs)
        if build:
            assert mel_dataset.build_mel()
        return mel_dataset

    def prepare_data(self, *args, **kwargs):
        # Build all mel spectrograms once, in parallel, and cache them on disk.
        datasets = dict()
        for data_option in data_options:
            with open(Path(self.root) / 'lab' / f'{data_option}.csv', mode='r') as f:
                # Exclude the header
                _ = next(f)
                all_rows = list(f)
            chunksize = max(len(all_rows) // max(self.num_worker, 1), 1)
            dataset = list()
            with mp.Pool(processes=self.num_worker) as pool:
                with tqdm(total=len(all_rows), desc=f'Building {data_option} mels') as pbar:
                    for sub_dataset in pool.imap(partial(self._build_subdataset, build=True),
                                                 all_rows, chunksize=chunksize):
                        dataset.append(sub_dataset)
                        pbar.update(1)

            datasets[data_option] = ConcatDataset(dataset)
        self.datasets = datasets
        return datasets

    def setup(self, stage=None):
        # Assemble the per-split datasets; assumes prepare_data() has already built the mel files.
        datasets = dict()
        for data_option in data_options:
            with open(Path(self.root) / 'lab' / f'{data_option}.csv', mode='r') as f:
                # Exclude the header
                _ = next(f)
                all_rows = list(f)
            dataset = list()
            for row in all_rows:
                dataset.append(self._build_subdataset(row))

            datasets[data_option] = ConcatDataset(dataset)
        self.datasets = datasets
        return datasets
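

# --------------------------------------------------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the module): how the datamodule might be wired up.
# The data root and all audio/mel parameters below are assumptions chosen for demonstration; the code above only
# requires a layout of <data_root>/primates/ containing wav/ and lab/<split>.csv.
# --------------------------------------------------------------------------------------------------------------------
if __name__ == '__main__':
    datamodule = PrimatesLibrosaDatamodule(
        data_root='data',      # assumed location of the 'primates' folder
        batch_size=32,
        num_worker=4,
        sr=16000,              # assumed sampling rate of the wav files
        n_mels=64,
        n_fft=512,
        hop_length=256,
    )
    datamodule.prepare_data()                      # build and cache the mel spectrograms (multiprocessed)
    datamodule.setup()                             # assemble the train/devel/test ConcatDatasets
    print(f'Mel sample shape: {datamodule.shape}')
    train_loader = datamodule.train_dataloader()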