New Model, Many Changes
main.py
@@ -6,14 +6,13 @@ import warnings
 import torch
 from pytorch_lightning import Trainer
-from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
+from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor
 
 from ml_lib.modules.util import LightningBaseModule
 from ml_lib.utils.config import Config
 from ml_lib.utils.logging import Logger
 
 # Project Specific Logger SubClasses
-from util.config import MConfig
 
 
 warnings.filterwarnings('ignore', category=FutureWarning)
 warnings.filterwarnings('ignore', category=UserWarning)
@@ -37,35 +36,30 @@ def run_lightning_loop(config_obj):
     # Callbacks
     # =============================================================================
     # Checkpoint Saving
-    checkpoint_callback = ModelCheckpoint(
-        monitor='uar_score',
+    ckpt_callback = ModelCheckpoint(
+        monitor='mean_loss',
         filepath=str(logger.log_dir / 'ckpt_weights'),
         verbose=False,
        save_top_k=5,
     )
-
-    # Early Stopping
-    # TODO: For This to work, set a validation step and End Eval and Score
-    early_stopping_callback = EarlyStopping(
-        monitor='uar_score',
-        min_delta=0.01,
-        patience=10,
-    )
+    # Learning Rate Logger
+    lr_logger = LearningRateMonitor(logging_interval='epoch')
 
     # Trainer
     # =============================================================================
     trainer = Trainer(max_epochs=config_obj.train.epochs,
                       show_progress_bar=True,
                       weights_save_path=logger.log_dir,
                       gpus=[0] if torch.cuda.is_available() else None,
                       check_val_every_n_epoch=10,
                       # num_sanity_val_steps=config_obj.train.num_sanity_val_steps,
                       # row_log_interval=(model.n_train_batches * 0.1),  # TODO: Better Value / Setting
                       # log_save_interval=(model.n_train_batches * 0.2),  # TODO: Better Value / Setting
-                      checkpoint_callback=checkpoint_callback,
+                      checkpoint_callback=True,
+                      callbacks=[lr_logger, ckpt_callback],
                       logger=logger,
                       fast_dev_run=config_obj.main.debug,
-                      early_stop_callback=None
+                      auto_lr_find=not config_obj.main.debug
                       )
 
     # Model
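
In isolation, the new callback wiring looks roughly like the sketch below. It assumes a LightningModule that actually logs a 'mean_loss' metric during validation (otherwise ModelCheckpoint has nothing to monitor); the output path is a placeholder, and the filepath argument reflects the older Lightning API used in this file (newer releases split it into dirpath/filename).

    # Standalone sketch of the callback setup above; 'mean_loss' must be a
    # metric the LightningModule logs, and 'logs/ckpt_weights' is a placeholder.
    from pytorch_lightning import Trainer
    from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor

    ckpt_callback = ModelCheckpoint(
        monitor='mean_loss',            # checkpoint on the logged validation loss
        filepath='logs/ckpt_weights',   # older-API argument; newer PL uses dirpath/filename
        verbose=False,
        save_top_k=5,                   # keep only the five best checkpoints
    )
    lr_logger = LearningRateMonitor(logging_interval='epoch')  # log LR once per epoch

    trainer = Trainer(checkpoint_callback=True,        # keep checkpointing enabled
                      callbacks=[lr_logger, ckpt_callback])

Passing the ModelCheckpoint through callbacks= while leaving checkpoint_callback=True appears to match the transitional Lightning API of this era; on later releases only the callbacks list is needed.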
@@ -78,10 +72,15 @@ def run_lightning_loop(config_obj):
 
     # Train It
+    if config_obj.model.type.lower() != 'ensemble':
+        if not config_obj.main.debug and not config_obj.train.lr:
+            trainer.tune(model)
+        # ToDo: LR Finder Plot
+        # fig = lr_finder.plot(suggest=True)
 
     trainer.fit(model)
 
     # Save the last state & all parameters
-    trainer.save_checkpoint(logger.log_dir / 'weights.ckpt')
+    trainer.save_checkpoint(str(logger.log_dir / 'weights.ckpt'))
     model.save_to_disk(logger.log_dir)
 
     # Evaluate It
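
With auto_lr_find enabled on the Trainer, trainer.tune(model) runs Lightning's learning-rate range test and writes the suggested value back onto the model before fit. The plot left as a ToDo above could be recovered by invoking the LR finder directly; a sketch, assuming a Lightning version where it lives on trainer.tuner (its location moved between releases):

    # Sketch of recovering the LR-finder figure left as a ToDo above; the
    # trainer.tuner location and the output filename are assumptions.
    lr_finder = trainer.tuner.lr_find(model)   # run the LR range test manually
    fig = lr_finder.plot(suggest=True)         # loss-vs-LR curve, suggestion marked
    fig.savefig('lr_finder.png')               # placeholder output path
    model.lr = lr_finder.suggestion()          # apply the suggested learning rate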
@@ -99,8 +98,7 @@ def run_lightning_loop(config_obj):
             outputs.append(
                 model.validation_step((batch_x, label), idx, 1)
             )
-    summary_dict = model.validation_epoch_end([outputs])
-    print(summary_dict['log']['uar_score'])
+    model.validation_epoch_end([outputs])
 
     # trainer.test()
     outpath = Path(config_obj.train.outpath)
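
The hunk above sits inside a manual evaluation loop whose surrounding code is not shown in this diff; it presumably resembles the following sketch (the dataloader name and batch layout are assumptions). After this commit the summary dict is discarded rather than printed:

    # Assumed shape of the manual validation loop around the hunk above;
    # val_dataloader and the (batch_x, label) layout are guesses, not from the diff.
    outputs = []
    model.eval()
    with torch.no_grad():
        for idx, (batch_x, label) in enumerate(val_dataloader):
            outputs.append(
                model.validation_step((batch_x, label), idx, 1)
            )
    model.validation_epoch_end([outputs])  # aggregate; the uar_score is no longer printed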
@@ -132,6 +130,6 @@ if __name__ == "__main__":
 
     from _paramters import main_arg_parser
 
-    config = MConfig.read_argparser(main_arg_parser)
+    config = Config.read_argparser(main_arg_parser)
     fix_all_random_seeds(config)
     trained_model = run_lightning_loop(config)
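
Config.read_argparser comes from the project-internal ml_lib, and its contract is not visible in this diff. A purely hypothetical sketch of the behavior the rest of main.py relies on, namely turning flat argparse options into nested section.attribute access:

    # Hypothetical illustration only: ml_lib's actual Config implementation
    # is not part of this commit. Assumes 'section_option' argparse naming.
    from argparse import ArgumentParser

    class _Section(dict):
        __getattr__ = dict.__getitem__           # enables cfg.train.epochs style access

    def read_argparser_sketch(parser: ArgumentParser, args=None):
        parsed = vars(parser.parse_args(args))
        cfg = _Section()
        for key, value in parsed.items():
            section, _, name = key.partition('_')    # 'train_epochs' -> ('train', 'epochs')
            cfg.setdefault(section, _Section())[name] = value
        return cfg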