masks_augments_compare-21/util/module_mixins.py
2020-05-14 23:08:36 +02:00

189 lines
7.8 KiB
Python

from collections import defaultdict
from abc import ABC
from argparse import Namespace
import sklearn
import torch
import numpy as np
from torch import nn
from torch.optim import Adam
from torch.utils.data import DataLoader, RandomSampler
from torchcontrib.optim import SWA
from torchvision.transforms import Compose, RandomApply
from ml_lib.audio_toolset.audio_augmentation import Speed
from ml_lib.audio_toolset.mel_augmentation import NoiseInjection, LoudnessManipulator, ShiftTime, MaskAug
from ml_lib.audio_toolset.audio_io import AudioToMel, MelToImage, NormalizeLocal
from ml_lib.modules.utils import LightningBaseModule
from ml_lib.utils.transforms import ToTensor
import variables as V
class BaseOptimizerMixin:
    """Mixin supplying optimizer construction and maintenance hooks.

    Expects to be mixed into a LightningBaseModule (asserted in every hook)
    that exposes ``self.params`` (hparams namespace) and ``self.trainer``.
    """

    def configure_optimizers(self):
        """Build an Adam optimizer, optionally wrapped in SWA.

        Returns:
            The Adam optimizer, or an SWA wrapper around it when
            ``params.sto_weight_avg`` is truthy.
        """
        assert isinstance(self, LightningBaseModule)
        optimizer = Adam(params=self.parameters(), lr=self.params.lr, weight_decay=1e-7)
        if self.params.sto_weight_avg:
            # Stochastic Weight Averaging: begin after 10 steps, average
            # every 5 steps, with an SWA learning rate of 0.05.
            optimizer = SWA(optimizer, swa_start=10, swa_freq=5, swa_lr=0.05)
        return optimizer

    def on_train_end(self):
        """Swap the averaged SWA weights into the model once training ends."""
        assert isinstance(self, LightningBaseModule)
        for optimizer in self.trainer.optimizers:
            if isinstance(optimizer, SWA):
                optimizer.swap_swa_sgd()

    def on_epoch_end(self):
        """Reset optimizer state every ``params.opt_reset_interval`` epochs (if configured)."""
        assert isinstance(self, LightningBaseModule)
        interval = self.params.opt_reset_interval
        if interval and self.current_epoch % interval == 0:
            for optimizer in self.trainer.optimizers:
                # Dropping the state dict clears momentum/variance estimates.
                optimizer.state = defaultdict(dict)
class BaseTrainMixin:
    """Mixin supplying the training step and epoch-level loss aggregation."""

    # Loss criteria shared by all instances (class attributes).
    absolute_loss = nn.L1Loss()
    nll_loss = nn.NLLLoss()
    bce_loss = nn.BCELoss()

    def training_step(self, batch_xy, batch_nb, *args, **kwargs):
        """Forward one batch and return its binary-cross-entropy loss.

        Args:
            batch_xy: ``(inputs, targets)`` tuple from the train dataloader.
            batch_nb: batch index (unused).

        Returns:
            ``dict(loss=...)`` as expected by the Lightning training loop.
        """
        assert isinstance(self, LightningBaseModule)
        batch_x, batch_y = batch_xy
        prediction = self(batch_x).main_out
        return dict(loss=self.bce_loss(prediction, batch_y))

    def training_epoch_end(self, outputs):
        """Average every logged ``*loss*`` value over the epoch.

        Each matching key in the step outputs is exposed under ``mean_<key>``
        in the returned ``log`` dict.
        """
        assert isinstance(self, LightningBaseModule)
        log = dict()
        for key in outputs[0].keys():
            if 'loss' not in key:
                continue
            stacked = torch.stack([step[key] for step in outputs])
            log[f'mean_{key}'] = torch.mean(stacked)
        return dict(log=log)
class BaseValMixin:
    """Mixin supplying the validation step and multi-dataloader aggregation."""

    # Loss criteria shared by all instances (class attributes).
    absolute_loss = nn.L1Loss()
    nll_loss = nn.NLLLoss()
    bce_loss = nn.BCELoss()

    def validation_step(self, batch_xy, batch_idx, dataloader_idx, *args, **kwargs):
        """Forward one batch and collect loss plus raw outputs for epoch metrics.

        Returns:
            dict with the BCE loss, the batch index, the raw predictions ``y``
            and the targets ``batch_y`` (needed for the epoch-end UAR score).
        """
        assert isinstance(self, LightningBaseModule)
        batch_x, batch_y = batch_xy
        y = self(batch_x).main_out
        val_bce_loss = self.bce_loss(y, batch_y)
        return dict(val_bce_loss=val_bce_loss,
                    batch_idx=batch_idx, y=y, batch_y=batch_y)

    def validation_epoch_end(self, outputs, *args, **kwargs):
        """Aggregate losses and unweighted average recall per dataloader.

        ``outputs`` holds one list of step dicts per validation dataloader;
        index 0 is the devel/val split (unsuffixed keys), index 1 is the
        training split re-evaluated in eval mode (``_train`` suffix).
        """
        assert isinstance(self, LightningBaseModule)
        # FIX: ``import sklearn`` alone does not import the ``sklearn.metrics``
        # submodule, so ``sklearn.metrics.recall_score`` can raise
        # AttributeError. Import it explicitly; the local import follows this
        # module's existing lazy-import style (cf. build_dataset).
        from sklearn.metrics import recall_score

        summary_dict = dict(log=dict())
        for dataloader_idx, steps in enumerate(outputs):
            ident = '' if dataloader_idx == 0 else '_train'
            # Mean of every '*loss*' entry across the epoch for this loader.
            summary_dict['log'].update({f'mean{ident}_{key}': torch.mean(torch.stack([step[key]
                                                                                      for step in steps]))
                                        for key in steps[0].keys() if 'loss' in key}
                                       )
            # UnweightedAverageRecall: macro-averaged recall over hard 0/1
            # predictions obtained by thresholding at 0.5.
            y_true = torch.cat([step['batch_y'] for step in steps]).cpu().numpy()
            y_pred = torch.cat([step['y'] for step in steps]).squeeze().cpu().numpy()
            y_pred = (y_pred >= 0.5).astype(np.float32)
            uar_score = recall_score(y_true, y_pred, labels=[0, 1], average='macro',
                                     sample_weight=None, zero_division='warn')
            summary_dict['log'].update({f'uar{ident}_score': uar_score})
        return summary_dict
class BinaryMaskDatasetMixin:
    """Mixin that assembles the datasets for the binary-masks task."""

    def build_dataset(self):
        """Build the transform pipelines and return all dataset splits.

        Returns:
            Namespace with ``train_dataset``, ``val_train_dataset``,
            ``val_dataset`` and ``test_dataset``.
        """
        assert isinstance(self, LightningBaseModule)
        # Dataset
        # =============================================================================
        # Mel pipeline for training: adds audio-level speed perturbation
        # before the mel conversion.
        mel_transforms_train = Compose([
            Speed(speed_factor=self.params.speed_factor, max_ratio=self.params.speed_ratio),
            AudioToMel(sr=self.params.sr, n_mels=self.params.n_mels, n_fft=self.params.n_fft,
                       hop_length=self.params.hop_length),
            MelToImage(),
        ])
        # Plain mel pipeline for evaluation splits (no audio augmentation).
        mel_transforms = Compose([
            AudioToMel(sr=self.params.sr, n_mels=self.params.n_mels, n_fft=self.params.n_fft,
                       hop_length=self.params.hop_length),
            MelToImage(),
        ])
        # Spectrogram-level augmentations, applied as one bundle with p=0.6,
        # followed by normalization and tensor conversion.
        aug_transforms = Compose([
            RandomApply([
                NoiseInjection(self.params.noise_ratio),
                LoudnessManipulator(self.params.loudness_ratio),
                ShiftTime(self.params.shift_ratio),
                MaskAug(self.params.mask_ratio),
            ], p=0.6),
            NormalizeLocal(),
            ToTensor(),
        ])
        val_transforms = Compose([NormalizeLocal(), ToTensor()])

        # Lazy import, as in the original module layout.
        from datasets.binar_masks import BinaryMasksDataset

        def make_eval_split(setting):
            # Evaluation-style dataset: plain mel pipeline, no augmentation.
            return BinaryMasksDataset(self.params.root, setting=setting,
                                      mel_transforms=mel_transforms, transforms=val_transforms)

        return Namespace(
            # TRAIN DATASET (augmented, with mixup)
            train_dataset=BinaryMasksDataset(self.params.root, setting=V.DATA_OPTIONS.train,
                                             mixup=self.params.mixup,
                                             mel_transforms=mel_transforms_train, transforms=aug_transforms),
            # VALIDATION DATASETS (train split re-used for eval, plus devel)
            val_train_dataset=make_eval_split(V.DATA_OPTIONS.train),
            val_dataset=make_eval_split(V.DATA_OPTIONS.devel),
            # TEST DATASET
            test_dataset=make_eval_split(V.DATA_OPTIONS.test),
        )
class BaseDataloadersMixin(ABC):
    """Mixin supplying train/val/test dataloaders.

    ``val_dataloader`` returns two loaders; their order (val split first,
    then the re-evaluated training split) must match the ``dataloader_idx``
    handling in ``BaseValMixin.validation_epoch_end``.
    """

    # Dataloaders
    # ================================================================================
    def train_dataloader(self):
        """Shuffled loader over the training split."""
        assert isinstance(self, LightningBaseModule)
        # sampler = RandomSampler(self.dataset.train_dataset, True, len(self.dataset.train_dataset))
        sampler = None
        # With an explicit sampler, shuffle must be left unset (None).
        shuffle = True if not sampler else None
        return DataLoader(dataset=self.dataset.train_dataset,
                          shuffle=shuffle,
                          sampler=sampler,
                          batch_size=self.params.batch_size,
                          num_workers=self.params.worker)

    def test_dataloader(self):
        """Deterministic (unshuffled) loader over the test split."""
        assert isinstance(self, LightningBaseModule)
        return DataLoader(dataset=self.dataset.test_dataset,
                          shuffle=False,
                          batch_size=self.params.batch_size,
                          num_workers=self.params.worker)

    def val_dataloader(self):
        """Return ``[val loader, train-split eval loader]`` for validation."""
        assert isinstance(self, LightningBaseModule)
        # NOTE(review): shuffling a validation loader is unusual; kept as-is
        # to preserve behavior — confirm it is intentional.
        val_loader = DataLoader(dataset=self.dataset.val_dataset, shuffle=True,
                                batch_size=self.params.batch_size, num_workers=self.params.worker)
        train_eval_loader = DataLoader(self.dataset.val_train_dataset, num_workers=self.params.worker,
                                       batch_size=self.params.batch_size, shuffle=False)
        return [val_loader, train_eval_loader]