"""Lightning train/val/test step mixins sharing a cross-entropy objective."""
from abc import ABC

import torch

from ml_lib.modules.util import LightningBaseModule
from util.loss_mixin import LossMixin
from util.optimizer_mixin import OptimizerMixin


def _squeeze_keep_batch(tensor):
    """Drop singleton dimensions of ``tensor`` except the leading (batch) one.

    Identical to ``tensor.squeeze()`` whenever the first dimension is not 1,
    but a batch of size 1 stays a batch instead of collapsing, so predictions
    still line up with a label tensor of shape ``(1,)``.
    """
    if tensor.dim() == 0:
        return tensor
    kept = [size for size in tensor.shape[1:] if size != 1]
    return tensor.reshape(tensor.shape[0], *kept)


def _forward_with_ce_loss(module, batch_xy):
    """Run ``module`` on one batch and compute its cross-entropy loss.

    Args:
        module: the Lightning module; its forward output must expose
            ``main_out`` and it must provide ``ce_loss`` (presumably supplied
            by ``LossMixin`` -- not defined in this file).
        batch_xy: a ``(batch_x, batch_y)`` pair.

    Returns:
        ``(y, batch_y, loss)`` where ``y`` is the raw (unsqueezed)
        ``main_out`` and ``loss`` is ``ce_loss`` against ``batch_y.long()``.
    """
    batch_x, batch_y = batch_xy
    y = module(batch_x).main_out
    loss = module.ce_loss(_squeeze_keep_batch(y), batch_y.long())
    return y, batch_y, loss


def _mean_losses(outputs):
    """Average every loss entry across a non-empty list of step-output dicts.

    Keys are taken from the first output; only keys containing ``'loss'`` are
    averaged, each stored under ``'mean_<key>'``.
    """
    loss_keys = [key for key in outputs[0] if 'loss' in key]
    return {
        f'mean_{key}': torch.mean(torch.stack([output[key] for output in outputs]))
        for key in loss_keys
    }


def _log_summary(module, summary):
    """Log every ``name -> value`` pair of ``summary`` through ``module.log``."""
    for name, value in summary.items():
        module.log(name, value)


def _summarize_eval_epoch(module, outputs):
    """Log mean losses plus ``module.additional_scores(outputs)`` for an epoch.

    ``additional_scores`` is not defined in this file; it is expected to come
    from the concrete model and presumably returns a mapping of metric name to
    value -- TODO confirm. Nothing is logged for an empty epoch.
    """
    if not outputs:
        return
    summary = _mean_losses(outputs)
    summary.update(module.additional_scores(outputs))
    _log_summary(module, summary)


class TrainMixin:
    """Training step / epoch-end hooks using ``ce_loss`` on ``main_out``."""

    def training_step(self, batch_xy, batch_nb, *args, **kwargs):
        """Compute the cross-entropy loss for one training batch.

        Returns:
            ``dict(loss=...)`` as consumed by Lightning.
        """
        # Narrows the type for readers/tools; note asserts vanish under -O.
        assert isinstance(self, LightningBaseModule)
        _, _, loss = _forward_with_ce_loss(self, batch_xy)
        return dict(loss=loss)

    def training_epoch_end(self, outputs):
        """Log ``mean_<key>`` for every loss key in the epoch's step outputs."""
        assert isinstance(self, LightningBaseModule)
        if not outputs:
            return
        _log_summary(self, _mean_losses(outputs))


class ValMixin:
    """Validation step / epoch-end hooks using ``ce_loss`` on ``main_out``."""

    def validation_step(self, batch_xy, batch_idx, *args, **kwargs):
        """Compute the validation loss and keep predictions for epoch scoring.

        Returns:
            dict with ``val_loss``, ``batch_idx``, the raw model output ``y``
            and the labels ``batch_y``.
        """
        assert isinstance(self, LightningBaseModule)
        y, batch_y, val_loss = _forward_with_ce_loss(self, batch_xy)
        return dict(val_loss=val_loss, batch_idx=batch_idx, y=y, batch_y=batch_y)

    def validation_epoch_end(self, outputs, *_, **__):
        """Log mean validation losses and the model's additional scores."""
        assert isinstance(self, LightningBaseModule)
        _summarize_eval_epoch(self, outputs)


class TestMixin:
    """Test step / epoch-end hooks using ``ce_loss`` on ``main_out``."""

    def test_step(self, batch_xy, batch_idx, *_, **__):
        """Compute the test loss and keep predictions for epoch scoring.

        Returns:
            dict with ``test_loss``, ``batch_idx``, the raw model output ``y``
            and the labels ``batch_y``.
        """
        assert isinstance(self, LightningBaseModule)
        y, batch_y, test_loss = _forward_with_ce_loss(self, batch_xy)
        return dict(test_loss=test_loss, batch_idx=batch_idx, y=y, batch_y=batch_y)

    def test_epoch_end(self, outputs, *_, **__):
        """Log mean test losses and the model's additional scores."""
        assert isinstance(self, LightningBaseModule)
        _summarize_eval_epoch(self, outputs)


class CombinedModelMixins(LossMixin,
                          TrainMixin,
                          ValMixin,
                          TestMixin,
                          OptimizerMixin,
                          LightningBaseModule,
                          ABC
                          ):
    """Abstract base bundling loss, step, and optimizer mixins onto the module.

    Base order is significant: it fixes the method resolution order, so the
    mixins' hooks take precedence over ``LightningBaseModule``.
    """