New Model, Many Changes

This commit is contained in:
Si11ium
2020-11-21 09:28:26 +01:00
parent 7bac9e984b
commit be097a111a
12 changed files with 349 additions and 125 deletions

View File

@@ -1,26 +0,0 @@
from ml_lib.utils.config import Config
from models.conv_classifier import ConvClassifier
from models.bandwise_conv_classifier import BandwiseConvClassifier
from models.bandwise_conv_multihead_classifier import BandwiseConvMultiheadClassifier
from models.ensemble import Ensemble
from models.residual_conv_classifier import ResidualConvClassifier
from models.transformer_model import VisualTransformer
class MConfig(Config):
# TODO: There should be a way to automate this.
@property
def _model_map(self):
return dict(ConvClassifier=ConvClassifier,
CC=ConvClassifier,
BandwiseConvClassifier=BandwiseConvClassifier,
BCC=BandwiseConvClassifier,
BandwiseConvMultiheadClassifier=BandwiseConvMultiheadClassifier,
BCMC=BandwiseConvMultiheadClassifier,
Ensemble=Ensemble,
E=Ensemble,
ResidualConvClassifier=ResidualConvClassifier,
RCC=ResidualConvClassifier,
ViT=VisualTransformer
)

View File

@@ -8,7 +8,8 @@ import torch
import numpy as np
from torch import nn
from torch.optim import Adam
from torch.utils.data import DataLoader, RandomSampler
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
from torch.utils.data import DataLoader
from torchcontrib.optim import SWA
from torchvision.transforms import Compose, RandomApply
@@ -25,10 +26,23 @@ class BaseOptimizerMixin:
def configure_optimizers(self):
assert isinstance(self, LightningBaseModule)
opt = Adam(params=self.parameters(), lr=self.params.lr, weight_decay=self.params.weight_decay)
optimizer_dict = dict(
# 'optimizer':optimizer, # The Optimizer
# 'lr_scheduler': scheduler, # The LR scheduler
frequency=1, # The frequency of the scheduler
interval='epoch', # The unit of the scheduler's step size
# 'reduce_on_plateau': False, # For ReduceLROnPlateau scheduler
# 'monitor': 'mean_val_loss' # Metric to monitor
)
optimizer = Adam(params=self.parameters(), lr=self.params.lr, weight_decay=self.params.weight_decay)
if self.params.sto_weight_avg:
opt = SWA(opt, swa_start=10, swa_freq=5, swa_lr=0.05)
return opt
optimizer = SWA(optimizer, swa_start=10, swa_freq=5, swa_lr=0.05)
optimizer_dict.update(optimizer=optimizer)
if self.params.lr_warmup_steps:
scheduler = CosineAnnealingWarmRestarts(optimizer, self.params.lr_warmup_steps)
optimizer_dict.update(lr_scheduler=scheduler)
return optimizer_dict
def on_train_end(self):
assert isinstance(self, LightningBaseModule)
@@ -54,17 +68,18 @@ class BaseTrainMixin:
assert isinstance(self, LightningBaseModule)
batch_x, batch_y = batch_xy
y = self(batch_x).main_out
bce_loss = self.bce_loss(y, batch_y)
bce_loss = self.bce_loss(y.squeeze(), batch_y)
return dict(loss=bce_loss)
def training_epoch_end(self, outputs):
assert isinstance(self, LightningBaseModule)
keys = list(outputs[0].keys())
summary_dict = dict(log={f'mean_{key}': torch.mean(torch.stack([output[key]
for output in outputs]))
for key in keys if 'loss' in key})
return summary_dict
summary_dict = {f'mean_{key}': torch.mean(torch.stack([output[key]
for output in outputs]))
for key in keys if 'loss' in key}
for key in summary_dict.keys():
self.log(key, summary_dict[key])
class BaseValMixin:
@@ -77,17 +92,17 @@ class BaseValMixin:
assert isinstance(self, LightningBaseModule)
batch_x, batch_y = batch_xy
y = self(batch_x).main_out
val_bce_loss = self.bce_loss(y, batch_y)
val_bce_loss = self.bce_loss(y.squeeze(), batch_y)
return dict(val_bce_loss=val_bce_loss,
batch_idx=batch_idx, y=y, batch_y=batch_y)
def validation_epoch_end(self, outputs, *args, **kwargs):
def validation_epoch_end(self, outputs, *_, **__):
assert isinstance(self, LightningBaseModule)
summary_dict = dict(log=dict())
summary_dict = dict()
for output_idx, output in enumerate(outputs):
keys = list(output[0].keys())
ident = '' if output_idx == 0 else '_train'
summary_dict['log'].update({f'mean{ident}_{key}': torch.mean(torch.stack([output[key]
summary_dict.update({f'mean{ident}_{key}': torch.mean(torch.stack([output[key]
for output in output]))
for key in keys if 'loss' in key}
)
@@ -101,8 +116,9 @@ class BaseValMixin:
uar_score = sklearn.metrics.recall_score(y_true, y_pred, labels=[0, 1], average='macro',
sample_weight=None, zero_division='warn')
uar_score = torch.as_tensor(uar_score)
summary_dict['log'].update({f'uar{ident}_score': uar_score})
return summary_dict
summary_dict.update({f'uar{ident}_score': uar_score})
for key in summary_dict.keys():
self.log(key, summary_dict[key])
class BinaryMaskDatasetMixin:
@@ -139,7 +155,7 @@ class BinaryMaskDatasetMixin:
LoudnessManipulator(self.params.loudness_ratio),
ShiftTime(self.params.shift_ratio),
MaskAug(self.params.mask_ratio),
], p=0.6),
], p=0.6),
util_transforms])
# Datasets