68 lines
2.2 KiB
Python
68 lines
2.2 KiB
Python
from argparse import Namespace
|
|
|
|
import warnings
|
|
|
|
from pytorch_lightning import Trainer
|
|
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
|
|
|
|
from ml_lib.utils.config import parse_comandline_args_add_defaults
|
|
from ml_lib.utils.loggers import Logger
|
|
from ml_lib.utils.tools import locate_and_import_class, auto_cast
|
|
|
|
import variables as v
|
|
|
|
# Globally ignore every FutureWarning and UserWarning, whatever its message
# or originating module.
warnings.simplefilter('ignore', category=FutureWarning)
warnings.simplefilter('ignore', category=UserWarning)
|
|
|
|
|
|
def run_lightning_loop(h_params, data_class, model_class):
    """Train ``model_class`` on ``data_class`` with a Lightning ``Trainer``.

    The logger, datamodule, trainer and model are each built from
    ``h_params`` through their ``from_argparse_args`` constructors. Once
    fitting has finished, a final checkpoint named ``weights.ckpt`` is
    written below ``logger.save_dir``.

    Args:
        h_params: Hyperparameter namespace handed to every
            ``from_argparse_args`` constructor used here.
        data_class: Datamodule class. Must offer ``from_argparse_args``;
            its instances must offer ``setup()`` and a ``shape`` attribute.
        model_class: Model class. Must offer ``from_argparse_args``
            (accepting ``in_shape`` and ``n_classes``); its instances must
            offer ``init_weights()``.
    """
    with Logger.from_argparse_args(h_params) as logger:
        # Callbacks
        # =============================================================================
        # Checkpoint saving: keep up to three checkpoints, ranked by the
        # monitored 'mean_loss' metric, inside the logger's log directory.
        ckpt_callback = ModelCheckpoint(
            monitor='mean_loss',
            dirpath=str(logger.log_dir),
            filename='ckpt_weights',
            verbose=False,
            save_top_k=3,
        )

        # Learning rate logger: record the learning rate once per epoch.
        lr_logger = LearningRateMonitor(logging_interval='epoch')

        # START
        # =============================================================================
        # Let the datamodule pull what it wants.
        datamodule = data_class.from_argparse_args(h_params)
        # Set up before building the model, which reads `datamodule.shape`
        # (presumably populated by setup() -- confirm against the datamodule).
        datamodule.setup()

        # Let the Trainer pull what it wants and add the callbacks.
        trainer = Trainer.from_argparse_args(
            h_params,
            logger=logger,
            callbacks=[ckpt_callback, lr_logger],
        )

        # Let the model pull what it wants.
        model = model_class.from_argparse_args(
            h_params,
            in_shape=datamodule.shape,
            n_classes=v.N_CLASS_multi,
        )
        model.init_weights()

        trainer.fit(model, datamodule)

        # Persist the final weights next to the logger output.
        trainer.save_checkpoint(logger.save_dir / 'weights.ckpt')
|
|
|
|
|
|
def main():
    """Parse command-line arguments and config, then run the training loop."""
    # The parser hands back the argument mapping together with the data and
    # model classes it located.
    args_dict, data_cls, model_cls = parse_comandline_args_add_defaults('_parameters.ini')

    # The training loop expects attribute access, so wrap the mapping in a
    # Namespace before starting.
    run_lightning_loop(Namespace(**args_dict), data_cls, model_cls)
    print('done')


if __name__ == '__main__':
    main()
|