all debug and train running
@@ -71,5 +71,25 @@ class Logger(LightningLoggerBase):
         self.testtubelogger.log_metrics(metrics, step=step)
         pass
 
+    def close(self):
+        self.testtubelogger.close()
+        self.neptunelogger.close()
+
+    def log_config_as_ini(self):
+        self.config.write(self.log_dir)
+
+    def save(self):
+        self.testtubelogger.save()
+        self.neptunelogger.save()
+
+    def finalize(self, status):
+        self.testtubelogger.finalize()
+        self.neptunelogger.finalize()
+        self.log_config_as_ini()
+
+    def __enter__(self):
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        self.finalize('success')
+        pass
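The new methods make Logger usable as a context manager: `__enter__` returns the wrapper itself and `__exit__` calls `finalize('success')`, which flushes both the TestTube and Neptune backends and then writes the config out via `log_config_as_ini`. A minimal usage sketch (the `config` object and the metric name here are placeholders, not from this commit):

```python
# Sketch only: assumes a Config instance as used elsewhere in this repo.
with Logger(config) as logger:
    for step in range(10):
        logger.log_metrics({'train_loss': 1.0 / (step + 1)}, step=step)
# __exit__ has run here, so finalize('success') already closed out both
# backends and wrote the config .ini file.
```

Note that `__exit__` ignores `exc_type`, so as written `finalize('success')` runs even when the block exits on an exception.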
main.py
@@ -9,7 +9,7 @@ import warnings
 
 import torch
 from pytorch_lightning import Trainer
-from pytorch_lightning.callbacks import ModelCheckpoint
+from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
 
 from lib.modules.utils import LightningBaseModule
 from lib.utils.config import Config
@@ -43,7 +43,7 @@ main_arg_parser.add_argument("--transformations_to_tensor", type=strtobool, defa
 main_arg_parser.add_argument("--train_outpath", type=str, default="output", help="")
 main_arg_parser.add_argument("--train_version", type=strtobool, required=False, help="")
 main_arg_parser.add_argument("--train_epochs", type=int, default=10, help="")
-main_arg_parser.add_argument("--train_batch_size", type=int, default=512, help="")
+main_arg_parser.add_argument("--train_batch_size", type=int, default=256, help="")
 main_arg_parser.add_argument("--train_lr", type=float, default=0.002, help="")
 
 # Model
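Several of the flags above pass `strtobool` as the argparse `type`, so boolean options are given as strings on the command line. A quick sketch of how that parses (the flag name here is made up for illustration):

```python
from argparse import ArgumentParser
from distutils.util import strtobool

parser = ArgumentParser()
# Same pattern as the train flags above; "--dummy_flag" is illustrative.
parser.add_argument("--dummy_flag", type=strtobool, default=0, help="")

args = parser.parse_args(["--dummy_flag", "true"])
assert args.dummy_flag == 1  # strtobool maps "y"/"yes"/"true"/"1" -> 1
args = parser.parse_args(["--dummy_flag", "no"])
assert args.dummy_flag == 0  # and "n"/"no"/"false"/"0" -> 0
```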
@@ -64,47 +64,54 @@ main_arg_parser.add_argument("--project_neptune_key", type=str, default=os.geten
 args = main_arg_parser.parse_args()
 config = Config.read_namespace(args)
 
-# Logger
-# =============================================================================
-logger = Logger(config)
-
-# Checkpoint Callback
-# =============================================================================
-checkpoint_callback = ModelCheckpoint(
-    filepath=str(logger.log_dir / 'ckpt_weights'),
-    verbose=True,
-    period=1
-)
-
-
 if __name__ == "__main__":
 
-    # Model
+    # Logging
     # =============================================================================
-    # Init
-    model: LightningBaseModule = config.model_class(config.model_paramters)
-    model.init_weights()
+    # Logger
+    with Logger(config) as logger:
+        # Callbacks
+        # =============================================================================
+        # Checkpoint Saving
+        checkpoint_callback = ModelCheckpoint(
+            filepath=str(logger.log_dir / 'ckpt_weights'),
+            verbose=True, save_top_k=5,
+        )
+        # =============================================================================
+        # Early Stopping
+        # TODO: For This to work, one must set a validation step and End Eval and Score
+        early_stopping_callback = EarlyStopping(
+            monitor='val_loss',
+            min_delta=0.0,
+            patience=0,
+        )
 
-    # Trainer
-    # =============================================================================
-    trainer = Trainer(max_nb_epochs=config.train.epochs,
-                      show_progress_bar=True,
-                      weights_save_path=logger.log_dir,
-                      gpus=[0] if torch.cuda.is_available() else None,
-                      row_log_interval=model.data_len // 40,  # TODO: Better Value / Setting
-                      log_save_interval=model.data_len // 10,  # TODO: Better Value / Setting
-                      checkpoint_callback=checkpoint_callback,
-                      logger=logger,
-                      fast_dev_run=config.main.debug,
-                      early_stop_callback=None
-                      )
+        # Model
+        # =============================================================================
+        # Init
+        model: LightningBaseModule = config.model_class(config.model_paramters)
+        model.init_weights()
 
-    # Train it
-    trainer.fit(model)
+        # Trainer
+        # =============================================================================
+        trainer = Trainer(max_epochs=config.train.epochs,
+                          show_progress_bar=True,
+                          weights_save_path=logger.log_dir,
+                          gpus=[0] if torch.cuda.is_available() else None,
+                          row_log_interval=(model.data_len * 0.01),  # TODO: Better Value / Setting
+                          log_save_interval=(model.data_len * 0.04),  # TODO: Better Value / Setting
+                          checkpoint_callback=checkpoint_callback,
+                          logger=logger,
+                          fast_dev_run=config.main.debug,
+                          early_stop_callback=None
+                          )
 
-    # Save the last state & all parameters
-    config.exp_path.mkdir(parents=True, exist_ok=True)  # Todo: do i need this?
-    trainer.save_checkpoint(logger.log_dir / 'weights.ckpt')
-    model.save_to_disk(logger.log_dir)
+        # Train it
+        trainer.fit(model)
+
+        # Save the last state & all parameters
+        trainer.save_checkpoint(logger.log_dir / 'weights.ckpt')
+        model.save_to_disk(logger.log_dir)
+        pass
+    # TODO: Eval here!
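The TODO on the EarlyStopping block notes that `monitor='val_loss'` only has something to watch once the module defines a validation loop, and the Trainer accordingly still passes `early_stop_callback=None`. A sketch of the missing piece under the pre-1.0 Lightning API this commit targets (`validation_end` rather than the later `validation_epoch_end`); the subclass name and loss computation are assumptions, not part of this repo:

```python
import torch
from torch.nn import functional as F

from lib.modules.utils import LightningBaseModule


class LitModelWithValidation(LightningBaseModule):  # hypothetical subclass
    def validation_step(self, batch, batch_idx):
        x, y = batch
        return {'val_loss': F.mse_loss(self(x), y)}  # placeholder loss

    def validation_end(self, outputs):
        # Pre-1.0 Lightning aggregates the per-batch dicts here; the returned
        # 'val_loss' is what EarlyStopping(monitor='val_loss') watches.
        avg_loss = torch.stack([o['val_loss'] for o in outputs]).mean()
        return {'val_loss': avg_loss, 'log': {'val_loss': avg_loss}}
```

With that in place, the callback would actually be wired up by passing `early_stop_callback=early_stopping_callback` to the Trainer instead of `None`.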