ResidualModule and New Parameters, Speed Manipulation

This commit is contained in:
Si11ium
2020-05-12 12:37:26 +02:00
parent 3fbc98dfa3
commit 28bfcfdce3
8 changed files with 181 additions and 78 deletions

View File

@@ -3,6 +3,7 @@ from models.conv_classifier import ConvClassifier
from models.bandwise_conv_classifier import BandwiseConvClassifier
from models.bandwise_conv_multihead_classifier import BandwiseConvMultiheadClassifier
from models.ensemble import Ensemble
from models.residual_conv_classifier import ResidualConvClassifier
class MConfig(Config):
@@ -11,7 +12,13 @@ class MConfig(Config):
@property
def _model_map(self):
return dict(ConvClassifier=ConvClassifier,
CC=ConvClassifier,
BandwiseConvClassifier=BandwiseConvClassifier,
BCC=BandwiseConvClassifier,
BandwiseConvMultiheadClassifier=BandwiseConvMultiheadClassifier,
BCMC=BandwiseConvMultiheadClassifier,
Ensemble=Ensemble,
E=Ensemble,
ResidualConvClassifier=ResidualConvClassifier,
RCC=ResidualConvClassifier
)

View File

@@ -6,13 +6,14 @@ from argparse import Namespace
import sklearn
import torch
import numpy as np
from torch.nn import L1Loss
from torch import nn
from torch.optim import Adam
from torch.utils.data import DataLoader
from torch.utils.data import DataLoader, RandomSampler
from torchcontrib.optim import SWA
from torchvision.transforms import Compose, RandomApply
from ml_lib.audio_toolset.audio_augmentation import NoiseInjection, LoudnessManipulator, ShiftTime
from ml_lib.audio_toolset.audio_augmentation import Speed
from ml_lib.audio_toolset.mel_augmentation import NoiseInjection, LoudnessManipulator, ShiftTime, MaskAug
from ml_lib.audio_toolset.audio_io import AudioToMel, MelToImage, NormalizeLocal
from ml_lib.modules.utils import LightningBaseModule
from ml_lib.utils.transforms import ToTensor
@@ -24,17 +25,19 @@ class BaseOptimizerMixin:
def configure_optimizers(self):
assert isinstance(self, LightningBaseModule)
opt = Adam(params=self.parameters(), lr=self.params.lr)
opt = Adam(params=self.parameters(), lr=self.params.lr, weight_decay=0.04)
if self.params.sto_weight_avg:
opt = SWA(opt, swa_start=10, swa_freq=5, swa_lr=0.05)
return opt
def on_train_end(self):
assert isinstance(self, LightningBaseModule)
for opt in self.trainer.optimizers:
if isinstance(opt, SWA):
opt.swap_swa_sgd()
def on_epoch_end(self):
assert isinstance(self, LightningBaseModule)
if self.params.opt_reset_interval:
if self.current_epoch % self.params.opt_reset_interval == 0:
for opt in self.trainer.optimizers:
@@ -43,14 +46,19 @@ class BaseOptimizerMixin:
class BaseTrainMixin:
absolute_loss = nn.L1Loss()
nll_loss = nn.NLLLoss()
bce_loss = nn.BCELoss()
def training_step(self, batch_xy, batch_nb, *args, **kwargs):
assert isinstance(self, LightningBaseModule)
batch_x, batch_y = batch_xy
y = self(batch_x).main_out
loss = self.criterion(y, batch_y)
return dict(loss=loss)
bce_loss = self.bce_loss(y, batch_y)
return dict(loss=bce_loss)
def training_epoch_end(self, outputs):
assert isinstance(self, LightningBaseModule)
keys = list(outputs[0].keys())
summary_dict = dict(log={f'mean_{key}': torch.mean(torch.stack([output[key]
@@ -61,18 +69,20 @@ class BaseTrainMixin:
class BaseValMixin:
absolute_loss = L1Loss()
absolute_loss = nn.L1Loss()
nll_loss = nn.NLLLoss()
bce_loss = nn.BCELoss()
def validation_step(self, batch_xy, batch_idx, dataloader_idx, *args, **kwargs):
assert isinstance(self, LightningBaseModule)
batch_x, batch_y = batch_xy
y = self(batch_x).main_out
val_bce_loss = self.criterion(y, batch_y)
val_abs_loss = self.absolute_loss(y, batch_y)
return dict(val_bce_loss=val_bce_loss, val_abs_loss=val_abs_loss,
batch_idx=batch_idx, y=y, batch_y=batch_y
)
val_bce_loss = self.bce_loss(y, batch_y)
return dict(val_bce_loss=val_bce_loss,
batch_idx=batch_idx, y=y, batch_y=batch_y)
def validation_epoch_end(self, outputs, *args, **kwargs):
assert isinstance(self, LightningBaseModule)
summary_dict = dict(log=dict())
for output_idx, output in enumerate(outputs):
keys = list(output[0].keys())
@@ -103,6 +113,12 @@ class BinaryMaskDatasetFunction:
# Dataset
# =============================================================================
# Mel Transforms
mel_transforms_train = Compose([
# Audio to Mel Transformations
Speed(speed_factor=self.params.speed_factor, max_ratio=self.params.speed_ratio),
AudioToMel(sr=self.params.sr, n_mels=self.params.n_mels, n_fft=self.params.n_fft,
hop_length=self.params.hop_length),
MelToImage()])
mel_transforms = Compose([
# Audio to Mel Transformations
AudioToMel(sr=self.params.sr, n_mels=self.params.n_mels, n_fft=self.params.n_fft,
@@ -112,25 +128,28 @@
RandomApply([
NoiseInjection(self.params.noise_ratio),
LoudnessManipulator(self.params.loudness_ratio),
ShiftTime(self.params.shift_ratio)], p=0.5),
ShiftTime(self.params.shift_ratio),
MaskAug(self.params.mask_ratio),
], p=0.6),
# Utility
NormalizeLocal(), ToTensor()
])
val_transforms = Compose([NormalizeLocal(), ToTensor()])
# sampler = RandomSampler(train_dataset, True, len(train_dataset)) if params['bootstrap'] else None
# Datasets
from datasets.binar_masks import BinaryMasksDataset
dataset = Namespace(
**dict(
# TRAIN DATASET
train_dataset=BinaryMasksDataset(self.params.root, setting=V.DATA_OPTIONS.train,
mixup=self.params.mixup,
mel_transforms=mel_transforms, transforms=aug_transforms),
mel_transforms=mel_transforms_train, transforms=aug_transforms),
# VALIDATION DATASET
val_train_dataset=BinaryMasksDataset(self.params.root, setting=V.DATA_OPTIONS.train,
mel_transforms=mel_transforms, transforms=val_transforms),
val_dataset=BinaryMasksDataset(self.params.root, setting=V.DATA_OPTIONS.devel,
mel_transforms=mel_transforms, transforms=val_transforms),
# TEST DATASET
test_dataset=BinaryMasksDataset(self.params.root, setting=V.DATA_OPTIONS.test,
mel_transforms=mel_transforms, transforms=val_transforms),
)
@@ -144,18 +163,23 @@ class BaseDataloadersMixin(ABC):
# ================================================================================
# Train Dataloader
def train_dataloader(self):
return DataLoader(dataset=self.dataset.train_dataset, shuffle=True,
assert isinstance(self, LightningBaseModule)
# sampler = RandomSampler(self.dataset.train_dataset, True, len(self.dataset.train_dataset))
sampler = None
return DataLoader(dataset=self.dataset.train_dataset, shuffle=True if not sampler else None, sampler=sampler,
batch_size=self.params.batch_size,
num_workers=self.params.worker)
# Test Dataloader
def test_dataloader(self):
assert isinstance(self, LightningBaseModule)
return DataLoader(dataset=self.dataset.test_dataset, shuffle=False,
batch_size=self.params.batch_size,
num_workers=self.params.worker)
# Validation Dataloader
def val_dataloader(self):
assert isinstance(self, LightningBaseModule)
val_dataloader = DataLoader(dataset=self.dataset.val_dataset, shuffle=True,
batch_size=self.params.batch_size, num_workers=self.params.worker)