from pytorch_lightning import Callback, Trainer, LightningModule


class BestScoresCallback(Callback):
    """Track the best (highest) value of each monitored metric across validation runs.

    After every validation run the callback looks up each monitored name in
    ``trainer.callback_metrics`` and remembers the largest value seen so far
    together with the epoch at which it was observed.

    Higher is treated as better, and every metric starts from a baseline of
    ``0.0``; a metric that never reaches ``0.0`` is therefore never recorded.

    Attributes are plain dictionaries keyed by metric name:
    ``best_scores`` holds the best value and ``best_epoch`` the matching epoch.
    """

    def __init__(self, *monitors) -> None:
        """Create the callback.

        Args:
            *monitors: Names of the metrics to track. They can be given as
                separate arguments (``BestScoresCallback("a", "b")``) or as a
                single iterable of names (``BestScoresCallback(["a", "b"])``).
        """
        super().__init__()
        self.monitors = self._normalize_monitors(monitors)

        self.best_scores = {monitor: 0.0 for monitor in self.monitors}
        self.best_epoch = {monitor: 0 for monitor in self.monitors}

    @staticmethod
    def _normalize_monitors(monitors: tuple) -> list:
        """Return the monitor names as a flat list for either calling convention."""
        # A lone non-string argument is taken to be an iterable of names. A lone
        # string must NOT be iterated, otherwise it would be split into characters.
        if len(monitors) == 1 and not isinstance(monitors[0], str):
            return list(monitors[0])
        return list(monitors)

    def on_validation_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
        """Update the best score and epoch of every monitored metric.

        Args:
            trainer: Trainer whose ``callback_metrics`` are inspected.
            pl_module: Module whose ``current_epoch`` is recorded as the epoch.
        """
        epoch = pl_module.current_epoch

        for monitor in self.best_scores:
            current_score = trainer.callback_metrics.get(monitor)
            if current_score is None:
                # Metric was not logged for this validation run; keep the old best.
                continue

            # ``>=`` so that a tie moves the best epoch forward to the latest one.
            if current_score >= self.best_scores[monitor]:
                self.best_scores[monitor] = current_score
                self.best_epoch[monitor] = max(self.best_epoch[monitor], epoch)