Transformer running
This commit is contained in:
@ -1,8 +1,8 @@
|
||||
from multiprocessing.pool import ApplyResult
|
||||
import multiprocessing as mp
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
|
||||
from torch.utils.data import DataLoader, ConcatDataset
|
||||
from torch.utils.data import DataLoader, ConcatDataset, WeightedRandomSampler
|
||||
from torchvision.transforms import Compose, RandomApply
|
||||
from tqdm import tqdm
|
||||
|
||||
@ -10,9 +10,8 @@ from ml_lib.audio_toolset.audio_io import NormalizeLocal
|
||||
from ml_lib.audio_toolset.audio_to_mel_dataset import LibrosaAudioToMelDataset
|
||||
from ml_lib.audio_toolset.mel_augmentation import NoiseInjection, LoudnessManipulator, ShiftTime, MaskAug
|
||||
from ml_lib.utils._basedatamodule import _BaseDataModule, DATA_OPTION_test, DATA_OPTION_train, DATA_OPTION_devel
|
||||
from ml_lib.utils.equal_sampler import EqualSampler
|
||||
from ml_lib.utils.transforms import ToTensor
|
||||
import multiprocessing as mp
|
||||
|
||||
|
||||
data_options = [DATA_OPTION_test, DATA_OPTION_train, DATA_OPTION_devel]
|
||||
|
||||
@ -34,11 +33,16 @@ class PrimatesLibrosaDatamodule(_BaseDataModule):
|
||||
def wav_folder(self):
|
||||
return self.root / 'wav'
|
||||
|
||||
def __init__(self, data_root, batch_size, num_worker, sr, n_mels, n_fft, hop_length,
|
||||
sample_segment_len=40, sample_hop_len=15):
|
||||
def __init__(self, data_root, batch_size, num_worker, sr, n_mels, n_fft, hop_length, sampler=None,
|
||||
sample_segment_len=40, sample_hop_len=15, random_apply_chance=0.5,
|
||||
loudness_ratio=0.3, shift_ratio=0.3, noise_ratio=0.3, mask_ratio=0.3):
|
||||
super(PrimatesLibrosaDatamodule, self).__init__()
|
||||
self.sampler = sampler
|
||||
self.samplers = None
|
||||
|
||||
self.sample_hop_len = sample_hop_len
|
||||
self.sample_segment_len = sample_segment_len
|
||||
|
||||
self.num_worker = num_worker or 1
|
||||
self.batch_size = batch_size
|
||||
self.root = Path(data_root) / 'primates'
|
||||
@ -51,23 +55,22 @@ class PrimatesLibrosaDatamodule(_BaseDataModule):
|
||||
self.utility_transforms = Compose([NormalizeLocal(), ToTensor()])
|
||||
|
||||
# Data Augmentations
|
||||
self.random_apply_chance = random_apply_chance
|
||||
self.mel_augmentations = Compose([
|
||||
# ToDo: HP Search this parameters, make it adjustable from outside
|
||||
RandomApply([NoiseInjection(0.2)], p=0.3),
|
||||
RandomApply([LoudnessManipulator(0.5)], p=0.3),
|
||||
RandomApply([ShiftTime(0.4)], p=0.3),
|
||||
RandomApply([MaskAug(0.2)], p=0.3),
|
||||
RandomApply([NoiseInjection(noise_ratio)], p=random_apply_chance),
|
||||
RandomApply([LoudnessManipulator(loudness_ratio)], p=random_apply_chance),
|
||||
RandomApply([ShiftTime(shift_ratio)], p=random_apply_chance),
|
||||
RandomApply([MaskAug(mask_ratio)], p=random_apply_chance),
|
||||
self.utility_transforms])
|
||||
|
||||
def train_dataloader(self):
|
||||
return DataLoader(dataset=self.datasets[DATA_OPTION_train], shuffle=True,
|
||||
batch_size=self.batch_size, pin_memory=True,
|
||||
num_workers=self.num_worker)
|
||||
return DataLoader(dataset=self.datasets[DATA_OPTION_train], num_workers=self.num_worker, pin_memory=True,
|
||||
sampler=self.samplers[DATA_OPTION_train], batch_size=self.batch_size)
|
||||
|
||||
# Validation Dataloader
|
||||
def val_dataloader(self):
|
||||
return DataLoader(dataset=self.datasets[DATA_OPTION_devel], shuffle=False, pin_memory=True,
|
||||
batch_size=self.batch_size, num_workers=self.num_worker)
|
||||
return DataLoader(dataset=self.datasets[DATA_OPTION_devel], num_workers=self.num_worker, pin_memory=True,
|
||||
sampler=self.samplers[DATA_OPTION_devel], batch_size=self.batch_size)
|
||||
|
||||
# Test Dataloader
|
||||
def test_dataloader(self):
|
||||
@ -79,6 +82,7 @@ class PrimatesLibrosaDatamodule(_BaseDataModule):
|
||||
slice_file_name, class_name = row.strip().split(',')
|
||||
class_id = self.class_names.get(class_name, -1)
|
||||
audio_file_path = self.wav_folder / slice_file_name
|
||||
|
||||
# DATA OPTION DIFFERENTIATION !!!!!!!!!!! - Begin
|
||||
kwargs = self.__dict__
|
||||
if any([x in slice_file_name for x in [DATA_OPTION_devel, DATA_OPTION_test]]):
|
||||
@ -91,7 +95,7 @@ class PrimatesLibrosaDatamodule(_BaseDataModule):
|
||||
mel_dataset = LibrosaAudioToMelDataset(audio_file_path, class_id, **kwargs)
|
||||
if build:
|
||||
assert mel_dataset.build_mel()
|
||||
return mel_dataset
|
||||
return mel_dataset, class_id, slice_file_name
|
||||
|
||||
def prepare_data(self, *args, **kwargs):
|
||||
datasets = dict()
|
||||
@ -103,22 +107,21 @@ class PrimatesLibrosaDatamodule(_BaseDataModule):
|
||||
chunksize = len(all_rows) // max(self.num_worker, 1)
|
||||
dataset = list()
|
||||
with mp.Pool(processes=self.num_worker) as pool:
|
||||
pbar = tqdm(total=len(all_rows))
|
||||
|
||||
def update():
|
||||
pbar.update(chunksize)
|
||||
from itertools import repeat
|
||||
results = pool.starmap_async(self._build_subdataset, zip(all_rows, repeat(True, len(all_rows))),
|
||||
chunksize=chunksize)
|
||||
for sub_dataset in results.get():
|
||||
dataset.append(sub_dataset)
|
||||
update() # FIXME: will i ever get this to work?
|
||||
dataset.append(sub_dataset[0])
|
||||
datasets[data_option] = ConcatDataset(dataset)
|
||||
self.datasets = datasets
|
||||
return datasets
|
||||
|
||||
def setup(self, stag=None):
|
||||
datasets = dict()
|
||||
samplers = dict()
|
||||
weights = dict()
|
||||
|
||||
for data_option in data_options:
|
||||
with open(Path(self.root) / 'lab' / f'{data_option}.csv', mode='r') as f:
|
||||
# Exclude the header
|
||||
@ -126,7 +129,38 @@ class PrimatesLibrosaDatamodule(_BaseDataModule):
|
||||
all_rows = list(f)
|
||||
dataset = list()
|
||||
for row in all_rows:
|
||||
dataset.append(self._build_subdataset(row))
|
||||
mel_dataset, class_id, _ = self._build_subdataset(row)
|
||||
dataset.append(mel_dataset)
|
||||
datasets[data_option] = ConcatDataset(dataset)
|
||||
|
||||
# Build Weighted Sampler for train and val
|
||||
if data_option in [DATA_OPTION_train, DATA_OPTION_devel]:
|
||||
if self.sampler == EqualSampler.__name__:
|
||||
class_idxs = [[idx for idx, (_, __, label) in enumerate(datasets[data_option]) if label == class_idx]
|
||||
for class_idx in range(len(self.class_names))
|
||||
]
|
||||
samplers[data_option] = EqualSampler(class_idxs)
|
||||
elif self.sampler == WeightedRandomSampler.__name__:
|
||||
class_counts = defaultdict(lambda: 0)
|
||||
for _, __, label in datasets[data_option]:
|
||||
class_counts[label] += 1
|
||||
len_largest_class = max(class_counts.values())
|
||||
|
||||
weights[data_option] = [1 / class_counts[x] for x in range(len(class_counts))]
|
||||
weights[data_option] = [weights[data_option][datasets[data_option][i][-1]]
|
||||
for i in range(len(datasets[data_option]))]
|
||||
samplers[data_option] = WeightedRandomSampler(weights[data_option],
|
||||
len_largest_class * len(self.class_names))
|
||||
else:
|
||||
samplers[data_option] = None
|
||||
self.datasets = datasets
|
||||
self.samplers = samplers
|
||||
return datasets
|
||||
|
||||
def purge(self):
|
||||
import shutil
|
||||
|
||||
shutil.rmtree(self.mel_folder, ignore_errors=True)
|
||||
print('Mel Folder has been recursively deleted')
|
||||
print(f'Folder still exists: {self.mel_folder.exists()}')
|
||||
return not self.mel_folder.exists()
|
||||
|
Reference in New Issue
Block a user