Mel_Vision_Transformer_ComP.../datasets/primates_librosa_datamodule.py
2021-03-04 12:01:09 +01:00

167 lines
7.4 KiB
Python

import multiprocessing as mp
from collections import defaultdict
from pathlib import Path
from torch.utils.data import DataLoader, ConcatDataset, WeightedRandomSampler
from torchvision.transforms import Compose, RandomApply
from tqdm import tqdm
from ml_lib.audio_toolset.audio_io import NormalizeLocal
from ml_lib.audio_toolset.audio_to_mel_dataset import LibrosaAudioToMelDataset
from ml_lib.audio_toolset.mel_augmentation import NoiseInjection, LoudnessManipulator, ShiftTime, MaskAug
from ml_lib.utils._basedatamodule import _BaseDataModule, DATA_OPTION_test, DATA_OPTION_train, DATA_OPTION_devel
from ml_lib.utils.equal_sampler import EqualSampler
from ml_lib.utils.transforms import ToTensor
# Partition identifiers this module iterates over when preparing and setting up
# its datasets; they are processed in exactly this order (test, train, devel).
data_options = [
    DATA_OPTION_test,
    DATA_OPTION_train,
    DATA_OPTION_devel,
]
class PrimatesLibrosaDatamodule(_BaseDataModule):
    """Data module that serves mel-spectrogram segments of the primates corpus.

    Label files are read from ``<data_root>/primates/lab/<option>.csv`` and the
    audio referenced by each row from ``<data_root>/primates/wav``. Every CSV
    row becomes one ``LibrosaAudioToMelDataset``; all rows of a partition are
    joined into a single ``ConcatDataset``.
    """

    # Class label (as written in the label CSVs) -> integer class id.
    # NOTE(review): the spellings presumably mirror the label files - do not "correct" them.
    class_names = {key: val for val, key in enumerate(['background', 'chimpanze', 'geunon', 'mandrille', 'redcap'])}

    @property
    def shape(self):
        """Shape of element ``[1]`` of the first sample of the first train sub-dataset.

        Presumably this is the mel spectrogram of a sample. Requires
        ``self.datasets`` to be populated (see ``setup`` / ``prepare_data``).
        """
        return self.datasets[DATA_OPTION_train].datasets[0][0][1].shape

    @property
    def mel_folder(self):
        """Path ``<root>/mel_folder``; this is the folder removed by ``purge``."""
        return self.root / 'mel_folder'

    @property
    def wav_folder(self):
        """Path ``<root>/wav`` holding the audio files named in the label CSVs."""
        return self.root / 'wav'

    def __init__(self, data_root, batch_size, num_worker, sr, n_mels, n_fft, hop_length, sampler=None,
                 sample_segment_len=40, sample_hop_len=15, random_apply_chance=0.5,
                 loudness_ratio=0.3, shift_ratio=0.3, noise_ratio=0.3, mask_ratio=0.3):
        """Store the configuration and assemble the transform pipelines.

        Args:
            data_root: Directory that contains the ``primates`` folder.
            batch_size: Batch size handed to every ``DataLoader``.
            num_worker: Worker count for data loaders and for the process pool
                in ``prepare_data``; falsy values fall back to ``1``.
            sr, n_mels, n_fft, hop_length: Bundled into ``self.mel_kwargs``.
            sampler: Name of the sampler class to use for train/devel
                (``EqualSampler.__name__`` or ``WeightedRandomSampler.__name__``);
                any other value disables sampling.
            sample_segment_len, sample_hop_len: Stored on the instance. Note that
                ``_build_subdataset`` passes values derived from
                ``mel_length_in_seconds`` to the sub-datasets instead.
            random_apply_chance: Probability with which each augmentation is applied.
            loudness_ratio, shift_ratio, noise_ratio, mask_ratio: Arguments of
                the respective augmentation constructors.
        """
        super().__init__()
        self.sampler = sampler
        self.samplers = None
        self.sample_hop_len = sample_hop_len
        self.sample_segment_len = sample_segment_len
        self.num_worker = num_worker or 1
        self.batch_size = batch_size
        self.root = Path(data_root) / 'primates'
        self.mel_length_in_seconds = 0.7

        # Mel Transforms - forwarded together with all other instance attributes
        # (via a copy of self.__dict__) to the sub-dataset class.
        self.mel_kwargs = dict(sr=sr, n_mels=n_mels, n_fft=n_fft, hop_length=hop_length)

        # Utility
        self.utility_transforms = Compose([NormalizeLocal(), ToTensor()])

        # Data Augmentations
        self.random_apply_chance = random_apply_chance
        self.mel_augmentations = Compose([
            RandomApply([NoiseInjection(noise_ratio)], p=random_apply_chance),
            RandomApply([LoudnessManipulator(loudness_ratio)], p=random_apply_chance),
            RandomApply([ShiftTime(shift_ratio)], p=random_apply_chance),
            RandomApply([MaskAug(mask_ratio)], p=random_apply_chance),
            self.utility_transforms])

    def train_dataloader(self):
        """Return the train loader, using the train sampler built in ``setup``."""
        return DataLoader(dataset=self.datasets[DATA_OPTION_train], num_workers=self.num_worker, pin_memory=True,
                          sampler=self.samplers[DATA_OPTION_train], batch_size=self.batch_size)

    # Validation Dataloader
    def val_dataloader(self):
        """Return the devel loader, using the devel sampler built in ``setup``."""
        return DataLoader(dataset=self.datasets[DATA_OPTION_devel], num_workers=self.num_worker, pin_memory=True,
                          sampler=self.samplers[DATA_OPTION_devel], batch_size=self.batch_size)

    # Test Dataloader
    def test_dataloader(self):
        """Return the test loader; samples are served in order, without a sampler."""
        return DataLoader(dataset=self.datasets[DATA_OPTION_test], shuffle=False,
                          batch_size=self.batch_size, pin_memory=True,
                          num_workers=self.num_worker)

    def _read_label_rows(self, data_option):
        """Return the non-blank data rows (header dropped) of ``lab/<data_option>.csv``."""
        with open(self.root / 'lab' / f'{data_option}.csv', mode='r') as f:
            # Exclude the header; the default keeps an empty file from raising StopIteration.
            next(f, None)
            return [row for row in f if row.strip()]

    def _build_subdataset(self, row, build=False):
        """Create the mel dataset for one ``<file name>,<class name>`` CSV row.

        Args:
            row: One line of a label CSV.
            build: If true, ``build_mel()`` is invoked on the new dataset.

        Returns:
            Tuple ``(mel_dataset, class_id, slice_file_name)``. ``class_id`` is
            ``-1`` when the class name is not listed in ``class_names``.

        Raises:
            AssertionError: If ``build`` is set and ``build_mel()`` returns a falsy value.
        """
        slice_file_name, class_name = row.strip().split(',')
        class_id = self.class_names.get(class_name, -1)
        audio_file_path = self.wav_folder / slice_file_name

        # Work on a COPY of the instance attributes. Updating self.__dict__ itself
        # would permanently replace self.mel_augmentations as soon as the first
        # devel/test row is seen and thereby switch augmentation off for train rows.
        kwargs = dict(self.__dict__)

        # DATA OPTION DIFFERENTIATION - devel and test files only get the utility transforms.
        if any(option in slice_file_name for option in (DATA_OPTION_devel, DATA_OPTION_test)):
            kwargs.update(mel_augmentations=self.utility_transforms)

        # Segment length in mel frames covering `mel_length_in_seconds` of audio;
        # consecutive segments overlap by half a segment.
        target_frames = self.mel_length_in_seconds * self.mel_kwargs['sr']
        sample_segment_length = target_frames // self.mel_kwargs['hop_length'] + 1
        kwargs.update(sample_segment_len=sample_segment_length, sample_hop_len=sample_segment_length // 2)

        mel_dataset = LibrosaAudioToMelDataset(audio_file_path, class_id, **kwargs)
        if build:
            # Call build_mel() outside of an `assert` so it still runs under `python -O`.
            # The exception type is kept for backward compatibility.
            if not mel_dataset.build_mel():
                raise AssertionError(f'Building the mel data for {slice_file_name} failed.')
        return mel_dataset, class_id, slice_file_name

    def prepare_data(self, *args, **kwargs):
        """Build the mel data of every partition in a process pool.

        Returns:
            Dict mapping each data option to its ``ConcatDataset``; the same
            dict is stored as ``self.datasets``.
        """
        datasets = dict()
        for data_option in data_options:
            all_rows = self._read_label_rows(data_option)
            # A chunksize of 0 (fewer rows than workers) would make the pool
            # dispatch no tasks at all, so clamp it to at least 1.
            chunksize = max(1, len(all_rows) // max(self.num_worker, 1))
            with mp.Pool(processes=self.num_worker) as pool:
                results = pool.starmap_async(self._build_subdataset, [(row, True) for row in all_rows],
                                             chunksize=chunksize)
                dataset = [sub_dataset[0] for sub_dataset in results.get()]
            datasets[data_option] = ConcatDataset(dataset)
        self.datasets = datasets
        return datasets

    def _build_sampler(self, dataset):
        """Return the sampler selected by ``self.sampler`` for ``dataset``, or ``None``."""
        if self.sampler not in (EqualSampler.__name__, WeightedRandomSampler.__name__):
            return None

        # Collect all labels in ONE pass; each access loads a sample.
        # Assumes the label is the last element of a sample and is integer-like.
        labels = [int(dataset[idx][-1]) for idx in range(len(dataset))]

        if self.sampler == EqualSampler.__name__:
            # One index list per known class id; samples of unknown classes are left out.
            class_idxs = [[] for _ in range(len(self.class_names))]
            for idx, label in enumerate(labels):
                if 0 <= label < len(class_idxs):
                    class_idxs[label].append(idx)
            return EqualSampler(class_idxs)

        class_counts = defaultdict(int)
        for label in labels:
            class_counts[label] += 1
        len_largest_class = max(class_counts.values())
        # Inverse class frequency per sample, keyed by the labels actually present,
        # so absent classes or the "unknown" id -1 cannot cause a division by zero.
        weights = [1 / class_counts[label] for label in labels]
        return WeightedRandomSampler(weights, len_largest_class * len(self.class_names))

    def setup(self, stag=None):
        """Create the datasets of all partitions and the train/devel samplers.

        Args:
            stag: Unused.
                NOTE(review): callers passing ``stage=...`` by keyword would fail;
                the name is kept unchanged for interface compatibility.

        Returns:
            Dict mapping each data option to its ``ConcatDataset``; it is also
            stored as ``self.datasets``. Samplers are stored as ``self.samplers``.
        """
        datasets = dict()
        samplers = dict()
        for data_option in data_options:
            dataset = [self._build_subdataset(row)[0] for row in self._read_label_rows(data_option)]
            datasets[data_option] = ConcatDataset(dataset)

            # Build Sampler for train and val only
            if data_option in (DATA_OPTION_train, DATA_OPTION_devel):
                samplers[data_option] = self._build_sampler(datasets[data_option])
        self.datasets = datasets
        self.samplers = samplers
        return datasets

    def purge(self):
        """Recursively delete the mel folder.

        Returns:
            ``True`` if the folder no longer exists afterwards.
        """
        import shutil

        shutil.rmtree(self.mel_folder, ignore_errors=True)
        print('Mel Folder has been recursively deleted')
        print(f'Folder still exists: {self.mel_folder.exists()}')
        return not self.mel_folder.exists()