from functools import partial
from pathlib import Path

import multiprocessing as mp

from torch.utils.data import DataLoader, ConcatDataset
from torchvision.transforms import Compose, RandomApply
from tqdm import tqdm

from ml_lib.audio_toolset.audio_io import NormalizeLocal
from ml_lib.audio_toolset.audio_to_mel_dataset import LibrosaAudioToMelDataset
from ml_lib.audio_toolset.mel_augmentation import NoiseInjection, LoudnessManipulator, ShiftTime, MaskAug
from ml_lib.utils._basedatamodule import _BaseDataModule, DATA_OPTION_test, DATA_OPTION_train, DATA_OPTION_devel
from ml_lib.utils.transforms import ToTensor

data_options = [DATA_OPTION_test, DATA_OPTION_train, DATA_OPTION_devel]


class PrimatesLibrosaDatamodule(_BaseDataModule):
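    """Datamodule that turns the primate vocalisation wav files into mel-spectrogram datasets.

    Split definitions are read from ``<data_root>/primates/lab/<split>.csv``; one
    ``LibrosaAudioToMelDataset`` is built per audio file and the per-file datasets are
    concatenated per split (train/devel/test).
    """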

    # Maps a label string from the csv files to its integer class id.
    class_names = {name: idx for idx, name in enumerate(['background', 'chimpanze', 'geunon', 'mandrille', 'redcap'])}

    @property
    def shape(self):
        # Shape of a single mel sample, taken from the first item of the training split.
        return self.datasets[DATA_OPTION_train].datasets[0][0][1].shape

    @property
    def mel_folder(self):
        return self.root / 'mel_folder'

    @property
    def wav_folder(self):
        return self.root / 'wav'

    def __init__(self, data_root, batch_size, num_worker, sr, n_mels, n_fft, hop_length,
                 sample_segment_len=40, sample_hop_len=15):
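        """
        :param data_root:           Root folder that contains the ``primates`` corpus directory.
        :param batch_size:          Batch size used by all dataloaders.
        :param num_worker:          Number of workers for dataset building and data loading.
        :param sr:                  Target sampling rate, passed through to the mel transform.
        :param n_mels:              Number of mel bands.
        :param n_fft:               FFT window size.
        :param hop_length:          Hop length between STFT frames.
        :param sample_segment_len:  Default segment length in mel frames (overridden per file in
                                    ``_build_subdataset``).
        :param sample_hop_len:      Default hop between segments in mel frames (also overridden).
        """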
        super(PrimatesLibrosaDatamodule, self).__init__()
        self.sample_hop_len = sample_hop_len
        self.sample_segment_len = sample_segment_len
        self.num_worker = num_worker or 1
        self.batch_size = batch_size
        self.root = Path(data_root) / 'primates'
        self.mel_length_in_seconds = 0.7

        # Mel transforms - pushed together with all other parameters through self.__dict__
        # to the sub-dataset class in _build_subdataset().
        self.mel_kwargs = dict(sr=sr, n_mels=n_mels, n_fft=n_fft, hop_length=hop_length)

        # Utility transforms applied to every sample (train and eval alike).
        self.utility_transforms = Compose([NormalizeLocal(), ToTensor()])

        # Data augmentations; replaced by the plain utility transforms for devel/test
        # files in _build_subdataset(), so they only affect the training split.
        self.mel_augmentations = Compose([
            # TODO: These parameters should be tuned (HP search) and made adjustable from outside.
            RandomApply([NoiseInjection(0.2)], p=0.3),
            RandomApply([LoudnessManipulator(0.5)], p=0.3),
            RandomApply([ShiftTime(0.4)], p=0.3),
            RandomApply([MaskAug(0.2)], p=0.3),
            self.utility_transforms])

    def train_dataloader(self):
        return DataLoader(dataset=self.datasets[DATA_OPTION_train], shuffle=True,
                          batch_size=self.batch_size, pin_memory=True,
                          num_workers=self.num_worker)

    # Validation Dataloader
    def val_dataloader(self):
        return DataLoader(dataset=self.datasets[DATA_OPTION_devel], shuffle=False, pin_memory=True,
                          batch_size=self.batch_size, num_workers=self.num_worker)

    # Test Dataloader
    def test_dataloader(self):
        return DataLoader(dataset=self.datasets[DATA_OPTION_test], shuffle=False,
                          batch_size=self.batch_size, pin_memory=True,
                          num_workers=self.num_worker)

    def _build_subdataset(self, row, build=False):
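        """Build a single-file mel dataset from one csv row (``<file_name>,<label>``).

        Devel/test files get the plain utility transforms instead of the augmentation
        pipeline. If ``build`` is True, the mel spectrogram is computed immediately via
        ``build_mel()``.
        """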
        slice_file_name, class_name = row.strip().split(',')
        class_id = self.class_names.get(class_name, -1)
        audio_file_path = self.wav_folder / slice_file_name

        # Split-dependent transforms: devel/test files only get the utility transforms.
        # Copy self.__dict__ so the per-file overrides below do not mutate the datamodule itself.
        kwargs = dict(self.__dict__)
        if any(x in slice_file_name for x in [DATA_OPTION_devel, DATA_OPTION_test]):
            kwargs.update(mel_augmentations=self.utility_transforms)

        # Segment length in mel frames that covers `mel_length_in_seconds` of audio;
        # consecutive segments overlap by half a window.
        target_frames = self.mel_length_in_seconds * self.mel_kwargs['sr']
        sample_segment_length = int(target_frames // self.mel_kwargs['hop_length']) + 1
        kwargs.update(sample_segment_len=sample_segment_length, sample_hop_len=sample_segment_length // 2)
        mel_dataset = LibrosaAudioToMelDataset(audio_file_path, class_id, **kwargs)
        if build:
            assert mel_dataset.build_mel()
        return mel_dataset

    def prepare_data(self, *args, **kwargs):
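        """Build all mel spectrograms once, in parallel, before training starts.

        A pool of ``num_worker`` processes builds the per-file datasets (including the mel
        computation); the results are concatenated per split.
        """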
        datasets = dict()
        for data_option in data_options:
            with open(Path(self.root) / 'lab' / f'{data_option}.csv', mode='r') as f:
                # Exclude the header
                _ = next(f)
                all_rows = list(f)

            # Build the per-file datasets in parallel; chunksize must be at least 1.
            chunksize = max(len(all_rows) // max(self.num_worker, 1), 1)
            dataset = list()
            with mp.Pool(processes=self.num_worker) as pool:
                pbar = tqdm(total=len(all_rows))
                # imap keeps the row order and lets the progress bar advance per finished file
                # (the previous starmap_async variant could only update once all results were in).
                results = pool.imap(partial(self._build_subdataset, build=True), all_rows,
                                    chunksize=chunksize)
                for sub_dataset in results:
                    dataset.append(sub_dataset)
                    pbar.update(1)
                pbar.close()
            datasets[data_option] = ConcatDataset(dataset)
        self.datasets = datasets
        return datasets

    def setup(self, stage=None):
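        """Assemble the train/devel/test datasets from the csv split files.

        Unlike ``prepare_data`` this runs sequentially and does not (re)build the mel
        spectrograms; it only wraps the already prepared files.
        """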
        datasets = dict()
        for data_option in data_options:
            with open(Path(self.root) / 'lab' / f'{data_option}.csv', mode='r') as f:
                # Exclude the header
                _ = next(f)
                all_rows = list(f)

            dataset = list()
            for row in all_rows:
                dataset.append(self._build_subdataset(row))
            datasets[data_option] = ConcatDataset(dataset)
        self.datasets = datasets
        return datasets
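

# Minimal usage sketch (not part of the module's public API). The hyperparameter values
# below are illustrative assumptions, not values prescribed by this repository, and the
# run requires the primates corpus to be present under `<data_root>/primates`.
if __name__ == '__main__':
    datamodule = PrimatesLibrosaDatamodule(data_root='data', batch_size=32, num_worker=4,
                                           sr=16000, n_mels=64, n_fft=1024, hop_length=256)
    datamodule.prepare_data()   # builds the mel spectrograms in parallel
    datamodule.setup()          # assembles the train/devel/test ConcatDatasets
    print(datamodule.shape)     # shape of a single mel sample
    for batch in datamodule.train_dataloader():
        break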