Audio IO Ready
datasets/binar_masks.py
@@ -1,3 +1,4 @@
+import pickle
 from collections import defaultdict
 from pathlib import Path
 
@@ -5,23 +6,28 @@ import librosa as librosa
 from torch.utils.data import Dataset
 
 import variables as V
+from ml_lib.modules.utils import F_x
 
 
-class BinaryMasks(Dataset):
-    _to_label = defaultdict(-1)
+class BinaryMasksDataset(Dataset):
+    _to_label = defaultdict(lambda: -1)
     _to_label['clear'] = V.CLEAR
     _to_label['mask'] = V.MASK
+    settings = ['test', 'devel', 'train']
 
-    def __init__(self, data_root, setting):
-        assert isinstance(setting, str)
-        assert setting in ['test', 'devel', 'train']
-        super(BinaryMasks, self).__init__()
+    def __init__(self, data_root, setting, transforms=None):
+        assert isinstance(setting, str), f'Setting has to be a string, but was: {setting}.'
+        assert setting in self.settings, f'Setting must match one of: {self.settings}.'
+        assert transforms is None or callable(transforms), f'Transforms has to be callable, but was: {transforms}'
+        super(BinaryMasksDataset, self).__init__()
 
         self.data_root = Path(data_root)
         self.setting = setting
+        self._transforms = transforms or F_x()
         self._labels = self._build_labels()
         self._wav_folder = self.data_root / 'wav'
-        self._files = list(sorted(self._labels.keys()))
+        self._wav_files = list(sorted(self._labels.keys()))
+        self._mel_folder = self.data_root / 'raw_mel'
 
     def _build_labels(self):
         with open(Path(self.data_root) / 'lab' / 'labels.csv', mode='r') as f:
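A note on the `_to_label` change in this hunk: `defaultdict` requires a zero-argument callable (or None) as its default_factory, so the old `defaultdict(-1)` raises a TypeError the moment the class body runs, while `defaultdict(lambda: -1)` maps every unknown key to -1 as intended. A minimal sketch of the difference; the value 1 below is a stand-in for V.CLEAR, which is defined in the project's variables module:

    from collections import defaultdict

    try:
        defaultdict(-1)                      # TypeError: first argument must be callable or None
    except TypeError as err:
        print(err)

    _to_label = defaultdict(lambda: -1)      # the lambda is called for every missing key
    _to_label['clear'] = 1                   # stand-in for V.CLEAR
    print(_to_label['anything_else'])        # -> -1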
@@ -39,15 +45,16 @@ class BinaryMasks(Dataset):
         return len(self._labels)
 
     def __getitem__(self, item):
-        key = self._files[item]
-        sample = librosa.core.load(self._wav_folder / self._files[key])
+        key = self._wav_files[item]
+        filename = key[:-4] + '.pik'
 
+        if not (self._mel_folder / filename).exists():
+            raw_sample, sr = librosa.core.load(self._wav_folder / self._wav_files[item])
+            transformed_sample = self._transforms(raw_sample)
+            self._mel_folder.mkdir(exist_ok=True, parents=True)
+            with (self._mel_folder / filename).open(mode='wb') as f:
+                pickle.dump(transformed_sample, f, protocol=pickle.HIGHEST_PROTOCOL)
+        with (self._mel_folder / filename).open(mode='rb') as f:
+            sample = pickle.load(f, fix_imports=True)
         label = self._labels[key]
         return sample, label
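To make the caching path in `__getitem__` concrete, here is a usage sketch. It is illustrative only: the data root and the transform below are placeholders, not part of the commit; the class itself expects a `wav/` folder and `lab/labels.csv` under the data root plus an optional callable transform.

    import librosa

    from datasets.binar_masks import BinaryMasksDataset

    def to_mel(raw_sample):
        # hypothetical stand-in for the Melspectogram/NormalizeLocal chain wired up in main.py
        return librosa.power_to_db(librosa.feature.melspectrogram(y=raw_sample))

    dataset = BinaryMasksDataset('path/to/data_root', setting='train', transforms=to_mel)

    sample, label = dataset[0]   # first access: wav is loaded, transformed, pickled to raw_mel/<name>.pik
    sample, label = dataset[0]   # second access: the cached pickle is read back, librosa is not touched

The cache key is derived from the wav filename (`key[:-4] + '.pik'`), so stale features in `raw_mel/` survive a change of transform until the folder is cleared.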
main.py (18 lines changed)
@@ -11,7 +11,9 @@ import torch
 from pytorch_lightning import Trainer
 from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
 from torch.utils.data import DataLoader
+from torchvision.transforms import Compose, ToTensor
 
+from ml_lib.audio_toolset.audio_io import Melspectogram, NormalizeLocal, AutoPadToShape
 from ml_lib.modules.utils import LightningBaseModule
 from ml_lib.utils.logging import Logger
 from util.config import MConfig
@@ -93,11 +95,15 @@ def run_lightning_loop(config_obj):
 
     # Dataset and Dataloaders
     # =============================================================================
-    # Train Dataset
-    from datasets.binar_masks import BinaryMasks
-    dataset = BinaryMasks(config_obj.data.root, setting='train')
-    # Train Dataloader
-    dataloader = DataLoader(dataset)
+    # Transforms
+    transforms = Compose([Melspectogram(), ToTensor(), NormalizeLocal()])
+    # Datasets
+    from datasets.binar_masks import BinaryMasksDataset
+    train_dataset = BinaryMasksDataset(config_obj.data.root, setting='train', transforms=transforms)
+    val_dataset = BinaryMasksDataset(config_obj.data.root, setting='devel', transforms=transforms)
+    # Dataloaders
+    train_dataloader = DataLoader(train_dataset)
+    val_dataloader = DataLoader(val_dataset)
 
     # Model
     # =============================================================================
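The two DataLoader calls above rely on PyTorch's defaults (batch_size=1, shuffle=False, single-process loading). Purely as an illustration, a parameterised version could look like the sketch below; the batch size and worker count are invented, and batching variable-length spectrograms would additionally need padding (the so far unused AutoPadToShape import hints at that) or a custom collate_fn:

    from torch.utils.data import DataLoader

    # illustrative values only, not part of this commit
    train_dataloader = DataLoader(train_dataset, batch_size=16, shuffle=True, num_workers=4)
    val_dataloader = DataLoader(val_dataset, batch_size=16, shuffle=False, num_workers=4)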
@@ -121,7 +127,7 @@ def run_lightning_loop(config_obj):
     )
 
     # Train It
-    trainer.fit(model)
+    trainer.fit(model, train_dataloader=train_dataloader, val_dataloaders=val_dataloader)
 
     # Save the last state & all parameters
     trainer.save_checkpoint(logger.log_dir / 'weights.ckpt')
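Passing the loaders to `trainer.fit(...)` as done above is one of the wiring options PyTorch Lightning offers in the versions this signature matches; the alternative is to expose them through the module's `train_dataloader`/`val_dataloader` hooks. A sketch under that assumption; the class below is hypothetical and shows only the data hooks, not a full model:

    import pytorch_lightning as pl
    from torch.utils.data import DataLoader

    class LitWithLoaders(pl.LightningModule):
        def __init__(self, train_dataset, val_dataset):
            super().__init__()
            self.train_dataset = train_dataset
            self.val_dataset = val_dataset

        def train_dataloader(self):
            return DataLoader(self.train_dataset)

        def val_dataloader(self):
            return DataLoader(self.val_dataset)

    # with the hooks on the module, trainer.fit(model) alone would be enough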