Transformer running

Steffen Illium
2021-03-04 12:01:09 +01:00
parent 7edd3834a1
commit ad254dae92
14 changed files with 679 additions and 134 deletions

View File

@@ -1,9 +1,13 @@
 from torch import nn
+
+from ml_lib.additions.losses import FocalLoss, FocalLossRob


 class LossMixin:

     absolute_loss = nn.L1Loss()
     nll_loss = nn.NLLLoss()
     bce_loss = nn.BCELoss()
     ce_loss = nn.CrossEntropyLoss()
+    focal_loss = FocalLoss(None)
+    focal_loss_rob = FocalLossRob()

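The new focal losses are registered alongside the standard criteria so they can be selected by name (see the training_step diff below). The internals of ml_lib's FocalLoss and FocalLossRob are not part of this commit; as a reference point, here is a minimal sketch of the standard focal loss formulation, FL(p_t) = -(1 - p_t)^gamma * log(p_t), with the gamma default and mean reduction being assumptions:

import torch
from torch import nn


class FocalLossSketch(nn.Module):
    """Standard focal loss over raw logits; gamma=2.0 is an assumed default."""

    def __init__(self, gamma=2.0):
        super().__init__()
        self.gamma = gamma

    def forward(self, logits, target):
        # Cross entropy yields -log(p_t); recover p_t for the modulating factor.
        ce = nn.functional.cross_entropy(logits, target, reduction='none')
        p_t = torch.exp(-ce)
        return ((1.0 - p_t) ** self.gamma * ce).mean()

Down-weighting well-classified examples this way keeps easy, frequent classes from dominating the loss, which is presumably why it is offered next to plain cross entropy here.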
View File

@@ -1,6 +1,7 @@
 from abc import ABC

 import torch
+import pandas as pd

 from ml_lib.modules.util import LightningBaseModule
 from util.loss_mixin import LossMixin
@@ -11,9 +12,15 @@ class TrainMixin:

     def training_step(self, batch_xy, batch_nb, *args, **kwargs):
         assert isinstance(self, LightningBaseModule)
-        batch_x, batch_y = batch_xy
+        batch_files, batch_x, batch_y = batch_xy
         y = self(batch_x).main_out
-        loss = self.ce_loss(y.squeeze(), batch_y.long())
+
+        if self.params.loss == 'focal_loss_rob':
+            labels_one_hot = torch.nn.functional.one_hot(batch_y, num_classes=5)
+            loss = self.__getattribute__(self.params.loss)(y, labels_one_hot)
+        else:
+            loss = self.__getattribute__(self.params.loss)(y, batch_y.long())
         return dict(loss=loss)

     def training_epoch_end(self, outputs):
@@ -23,21 +30,23 @@ class TrainMixin:
         summary_dict = {f'mean_{key}': torch.mean(torch.stack([output[key]
                                                                for output in outputs]))
                         for key in keys if 'loss' in key}

-        for key in summary_dict.keys():
-            self.log(key, summary_dict[key])
+        summary_dict.update(epoch=self.current_epoch)
+        self.log_dict(summary_dict)


 class ValMixin:

     def validation_step(self, batch_xy, batch_idx, *args, **kwargs):
         assert isinstance(self, LightningBaseModule)
-        batch_x, batch_y = batch_xy
+        batch_files, batch_x, batch_y = batch_xy
         model_out = self(batch_x)
         y = model_out.main_out
-        val_loss = self.ce_loss(y.squeeze(), batch_y.long())
-        return dict(val_loss=val_loss,
+        val_loss = self.ce_loss(y, batch_y.long())
+
+        self.metrics.update(y, batch_y)  # torch.argmax(y, -1), batch_y)
+        return dict(val_loss=val_loss, batch_files=batch_files,
                     batch_idx=batch_idx, y=y, batch_y=batch_y)

     def validation_epoch_end(self, outputs, *_, **__):
@@ -49,40 +58,40 @@ class ValMixin:
                                                                for output in outputs]))
                              for key in keys if 'loss' in key}
                             )
         # Sklearn Scores
         additional_scores = self.additional_scores(outputs)
         summary_dict.update(**additional_scores)

-        for key in summary_dict.keys():
-            self.log(key, summary_dict[key])
+        pl_metrics, pl_images = self.metrics.compute_and_prepare()
+        self.metrics.reset()
+        summary_dict.update(**pl_metrics)
+        summary_dict.update(epoch=self.current_epoch)
+
+        self.log_dict(summary_dict, on_epoch=True)
+        for name, image in pl_images.items():
+            self.logger.log_image(name, image, step=self.global_step)
+        pass


 class TestMixin:

     def test_step(self, batch_xy, batch_idx, *_, **__):
         assert isinstance(self, LightningBaseModule)
-        batch_x, batch_y = batch_xy
+        batch_files, batch_x, batch_y = batch_xy
         model_out = self(batch_x)
         y = model_out.main_out
-        test_loss = self.ce_loss(y.squeeze(), batch_y.long())
-        return dict(test_loss=test_loss,
-                    batch_idx=batch_idx, y=y, batch_y=batch_y)
+        return dict(batch_files=batch_files, batch_idx=batch_idx, y=y)

     def test_epoch_end(self, outputs, *_, **__):
         assert isinstance(self, LightningBaseModule)
         summary_dict = dict()
         keys = list(outputs[0].keys())
         summary_dict.update({f'mean_{key}': torch.mean(torch.stack([output[key]
                                                                     for output in outputs]))
                              for key in keys if 'loss' in key}
                             )
+        y_arg_max = torch.argmax(outputs[0]['y'], dim=-1)

         additional_scores = self.additional_scores(outputs)
         summary_dict.update(**additional_scores)

+        pd.DataFrame(data=dict(filenames=outputs[0]['batch_files'], prediction=y_arg_max))

-        for key in summary_dict.keys():
-            self.log(key, summary_dict[key])
+        # No logging, just inference.
+        # self.log_dict(summary_dict, on_epoch=True)


 class CombinedModelMixins(LossMixin,

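The training_step above dispatches to whichever loss self.params.loss names via __getattribute__, one-hot encoding the targets only for focal_loss_rob. A standalone sketch of that dispatch pattern follows; the batch shapes and the plain dict standing in for the mixin's attributes are assumptions, while the five-class count and the 'focal_loss_rob' key come from the diff:

import torch
from torch import nn

loss_fns = dict(ce_loss=nn.CrossEntropyLoss())   # stands in for the LossMixin attributes
loss_name = 'ce_loss'                            # would come from self.params.loss

logits = torch.randn(8, 5)                       # assumed: batch of 8, 5 classes
targets = torch.randint(0, 5, (8,))

if loss_name == 'focal_loss_rob':
    # focal_loss_rob expects one-hot targets, e.g. class 2 -> [0, 0, 1, 0, 0]
    targets = torch.nn.functional.one_hot(targets, num_classes=5)
loss = loss_fns[loss_name](logits, targets.long())

Note that validation still calls ce_loss directly and test_step no longer computes a loss at all, so the configurable loss only affects training.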
View File

@@ -1,6 +1,6 @@
 from collections import defaultdict

-from torch.optim import Adam
+from torch.optim import Adam, AdamW
 from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts, LambdaLR

 from torchcontrib.optim import SWA
@@ -20,12 +20,12 @@ class OptimizerMixin:
             # 'monitor': 'mean_val_loss'  # Metric to monitor
         )

-        optimizer = Adam(params=self.parameters(), lr=self.params.lr, weight_decay=self.params.weight_decay)
+        optimizer = AdamW(params=self.parameters(), lr=self.params.lr, weight_decay=self.params.weight_decay)
         if self.params.sto_weight_avg:
             optimizer = SWA(optimizer, swa_start=10, swa_freq=5, swa_lr=0.05)
         optimizer_dict.update(optimizer=optimizer)
-        if self.params.lr_warmup_steps:
-            scheduler = CosineAnnealingWarmRestarts(optimizer, self.params.lr_warmup_steps)
+        if self.params.lr_warm_restart_epochs:
+            scheduler = CosineAnnealingWarmRestarts(optimizer, self.params.lr_warm_restart_epochs)
         else:
             scheduler = LambdaLR(optimizer, lr_lambda=lambda epoch: 0.95 ** epoch)
         optimizer_dict.update(lr_scheduler=scheduler)
@@ -42,4 +42,4 @@ class OptimizerMixin:
         if self.params.opt_reset_interval:
             if self.current_epoch % self.params.opt_reset_interval == 0:
                 for opt in self.trainer.optimizers:
-                    opt.state = defaultdict(dict)
+                    opt.state = defaultdict(dict)
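Two things change in the optimizer setup: Adam is swapped for AdamW, which applies weight decay directly to the parameters instead of folding it into the gradient, and the renamed lr_warm_restart_epochs parameter now feeds T_0 of CosineAnnealingWarmRestarts, i.e. the number of epochs until the first restart. A minimal sketch of the resulting schedule, with the model and hyperparameter values assumed:

import torch
from torch import nn
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts

model = nn.Linear(10, 5)                                    # stand-in for the actual model
optimizer = AdamW(model.parameters(), lr=1e-3, weight_decay=1e-2)
scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=10)  # first restart after 10 epochs

for epoch in range(30):
    # ... one epoch of training steps would run here ...
    scheduler.step()  # LR anneals toward eta_min, then jumps back at epochs 10 and 20

When sto_weight_avg is set, the torchcontrib SWA wrapper additionally averages the weights visited late in training; swa_start=10, swa_freq=5, and swa_lr=0.05 are hard-coded in the diff above.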