Final Train Runs
This commit is contained in:
44
main.py
44
main.py
@ -2,12 +2,12 @@ from argparse import Namespace
|
||||
|
||||
import warnings
|
||||
|
||||
from pytorch_lightning import Trainer
|
||||
from pytorch_lightning import Trainer, Callback
|
||||
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
|
||||
|
||||
from ml_lib.utils.callbacks import BestScoresCallback
|
||||
from ml_lib.utils.config import parse_comandline_args_add_defaults
|
||||
from ml_lib.utils.loggers import Logger
|
||||
from ml_lib.utils.tools import locate_and_import_class, auto_cast
|
||||
|
||||
import variables as v
|
||||
|
||||
@ -15,22 +15,38 @@ warnings.filterwarnings('ignore', category=FutureWarning)
|
||||
warnings.filterwarnings('ignore', category=UserWarning)
|
||||
|
||||
|
||||
def run_lightning_loop(h_params, data_class, model_class):
|
||||
def run_lightning_loop(h_params, data_class, model_class, additional_callbacks=None):
|
||||
with Logger.from_argparse_args(h_params) as logger:
|
||||
# Callbacks
|
||||
# =============================================================================
|
||||
# Checkpoint Saving
|
||||
ckpt_callback = ModelCheckpoint(
|
||||
monitor='mean_loss',
|
||||
monitor='PL_recall_score',
|
||||
dirpath=str(logger.log_dir),
|
||||
filename='ckpt_weights',
|
||||
mode='max',
|
||||
verbose=False,
|
||||
save_top_k=3,
|
||||
save_last=True
|
||||
)
|
||||
|
||||
# Learning Rate Logger
|
||||
lr_logger = LearningRateMonitor(logging_interval='epoch')
|
||||
|
||||
|
||||
|
||||
# Track best scores
|
||||
score_callback = BestScoresCallback(['PL_recall_score'])
|
||||
|
||||
callbacks = [ckpt_callback, lr_logger, score_callback]
|
||||
|
||||
if additional_callbacks and isinstance(additional_callbacks, Callback):
|
||||
callbacks.append(additional_callbacks)
|
||||
elif additional_callbacks and isinstance(additional_callbacks, list):
|
||||
callbacks.extend(additional_callbacks)
|
||||
else:
|
||||
pass
|
||||
|
||||
# START
|
||||
# =============================================================================
|
||||
# Let Datamodule pull what it wants
|
||||
@ -38,19 +54,27 @@ def run_lightning_loop(h_params, data_class, model_class):
|
||||
datamodule.setup()
|
||||
|
||||
# Let Trainer pull what it wants and add callbacks
|
||||
trainer = Trainer.from_argparse_args(h_params, logger=logger, callbacks=[ckpt_callback, lr_logger])
|
||||
trainer = Trainer.from_argparse_args(h_params, logger=logger, callbacks=callbacks)
|
||||
|
||||
# Let Model pull what it wants
|
||||
model = model_class.from_argparse_args(h_params, in_shape=datamodule.shape, n_classes=v.N_CLASS_multi)
|
||||
model.init_weights()
|
||||
|
||||
# trainer.test(model=model, datamodule=datamodule)
|
||||
|
||||
trainer.fit(model, datamodule)
|
||||
trainer.save_checkpoint(logger.save_dir / 'last_weights.ckpt')
|
||||
|
||||
# Log parameters
|
||||
pytorch_total_params = sum(p.numel() for p in model.parameters())
|
||||
# logger.log_text('n_parameters', pytorch_total_params)
|
||||
|
||||
trainer.save_checkpoint(logger.save_dir / 'weights.ckpt')
|
||||
try:
|
||||
trainer.test(model=model, datamodule=datamodule)
|
||||
except:
|
||||
print('Test did not Suceed!')
|
||||
pass
|
||||
try:
|
||||
logger.log_metrics(score_callback.best_scores, step=trainer.global_step+1)
|
||||
except:
|
||||
print('debug max_score_logging')
|
||||
return score_callback.best_scores['PL_recall_score']
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
Reference in New Issue
Block a user