from collections import defaultdict
from pathlib import Path
from abc import ABC

import torch
import pandas as pd
from matplotlib import pyplot as plt

from ml_lib.modules.util import LightningBaseModule
from util.loss_mixin import LossMixin
from util.optimizer_mixin import OptimizerMixin


def _mean_losses(outputs):
    """Average every step-output entry whose key contains ``'loss'``.

    Keys are taken from ``outputs[0]``; each result is stored as ``mean_<key>``.
    """
    loss_keys = [key for key in outputs[0] if 'loss' in key]
    return {
        f'mean_{key}': torch.mean(torch.stack([output[key] for output in outputs]))
        for key in loss_keys
    }


def _group_by_file(samples):
    """Group per-sample predictions (and targets) by the file they came from.

    Args:
        samples: Iterable of ``(file_name, prediction, target)`` triples, one per
            sample. ``target`` may be ``None`` when no labels are available.

    Returns:
        A pair ``(predictions, targets)`` of dicts keyed by file name, both in
        order of first appearance. ``predictions[f]`` is ``torch.stack`` of all
        predictions seen for ``f``; ``targets[f]`` is the last target seen for ``f``.
    """
    grouped = defaultdict(list)
    targets = {}
    for file_name, prediction, target in samples:
        grouped[file_name].append(prediction)
        # Re-assigning an existing key keeps its position, so ``targets`` stays
        # aligned with ``grouped`` in insertion order.
        targets[file_name] = target
    return {file_name: torch.stack(preds) for file_name, preds in grouped.items()}, targets


def _vote_class(grouped_predictions):
    """Return the predicted class index per file as a 1-D tensor of shape ``(n_files,)``.

    Files with several samples use the argmax of their mean output; single-sample
    files use the (flattened) argmax of that sample. Because every element is a
    0-d argmax, stacking them keeps a 1-D result even when there is only one file
    (a trailing ``.squeeze()`` would collapse that case to 0-d).
    """
    return torch.stack([
        torch.argmax(x.mean(dim=0)) if x.shape[0] > 1 else torch.argmax(x)
        for x in grouped_predictions.values()
    ])


class TrainMixin:
    """Lightning training hooks; expects ``self.params``, ``self.bce_loss`` and the configured loss."""

    def training_step(self, batch_xy, batch_nb, *args, **kwargs):
        """Compute the training loss for one batch.

        ``batch_xy`` is unpacked as ``(batch_files, batch_x, batch_y)``. With
        ``params.n_classes <= 2`` the loss is ``self.bce_loss``; otherwise the loss
        callable named by ``params.loss`` is looked up on ``self``.
        ``'focal_loss_rob'`` receives one-hot labels, every other loss receives
        integer class labels.

        Returns:
            ``dict(loss=loss)``.
        """
        assert isinstance(self, LightningBaseModule)
        _, batch_x, batch_y = batch_xy
        y = self(batch_x).main_out
        if self.params.n_classes <= 2:
            loss = self.bce_loss(y.squeeze().float(), batch_y.float())
        elif self.params.loss == 'focal_loss_rob':
            labels_one_hot = torch.nn.functional.one_hot(batch_y, num_classes=self.params.n_classes)
            loss = getattr(self, self.params.loss)(y, labels_one_hot)
        else:
            # getattr (unlike calling __getattribute__ directly) also honours a
            # class-level __getattr__ fallback, e.g. torch.nn.Module submodules.
            loss = getattr(self, self.params.loss)(y, batch_y.long())
        return dict(loss=loss)

    def training_epoch_end(self, outputs):
        """Log the epoch mean of every loss-like entry returned by ``training_step``."""
        assert isinstance(self, LightningBaseModule)
        summary_dict = _mean_losses(outputs)
        summary_dict.update(epoch=self.current_epoch)
        self.log_dict(summary_dict)


class ValMixin:
    """Lightning validation hooks that score one aggregated prediction per source file."""

    def validation_step(self, batch_xy, batch_idx, *args, **kwargs):
        """Run the model on a batch, update ``self.metrics`` per file and return step outputs.

        Samples sharing a file name are aggregated before ``self.metrics.update``:
        binary runs average the outputs, multi-class runs vote via the argmax of
        the mean output (passed on as one-hot). The per-sample validation loss is
        returned together with the raw outputs for ``validation_epoch_end``.
        """
        assert isinstance(self, LightningBaseModule)
        batch_files, batch_x, batch_y = batch_xy
        model_out = self(batch_x)
        y = model_out.main_out

        sorted_y, sorted_batch_y = _group_by_file(zip(batch_files, y, batch_y))
        target_y = torch.stack(tuple(sorted_batch_y.values())).long()

        if self.params.n_classes <= 2:
            # NOTE(review): assumes main_out is (batch, 1) for binary runs — TODO confirm.
            mean_sorted_y = torch.stack(
                [x.mean(dim=0) if x.shape[0] > 1 else x.squeeze(-1) for x in sorted_y.values()]
            )
            self.metrics.update(mean_sorted_y, target_y)
            val_loss = self.bce_loss(y.squeeze().float(), batch_y.float())
        else:
            y_one_hot = torch.nn.functional.one_hot(
                _vote_class(sorted_y), num_classes=self.params.n_classes
            ).float()
            self.metrics.update(y_one_hot, target_y)
            val_loss = self.ce_loss(y, batch_y.long())

        return dict(batch_files=batch_files, val_loss=val_loss, batch_idx=batch_idx, y=y, batch_y=batch_y)

    def validation_epoch_end(self, outputs, *_, **__):
        """Aggregate the epoch's outputs per file, log losses/scores/metrics and metric images.

        Logged values: the epoch mean of every ``*loss*`` step output, the
        per-file ``val_max_vote_loss``, everything returned by
        ``self.additional_scores``, the metrics from
        ``self.metrics.compute_and_prepare()`` and the current epoch. Images
        returned by the metrics are sent to ``self.logger.log_image`` and closed.
        """
        assert isinstance(self, LightningBaseModule)
        mean_losses = _mean_losses(outputs)

        sorted_y, sorted_batch_y = _group_by_file(
            sample
            for output in outputs
            for sample in zip(output['batch_files'], output['y'], output['batch_y'])
        )
        sorted_batch_y = torch.stack(tuple(sorted_batch_y.values())).long()

        if self.params.n_classes <= 2:
            mean_sorted_y = torch.stack(
                [x.mean(dim=0) if x.shape[0] > 1 else x.squeeze().unsqueeze(-1) for x in sorted_y.values()]
            ).squeeze(1)
            max_vote_loss = self.bce_loss(mean_sorted_y.float(), sorted_batch_y.float())
            # Sklearn scores
            additional_scores = self.additional_scores(dict(y=mean_sorted_y, batch_y=sorted_batch_y))
        else:
            y_one_hot = torch.nn.functional.one_hot(
                _vote_class(sorted_y), num_classes=self.params.n_classes
            ).float()
            max_vote_loss = self.ce_loss(y_one_hot, sorted_batch_y)
            # Sklearn scores
            additional_scores = self.additional_scores(dict(y=y_one_hot, batch_y=sorted_batch_y))

        summary_dict = dict(val_max_vote_loss=max_vote_loss, **additional_scores)
        # Epoch-mean losses are merged last so they win over any equally named score.
        summary_dict.update(mean_losses)

        pl_metrics, pl_images = self.metrics.compute_and_prepare()
        self.metrics.reset()
        summary_dict.update(**pl_metrics)
        summary_dict.update(epoch=self.current_epoch)
        self.log_dict(summary_dict)

        for name, image in pl_images.items():
            self.logger.log_image(name, image, step=self.global_step)
            plt.close(image)


class TestMixin:
    """Lightning test hooks that export one prediction per source file to CSV."""

    def test_step(self, batch_xy, batch_idx, *_, **__):
        """Run inference on a batch; labels in ``batch_xy`` are ignored."""
        assert isinstance(self, LightningBaseModule)
        batch_files, batch_x, _ = batch_xy
        y = self(batch_x).main_out
        return dict(batch_files=batch_files, batch_idx=batch_idx, y=y)

    def test_epoch_end(self, outputs, *_, **__):
        """Write one aggregated prediction per file to ``<logger.log_dir>/predictions.csv``.

        Multi-class runs map the voted class index to a hard-coded class name;
        binary runs threshold the averaged output at 0.5. Reported file names
        have ``'.npy'`` replaced by ``'.wav'``. Nothing is logged.
        """
        assert isinstance(self, LightningBaseModule)
        # No logging, just inference.
        sorted_y = _group_by_file(
            (file_name, prediction.cpu(), None)
            for output in outputs
            for file_name, prediction in zip(output['batch_files'], output['y'])
        )[0]

        if self.params.n_classes > 2:
            pred = _vote_class(sorted_y).cpu()
            class_names = dict(enumerate(['background', 'chimpanze', 'geunon', 'mandrille', 'redcap']))
        else:
            pred = torch.stack(
                [x.mean(dim=0) if x.shape[0] > 1 else x.squeeze().unsqueeze(-1) for x in sorted_y.values()]
            )
            # reshape(-1) rather than squeeze(): with a single file, squeeze() gives a
            # 0-d tensor and iterating ``pred`` below would raise.
            pred = torch.where(pred.reshape(-1) > 0.5, 1, 0)
            class_names = dict(enumerate(['negative', 'positive']))

        df = pd.DataFrame(data=dict(
            filename=[Path(x).name.replace('.npy', '.wav') for x in sorted_y],
            prediction=[class_names[x.item()] for x in pred.cpu()],
        ))

        # Path() accepts both str and Path, so this works whichever the logger exposes.
        result_file = Path(self.logger.log_dir) / 'predictions.csv'
        if result_file.exists():
            try:
                result_file.unlink()
            except OSError as err:
                # Best effort only: opening with mode 'wb' below truncates the file anyway.
                print(f'Could not remove existing {result_file}: {err}')
        with result_file.open(mode='wb') as csv_file:
            df.to_csv(index=False, path_or_buf=csv_file)

        # NOTE(review): Neptune artifact upload is deliberately disabled (``if False``).
        if False:
            with result_file.open(mode='rb') as csv_file:
                try:
                    self.logger.neptunelogger.log_artifact(csv_file)
                except Exception:
                    print('No possible to send to neptune')


class CombinedModelMixins(LossMixin,
                          TrainMixin,
                          ValMixin,
                          TestMixin,
                          OptimizerMixin,
                          LightningBaseModule,
                          ABC):
    """Bundle the loss, train/val/test and optimizer mixins onto ``LightningBaseModule``."""