Dataset ready

Steffen Illium
2021-02-16 10:18:04 +01:00
parent 151b22a2c3
commit 7edd3834a1
11 changed files with 350 additions and 15 deletions

util/__init__.py (new file, +0)

util/loss_mixin.py (new file, +9)

@@ -0,0 +1,9 @@
from torch import nn


class LossMixin:
    absolute_loss = nn.L1Loss()
    nll_loss = nn.NLLLoss()
    bce_loss = nn.BCELoss()
    ce_loss = nn.CrossEntropyLoss()
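A minimal sketch of how a model could consume these shared loss attributes; the module and field names below are hypothetical and only illustrate the mixin pattern, assuming the losses are reached as plain attributes (e.g. self.ce_loss) on the inheriting class:

import torch
from torch import nn

from util.loss_mixin import LossMixin


class ToyClassifier(LossMixin, nn.Module):
    # Hypothetical example module; only the LossMixin usage is taken from this commit.
    def __init__(self, in_features=16, n_classes=4):
        super().__init__()
        self.head = nn.Linear(in_features, n_classes)

    def forward(self, x):
        return self.head(x)


model = ToyClassifier()
x, y = torch.randn(8, 16), torch.randint(0, 4, (8,))
loss = model.ce_loss(model(x), y)  # CrossEntropyLoss shared via the mixin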

util/module_mixins.py (new file, +95)

@@ -0,0 +1,95 @@
from abc import ABC

import torch

from ml_lib.modules.util import LightningBaseModule
from util.loss_mixin import LossMixin
from util.optimizer_mixin import OptimizerMixin


class TrainMixin:

    def training_step(self, batch_xy, batch_nb, *args, **kwargs):
        assert isinstance(self, LightningBaseModule)
        batch_x, batch_y = batch_xy
        y = self(batch_x).main_out
        loss = self.ce_loss(y.squeeze(), batch_y.long())
        return dict(loss=loss)

    def training_epoch_end(self, outputs):
        assert isinstance(self, LightningBaseModule)
        keys = list(outputs[0].keys())
        # Average every '*loss*' entry over the epoch and log the means.
        summary_dict = {f'mean_{key}': torch.mean(torch.stack([output[key]
                                                               for output in outputs]))
                        for key in keys if 'loss' in key}
        for key in summary_dict.keys():
            self.log(key, summary_dict[key])


class ValMixin:

    def validation_step(self, batch_xy, batch_idx, *args, **kwargs):
        assert isinstance(self, LightningBaseModule)
        batch_x, batch_y = batch_xy
        model_out = self(batch_x)
        y = model_out.main_out
        val_loss = self.ce_loss(y.squeeze(), batch_y.long())
        return dict(val_loss=val_loss,
                    batch_idx=batch_idx, y=y, batch_y=batch_y)

    def validation_epoch_end(self, outputs, *_, **__):
        assert isinstance(self, LightningBaseModule)
        summary_dict = dict()
        keys = list(outputs[0].keys())
        summary_dict.update({f'mean_{key}': torch.mean(torch.stack([output[key]
                                                                    for output in outputs]))
                             for key in keys if 'loss' in key}
                            )
        # Merge in any model-specific scores provided by the concrete model.
        additional_scores = self.additional_scores(outputs)
        summary_dict.update(**additional_scores)
        for key in summary_dict.keys():
            self.log(key, summary_dict[key])


class TestMixin:

    def test_step(self, batch_xy, batch_idx, *_, **__):
        assert isinstance(self, LightningBaseModule)
        batch_x, batch_y = batch_xy
        model_out = self(batch_x)
        y = model_out.main_out
        test_loss = self.ce_loss(y.squeeze(), batch_y.long())
        return dict(test_loss=test_loss,
                    batch_idx=batch_idx, y=y, batch_y=batch_y)

    def test_epoch_end(self, outputs, *_, **__):
        assert isinstance(self, LightningBaseModule)
        summary_dict = dict()
        keys = list(outputs[0].keys())
        summary_dict.update({f'mean_{key}': torch.mean(torch.stack([output[key]
                                                                    for output in outputs]))
                             for key in keys if 'loss' in key}
                            )
        additional_scores = self.additional_scores(outputs)
        summary_dict.update(**additional_scores)
        for key in summary_dict.keys():
            self.log(key, summary_dict[key])


class CombinedModelMixins(LossMixin,
                          TrainMixin,
                          ValMixin,
                          TestMixin,
                          OptimizerMixin,
                          LightningBaseModule,
                          ABC):
    pass
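For orientation, a hedged sketch of what a concrete model built on CombinedModelMixins might look like. It assumes LightningBaseModule (from ml_lib, not shown in this diff) wires up the constructor and the params namespace, and that the forward pass returns an object exposing a main_out attribute, since the step methods above read self(batch_x).main_out; every name below is illustrative only:

from collections import namedtuple

from torch import nn

from util.module_mixins import CombinedModelMixins

ModelOut = namedtuple('ModelOut', ['main_out'])


class ToyModel(CombinedModelMixins):
    # Hypothetical subclass; the constructor signature depends on LightningBaseModule.
    def __init__(self, hparams):
        super().__init__(hparams)
        self.classifier = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))

    def forward(self, batch_x):
        # Wrap the logits so the training/validation/test steps can read `.main_out`.
        return ModelOut(main_out=self.classifier(batch_x))

    def additional_scores(self, outputs):
        # The epoch-end hooks above expect a dict of extra metrics to log.
        return dict()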

util/optimizer_mixin.py (new file, +45)

@@ -0,0 +1,45 @@
from collections import defaultdict

from torch.optim import Adam
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts, LambdaLR
from torchcontrib.optim import SWA

from ml_lib.modules.util import LightningBaseModule


class OptimizerMixin:

    def configure_optimizers(self):
        assert isinstance(self, LightningBaseModule)
        optimizer_dict = dict(
            # 'optimizer': optimizer,      # The Optimizer
            # 'lr_scheduler': scheduler,   # The LR scheduler
            frequency=1,                   # The frequency of the scheduler
            interval='epoch',              # The unit of the scheduler's step size
            # 'reduce_on_plateau': False,  # For ReduceLROnPlateau scheduler
            # 'monitor': 'mean_val_loss'   # Metric to monitor
        )

        optimizer = Adam(params=self.parameters(), lr=self.params.lr, weight_decay=self.params.weight_decay)
        if self.params.sto_weight_avg:
            # Optionally wrap the base optimizer in stochastic weight averaging.
            optimizer = SWA(optimizer, swa_start=10, swa_freq=5, swa_lr=0.05)
        optimizer_dict.update(optimizer=optimizer)

        if self.params.lr_warmup_steps:
            scheduler = CosineAnnealingWarmRestarts(optimizer, self.params.lr_warmup_steps)
        else:
            scheduler = LambdaLR(optimizer, lr_lambda=lambda epoch: 0.95 ** epoch)
        optimizer_dict.update(lr_scheduler=scheduler)
        return optimizer_dict

    def on_train_end(self):
        assert isinstance(self, LightningBaseModule)
        # Swap in the averaged SWA weights once training is finished.
        for opt in self.trainer.optimizers:
            if isinstance(opt, SWA):
                opt.swap_swa_sgd()

    def on_epoch_end(self):
        assert isinstance(self, LightningBaseModule)
        if self.params.opt_reset_interval:
            if self.current_epoch % self.params.opt_reset_interval == 0:
                # Periodically clear the optimizer state (e.g. Adam moment estimates).
                for opt in self.trainer.optimizers:
                    opt.state = defaultdict(dict)
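The mixin reads several fields from self.params; a hedged sketch of a hyperparameter namespace that would satisfy it (the values are placeholders, and how params is attached to the module is defined by LightningBaseModule, not by this diff):

from argparse import Namespace

# Every attribute below is read somewhere in OptimizerMixin.
hparams = Namespace(
    lr=1e-3,               # Adam learning rate
    weight_decay=1e-5,     # Adam weight decay
    sto_weight_avg=False,  # wrap the optimizer in torchcontrib SWA when True
    lr_warmup_steps=0,     # >0 selects CosineAnnealingWarmRestarts, else the 0.95**epoch LambdaLR decay
    opt_reset_interval=0,  # reset optimizer state every N epochs when non-zero
)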