2021-03-27 16:39:07 +01:00

210 lines
8.9 KiB
Python

import multiprocessing as mp
from collections import defaultdict
from pathlib import Path
from torch.utils.data import DataLoader, ConcatDataset, WeightedRandomSampler
from torchvision.transforms import Compose, RandomApply
from ml_lib.audio_toolset.audio_io import NormalizeLocal
from ml_lib.audio_toolset.audio_to_mel_dataset import LibrosaAudioToMelDataset
from ml_lib.audio_toolset.mel_augmentation import NoiseInjection, LoudnessManipulator, ShiftTime, MaskAug
from ml_lib.utils._basedatamodule import _BaseDataModule, DATA_OPTION_test, DATA_OPTION_train, DATA_OPTION_devel
from ml_lib.utils.equal_sampler import EqualSampler
from ml_lib.utils.tools import add_argparse_args
from ml_lib.utils.transforms import ToTensor
# Every data split handled by this module, in the order manual_setup / prepare_data process them.
data_options = [
    DATA_OPTION_test,
    DATA_OPTION_train,
    DATA_OPTION_devel,
]
class CompareBase(_BaseDataModule):
    """Data module for one ComParE-style audio sub-dataset.

    Layout expected below ``<data_root>/<sub_dataset_name>``:

    * ``wav/`` - the audio files,
    * ``lab/`` - either one shared ``labels.csv`` or one ``<split>.csv`` per data split.
      Each csv starts with a header line followed by ``<file name>,<class name>`` rows.

    Every audio file is wrapped in its own ``LibrosaAudioToMelDataset``; the per-file datasets of a
    split are concatenated into one ``ConcatDataset``.
    """

    @property
    def class_names(self):
        """Mapping ``class name -> class id``; the id is the position in the ``class_names`` given to ``__init__``."""
        return {name: class_id for class_id, name in enumerate(self._class_names)}

    @property
    def n_classes(self):
        """Number of distinct class names."""
        return len(self.class_names)

    @property
    def shape(self):
        """Sample shape ``(1, n_mels, sample_segment_length)`` (the leading 1 is presumably a channel axis)."""
        return 1, int(self.mel_kwargs['n_mels']), int(self.sample_segment_length)

    @property
    def mel_folder(self):
        """Sibling folder ``<root>_mel_folder``; deleted by ``purge``."""
        return Path(f'{self.root}_mel_folder')

    @property
    def wav_folder(self):
        """Folder holding the audio files."""
        return self.root / 'wav'

    def __init__(self, sub_dataset_name, class_names, data_root, batch_size, num_worker, sr, n_mels, n_fft, hop_length, sampler=None,
                 random_apply_chance=0.5, target_mel_length_in_seconds=1,
                 loudness_ratio=0.3, shift_ratio=0.3, noise_ratio=0.3, mask_ratio=0.3):
        """
        :param sub_dataset_name: folder name of this sub-dataset below ``data_root``.
        :param class_names: ordered class names; the position of a name is its class id.
        :param data_root: folder containing all sub-datasets.
        :param batch_size: batch size of every data loader.
        :param num_worker: worker count for the data loaders and the dataset-building pool (falsy -> 1).
        :param sr: sample rate, forwarded (with ``n_mels``, ``n_fft``, ``hop_length``) to the sub-datasets.
        :param sampler: ``'EqualSampler'``, ``'WeightedRandomSampler'`` or ``None``; used for the training split only.
        :param random_apply_chance: probability with which each mel augmentation is applied.
        :param target_mel_length_in_seconds: segment length in seconds, converted to ``sample_segment_length`` frames.
        :param loudness_ratio: ratio argument of ``LoudnessManipulator``.
        :param shift_ratio: ratio argument of ``ShiftTime``.
        :param noise_ratio: ratio argument of ``NoiseInjection``.
        :param mask_ratio: ratio argument of ``MaskAug``.
        """
        super(CompareBase, self).__init__()
        self.sampler = sampler
        self.samplers = None
        self.num_worker = num_worker or 1
        self.batch_size = batch_size
        self.root = Path(data_root) / sub_dataset_name
        self._class_names = class_names
        self.mel_length_in_seconds = target_mel_length_in_seconds

        # Mel settings - like every other attribute they reach the sub-datasets via self.__dict__
        # (see _build_subdataset).
        self.mel_kwargs = dict(sr=sr, n_mels=n_mels, n_fft=n_fft, hop_length=hop_length)
        target_frames = self.mel_length_in_seconds * self.mel_kwargs['sr']
        # Number of hops needed to cover `target_frames` audio samples, plus the first frame.
        self.sample_segment_length = target_frames // self.mel_kwargs['hop_length'] + 1

        # Utility transforms, applied to every split.
        self.utility_transforms = Compose([
            NormalizeLocal(),
            ToTensor()
        ])

        # Data augmentations (replaced by the utility transforms for devel / test files).
        self.random_apply_chance = random_apply_chance
        self.mel_augmentations = Compose([
            RandomApply([NoiseInjection(noise_ratio)], p=random_apply_chance),
            RandomApply([LoudnessManipulator(loudness_ratio)], p=random_apply_chance),
            RandomApply([ShiftTime(shift_ratio)], p=random_apply_chance),
            RandomApply([MaskAug(mask_ratio)], p=random_apply_chance),
            self.utility_transforms])

    def train_dataloader(self):
        """Return the training loader.

        Uses the sampler built by ``prepare_data`` when there is one; otherwise (no sampler configured,
        or the datasets came from ``manual_setup``) the data is shuffled instead. ``DataLoader`` does not
        allow ``shuffle=True`` together with a custom sampler, hence the either/or.
        """
        sampler = (self.samplers or {}).get(DATA_OPTION_train)
        return DataLoader(dataset=self.datasets[DATA_OPTION_train], num_workers=self.num_worker, pin_memory=True,
                          sampler=sampler, shuffle=sampler is None, batch_size=self.batch_size)

    def val_dataloader(self):
        """Return the (unshuffled) validation loader over the devel split."""
        return DataLoader(dataset=self.datasets[DATA_OPTION_devel], shuffle=False,
                          batch_size=self.batch_size, pin_memory=False,
                          num_workers=self.num_worker)

    def test_dataloader(self):
        """Return the (unshuffled) test loader."""
        return DataLoader(dataset=self.datasets[DATA_OPTION_test], shuffle=False,
                          batch_size=self.batch_size, pin_memory=False,
                          num_workers=self.num_worker)

    def _build_subdataset(self, row, build=False, data_option=None):
        """Wrap one label-csv row in a ``LibrosaAudioToMelDataset``.

        :param row: one csv line ``<file name>,<class name>``.
        :param build: if True, call ``build_mel()`` on the new dataset.
        :param data_option: if given, rows whose file name does not contain this split name are skipped.
        :return: ``(dataset, class_id, file_name)``, or ``(None, -1, 'no_file')`` for skipped / blank rows.
            Class names not in ``class_names`` get class id -1.
        :raises AssertionError: if ``build`` is True and ``build_mel()`` returns a falsy value.
        """
        row = row.strip()
        if not row:
            # Blank line, e.g. a trailing empty line at the end of the csv.
            return None, -1, 'no_file'
        slice_file_name, class_name = row.split(',')
        if data_option is not None and data_option not in slice_file_name:
            return None, -1, 'no_file'

        class_id = self.class_names.get(class_name, -1)
        audio_file_path = self.wav_folder / slice_file_name

        # All attributes of this module are forwarded to the sub-dataset. Work on a copy so the
        # per-row overrides below never leak back into `self`.
        kwargs = dict(self.__dict__)
        # Devel / test files get no augmentation, only the utility transforms.
        if any(split in slice_file_name for split in (DATA_OPTION_devel, DATA_OPTION_test)):
            kwargs.update(mel_augmentations=self.utility_transforms)
        kwargs.update(sample_segment_len=self.sample_segment_length,
                      sample_hop_len=self.sample_segment_length // 2)

        mel_dataset = LibrosaAudioToMelDataset(audio_file_path, class_id, **kwargs)
        if build:
            # Called outside of `assert`, so the mel is still built under `python -O`.
            built = mel_dataset.build_mel()
            if not built:
                raise AssertionError(f'build_mel() failed for {audio_file_path}')
        return mel_dataset, class_id, slice_file_name

    def _label_file_name(self, data_option):
        """Return the label csv (relative to ``<root>/lab``) to read for ``data_option``.

        A shared ``labels.csv`` is used when it exists; otherwise the per-split ``<data_option>.csv``.
        """
        if (Path(self.root) / 'lab' / 'labels.csv').exists():
            return 'labels.csv'
        return f'{data_option}.csv'

    def manual_setup(self):
        """Build every split (forcing ``rebuild=True``) and store it in ``self.datasets``.

        No samplers are created; ``train_dataloader`` then falls back to shuffling.

        :return: dict ``data option -> ConcatDataset``.
        """
        datasets = dict()
        for data_option in data_options:
            lab_file = self._label_file_name(data_option)
            dataset = self._load_from_file(lab_file, data_option, rebuild=True)
            datasets[data_option] = ConcatDataset(dataset)
            print(f'{data_option}-dataset prepared.')
        self.datasets = datasets
        return datasets

    def prepare_data(self, *args, rebuild=False, subsets=None, **kwargs):
        """Build the requested splits and the training sampler.

        :param rebuild: forwarded to ``_build_subdataset`` as ``build``.
        :param subsets: optional collection of data options to build; all splits when ``None``.
        :return: dict ``data option -> ConcatDataset`` (also stored in ``self.datasets``).
        """
        datasets = dict()
        samplers = dict()
        for data_option in data_options:
            if subsets is not None and data_option not in subsets:
                print(f'{data_option} skipped...')
                continue
            lab_file = self._label_file_name(data_option)
            dataset = self._load_from_file(lab_file, data_option, rebuild=rebuild)
            datasets[data_option] = ConcatDataset(dataset)
            print(f'{data_option}-dataset set up!')
            # Only the training split gets a re-sampling strategy.
            if data_option == DATA_OPTION_train:
                samplers[data_option] = self._build_train_sampler(datasets[data_option])
        self.datasets = datasets
        self.samplers = samplers
        print(f'Dataset {self.__class__.__name__} setup done.')
        return datasets

    def _build_train_sampler(self, dataset):
        """Create the sampler named by ``self.sampler`` for ``dataset``; ``None`` for any other value."""
        if self.sampler not in (EqualSampler.__name__, WeightedRandomSampler.__name__):
            return None
        # Items are 3-tuples whose last entry is the label. Reading them loads every sample,
        # so collect the labels in a single pass and reuse them.
        labels = [label for _, __, label in dataset]

        if self.sampler == EqualSampler.__name__:
            class_idxs = [[idx for idx, label in enumerate(labels) if label == class_idx]
                          for class_idx in range(len(self.class_names))]
            return EqualSampler(class_idxs)

        class_counts = defaultdict(int)
        for label in labels:
            class_counts[label] += 1
        len_largest_class = max(class_counts.values())
        # Inverse class frequency per sample: every class present is drawn about equally often.
        weights = [1 / class_counts[label] for label in labels]
        return WeightedRandomSampler(weights, len_largest_class * len(self.class_names))

    def _load_from_file(self, lab_file, data_option, rebuild=False):
        """Build one sub-dataset per row of ``<root>/lab/<lab_file>`` using a process pool.

        :return: list of the created datasets; rows skipped by ``_build_subdataset`` are dropped.
        """
        with open(Path(self.root) / 'lab' / lab_file, mode='r') as f:
            next(f, None)  # skip the header line (an empty file just yields no rows)
            all_rows = list(f)

        n_workers = max(self.num_worker, 1)
        # Pool treats chunksize 0 as "no work" and returns None for every row, which happens when
        # there are fewer rows than workers - so always hand out at least one row per chunk.
        chunksize = max(1, len(all_rows) // n_workers)
        task_args = [(row, rebuild, data_option) for row in all_rows]
        with mp.Pool(processes=self.num_worker) as pool:
            results = pool.starmap_async(self._build_subdataset, task_args, chunksize=chunksize).get()
        return [sub_dataset for sub_dataset, _, _ in results if sub_dataset is not None]

    def purge(self):
        """Recursively delete ``mel_folder``; return True if it no longer exists."""
        import shutil
        shutil.rmtree(self.mel_folder, ignore_errors=True)
        print('Mel Folder has been recursively deleted')
        print(f'Folder still exists: {self.mel_folder.exists()}')
        return not self.mel_folder.exists()