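"""Generic PyTorch Lightning training loop.

Builds logger, callbacks, datamodule, trainer and model from the parsed
hyperparameters, runs fit and test, and returns the best-checkpoint info.
"""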
from argparse import Namespace
import warnings

import yaml
from pytorch_lightning import Trainer, Callback
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint

from ml_lib.utils.callbacks import BestScoresCallback
from ml_lib.utils.config import parse_comandline_args_add_defaults
from ml_lib.utils.loggers import Logger
from ml_lib.utils.tools import fix_all_random_seeds

import variables as v

warnings.filterwarnings('ignore', category=FutureWarning)
warnings.filterwarnings('ignore', category=UserWarning)

def run_lightning_loop(h_params: Namespace, data_class, model_class, seed=69, additional_callbacks=None):
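    """Run a full fit/test cycle for ``model_class`` on ``data_class``.

    Returns a Namespace with the trained model, the best checkpoint path and
    score, and the best scores tracked by ``BestScoresCallback``.
    """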
    fix_all_random_seeds(seed)

    with Logger.from_argparse_args(h_params) as logger:
        # Callbacks
        # =============================================================================
        # Checkpoint Saving
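        # ('PL_recall_score' is assumed to be a metric the model logs during
        #  validation; ModelCheckpoint can only monitor keys that get logged.)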
        ckpt_callback = ModelCheckpoint(
            monitor='PL_recall_score',
            dirpath=str(logger.log_dir),
            filename='ckpt_weights',
            mode='max',
            verbose=False,
            save_top_k=3,
            save_last=True
        )
        # Learning Rate Logger
        lr_logger = LearningRateMonitor(logging_interval='epoch')
        # Track best scores
        score_callback = BestScoresCallback(['PL_recall_score'])

        callbacks = [ckpt_callback, lr_logger, score_callback]
        # Accept either a single Callback or a list of Callbacks
        if isinstance(additional_callbacks, Callback):
            callbacks.append(additional_callbacks)
        elif isinstance(additional_callbacks, list):
            callbacks.extend(additional_callbacks)
        # START
        # =============================================================================
        # Let the Datamodule pull what it wants
        datamodule = data_class.from_argparse_args(h_params)

        # Final h_params setup: add the shapes the datamodule determined
        h_params = vars(h_params)
        h_params.update(in_shape=datamodule.shape, n_classes=datamodule.n_classes)
        h_params = Namespace(**h_params)

        # Let the Trainer pull what it wants and add the callbacks
        trainer = Trainer.from_argparse_args(h_params, logger=logger, callbacks=callbacks)

        # Let the Model pull what it wants
        model = model_class.from_argparse_args(h_params)
        model.init_weights()

        # Store the model as an object file
        model.save_to_disk(logger.save_dir)

        # Store h_params to a yaml file & Neptune (if available)
        logger.log_hyperparams(h_params)

        trainer.fit(model, datamodule=datamodule)
        trainer.save_checkpoint(logger.save_dir / 'last_weights.ckpt')
        trainer.test(model=model, datamodule=datamodule, ckpt_path='best')

    return Namespace(model=model,
                     best_model_path=ckpt_callback.best_model_path,
                     best_model_score=ckpt_callback.best_model_score.item(),
                     max_score_monitor=score_callback.best_scores)
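
# A minimal usage sketch (EarlyStopping is only an example of an extra
# callback; any pytorch_lightning Callback or list of Callbacks works):
#
#   from pytorch_lightning.callbacks import EarlyStopping
#   stopper = EarlyStopping(monitor='PL_recall_score', mode='max', patience=5)
#   result = run_lightning_loop(hparams, data_class, model_class,
#                               additional_callbacks=stopper)
#   print(result.best_model_path, result.best_model_score)
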
if __name__ == '__main__':
    # Parse command line args, read config and get model
    cmd_args, found_data_class, found_model_class, found_seed = parse_comandline_args_add_defaults('_parameters.ini')

    # To Namespace
    hparams = Namespace(**cmd_args)

    # Start
    # -----------------
    print(run_lightning_loop(hparams, found_data_class, found_model_class, found_seed))
    print('done')