Parameter Adjustmens and Ensemble Model Implementation

This commit is contained in:
Si11ium
2020-05-08 16:30:55 +02:00
parent c2860b0aed
commit 5e6b0e598f
16 changed files with 648 additions and 313 deletions

View File

@ -1,6 +1,8 @@
from ml_lib.utils.config import Config
from models.binary_classifier import BinaryClassifier
from models.bandwise_binary_classifier import BandwiseBinaryClassifier
from models.conv_classifier import ConvClassifier
from models.bandwise_conv_classifier import BandwiseConvClassifier
from models.bandwise_conv_multihead_classifier import BandwiseConvMultiheadClassifier
from models.ensemble import Ensemble
class MConfig(Config):
@ -8,5 +10,8 @@ class MConfig(Config):
@property
def _model_map(self):
return dict(BinaryClassifier=BinaryClassifier,
BandwiseBinaryClassifier=BandwiseBinaryClassifier)
return dict(ConvClassifier=ConvClassifier,
BandwiseConvClassifier=BandwiseConvClassifier,
BandwiseConvMultiheadClassifier=BandwiseConvMultiheadClassifier,
Ensemble=Ensemble,
)

View File

@ -1,11 +0,0 @@
from pathlib import Path
from ml_lib.utils.logging import Logger
class MLogger(Logger):
@property
def outpath(self):
# FIXME: Specify a special path
return Path(self.config.train.outpath)

147
util/module_mixins.py Normal file
View File

@ -0,0 +1,147 @@
from collections import defaultdict
from abc import ABC
from argparse import Namespace
import sklearn
import torch
import numpy as np
from torch.nn import L1Loss
from torch.optim import Adam
from torch.utils.data import DataLoader
from torchcontrib.optim import SWA
from torchvision.transforms import Compose, RandomApply
from ml_lib.audio_toolset.audio_augmentation import NoiseInjection, LoudnessManipulator, ShiftTime
from ml_lib.audio_toolset.audio_io import AudioToMel, MelToImage, NormalizeLocal
from ml_lib.utils.transforms import ToTensor
import variables as V
class BaseOptimizerMixin:
def configure_optimizers(self):
opt = Adam(params=self.parameters(), lr=self.params.lr)
if self.params.sto_weight_avg:
opt = SWA(opt, swa_start=10, swa_freq=5, swa_lr=0.05)
return opt
def on_train_end(self):
for opt in self.trainer.optimizers:
if isinstance(opt, SWA):
opt.swap_swa_sgd()
def on_epoch_end(self):
if False: # FIXME: Pass a new parameter to model args.
if self.current_epoch % self.params.opt_reset_interval == 0:
for opt in self.trainer.optimizers:
opt.state = defaultdict(dict)
class BaseTrainMixin:
def training_step(self, batch_xy, batch_nb, *args, **kwargs):
batch_x, batch_y = batch_xy
y = self(batch_x).main_out
loss = self.criterion(y, batch_y)
return dict(loss=loss)
def training_epoch_end(self, outputs):
keys = list(outputs[0].keys())
summary_dict = dict(log={f'mean_{key}': torch.mean(torch.stack([output[key]
for output in outputs]))
for key in keys if 'loss' in key})
return summary_dict
class BaseValMixin:
absolute_loss = L1Loss()
def validation_step(self, batch_xy, batch_idx, *args, **kwargs):
batch_x, batch_y = batch_xy
y = self(batch_x).main_out
val_bce_loss = self.criterion(y, batch_y)
val_abs_loss = self.absolute_loss(y, batch_y)
return dict(val_bce_loss=val_bce_loss, val_abs_loss=val_abs_loss,
batch_idx=batch_idx, y=y, batch_y=batch_y
)
def validation_epoch_end(self, outputs):
keys = list(outputs[0].keys())
summary_dict = dict(log={f'mean_{key}': torch.mean(torch.stack([output[key]
for output in outputs]))
for key in keys if 'loss' in key})
# UnweightedAverageRecall
y_true = torch.cat([output['batch_y'] for output in outputs]) .cpu().numpy()
y_pred = torch.cat([output['y'] for output in outputs]).squeeze().cpu().numpy()
y_pred = (y_pred >= 0.5).astype(np.float32)
uar_score = sklearn.metrics.recall_score(y_true, y_pred, labels=[0, 1], average='macro',
sample_weight=None, zero_division='warn')
summary_dict['log'].update(uar_score=uar_score)
return summary_dict
class BinaryMaskDatasetFunction:
def build_dataset(self):
# Dataset
# =============================================================================
# Mel Transforms
mel_transforms = Compose([
# Audio to Mel Transformations
AudioToMel(n_mels=self.params.n_mels), MelToImage()])
# Data Augmentations
aug_transforms = Compose([
RandomApply([
NoiseInjection(self.params.noise_ratio),
LoudnessManipulator(self.params.loudness_ratio),
ShiftTime(self.params.shift_ratio)], p=0.5),
# Utility
NormalizeLocal(), ToTensor()
])
val_transforms = Compose([NormalizeLocal(), ToTensor()])
# Datasets
from datasets.binar_masks import BinaryMasksDataset
dataset = Namespace(
**dict(
train_dataset=BinaryMasksDataset(self.params.root, setting=V.DATA_OPTIONS.train, mixup=self.params.mixup,
mel_transforms=mel_transforms, transforms=aug_transforms),
val_dataset=BinaryMasksDataset(self.params.root, setting=V.DATA_OPTIONS.devel,
mel_transforms=mel_transforms, transforms=val_transforms),
test_dataset=BinaryMasksDataset(self.params.root, setting=V.DATA_OPTIONS.test,
mel_transforms=mel_transforms, transforms=val_transforms),
)
)
return dataset
class BaseDataloadersMixin(ABC):
# Dataloaders
# ================================================================================
# Train Dataloader
def train_dataloader(self):
return DataLoader(dataset=self.dataset.train_dataset, shuffle=True,
batch_size=self.params.batch_size,
num_workers=self.params.worker)
# Test Dataloader
def test_dataloader(self):
return DataLoader(dataset=self.dataset.test_dataset, shuffle=False,
batch_size=self.params.batch_size,
num_workers=self.params.worker)
# Validation Dataloader
def val_dataloader(self):
return DataLoader(dataset=self.dataset.val_dataset, shuffle=True,
batch_size=self.params.batch_size,
num_workers=self.params.worker)