New Model, Many Changes
@@ -1,26 +0,0 @@
-from ml_lib.utils.config import Config
-from models.conv_classifier import ConvClassifier
-from models.bandwise_conv_classifier import BandwiseConvClassifier
-from models.bandwise_conv_multihead_classifier import BandwiseConvMultiheadClassifier
-from models.ensemble import Ensemble
-from models.residual_conv_classifier import ResidualConvClassifier
-from models.transformer_model import VisualTransformer
-
-
-class MConfig(Config):
-    # TODO: There should be a way to automate this.
-
-    @property
-    def _model_map(self):
-        return dict(ConvClassifier=ConvClassifier,
-                    CC=ConvClassifier,
-                    BandwiseConvClassifier=BandwiseConvClassifier,
-                    BCC=BandwiseConvClassifier,
-                    BandwiseConvMultiheadClassifier=BandwiseConvMultiheadClassifier,
-                    BCMC=BandwiseConvMultiheadClassifier,
-                    Ensemble=Ensemble,
-                    E=Ensemble,
-                    ResidualConvClassifier=ResidualConvClassifier,
-                    RCC=ResidualConvClassifier,
-                    ViT=VisualTransformer
-                    )
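The deleted MConfig maintained its name-to-class map by hand, and the removed TODO asks for a way to automate it. One possible sketch, assuming the model classes share a common base class; the base class, stand-in models, and initials-based alias scheme below are illustrative only and not part of the repository:

# Hypothetical sketch: build the name -> class map automatically instead of
# maintaining it by hand (addresses the removed "TODO: automate this").
class ModelBase:
    """Stand-in for the project's common model base class."""

class ConvClassifier(ModelBase):
    pass

class ResidualConvClassifier(ModelBase):
    pass

def build_model_map(base=ModelBase):
    """Map every direct subclass under its full name and its initials (CC, RCC, ...)."""
    mapping = {}
    for cls in base.__subclasses__():
        mapping[cls.__name__] = cls
        mapping[''.join(ch for ch in cls.__name__ if ch.isupper())] = cls
    return mapping

print(sorted(build_model_map()))  # ['CC', 'ConvClassifier', 'RCC', 'ResidualConvClassifier']

Note that hand-picked aliases such as ViT would still need an explicit override, since the initials of VisualTransformer are VT.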
@@ -8,7 +8,8 @@ import torch
 import numpy as np
 from torch import nn
 from torch.optim import Adam
+from torch.utils.data import DataLoader, RandomSampler
+from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
-from torch.utils.data import DataLoader
 from torchcontrib.optim import SWA
 from torchvision.transforms import Compose, RandomApply
 
@@ -25,10 +26,23 @@ class BaseOptimizerMixin:
 
     def configure_optimizers(self):
         assert isinstance(self, LightningBaseModule)
-        opt = Adam(params=self.parameters(), lr=self.params.lr, weight_decay=self.params.weight_decay)
+        optimizer_dict = dict(
+            # 'optimizer': optimizer,      # The Optimizer
+            # 'lr_scheduler': scheduler,   # The LR scheduler
+            frequency=1,                   # The frequency of the scheduler
+            interval='epoch',              # The unit of the scheduler's step size
+            # 'reduce_on_plateau': False,  # For ReduceLROnPlateau scheduler
+            # 'monitor': 'mean_val_loss'   # Metric to monitor
+        )
+
+        optimizer = Adam(params=self.parameters(), lr=self.params.lr, weight_decay=self.params.weight_decay)
         if self.params.sto_weight_avg:
-            opt = SWA(opt, swa_start=10, swa_freq=5, swa_lr=0.05)
-        return opt
+            optimizer = SWA(optimizer, swa_start=10, swa_freq=5, swa_lr=0.05)
+        optimizer_dict.update(optimizer=optimizer)
+        if self.params.lr_warmup_steps:
+            scheduler = CosineAnnealingWarmRestarts(optimizer, self.params.lr_warmup_steps)
+            optimizer_dict.update(lr_scheduler=scheduler)
+        return optimizer_dict
 
     def on_train_end(self):
         assert isinstance(self, LightningBaseModule)
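For reference, the returned dictionary tells Lightning to run the Adam (or SWA-wrapped) optimizer and to step the CosineAnnealingWarmRestarts scheduler once per epoch. A hedged sketch of what that amounts to in a plain PyTorch loop; the toy model, learning rate, and T_0 value are illustrative, not taken from the repository:

# Rough equivalent of optimizer_dict with interval='epoch' and frequency=1,
# written as an explicit loop. Values below are placeholders.
import torch
from torch import nn
from torch.optim import Adam
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts

model = nn.Linear(4, 1)
optimizer = Adam(model.parameters(), lr=1e-3, weight_decay=1e-5)
scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=10)  # warm restart period of 10 epochs

for epoch in range(3):
    for _ in range(5):                     # toy "batches"
        loss = model(torch.randn(8, 4)).pow(2).mean()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()                   # driven by the 'optimizer' entry
    scheduler.step()                       # stepped once per 'interval' ('epoch'), every 'frequency' (1)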
@@ -54,17 +68,18 @@ class BaseTrainMixin:
         assert isinstance(self, LightningBaseModule)
         batch_x, batch_y = batch_xy
         y = self(batch_x).main_out
-        bce_loss = self.bce_loss(y, batch_y)
+        bce_loss = self.bce_loss(y.squeeze(), batch_y)
         return dict(loss=bce_loss)
 
     def training_epoch_end(self, outputs):
         assert isinstance(self, LightningBaseModule)
         keys = list(outputs[0].keys())
 
-        summary_dict = dict(log={f'mean_{key}': torch.mean(torch.stack([output[key]
-                                                                        for output in outputs]))
-                                 for key in keys if 'loss' in key})
-        return summary_dict
+        summary_dict = {f'mean_{key}': torch.mean(torch.stack([output[key]
+                                                               for output in outputs]))
+                        for key in keys if 'loss' in key}
+        for key in summary_dict.keys():
+            self.log(key, summary_dict[key])
 
 
 class BaseValMixin:
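training_epoch_end now averages every loss-like key over the per-batch outputs and reports each mean through self.log instead of returning a dict(log=...). The aggregation itself is plain tensor arithmetic; a standalone sketch with made-up batch outputs:

# Sketch of the epoch-end aggregation only; the outputs list is fabricated.
import torch

outputs = [dict(loss=torch.tensor(0.9)), dict(loss=torch.tensor(0.7)), dict(loss=torch.tensor(0.5))]
keys = list(outputs[0].keys())

summary_dict = {f'mean_{key}': torch.mean(torch.stack([output[key] for output in outputs]))
                for key in keys if 'loss' in key}
print(summary_dict)  # {'mean_loss': tensor(0.7000)}
# In the mixin, each entry is then reported with self.log(key, value).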
@@ -77,17 +92,17 @@ class BaseValMixin:
         assert isinstance(self, LightningBaseModule)
         batch_x, batch_y = batch_xy
         y = self(batch_x).main_out
-        val_bce_loss = self.bce_loss(y, batch_y)
+        val_bce_loss = self.bce_loss(y.squeeze(), batch_y)
         return dict(val_bce_loss=val_bce_loss,
                     batch_idx=batch_idx, y=y, batch_y=batch_y)
 
-    def validation_epoch_end(self, outputs, *args, **kwargs):
+    def validation_epoch_end(self, outputs, *_, **__):
         assert isinstance(self, LightningBaseModule)
-        summary_dict = dict(log=dict())
+        summary_dict = dict()
         for output_idx, output in enumerate(outputs):
             keys = list(output[0].keys())
             ident = '' if output_idx == 0 else '_train'
-            summary_dict['log'].update({f'mean{ident}_{key}': torch.mean(torch.stack([output[key]
+            summary_dict.update({f'mean{ident}_{key}': torch.mean(torch.stack([output[key]
                                                                                       for output in output]))
                                         for key in keys if 'loss' in key}
                                        )
@@ -101,8 +116,9 @@ class BaseValMixin:
             uar_score = sklearn.metrics.recall_score(y_true, y_pred, labels=[0, 1], average='macro',
                                                      sample_weight=None, zero_division='warn')
             uar_score = torch.as_tensor(uar_score)
-            summary_dict['log'].update({f'uar{ident}_score': uar_score})
-        return summary_dict
+            summary_dict.update({f'uar{ident}_score': uar_score})
+        for key in summary_dict.keys():
+            self.log(key, summary_dict[key])
 
 
 class BinaryMaskDatasetMixin:
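The UAR value logged here is sklearn's macro-averaged recall over the two classes, i.e. the mean of per-class recalls. A standalone sketch with made-up labels and predictions:

# Sketch of the UAR computation as used in validation_epoch_end;
# y_true / y_pred are fabricated binary labels for illustration.
import sklearn.metrics
import torch

y_true = [0, 0, 1, 1, 1, 0]
y_pred = [0, 1, 1, 1, 0, 0]

uar_score = sklearn.metrics.recall_score(y_true, y_pred, labels=[0, 1], average='macro',
                                         sample_weight=None, zero_division='warn')
uar_score = torch.as_tensor(uar_score)
print(uar_score)  # (recall_class0 + recall_class1) / 2 = (2/3 + 2/3) / 2 = 0.6667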
@@ -139,7 +155,7 @@ class BinaryMaskDatasetMixin:
                                LoudnessManipulator(self.params.loudness_ratio),
                                ShiftTime(self.params.shift_ratio),
                                MaskAug(self.params.mask_ratio),
-                               ], p=0.6),
+                               ], p=0.6),
                                util_transforms])
 
             # Datasets
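The augmentation block groups the waveform manipulators under a single RandomApply, so the whole group fires with probability 0.6 before the always-on utility transforms. A minimal sketch of the same wiring, with Lambda placeholders standing in for LoudnessManipulator, ShiftTime, and MaskAug (the placeholder operations are illustrative, not the project's implementations):

# Sketch of the augmentation composition pattern used above.
import torch
from torchvision.transforms import Compose, Lambda, RandomApply

augmentations = RandomApply([
    Lambda(lambda x: x * 1.1),                                        # stand-in for LoudnessManipulator
    Lambda(lambda x: torch.roll(x, 1, dims=-1)),                      # stand-in for ShiftTime
    Lambda(lambda x: x.masked_fill(torch.rand_like(x) < 0.1, 0.0)),   # stand-in for MaskAug
], p=0.6)

util_transforms = Compose([Lambda(lambda x: x.float())])
transforms = Compose([augmentations, util_transforms])

print(transforms(torch.randn(1, 16)))  # either the augmented or the untouched sample, cast to float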