# masks_augments_compare-21/util/module_mixins.py

from collections import defaultdict
from abc import ABC
from argparse import Namespace
import sklearn.metrics  # plain `import sklearn` does not expose the metrics submodule used below
import torch
import numpy as np
from torch.nn import L1Loss
from torch.optim import Adam
from torch.utils.data import DataLoader
from torchcontrib.optim import SWA
from torchvision.transforms import Compose, RandomApply
from ml_lib.audio_toolset.audio_augmentation import NoiseInjection, LoudnessManipulator, ShiftTime
from ml_lib.audio_toolset.audio_io import AudioToMel, MelToImage, NormalizeLocal
from ml_lib.utils.transforms import ToTensor
import variables as V
class BaseOptimizerMixin:
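    """Optimizer setup shared by the models: Adam with the configured learning
    rate, optionally wrapped in Stochastic Weight Averaging (SWA) when
    `params.sto_weight_avg` is set; SWA weights are swapped in at train end."""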
def configure_optimizers(self):
opt = Adam(params=self.parameters(), lr=self.params.lr)
if self.params.sto_weight_avg:
opt = SWA(opt, swa_start=10, swa_freq=5, swa_lr=0.05)
return opt
def on_train_end(self):
for opt in self.trainer.optimizers:
if isinstance(opt, SWA):
opt.swap_swa_sgd()
def on_epoch_end(self):
        # FIXME-resolution sketch: gate on a hyperparameter instead of `if False`.
        # Assumes `opt_reset_interval` may be absent from the args namespace,
        # so it defaults to 0 (feature disabled).
        opt_reset_interval = getattr(self.params, 'opt_reset_interval', 0)
        if opt_reset_interval and self.current_epoch % opt_reset_interval == 0:
            for opt in self.trainer.optimizers:
                opt.state = defaultdict(dict)
class BaseTrainMixin:
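    """Training hooks: computes the configured criterion on the model's
    `main_out` and logs the epoch mean of every returned loss."""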
def training_step(self, batch_xy, batch_nb, *args, **kwargs):
batch_x, batch_y = batch_xy
y = self(batch_x).main_out
loss = self.criterion(y, batch_y)
return dict(loss=loss)
def training_epoch_end(self, outputs):
keys = list(outputs[0].keys())
summary_dict = dict(log={f'mean_{key}': torch.mean(torch.stack([output[key]
for output in outputs]))
for key in keys if 'loss' in key})
return summary_dict
class BaseValMixin:
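    """Validation hooks: per batch, reports the configured criterion and an L1
    loss; per epoch, logs the mean losses plus unweighted average recall (UAR)
    over predictions binarized at 0.5."""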
absolute_loss = L1Loss()
def validation_step(self, batch_xy, batch_idx, *args, **kwargs):
batch_x, batch_y = batch_xy
y = self(batch_x).main_out
val_bce_loss = self.criterion(y, batch_y)
val_abs_loss = self.absolute_loss(y, batch_y)
return dict(val_bce_loss=val_bce_loss, val_abs_loss=val_abs_loss,
batch_idx=batch_idx, y=y, batch_y=batch_y
)
def validation_epoch_end(self, outputs):
keys = list(outputs[0].keys())
summary_dict = dict(log={f'mean_{key}': torch.mean(torch.stack([output[key]
for output in outputs]))
for key in keys if 'loss' in key})
# UnweightedAverageRecall
        y_true = torch.cat([output['batch_y'] for output in outputs]).cpu().numpy()
y_pred = torch.cat([output['y'] for output in outputs]).squeeze().cpu().numpy()
y_pred = (y_pred >= 0.5).astype(np.float32)
uar_score = sklearn.metrics.recall_score(y_true, y_pred, labels=[0, 1], average='macro',
sample_weight=None, zero_division='warn')
summary_dict['log'].update(uar_score=uar_score)
return summary_dict
class BinaryMaskDatasetFunction:
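    """Builds train/devel/test `BinaryMasksDataset` splits sharing one
    mel-spectrogram pipeline; the noise/loudness/time-shift augmentations are
    applied (as a group, with probability 0.5) to the training split only."""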
def build_dataset(self):
# Dataset
# =============================================================================
# Mel Transforms
mel_transforms = Compose([
# Audio to Mel Transformations
AudioToMel(n_mels=self.params.n_mels), MelToImage()])
# Data Augmentations
aug_transforms = Compose([
RandomApply([
NoiseInjection(self.params.noise_ratio),
LoudnessManipulator(self.params.loudness_ratio),
ShiftTime(self.params.shift_ratio)], p=0.5),
# Utility
NormalizeLocal(), ToTensor()
])
val_transforms = Compose([NormalizeLocal(), ToTensor()])
# Datasets
from datasets.binar_masks import BinaryMasksDataset
        dataset = Namespace(
            train_dataset=BinaryMasksDataset(self.params.root, setting=V.DATA_OPTIONS.train,
                                             mixup=self.params.mixup,
                                             mel_transforms=mel_transforms, transforms=aug_transforms),
            val_dataset=BinaryMasksDataset(self.params.root, setting=V.DATA_OPTIONS.devel,
                                           mel_transforms=mel_transforms, transforms=val_transforms),
            test_dataset=BinaryMasksDataset(self.params.root, setting=V.DATA_OPTIONS.test,
                                            mel_transforms=mel_transforms, transforms=val_transforms),
        )
return dataset
class BaseDataloadersMixin(ABC):
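    """Wraps the datasets built by `build_dataset()` into DataLoaders with the
    configured batch size and worker count."""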
# Dataloaders
# ================================================================================
# Train Dataloader
def train_dataloader(self):
return DataLoader(dataset=self.dataset.train_dataset, shuffle=True,
batch_size=self.params.batch_size,
num_workers=self.params.worker)
# Test Dataloader
def test_dataloader(self):
return DataLoader(dataset=self.dataset.test_dataset, shuffle=False,
batch_size=self.params.batch_size,
num_workers=self.params.worker)
# Validation Dataloader
def val_dataloader(self):
        # Validation order does not affect the epoch metrics; keep it deterministic.
        return DataLoader(dataset=self.dataset.val_dataset, shuffle=False,
batch_size=self.params.batch_size,
num_workers=self.params.worker)
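
# A minimal composition sketch (an assumption, not code from this repository):
# the mixins appear designed for a pytorch_lightning.LightningModule host that
# supplies `self.params` (an argparse-style namespace), `self.criterion`, and
# `self.dataset`. `MaskModel` and the BCELoss choice below are hypothetical.
#
#   import pytorch_lightning as pl
#   from torch.nn import BCELoss
#
#   class MaskModel(BaseOptimizerMixin, BaseTrainMixin, BaseValMixin,
#                   BinaryMaskDatasetFunction, BaseDataloadersMixin,
#                   pl.LightningModule):
#       def __init__(self, params):
#           super().__init__()
#           self.params = params
#           self.criterion = BCELoss()
#           self.dataset = self.build_dataset()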