all debug and train running

commit 3a6c65240f
parent 26dc052e33
Author: steffen
Date: 2020-03-04 20:03:16 +01:00
2 changed files with 64 additions and 37 deletions

@@ -71,5 +71,25 @@ class Logger(LightningLoggerBase):
         self.testtubelogger.log_metrics(metrics, step=step)
         pass

+    def close(self):
+        self.testtubelogger.close()
+        self.neptunelogger.close()
+
+    def log_config_as_ini(self):
+        self.config.write(self.log_dir)
+
+    def save(self):
+        self.testtubelogger.save()
+        self.neptunelogger.save()
+
+    def finalize(self, status):
+        self.testtubelogger.finalize()
+        self.neptunelogger.finalize()
+        self.log_config_as_ini()
+
+    def __enter__(self):
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        self.finalize('success')
+        pass
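
The __enter__/__exit__ pair added above makes the combined logger usable as a context manager, which is how main.py now drives it. A minimal usage sketch, assuming a Config built the same way as in main.py (the metrics dict is illustrative):

    config = Config.read_namespace(main_arg_parser.parse_args())
    with Logger(config) as logger:
        # Both wrapped backends (test tube and Neptune) receive the same call.
        logger.log_metrics({'train_loss': 0.5}, step=0)
    # Leaving the block triggers __exit__, which calls finalize('success'):
    # both backends are finalized and the config is written out as an .ini
    # file via log_config_as_ini(). Note that it reports 'success' even when
    # an exception propagates out of the block.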

main.py (81 changes)

@@ -9,7 +9,7 @@ import warnings
 import torch
 from pytorch_lightning import Trainer
-from pytorch_lightning.callbacks import ModelCheckpoint
+from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping

 from lib.modules.utils import LightningBaseModule
 from lib.utils.config import Config
@@ -43,7 +43,7 @@ main_arg_parser.add_argument("--transformations_to_tensor", type=strtobool, defa
 main_arg_parser.add_argument("--train_outpath", type=str, default="output", help="")
 main_arg_parser.add_argument("--train_version", type=strtobool, required=False, help="")
 main_arg_parser.add_argument("--train_epochs", type=int, default=10, help="")
-main_arg_parser.add_argument("--train_batch_size", type=int, default=512, help="")
+main_arg_parser.add_argument("--train_batch_size", type=int, default=256, help="")
 main_arg_parser.add_argument("--train_lr", type=float, default=0.002, help="")

 # Model
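
Several of these flags (--train_version above, --transformations_to_tensor in the hunk context) use strtobool as the argparse type. A small sketch of how that parses, assuming main.py takes strtobool from distutils.util:

    from distutils.util import strtobool
    import argparse

    parser = argparse.ArgumentParser()
    # strtobool maps 'y/yes/t/true/on/1' to 1 and 'n/no/f/false/off/0' to 0;
    # anything else raises ValueError, which argparse reports as a usage error.
    parser.add_argument("--train_version", type=strtobool, required=False, help="")
    print(parser.parse_args(["--train_version", "yes"]).train_version)  # -> 1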
@@ -64,47 +64,54 @@ main_arg_parser.add_argument("--project_neptune_key", type=str, default=os.geten
 args = main_arg_parser.parse_args()
 config = Config.read_namespace(args)

-# Logger
-# =============================================================================
-logger = Logger(config)
-
-# Checkpoint Callback
-# =============================================================================
-checkpoint_callback = ModelCheckpoint(
-    filepath=str(logger.log_dir / 'ckpt_weights'),
-    verbose=True,
-    period=1
-)

 if __name__ == "__main__":
-    # Model
+    # Logging
     # =============================================================================
-    # Init
-    model: LightningBaseModule = config.model_class(config.model_paramters)
-    model.init_weights()
+    # Logger
+    with Logger(config) as logger:
+        # Callbacks
+        # =============================================================================
+        # Checkpoint Saving
+        checkpoint_callback = ModelCheckpoint(
+            filepath=str(logger.log_dir / 'ckpt_weights'),
+            verbose=True, save_top_k=5,
+        )
+        # =============================================================================
+        # Early Stopping
+        # TODO: For this to work, one must set a validation step with end-of-epoch eval and score
+        early_stopping_callback = EarlyStopping(
+            monitor='val_loss',
+            min_delta=0.0,
+            patience=0,
+        )

-    # Trainer
-    # =============================================================================
-    trainer = Trainer(max_nb_epochs=config.train.epochs,
-                      show_progress_bar=True,
-                      weights_save_path=logger.log_dir,
-                      gpus=[0] if torch.cuda.is_available() else None,
-                      row_log_interval=model.data_len // 40,  # TODO: Better Value / Setting
-                      log_save_interval=model.data_len // 10,  # TODO: Better Value / Setting
-                      checkpoint_callback=checkpoint_callback,
-                      logger=logger,
-                      fast_dev_run=config.main.debug,
-                      early_stop_callback=None
-                      )
+        # Model
+        # =============================================================================
+        # Init
+        model: LightningBaseModule = config.model_class(config.model_paramters)
+        model.init_weights()

-    # Train it
-    trainer.fit(model)
+        # Trainer
+        # =============================================================================
+        trainer = Trainer(max_epochs=config.train.epochs,
+                          show_progress_bar=True,
+                          weights_save_path=logger.log_dir,
+                          gpus=[0] if torch.cuda.is_available() else None,
+                          row_log_interval=(model.data_len * 0.01),  # TODO: Better Value / Setting
+                          log_save_interval=(model.data_len * 0.04),  # TODO: Better Value / Setting
+                          checkpoint_callback=checkpoint_callback,
+                          logger=logger,
+                          fast_dev_run=config.main.debug,
+                          early_stop_callback=None
+                          )

-    # Save the last state & all parameters
-    config.exp_path.mkdir(parents=True, exist_ok=True)  # TODO: do I need this?
-    trainer.save_checkpoint(logger.log_dir / 'weights.ckpt')
-    model.save_to_disk(logger.log_dir)
+        # Train it
+        trainer.fit(model)
+
+        # Save the last state & all parameters
+        trainer.save_checkpoint(logger.log_dir / 'weights.ckpt')
+        model.save_to_disk(logger.log_dir)
+        pass
+    # TODO: Eval here!
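
The EarlyStopping callback above is constructed but the Trainer still receives early_stop_callback=None, because nothing reports the monitored 'val_loss' yet; the TODO notes that a validation step is needed first. A hedged sketch of that missing piece, written against the 0.x-era pytorch-lightning hooks this commit targets (the cross-entropy loss is illustrative, not the project's actual objective):

    import torch
    import torch.nn.functional as F

    # These hooks would live on the LightningModule subclass behind config.model_class.
    def validation_step(self, batch, batch_idx):
        x, y = batch
        return {'val_loss': F.cross_entropy(self(x), y)}  # illustrative loss

    def validation_end(self, outputs):
        # Reduce the per-batch losses to the single scalar EarlyStopping monitors.
        avg_loss = torch.stack([out['val_loss'] for out in outputs]).mean()
        return {'val_loss': avg_loss, 'log': {'val_loss': avg_loss}}

With those hooks in place, the Trainer call could pass early_stop_callback=early_stopping_callback instead of None.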