Final Train Runs
This commit is contained in:
23
utils/callbacks.py
Normal file
23
utils/callbacks.py
Normal file
@ -0,0 +1,23 @@
|
||||
from pytorch_lightning import Callback, Trainer, LightningModule
|
||||
|
||||
|
||||
class BestScoresCallback(Callback):
|
||||
|
||||
def __init__(self, *monitors) -> None:
|
||||
super().__init__()
|
||||
self.monitors = list(*monitors)
|
||||
|
||||
self.best_scores = {monitor: 0.0 for monitor in self.monitors}
|
||||
self.best_epoch = {monitor: 0 for monitor in self.monitors}
|
||||
|
||||
def on_validation_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
|
||||
epoch = pl_module.current_epoch
|
||||
|
||||
for monitor in self.best_scores.keys():
|
||||
current_score = trainer.callback_metrics.get(monitor)
|
||||
if current_score is None:
|
||||
pass
|
||||
else:
|
||||
self.best_scores[monitor] = max(self.best_scores[monitor], current_score)
|
||||
if self.best_scores[monitor] == current_score:
|
||||
self.best_epoch[monitor] = max(self.best_epoch[monitor], epoch)
|
Reference in New Issue
Block a user