2021-03-22 16:43:19 +01:00

202 lines
7.9 KiB
Python

from collections import defaultdict
from pathlib import Path
from abc import ABC
import torch
import pandas as pd
from ml_lib.modules.util import LightningBaseModule
from util.loss_mixin import LossMixin
from util.optimizer_mixin import OptimizerMixin
class TrainMixin:
    """Training hooks for a model built on ``LightningBaseModule``.

    The host class is expected to provide ``self.params`` (``n_classes`` and
    ``loss`` are read here), ``self.bce_loss``, the loss callable named by
    ``self.params.loss``, ``self.current_epoch`` and ``self.log_dict``.
    """

    def training_step(self, batch_xy, batch_nb, *args, **kwargs):
        """Run the forward pass for one batch and compute its training loss.

        Args:
            batch_xy: Triple ``(batch_files, batch_x, batch_y)``; only
                ``batch_x`` and ``batch_y`` are used here.
            batch_nb: Index of the batch (unused).

        Returns:
            ``dict(loss=<loss>)``.
        """
        assert isinstance(self, LightningBaseModule)
        _, batch_x, batch_y = batch_xy
        y = self(batch_x).main_out

        if self.params.n_classes <= 2:
            # NOTE(review): ValMixin.validation_step feeds float targets to
            # bce_loss while this call passes long targets - confirm which
            # dtype bce_loss expects.
            loss = self.bce_loss(y, batch_y.long())
        else:
            if self.params.loss == 'focal_loss_rob':
                # This loss is fed one-hot encoded targets.
                target = torch.nn.functional.one_hot(batch_y, num_classes=self.params.n_classes)
            else:
                target = batch_y.long()
            # getattr() goes through the regular attribute protocol, including
            # any __getattr__ fallback in the class hierarchy; calling
            # __getattribute__ directly would skip that fallback.
            loss_fn = getattr(self, self.params.loss)
            loss = loss_fn(y, target)
        return dict(loss=loss)

    def training_epoch_end(self, outputs):
        """Log the epoch mean of every ``*loss*`` entry of the step outputs.

        Args:
            outputs: Sequence of the dicts returned by ``training_step``.
        """
        assert isinstance(self, LightningBaseModule)
        loss_keys = [key for key in outputs[0].keys() if 'loss' in key]
        summary_dict = {
            f'mean_{key}': torch.mean(torch.stack([output[key] for output in outputs]))
            for key in loss_keys
        }
        summary_dict.update(epoch=self.current_epoch)
        self.log_dict(summary_dict)
class ValMixin:
    """Validation hooks that aggregate model outputs per source file.

    Several samples may share one file name; their outputs are averaged and
    the argmax of that average is used as the prediction for the file.

    The host class is expected to provide ``self.params.n_classes``,
    ``self.metrics``, ``self.bce_loss``, ``self.ce_loss``,
    ``self.additional_scores``, ``self.log_dict`` and ``self.logger``.
    """

    @staticmethod
    def _group_by_file(batches):
        """Collect outputs and targets under the file name they belong to.

        Args:
            batches: Iterable of ``(file_names, y, batch_y)`` triples in which
                ``y[idx]`` and ``batch_y[idx]`` belong to ``file_names[idx]``.

        Returns:
            Tuple ``(outputs_by_file, target_by_file)``. The first maps a file
            name to the list of its outputs, the second to the last target
            seen for that file. Both keep first-seen file order.
        """
        outputs_by_file = defaultdict(list)
        target_by_file = dict()
        for file_names, y, batch_y in batches:
            for idx, file_name in enumerate(file_names):
                outputs_by_file[file_name].append(y[idx])
                target_by_file[file_name] = batch_y[idx]
        return dict(outputs_by_file), target_by_file

    @staticmethod
    def _argmax_of_mean_per_file(outputs_by_file):
        """Return one predicted class index per file, shape ``(n_files,)``.

        The result is deliberately not squeezed: with a single file it stays
        1-D, so the one-hot encoding built from it always has a leading
        per-file dimension.
        """
        votes = []
        for rows in outputs_by_file.values():
            stacked = torch.stack(rows)
            if stacked.shape[0] > 1:
                votes.append(torch.argmax(stacked.mean(dim=0)))
            else:
                votes.append(torch.argmax(stacked))
        return torch.stack(votes)

    def validation_step(self, batch_xy, batch_idx, *args, **kwargs):
        """Evaluate one batch, update the metrics and compute the batch loss.

        Args:
            batch_xy: Triple ``(batch_files, batch_x, batch_y)``.
            batch_idx: Index of the batch; passed through to the result.

        Returns:
            Dict with the keys ``batch_files``, ``val_loss``, ``batch_idx``,
            ``y`` and ``batch_y``.
        """
        assert isinstance(self, LightningBaseModule)
        batch_files, batch_x, batch_y = batch_xy
        model_out = self(batch_x)
        y = model_out.main_out

        outputs_by_file, target_by_file = self._group_by_file([(batch_files, y, batch_y)])
        y_max = self._argmax_of_mean_per_file(outputs_by_file)
        y_one_hot = torch.nn.functional.one_hot(y_max, num_classes=self.params.n_classes).float()
        target_y = torch.stack(tuple(target_by_file.values())).long()
        if self.params.n_classes <= 2 and target_y.ndim == 1:
            target_y = target_y.unsqueeze(-1)
        self.metrics.update(y_one_hot, target_y)

        if self.params.n_classes <= 2:
            val_loss = self.bce_loss(y.squeeze().float(), batch_y.float())
        else:
            val_loss = self.ce_loss(y, batch_y.long())
        return dict(batch_files=batch_files, val_loss=val_loss,
                    batch_idx=batch_idx, y=y, batch_y=batch_y)

    def validation_epoch_end(self, outputs, *_, **__):
        """Summarise an epoch: mean losses, per-file vote loss, scores, images.

        Args:
            outputs: Sequence of the dicts returned by ``validation_step``.
        """
        assert isinstance(self, LightningBaseModule)
        loss_keys = [key for key in outputs[0].keys() if 'loss' in key]
        summary_dict = {
            f'mean_{key}': torch.mean(torch.stack([output[key] for output in outputs]))
            for key in loss_keys
        }

        outputs_by_file, target_by_file = self._group_by_file(
            (output['batch_files'], output['y'], output['batch_y']) for output in outputs
        )
        sorted_batch_y = torch.stack(tuple(target_by_file.values())).long()

        # A loss on the averaged per-file outputs ("mean vote") used to be
        # computed here as well; it was already disabled and has been dropped.
        y_max = self._argmax_of_mean_per_file(outputs_by_file)
        y_one_hot = torch.nn.functional.one_hot(y_max, num_classes=self.params.n_classes).float()

        # NOTE(review): every other branch in this file selects bce_loss with
        # ``n_classes <= 2``; this one uses ``>= 2``. Left as is - confirm
        # which loss is intended for the two-class case.
        if self.params.n_classes >= 2:
            max_vote_loss = self.ce_loss(y_one_hot, sorted_batch_y)
        else:
            max_vote_loss = self.bce_loss(y_one_hot, sorted_batch_y)
        summary_dict.update(val_max_vote_loss=max_vote_loss)

        # Sklearn Scores
        additional_scores = self.additional_scores(dict(y=y_one_hot, batch_y=sorted_batch_y))
        summary_dict.update(**additional_scores)

        pl_metrics, pl_images = self.metrics.compute_and_prepare()
        self.metrics.reset()
        summary_dict.update(**pl_metrics)
        summary_dict.update(epoch=self.current_epoch)

        self.log_dict(summary_dict, on_epoch=True)
        for name, image in pl_images.items():
            self.logger.log_image(name, image, step=self.global_step)
class TestMixin:
    """Inference hooks that write one prediction per source file to a CSV.

    The host class is expected to provide ``self.params.n_classes`` and
    ``self.logger`` (``log_dir`` and ``neptunelogger`` are used).
    """

    # The class name matches pytest's default ``Test*`` collection pattern.
    # Opt out so that importing this mixin into a test module does not make
    # pytest treat ``test_step`` / ``test_epoch_end`` as test cases.
    __test__ = False

    @staticmethod
    def _prediction_class_names(n_classes):
        """Return the ``{class index: class name}`` mapping for ``n_classes``.

        Raises:
            AttributeError: If ``n_classes`` is neither 2 nor 5.
        """
        if n_classes == 5:
            names = ['background', 'chimpanze', 'geunon', 'mandrille', 'redcap']
        elif n_classes == 2:
            names = ['negative', 'positive']
        else:
            raise AttributeError('n_classes has to be any of: [2, 5]')
        return dict(enumerate(names))

    @staticmethod
    def _predict_per_file(outputs_by_file):
        """Return one predicted class index per file, shape ``(n_files,)``.

        The outputs of a file are averaged before taking the argmax. The
        result is not squeezed, so it stays iterable with a single file.
        """
        votes = []
        for rows in outputs_by_file.values():
            stacked = torch.stack(rows)
            if stacked.shape[0] > 1:
                votes.append(torch.argmax(stacked.mean(dim=0)))
            else:
                votes.append(torch.argmax(stacked))
        return torch.stack(votes).cpu()

    def test_step(self, batch_xy, batch_idx, *_, **__):
        """Run the forward pass for one batch.

        Args:
            batch_xy: Triple ``(batch_files, batch_x, batch_y)``; the targets
                are ignored.
            batch_idx: Index of the batch; passed through to the result.

        Returns:
            Dict with the keys ``batch_files``, ``batch_idx`` and ``y``.
        """
        assert isinstance(self, LightningBaseModule)
        batch_files, batch_x, _ = batch_xy
        model_out = self(batch_x)
        y = model_out.main_out
        return dict(batch_files=batch_files, batch_idx=batch_idx, y=y)

    def test_epoch_end(self, outputs, *_, **__):
        """Write per-file predictions to ``predictions.csv`` and upload them.

        The file is created in ``self.logger.log_dir``; an existing file is
        replaced. The upload through ``self.logger.neptunelogger`` is best
        effort: a failure is reported on stdout and otherwise ignored.

        Args:
            outputs: Sequence of the dicts returned by ``test_step``.

        Raises:
            AttributeError: If ``self.params.n_classes`` is neither 2 nor 5.
        """
        assert isinstance(self, LightningBaseModule)
        # No logging, just inference.
        outputs_by_file = defaultdict(list)
        for output in outputs:
            for idx, file_name in enumerate(output['batch_files']):
                outputs_by_file[file_name].append(output['y'][idx].cpu())

        y_max = self._predict_per_file(outputs_by_file)
        class_names = self._prediction_class_names(self.params.n_classes)

        df = pd.DataFrame(data=dict(filename=[Path(x).name for x in outputs_by_file],
                                    prediction=[class_names[x.item()] for x in y_max]))

        # Wrap log_dir before joining so that a plain string works as well.
        result_file = Path(self.logger.log_dir) / 'predictions.csv'
        if result_file.exists():
            try:
                result_file.unlink()
            except OSError:
                print('File already existed')

        with result_file.open(mode='wb') as csv_file:
            df.to_csv(index=False, path_or_buf=csv_file)
        with result_file.open(mode='rb') as csv_file:
            try:
                self.logger.neptunelogger.log_artifact(csv_file)
            except Exception:
                # Best effort only; the CSV on disk is the primary result.
                print('No possible to send to neptune')
class CombinedModelMixins(
    LossMixin,
    TrainMixin,
    ValMixin,
    TestMixin,
    OptimizerMixin,
    LightningBaseModule,
    ABC,
):
    """Single base class bundling the mixins above with ``LightningBaseModule``.

    It adds no members of its own. The order of the bases determines the
    method resolution order and is kept exactly as listed.
    """