import multiprocessing as mp
from collections import defaultdict
from pathlib import Path

from torch.utils.data import DataLoader, ConcatDataset, WeightedRandomSampler
from torchvision.transforms import Compose, RandomApply

from ml_lib.audio_toolset.audio_io import NormalizeLocal
from ml_lib.audio_toolset.audio_to_mel_dataset import LibrosaAudioToMelDataset
from ml_lib.audio_toolset.mel_augmentation import NoiseInjection, LoudnessManipulator, ShiftTime, MaskAug
from ml_lib.utils._basedatamodule import _BaseDataModule, DATA_OPTION_test, DATA_OPTION_train, DATA_OPTION_devel
from ml_lib.utils.equal_sampler import EqualSampler
from ml_lib.utils.tools import add_argparse_args
from ml_lib.utils.transforms import ToTensor

data_options = [DATA_OPTION_test, DATA_OPTION_train, DATA_OPTION_devel]


class CompareBase(_BaseDataModule):

    @property
    def class_names(self):
        # Maps class name -> integer class id, in the order given at construction time.
        return {key: val for val, key in enumerate(self._class_names)}

    @property
    def n_classes(self):
        return len(self.class_names)

    @property
    def shape(self):
        return 1, int(self.mel_kwargs['n_mels']), int(self.sample_segment_length)

    @property
    def mel_folder(self):
        return Path(f'{self.root}_mel_folder')

    @property
    def wav_folder(self):
        return self.root / 'wav'

    def __init__(self, sub_dataset_name, class_names, data_root, batch_size, num_worker, sr, n_mels, n_fft,
                 hop_length, sampler=None, random_apply_chance=0.5, target_mel_length_in_seconds=1,
                 loudness_ratio=0.3, shift_ratio=0.3, noise_ratio=0.3, mask_ratio=0.3):
        super(CompareBase, self).__init__()
        self.sampler = sampler
        self.samplers = None
        self.num_worker = num_worker or 1
        self.batch_size = batch_size
        self.root = Path(data_root) / sub_dataset_name
        self._class_names = class_names
        self.mel_length_in_seconds = target_mel_length_in_seconds

        # Mel transforms - passed along with all other parameters via self.__dict__ to the sub-dataset class.
        self.mel_kwargs = dict(sr=sr, n_mels=n_mels, n_fft=n_fft, hop_length=hop_length)

        # Number of mel frames corresponding to the requested segment length in seconds.
        target_frames = self.mel_length_in_seconds * self.mel_kwargs['sr']
        self.sample_segment_length = target_frames // self.mel_kwargs['hop_length'] + 1

        # Utility transforms, applied to every sample regardless of split.
        self.utility_transforms = Compose([
            NormalizeLocal(),
            ToTensor()
        ])

        # Data augmentations, applied to the training split only.
        self.random_apply_chance = random_apply_chance
        self.mel_augmentations = Compose([
            RandomApply([NoiseInjection(noise_ratio)], p=random_apply_chance),
            RandomApply([LoudnessManipulator(loudness_ratio)], p=random_apply_chance),
            RandomApply([ShiftTime(shift_ratio)], p=random_apply_chance),
            RandomApply([MaskAug(mask_ratio)], p=random_apply_chance),
            self.utility_transforms])

    # Train Dataloader
    def train_dataloader(self):
        return DataLoader(dataset=self.datasets[DATA_OPTION_train], num_workers=self.num_worker, pin_memory=True,
                          sampler=self.samplers[DATA_OPTION_train], batch_size=self.batch_size)

    # Validation Dataloader
    def val_dataloader(self):
        return DataLoader(dataset=self.datasets[DATA_OPTION_devel], shuffle=False, batch_size=self.batch_size,
                          pin_memory=False, num_workers=self.num_worker)

    # Test Dataloader
    def test_dataloader(self):
        return DataLoader(dataset=self.datasets[DATA_OPTION_test], shuffle=False, batch_size=self.batch_size,
                          pin_memory=False, num_workers=self.num_worker)

    def _build_subdataset(self, row, build=False, data_option=None):
        slice_file_name, class_name = row.strip().split(',')
        if data_option is not None:
            if data_option not in slice_file_name:
                return None, -1, 'no_file'
        class_id = self.class_names.get(class_name, -1)
        audio_file_path = self.wav_folder / slice_file_name

        # DATA OPTION DIFFERENTIATION - Begin
        kwargs = dict(self.__dict__)  # copy, so the per-file overrides below do not leak back into the datamodule
        if any([x in slice_file_name for x in [DATA_OPTION_devel, DATA_OPTION_test]]):
            # devel/test samples are never augmented; they only get the utility transforms
            kwargs.update(mel_augmentations=self.utility_transforms)
        # DATA OPTION DIFFERENTIATION - End

        kwargs.update(sample_segment_len=self.sample_segment_length,
                      sample_hop_len=self.sample_segment_length // 2)
        mel_dataset = LibrosaAudioToMelDataset(audio_file_path, class_id, **kwargs)
        if build:
            assert mel_dataset.build_mel()
        return mel_dataset, class_id, slice_file_name

    def manual_setup(self):
        datasets = dict()

        # Prefer a single labels.csv; fall back to one csv per split (train/devel/test).
        label_csv_file = Path(self.root) / 'lab' / 'labels.csv'
        lab_file = label_csv_file.name if label_csv_file.exists() else None

        for data_option in data_options:
            if lab_file is None:
                lab_file = f'{data_option}.csv'
            elif any([x in lab_file for x in data_options]):
                lab_file = f'{data_option}.csv'
            dataset = self._load_from_file(lab_file, data_option, rebuild=True)
            datasets[data_option] = ConcatDataset(dataset)
            print(f'{data_option}-dataset prepared.')
        self.datasets = datasets
        return datasets

    def prepare_data(self, *args, rebuild=False, subsets=None, **kwargs):
        datasets = dict()
        samplers = dict()
        weights = dict()

        # Prefer a single labels.csv; fall back to one csv per split (train/devel/test).
        label_csv_file = Path(self.root) / 'lab' / 'labels.csv'
        lab_file = label_csv_file.name if label_csv_file.exists() else None

        for data_option in data_options:
            if subsets is not None and data_option not in subsets:
                print(f'{data_option} skipped...')
                continue

            if lab_file is None:
                lab_file = f'{data_option}.csv'
            elif any([x in lab_file for x in data_options]):
                lab_file = f'{data_option}.csv'

            dataset = self._load_from_file(lab_file, data_option, rebuild=rebuild)
            datasets[data_option] = ConcatDataset(dataset)
            print(f'{data_option}-dataset set up!')

            # Build a class-balancing sampler for the training split.
            if data_option in [DATA_OPTION_train]:
                if self.sampler == EqualSampler.__name__:
                    class_idxs = [[idx for idx, (_, __, label) in enumerate(datasets[data_option])
                                   if label == class_idx]
                                  for class_idx in range(len(self.class_names))]
                    samplers[data_option] = EqualSampler(class_idxs)
                elif self.sampler == WeightedRandomSampler.__name__:
                    class_counts = defaultdict(lambda: 0)
                    for _, __, label in datasets[data_option]:
                        class_counts[label] += 1
                    len_largest_class = max(class_counts.values())

                    # Inverse-frequency weight per class, then expanded to one weight per sample.
                    weights[data_option] = [1 / class_counts[x] for x in range(len(class_counts))]
                    weights[data_option] = [weights[data_option][datasets[data_option][i][-1]]
                                            for i in range(len(datasets[data_option]))]
                    samplers[data_option] = WeightedRandomSampler(weights[data_option],
                                                                  len_largest_class * len(self.class_names))
                else:
                    samplers[data_option] = None
        self.datasets = datasets
        self.samplers = samplers
        print(f'Dataset {self.__class__.__name__} setup done.')
        return datasets

    def _load_from_file(self, lab_file, data_option, rebuild=False):
        with open(Path(self.root) / 'lab' / lab_file, mode='r') as f:
            # Exclude the header
            _ = next(f)
            all_rows = list(f)

        # At least 1, so the pool does not fail when there are fewer rows than workers.
        chunksize = max(len(all_rows) // max(self.num_worker, 1), 1)
        dataset = list()
        with mp.Pool(processes=self.num_worker) as pool:
            from itertools import repeat
            results = pool.starmap_async(self._build_subdataset,
                                         zip(all_rows,
                                             repeat(rebuild, len(all_rows)),
                                             repeat(data_option, len(all_rows))),
                                         chunksize=chunksize)
            for sub_dataset in results.get():
                if sub_dataset[0] is not None:
                    dataset.append(sub_dataset[0])
        return dataset

    def purge(self):
        import shutil

        shutil.rmtree(self.mel_folder, ignore_errors=True)
        print('Mel folder has been recursively deleted.')
        print(f'Folder still exists: {self.mel_folder.exists()}')
        return not self.mel_folder.exists()
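

# Usage sketch (illustrative only, not part of the original module): how a concrete data
# module might instantiate CompareBase and build its datasets. Every value below
# (sub-dataset folder name, class names, sample rate, mel/batch parameters) is a
# hypothetical placeholder chosen for demonstration.
if __name__ == '__main__':
    datamodule = CompareBase(sub_dataset_name='example_subset',      # hypothetical folder under data_root
                             class_names=['class_a', 'class_b'],     # hypothetical label set
                             data_root='data',
                             batch_size=32,
                             num_worker=4,
                             sr=16000, n_mels=128, n_fft=1024, hop_length=256,
                             sampler=WeightedRandomSampler.__name__)
    # Converts the wav files to mel spectrograms (rebuild=True) and builds the samplers.
    datamodule.prepare_data(rebuild=True)
    train_loader = datamodule.train_dataloader()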