paper preperations and notebooks, optuna callbacks, new plots
This commit is contained in:
@@ -0,0 +1,136 @@
|
||||
from pathlib import Path
|
||||
|
||||
import multiprocessing as mp
|
||||
import torch
|
||||
from torch.utils.data import ConcatDataset
|
||||
from torchvision.transforms import RandomApply, ToTensor, Compose
|
||||
|
||||
from ml_lib.audio_toolset.audio_io import NormalizeLocal
|
||||
from ml_lib.audio_toolset.audio_to_mel_dataset import LibrosaAudioToMelDataset
|
||||
from ml_lib.audio_toolset.mel_augmentation import NoiseInjection, LoudnessManipulator, ShiftTime, MaskAug
|
||||
from ml_lib.utils._basedatamodule import _BaseDataModule, DATA_OPTION_devel, DATA_OPTION_train, DATA_OPTION_test
|
||||
|
||||
data_options = [DATA_OPTION_test, DATA_OPTION_train, DATA_OPTION_devel]
|
||||
|
||||
try:
|
||||
torch.multiprocessing.set_sharing_strategy('file_system')
|
||||
except AttributeError:
|
||||
pass
|
||||
|
||||
|
||||
class UrbanSound8K(_BaseDataModule):
|
||||
|
||||
_class_names = ['air_conditioner', 'car_horn', 'children_playing', 'dog_bark', 'drilling',
|
||||
'engine_idling', 'gun_shot', 'jackhammer', 'siren', 'street_music']
|
||||
|
||||
@property
|
||||
def class_names(self):
|
||||
return {key: val for val, key in enumerate(self._class_names)}
|
||||
|
||||
@property
|
||||
def n_classes(self):
|
||||
return len(self.class_names)
|
||||
|
||||
@property
|
||||
def sample_shape(self):
|
||||
return self[0][1].shape
|
||||
|
||||
# Data Structures
|
||||
@property
|
||||
def mel_folder(self):
|
||||
return self.data_root / 'mel'
|
||||
|
||||
@property
|
||||
def wav_folder(self):
|
||||
return self.data_root / self._wav_folder_name
|
||||
|
||||
@property
|
||||
def _container_ext(self):
|
||||
return '.mel'
|
||||
|
||||
def __init__(self, data_root, num_worker, sr, n_mels, n_fft, hop_length, sampler=None,
|
||||
random_apply_chance=0.5, target_mel_length_in_seconds=1, fold=1, setting=DATA_OPTION_train,
|
||||
loudness_ratio=0.3, shift_ratio=0.3, noise_ratio=0.3, mask_ratio=0.3, rebuild=False):
|
||||
assert isinstance(setting, str), f'Setting has to be a string, but was: {type(setting)}.'
|
||||
assert fold in range(1, 11) if isinstance(fold, int) else all([f in range(1, 11) for f in fold])
|
||||
super(UrbanSound8K, self).__init__()
|
||||
self.num_worker = num_worker or 1
|
||||
self.sampler = sampler
|
||||
|
||||
# Dataset Paramters
|
||||
self.fold = fold if isinstance(fold, list) else [fold]
|
||||
|
||||
# Dataset Parameters
|
||||
self.data_root = Path(data_root) / self.__class__.__name__
|
||||
self._wav_folder_name = 'audio'
|
||||
|
||||
self.mel_length_in_seconds = target_mel_length_in_seconds
|
||||
|
||||
# Mel Transforms - will be pushed with all other paramters by self.__dict__ to subdataset-class
|
||||
self.mel_kwargs = dict(sr=sr, n_mels=n_mels, n_fft=n_fft, hop_length=hop_length)
|
||||
|
||||
target_frames = self.mel_length_in_seconds * self.mel_kwargs['sr']
|
||||
self.sample_segment_length = target_frames // self.mel_kwargs['hop_length'] + 1
|
||||
|
||||
# Utility
|
||||
self.utility_transforms = Compose([
|
||||
NormalizeLocal(),
|
||||
ToTensor()
|
||||
])
|
||||
|
||||
# Data Augmentations
|
||||
self.random_apply_chance = random_apply_chance
|
||||
self.mel_augmentations = Compose([
|
||||
RandomApply([NoiseInjection(noise_ratio)], p=random_apply_chance),
|
||||
RandomApply([LoudnessManipulator(loudness_ratio)], p=random_apply_chance),
|
||||
RandomApply([ShiftTime(shift_ratio)], p=random_apply_chance),
|
||||
RandomApply([MaskAug(mask_ratio)], p=random_apply_chance),
|
||||
self.utility_transforms])
|
||||
|
||||
# Find all raw files and turn generator to persistent list:
|
||||
self._wav_files = list(self.wav_folder.rglob('*.wav'))
|
||||
|
||||
# Build the Dataset
|
||||
self._dataset = self._build_dataset(rebuild)
|
||||
|
||||
def _build_subdataset(self, row, build=False):
|
||||
slice_file_name, fs_id, start, end, salience, fold, class_id, class_name = row.strip().split(',')
|
||||
fold, class_id = (int(x) for x in (fold, class_id))
|
||||
if int(fold) in self.fold:
|
||||
audio_file_path = self.wav_folder / f'fold{fold}' / slice_file_name
|
||||
kwargs = dict(sample_segment_len=self.sample_segment_length,
|
||||
sample_hop_len=self.sample_segment_length // 2)
|
||||
mel_dataset = LibrosaAudioToMelDataset(audio_file_path, class_id, mel_kwargs=self.mel_kwargs, **kwargs)
|
||||
if build:
|
||||
assert mel_dataset.build_mel()
|
||||
return mel_dataset, class_id, slice_file_name
|
||||
else:
|
||||
return None
|
||||
|
||||
def _build_dataset(self, build=False):
|
||||
dataset = list()
|
||||
with open(Path(self.data_root) / 'metadata' / 'UrbanSound8K.csv', mode='r') as f:
|
||||
# Exclude the header
|
||||
_ = next(f)
|
||||
all_rows = list(f)
|
||||
chunksize = len(all_rows) // max(self.num_worker, 1)
|
||||
with mp.Pool(processes=self.num_worker) as pool:
|
||||
from itertools import repeat
|
||||
results = pool.starmap_async(self._build_subdataset,
|
||||
zip(all_rows,
|
||||
repeat(build, len(all_rows))
|
||||
),
|
||||
chunksize=chunksize)
|
||||
for sub_dataset in results.get():
|
||||
if sub_dataset is not None:
|
||||
if sub_dataset[0] is not None:
|
||||
dataset.append(sub_dataset[0])
|
||||
return ConcatDataset(dataset)
|
||||
|
||||
def __len__(self):
|
||||
return len(self._dataset)
|
||||
|
||||
def __getitem__(self, item):
|
||||
file_path, transformed_samples, label = self._dataset[item]
|
||||
label = torch.as_tensor(label, dtype=torch.int)
|
||||
return file_path, transformed_samples, label
|
||||
@@ -0,0 +1,62 @@
|
||||
from copy import deepcopy
|
||||
from pathlib import Path
|
||||
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from datasets.urban_8k import UrbanSound8K
|
||||
from ml_lib.utils._basedatamodule import _BaseDataModule, DATA_OPTION_train, DATA_OPTION_devel
|
||||
from ml_lib.utils.tools import add_argparse_args
|
||||
|
||||
|
||||
class Urban8KLibrosaDatamodule(_BaseDataModule):
|
||||
|
||||
def __init__(self, batch_size, num_worker, data_root, sr, n_mels, n_fft, hop_length,
|
||||
sampler=None, val_fold=9,
|
||||
random_apply_chance=0.5, target_mel_length_in_seconds=1,
|
||||
loudness_ratio=0.3, shift_ratio=0.3, noise_ratio=0.3, mask_ratio=0.3, **kwargs):
|
||||
super(Urban8KLibrosaDatamodule, self).__init__()
|
||||
self.batch_size = batch_size
|
||||
self.num_worker = num_worker
|
||||
|
||||
self.val_fold = val_fold
|
||||
|
||||
self.kwargs = kwargs
|
||||
self.kwargs.update(data_root=data_root, num_worker=num_worker,
|
||||
sr=sr, n_mels=n_mels, n_fft=n_fft, hop_length=hop_length, sampler=sampler,
|
||||
random_apply_chance=random_apply_chance,
|
||||
target_mel_length_in_seconds=target_mel_length_in_seconds,
|
||||
loudness_ratio=loudness_ratio, shift_ratio=shift_ratio, noise_ratio=noise_ratio,
|
||||
mask_ratio=mask_ratio)
|
||||
|
||||
@classmethod
|
||||
def add_argparse_args(cls, parent_parser):
|
||||
return add_argparse_args(UrbanSound8K, parent_parser)
|
||||
|
||||
@classmethod
|
||||
def from_argparse_args(cls, args, **kwargs):
|
||||
val_fold = kwargs.get('val_fold', 10)
|
||||
kwargs.update(val_fold=val_fold)
|
||||
return super(Urban8KLibrosaDatamodule, cls).from_argparse_args(args, **kwargs)
|
||||
|
||||
def train_dataloader(self):
|
||||
return DataLoader(dataset=self.datasets[DATA_OPTION_train], shuffle=True,
|
||||
batch_size=self.batch_size, pin_memory=True,
|
||||
num_workers=self.num_worker)
|
||||
|
||||
# Validation Dataloader
|
||||
def val_dataloader(self):
|
||||
return DataLoader(dataset=self.datasets[DATA_OPTION_devel], shuffle=False, pin_memory=True,
|
||||
batch_size=self.batch_size, num_workers=self.num_worker)
|
||||
|
||||
def prepare_data(self, stag=None):
|
||||
# Train Datasset
|
||||
self.datasets[DATA_OPTION_train] = UrbanSound8K(fold=[x for x in list(range(1, 11)) if x != self.val_fold],
|
||||
**self.kwargs)
|
||||
# Devel Datasset
|
||||
self.datasets[DATA_OPTION_devel] = UrbanSound8K(fold=self.val_fold, **self.kwargs)
|
||||
|
||||
def manual_setup(self):
|
||||
UrbanSound8K(fold=[x for x in list(range(1, 11))], rebuild=True, **self.kwargs)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user