Transformer running

Steffen Illium
2021-03-04 12:01:09 +01:00
parent 7edd3834a1
commit ad254dae92
14 changed files with 679 additions and 134 deletions


@@ -1,8 +1,8 @@
from multiprocessing.pool import ApplyResult
import multiprocessing as mp
from collections import defaultdict
from pathlib import Path
from typing import List
-from torch.utils.data import DataLoader, ConcatDataset
+from torch.utils.data import DataLoader, ConcatDataset, WeightedRandomSampler
from torchvision.transforms import Compose, RandomApply
from tqdm import tqdm
@@ -10,9 +10,8 @@ from ml_lib.audio_toolset.audio_io import NormalizeLocal
from ml_lib.audio_toolset.audio_to_mel_dataset import LibrosaAudioToMelDataset
from ml_lib.audio_toolset.mel_augmentation import NoiseInjection, LoudnessManipulator, ShiftTime, MaskAug
from ml_lib.utils._basedatamodule import _BaseDataModule, DATA_OPTION_test, DATA_OPTION_train, DATA_OPTION_devel
from ml_lib.utils.equal_sampler import EqualSampler
from ml_lib.utils.transforms import ToTensor
import multiprocessing as mp
data_options = [DATA_OPTION_test, DATA_OPTION_train, DATA_OPTION_devel]
@@ -34,11 +33,16 @@ class PrimatesLibrosaDatamodule(_BaseDataModule):
def wav_folder(self):
return self.root / 'wav'
-def __init__(self, data_root, batch_size, num_worker, sr, n_mels, n_fft, hop_length,
-sample_segment_len=40, sample_hop_len=15):
+def __init__(self, data_root, batch_size, num_worker, sr, n_mels, n_fft, hop_length, sampler=None,
+sample_segment_len=40, sample_hop_len=15, random_apply_chance=0.5,
+loudness_ratio=0.3, shift_ratio=0.3, noise_ratio=0.3, mask_ratio=0.3):
super(PrimatesLibrosaDatamodule, self).__init__()
self.sampler = sampler
self.samplers = None
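# `sampler` is matched by class name in setup() (EqualSampler or
# WeightedRandomSampler); the per-split sampler instances built there are
# stored in self.samplers.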
self.sample_hop_len = sample_hop_len
self.sample_segment_len = sample_segment_len
self.num_worker = num_worker or 1
self.batch_size = batch_size
self.root = Path(data_root) / 'primates'
@@ -51,23 +55,22 @@ class PrimatesLibrosaDatamodule(_BaseDataModule):
self.utility_transforms = Compose([NormalizeLocal(), ToTensor()])
# Data Augmentations
self.random_apply_chance = random_apply_chance
self.mel_augmentations = Compose([
# ToDo: HP-search these parameters, make them adjustable from outside
-RandomApply([NoiseInjection(0.2)], p=0.3),
-RandomApply([LoudnessManipulator(0.5)], p=0.3),
-RandomApply([ShiftTime(0.4)], p=0.3),
-RandomApply([MaskAug(0.2)], p=0.3),
+RandomApply([NoiseInjection(noise_ratio)], p=random_apply_chance),
+RandomApply([LoudnessManipulator(loudness_ratio)], p=random_apply_chance),
+RandomApply([ShiftTime(shift_ratio)], p=random_apply_chance),
+RandomApply([MaskAug(mask_ratio)], p=random_apply_chance),
self.utility_transforms])
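# Each augmentation is now applied independently with probability
# `random_apply_chance`, and the *_ratio arguments expose the strengths
# that were previously hardcoded.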
def train_dataloader(self):
-return DataLoader(dataset=self.datasets[DATA_OPTION_train], shuffle=True,
-batch_size=self.batch_size, pin_memory=True,
-num_workers=self.num_worker)
+return DataLoader(dataset=self.datasets[DATA_OPTION_train], num_workers=self.num_worker, pin_memory=True,
+sampler=self.samplers[DATA_OPTION_train], batch_size=self.batch_size)
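# DataLoader rejects shuffle=True combined with a custom sampler, hence the
# dropped shuffle flag; if no sampler was configured (sampler=None), the
# loader now iterates sequentially instead of shuffling.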
# Validation Dataloader
def val_dataloader(self):
-return DataLoader(dataset=self.datasets[DATA_OPTION_devel], shuffle=False, pin_memory=True,
-batch_size=self.batch_size, num_workers=self.num_worker)
+return DataLoader(dataset=self.datasets[DATA_OPTION_devel], num_workers=self.num_worker, pin_memory=True,
+sampler=self.samplers[DATA_OPTION_devel], batch_size=self.batch_size)
# Test Dataloader
def test_dataloader(self):
@@ -79,6 +82,7 @@ class PrimatesLibrosaDatamodule(_BaseDataModule):
slice_file_name, class_name = row.strip().split(',')
class_id = self.class_names.get(class_name, -1)
audio_file_path = self.wav_folder / slice_file_name
# DATA OPTION DIFFERENTIATION !!!!!!!!!!! - Begin
kwargs = self.__dict__
if any([x in slice_file_name for x in [DATA_OPTION_devel, DATA_OPTION_test]]):
@@ -91,7 +95,7 @@ class PrimatesLibrosaDatamodule(_BaseDataModule):
mel_dataset = LibrosaAudioToMelDataset(audio_file_path, class_id, **kwargs)
if build:
assert mel_dataset.build_mel()
-return mel_dataset
+return mel_dataset, class_id, slice_file_name
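# class_id and the file name ride along so callers can read the label
# without touching the dataset: setup() unpacks the tuple, prepare_data()
# indexes [0].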
def prepare_data(self, *args, **kwargs):
datasets = dict()
@@ -103,22 +107,21 @@ class PrimatesLibrosaDatamodule(_BaseDataModule):
chunksize = len(all_rows) // max(self.num_worker, 1)
dataset = list()
with mp.Pool(processes=self.num_worker) as pool:
pbar = tqdm(total=len(all_rows))
def update():
pbar.update(chunksize)
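# NOTE: `update` is never registered with the async call below, so the bar
# never advances (the FIXME). Per-task tqdm progress with a Pool is usually
# done with pool.imap_unordered, which yields results as workers finish.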
from itertools import repeat
results = pool.starmap_async(self._build_subdataset, zip(all_rows, repeat(True, len(all_rows))),
chunksize=chunksize)
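# starmap_async fans _build_subdataset(row, build=True) out over the worker
# pool in chunks; results.get() blocks until every row has been processed.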
for sub_dataset in results.get():
-dataset.append(sub_dataset)
-update()  # FIXME: will I ever get this to work?
+dataset.append(sub_dataset[0])
datasets[data_option] = ConcatDataset(dataset)
self.datasets = datasets
return datasets
def setup(self, stage=None):
datasets = dict()
samplers = dict()
weights = dict()
for data_option in data_options:
with open(Path(self.root) / 'lab' / f'{data_option}.csv', mode='r') as f:
# Exclude the header
@@ -126,7 +129,38 @@ class PrimatesLibrosaDatamodule(_BaseDataModule):
all_rows = list(f)
dataset = list()
for row in all_rows:
-dataset.append(self._build_subdataset(row))
+mel_dataset, class_id, _ = self._build_subdataset(row)
+dataset.append(mel_dataset)
datasets[data_option] = ConcatDataset(dataset)
# Build Weighted Sampler for train and val
if data_option in [DATA_OPTION_train, DATA_OPTION_devel]:
if self.sampler == EqualSampler.__name__:
class_idxs = [[idx for idx, (_, __, label) in enumerate(datasets[data_option]) if label == class_idx]
for class_idx in range(len(self.class_names))
]
samplers[data_option] = EqualSampler(class_idxs)
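# class_idxs collects, per class, every dataset index carrying that label;
# EqualSampler (project-local, from ml_lib) presumably draws evenly from
# these buckets.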
elif self.sampler == WeightedRandomSampler.__name__:
class_counts = defaultdict(lambda: 0)
for _, __, label in datasets[data_option]:
class_counts[label] += 1
len_largest_class = max(class_counts.values())
weights[data_option] = [1 / class_counts[x] for x in range(len(class_counts))]
weights[data_option] = [weights[data_option][datasets[data_option][i][-1]]
for i in range(len(datasets[data_option]))]
samplers[data_option] = WeightedRandomSampler(weights[data_option],
len_largest_class * len(self.class_names))
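# Per-sample weight is the inverse of its class frequency, so every class
# is drawn with roughly equal probability; requesting
# len_largest_class * n_classes samples (with replacement, the
# WeightedRandomSampler default) keeps the epoch size close to that of a
# fully balanced dataset.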
else:
samplers[data_option] = None
self.datasets = datasets
self.samplers = samplers
return datasets
def purge(self):
import shutil
shutil.rmtree(self.mel_folder, ignore_errors=True)
print('Mel Folder has been recursively deleted')
print(f'Folder still exists: {self.mel_folder.exists()}')
return not self.mel_folder.exists()
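
A minimal usage sketch of the new sampler switch (hypothetical argument
values; only the `sampler` string is compared by class name in setup()):

dm = PrimatesLibrosaDatamodule(data_root='data', batch_size=32, num_worker=4,
                               sr=16000, n_mels=64, n_fft=512, hop_length=256,
                               sampler='WeightedRandomSampler')
dm.prepare_data()
dm.setup()
train_loader = dm.train_dataloader()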