Transformer running
This commit is contained in:
67
main.py
67
main.py
@ -1,45 +1,29 @@
|
||||
import configparser
|
||||
from argparse import ArgumentParser, Namespace
|
||||
from argparse import Namespace
|
||||
|
||||
import warnings
|
||||
|
||||
from pytorch_lightning import Trainer
|
||||
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
|
||||
|
||||
from ml_lib.utils.logging import Logger
|
||||
from ml_lib.utils.config import parse_comandline_args_add_defaults
|
||||
from ml_lib.utils.loggers import Logger
|
||||
from ml_lib.utils.tools import locate_and_import_class, auto_cast
|
||||
|
||||
import variables as v
|
||||
|
||||
if __name__ == '__main__':
|
||||
# Argument Parser and default Values
|
||||
# =============================================================================
|
||||
# Load Defaults from _parameters.ini file
|
||||
config = configparser.ConfigParser()
|
||||
config.read('_parameters.ini')
|
||||
project = config['project']
|
||||
warnings.filterwarnings('ignore', category=FutureWarning)
|
||||
warnings.filterwarnings('ignore', category=UserWarning)
|
||||
|
||||
data_class = locate_and_import_class(project['data_name'], 'datasets')
|
||||
model_class = locate_and_import_class(project['model_name'], 'models')
|
||||
|
||||
tmp_params = dict()
|
||||
for key in ['project', 'train', 'data', 'model_cnn']:
|
||||
defaults = config[key]
|
||||
tmp_params.update({key: auto_cast(val) for key, val in defaults.items()})
|
||||
|
||||
# Parse Command Line
|
||||
parser = ArgumentParser()
|
||||
for module in [Logger, Trainer, data_class, model_class]:
|
||||
parser = module.add_argparse_args(parser)
|
||||
cmd_args, _ = parser.parse_known_args()
|
||||
tmp_params.update({key: val for key, val in vars(cmd_args).items() if val is not None})
|
||||
hparams = Namespace(**tmp_params)
|
||||
|
||||
with Logger.from_argparse_args(hparams) as logger:
|
||||
def run_lightning_loop(h_params, data_class, model_class):
|
||||
with Logger.from_argparse_args(h_params) as logger:
|
||||
# Callbacks
|
||||
# =============================================================================
|
||||
# Checkpoint Saving
|
||||
ckpt_callback = ModelCheckpoint(
|
||||
monitor='mean_loss',
|
||||
filepath=str(logger.log_dir / 'ckpt_weights'),
|
||||
dirpath=str(logger.log_dir),
|
||||
filename='ckpt_weights',
|
||||
verbose=False,
|
||||
save_top_k=3,
|
||||
)
|
||||
@ -47,22 +31,37 @@ if __name__ == '__main__':
|
||||
# Learning Rate Logger
|
||||
lr_logger = LearningRateMonitor(logging_interval='epoch')
|
||||
|
||||
#
|
||||
# START
|
||||
# =============================================================================
|
||||
# Let Datamodule pull what it wants
|
||||
datamodule = data_class.from_argparse_args(hparams)
|
||||
datamodule = data_class.from_argparse_args(h_params)
|
||||
datamodule.setup()
|
||||
model_in_shape = datamodule.shape
|
||||
|
||||
# Let Trainer pull what it wants and add callbacks
|
||||
trainer = Trainer.from_argparse_args(hparams, callbacks=[ckpt_callback, lr_logger])
|
||||
trainer = Trainer.from_argparse_args(h_params, logger=logger, callbacks=[ckpt_callback, lr_logger])
|
||||
|
||||
# Let Model pull what it wants
|
||||
model = model_class.from_argparse_args(hparams, in_shape=datamodule.shape, n_classes=v.N_CLASS_multi)
|
||||
model = model_class.from_argparse_args(h_params, in_shape=datamodule.shape, n_classes=v.N_CLASS_multi)
|
||||
model.init_weights()
|
||||
|
||||
logger.log_hyperparams(dict(model.params))
|
||||
trainer.fit(model, datamodule)
|
||||
|
||||
trainer.save_checkpoint(trainer.logger.save_dir)
|
||||
# Log paramters
|
||||
pytorch_total_params = sum(p.numel() for p in model.parameters())
|
||||
# logger.log_text('n_parameters', pytorch_total_params)
|
||||
|
||||
trainer.save_checkpoint(logger.save_dir / 'weights.ckpt')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
# Parse comandline args, read config and get model
|
||||
cmd_args, found_data_class, found_model_class = parse_comandline_args_add_defaults('_parameters.ini')
|
||||
|
||||
# To NameSpace
|
||||
hparams = Namespace(**cmd_args)
|
||||
|
||||
# Start
|
||||
# -----------------
|
||||
run_lightning_loop(hparams, found_data_class, found_model_class)
|
||||
print('done')
|
||||
pass
|
||||
|
Reference in New Issue
Block a user