2021-04-02 08:45:11 +02:00

204 lines
8.2 KiB
Python

from collections import defaultdict
from pathlib import Path
from abc import ABC
import torch
import pandas as pd
from matplotlib import pyplot as plt
from ml_lib.modules.util import LightningBaseModule
from util.loss_mixin import LossMixin
from util.optimizer_mixin import OptimizerMixin
class TrainMixin:
    """Lightning training hooks shared by the combined model mixins.

    Expects the host class to provide ``params.n_classes``, ``params.loss``,
    ``bce_loss`` and an attribute named by ``params.loss`` (presumably supplied
    by ``LossMixin`` -- not visible here).
    """

    def training_step(self, batch_xy, batch_nb, *args, **kwargs):
        """Compute the training loss for one batch.

        ``batch_xy`` is unpacked as ``(batch_files, batch_x, batch_y)``; the file
        names are not used during training. For ``n_classes <= 2`` the model
        output is scored with ``self.bce_loss``; otherwise with the attribute
        named by ``self.params.loss`` (one-hot targets for ``'focal_loss_rob'``,
        integer class ids for every other loss).

        Returns:
            ``dict(loss=...)``, the format Lightning expects from a training step.
        """
        assert isinstance(self, LightningBaseModule)
        _, batch_x, batch_y = batch_xy
        y = self(batch_x).main_out

        if self.params.n_classes <= 2:
            # Flatten both sides instead of ``y.squeeze()``: for a batch of size 1
            # ``squeeze()`` collapses y to a 0-d tensor, which no longer matches the
            # (1,)-shaped target and makes the BCE loss raise.
            loss = self.bce_loss(y.reshape(-1).float(), batch_y.reshape(-1).float())
        else:
            # ``getattr`` (unlike ``__getattribute__``) also finds losses that
            # nn.Module registered as submodules.
            loss_fn = getattr(self, self.params.loss)
            if self.params.loss == 'focal_loss_rob':
                # ``one_hot`` only accepts int64 input, so cast like the other branch.
                labels_one_hot = torch.nn.functional.one_hot(batch_y.long(),
                                                             num_classes=self.params.n_classes)
                loss = loss_fn(y, labels_one_hot)
            else:
                loss = loss_fn(y, batch_y.long())
        return dict(loss=loss)

    def training_epoch_end(self, outputs):
        """Log ``mean_<key>`` for every output key containing ``'loss'``, plus the epoch index."""
        assert isinstance(self, LightningBaseModule)
        loss_keys = [key for key in outputs[0].keys() if 'loss' in key]
        summary_dict = {f'mean_{key}': torch.mean(torch.stack([output[key] for output in outputs]))
                        for key in loss_keys}
        summary_dict.update(epoch=self.current_epoch)
        self.log_dict(summary_dict)
class ValMixin:
    """Lightning validation hooks with per-file ("max vote") aggregation.

    Several samples in a batch may come from the same file; their model outputs
    are averaged per file before being scored. Expects the host class to provide
    ``params.n_classes``, ``bce_loss``, ``ce_loss``, ``metrics``,
    ``additional_scores`` and a logger with ``log_image`` (not visible here).
    """

    @staticmethod
    def __group_by_file(chunks):
        """Group per-sample outputs and targets by file name.

        Args:
            chunks: iterable of ``(batch_files, y, batch_y)`` triples; row ``i`` of
                ``y`` / ``batch_y`` belongs to ``batch_files[i]``.

        Returns:
            ``(grouped_y, targets)``: a dict mapping each file name (first-seen
            order) to the stacked outputs of its samples, and a long tensor with
            one target per file in the same order (the last target seen wins).
        """
        outputs_per_file = defaultdict(list)
        target_per_file = {}
        for batch_files, y, batch_y in chunks:
            for idx, file_name in enumerate(batch_files):
                outputs_per_file[file_name].append(y[idx])
                target_per_file[file_name] = batch_y[idx]
        grouped_y = {name: torch.stack(rows) for name, rows in outputs_per_file.items()}
        targets = torch.stack(tuple(target_per_file.values())).long()
        return grouped_y, targets

    @staticmethod
    def __vote_one_hot(grouped_y, n_classes):
        """Return a ``(n_files, n_classes)`` float one-hot of each file's argmax over its mean output."""
        # The mean over a single row is that row, so single-sample files get the
        # same vote as ``argmax(x)``. No ``.squeeze()``: with a single file it
        # would make the votes 0-d and one_hot would drop the file axis.
        votes = torch.stack([x.mean(dim=0).argmax() for x in grouped_y.values()])
        return torch.nn.functional.one_hot(votes, num_classes=n_classes).float()

    def validation_step(self, batch_xy, batch_idx, *args, **kwargs):
        """Score one validation batch and feed per-file votes into ``self.metrics``.

        Returns:
            dict with ``batch_files``, ``val_loss`` (per-sample loss of the batch),
            ``batch_idx``, the raw outputs ``y`` and targets ``batch_y``; these are
            consumed again by ``validation_epoch_end``.
        """
        assert isinstance(self, LightningBaseModule)
        batch_files, batch_x, batch_y = batch_xy
        y = self(batch_x).main_out

        grouped_y, target_y = self.__group_by_file([(batch_files, y, batch_y)])
        if self.params.n_classes <= 2:
            mean_grouped_y = torch.stack([x.mean(dim=0) for x in grouped_y.values()])
            self.metrics.update(mean_grouped_y, target_y)
        else:
            self.metrics.update(self.__vote_one_hot(grouped_y, self.params.n_classes), target_y)

        if self.params.n_classes <= 2:
            # Flatten instead of ``squeeze()`` so a batch of size 1 keeps matching shapes.
            val_loss = self.bce_loss(y.reshape(-1).float(), batch_y.reshape(-1).float())
        else:
            val_loss = self.ce_loss(y, batch_y.long())
        return dict(batch_files=batch_files, val_loss=val_loss,
                    batch_idx=batch_idx, y=y, batch_y=batch_y)

    def validation_epoch_end(self, outputs, *_, **__):
        """Log mean step losses, the per-file max-vote loss, extra scores and metrics."""
        assert isinstance(self, LightningBaseModule)
        loss_keys = [key for key in outputs[0].keys() if 'loss' in key]
        summary_dict = {f'mean_{key}': torch.mean(torch.stack([output[key] for output in outputs]))
                        for key in loss_keys}

        grouped_y, targets = self.__group_by_file(
            (output['batch_files'], output['y'], output['batch_y']) for output in outputs)
        if self.params.n_classes <= 2:
            # One averaged score per file, flattened to (n_files,) to line up with targets.
            votes = torch.stack([x.mean(dim=0) for x in grouped_y.values()]).reshape(-1)
            max_vote_loss = self.bce_loss(votes.float(), targets.float())
        else:
            votes = self.__vote_one_hot(grouped_y, self.params.n_classes)
            max_vote_loss = self.ce_loss(votes, targets)
        # Sklearn scores
        additional_scores = self.additional_scores(dict(y=votes, batch_y=targets))
        summary_dict.update(val_max_vote_loss=max_vote_loss, **additional_scores)

        pl_metrics, pl_images = self.metrics.compute_and_prepare()
        self.metrics.reset()
        summary_dict.update(**pl_metrics)
        summary_dict.update(epoch=self.current_epoch)
        self.log_dict(summary_dict)

        for name, image in pl_images.items():
            self.logger.log_image(name, image, step=self.global_step)
            # Free the figure once it has been handed to the logger.
            plt.close(image)
class TestMixin:
    """Lightning test hooks that write per-file predictions to ``predictions.csv``.

    Nothing is logged as a metric; model outputs are averaged per file and the
    resulting labels are written to ``<logger.log_dir>/predictions.csv``.
    """

    # Label names written to the CSV, indexed by predicted class id. Kept as class
    # attributes so a subclass can override them.
    MULTICLASS_NAMES = ('background', 'chimpanze', 'geunon', 'mandrille', 'redcap')
    BINARY_NAMES = ('negative', 'positive')

    def test_step(self, batch_xy, batch_idx, *_, **__):
        """Run the model on one batch; targets in ``batch_xy`` are ignored."""
        assert isinstance(self, LightningBaseModule)
        batch_files, batch_x, _ = batch_xy
        y = self(batch_x).main_out
        return dict(batch_files=batch_files, batch_idx=batch_idx, y=y)

    def test_epoch_end(self, outputs, *_, **__):
        """Aggregate outputs per file and write ``filename,prediction`` rows to CSV."""
        assert isinstance(self, LightningBaseModule)
        # No logging, just inference.
        outputs_per_file = defaultdict(list)
        for output in outputs:
            for idx, file_name in enumerate(output['batch_files']):
                outputs_per_file[file_name].append(output['y'][idx].cpu())
        mean_per_file = [torch.stack(rows).mean(dim=0) for rows in outputs_per_file.values()]

        # Both branches keep a (n_files,) shape even for a single file; the former
        # ``.squeeze()`` produced a 0-d tensor there, which cannot be iterated.
        if self.params.n_classes > 2:
            pred = torch.stack([x.argmax() for x in mean_per_file])
            class_names = self.MULTICLASS_NAMES
        else:
            # NOTE(review): the 0.5 threshold assumes main_out holds probabilities -- confirm.
            pred = (torch.stack(mean_per_file).reshape(-1) > 0.5).long()
            class_names = self.BINARY_NAMES

        df = pd.DataFrame(data=dict(
            filename=[Path(x).name.replace('.npy', '.wav') for x in outputs_per_file],
            prediction=[class_names[x] for x in pred.tolist()],
        ))
        # ``log_dir`` may be a plain str; ``str / str`` would raise TypeError.
        result_file = Path(self.logger.log_dir) / 'predictions.csv'
        # to_csv truncates an existing file, so there is no need to unlink it first.
        df.to_csv(result_file, index=False)
class CombinedModelMixins(
    LossMixin,
    TrainMixin,
    ValMixin,
    TestMixin,
    OptimizerMixin,
    LightningBaseModule,
    ABC,
):
    """Bundle the loss, train/val/test and optimizer mixins onto ``LightningBaseModule``.

    The base-class order defines the MRO: a method defined by a mixin listed
    earlier wins over one with the same name listed further right.
    """