From 1f25bf599b7b3d5ea957946ec8e38dc785fa2c46 Mon Sep 17 00:00:00 2001 From: Si11ium Date: Thu, 5 Mar 2020 16:58:23 +0100 Subject: [PATCH] eval written --- .idea/deployment.xml | 11 +++++- .idea/hom_traj_gen.iml | 2 +- .idea/misc.xml | 2 +- lib/evaluation/classification.py | 21 ++++------ lib/models/generators/cnn.py | 2 +- .../homotopy_classification/cnn_based.py | 38 +++++++++++++++++++ lib/modules/utils.py | 18 --------- lib/utils/config.py | 34 ++++++++++++----- lib/utils/logging.py | 8 ++++ main.py | 38 ++++++++++++------- main_post.py | 0 multi_run.py | 27 ++++++------- 12 files changed, 127 insertions(+), 74 deletions(-) delete mode 100644 main_post.py diff --git a/.idea/deployment.xml b/.idea/deployment.xml index b9d4963..533fe8a 100644 --- a/.idea/deployment.xml +++ b/.idea/deployment.xml @@ -1,8 +1,15 @@ - + - + + + + + + + + diff --git a/.idea/hom_traj_gen.iml b/.idea/hom_traj_gen.iml index 241d6f7..4b1d9c2 100644 --- a/.idea/hom_traj_gen.iml +++ b/.idea/hom_traj_gen.iml @@ -2,7 +2,7 @@ - + \ No newline at end of file diff --git a/.idea/misc.xml b/.idea/misc.xml index 06cb946..f164374 100644 --- a/.idea/misc.xml +++ b/.idea/misc.xml @@ -3,5 +3,5 @@ - + \ No newline at end of file diff --git a/lib/evaluation/classification.py b/lib/evaluation/classification.py index 4e56c10..9fb8063 100644 --- a/lib/evaluation/classification.py +++ b/lib/evaluation/classification.py @@ -1,29 +1,24 @@ import matplotlib.pyplot as plt - from sklearn.metrics import roc_curve, auc class ROCEvaluation(object): - BINARY_PROBLEM = 2 linewidth = 2 - def __init__(self, save_fig=True): + def __init__(self, prepare_figure=False): + self.prepare_figure = prepare_figure self.epoch = 0 - pass - def __call__(self, prediction, label, prepare_fig=True): + def __call__(self, prediction, label, plotting=False): # Compute ROC curve and ROC area fpr, tpr, _ = roc_curve(prediction, label) roc_auc = auc(fpr, tpr) - - if prepare_fig: - fig = self._prepare_fig() - fig.plot(fpr, tpr, color='darkorange', - lw=2, label=f'ROC curve (area = {roc_auc})') - self._prepare_fig() - return roc_auc + if plotting: + fig = plt.gcf() + fig.plot(fpr, tpr, color='darkorange', lw=self.linewidth, label=f'ROC curve (area = {roc_auc})') + return roc_auc, fpr, tpr def _prepare_fig(self): fig = plt.gcf() @@ -32,6 +27,6 @@ class ROCEvaluation(object): fig.ylim([0.0, 1.05]) fig.xlabel('False Positive Rate') fig.ylabel('True Positive Rate') - fig.legend(loc="lower right") + return fig diff --git a/lib/models/generators/cnn.py b/lib/models/generators/cnn.py index fa04f4a..df6fe03 100644 --- a/lib/models/generators/cnn.py +++ b/lib/models/generators/cnn.py @@ -1,4 +1,4 @@ -from dataset.dataset import TrajPairData +from datasets.paired_dataset import TrajPairData from lib.modules.blocks import ConvModule from lib.modules.utils import LightningBaseModule diff --git a/lib/models/homotopy_classification/cnn_based.py b/lib/models/homotopy_classification/cnn_based.py index 201892d..e5216cb 100644 --- a/lib/models/homotopy_classification/cnn_based.py +++ b/lib/models/homotopy_classification/cnn_based.py @@ -5,8 +5,10 @@ import torch from torch import nn import torch.nn.functional as F from torch.optim import Adam +from torch.utils.data import DataLoader from datasets.trajectory_dataset import TrajData +from lib.evaluation.classification import ROCEvaluation from lib.modules.utils import LightningBaseModule, Flatten from lib.modules.blocks import ConvModule, ResidualModule @@ -24,6 +26,22 @@ class ConvHomDetector(LightningBaseModule): loss = F.binary_cross_entropy(pred_y, batch_y.float()) return {'loss': loss, 'log': dict(loss=loss)} + def test_step(self, batch_xy, **kwargs): + batch_x, batch_y = batch_xy + pred_y = self(batch_x) + return dict(prediction=pred_y, label=batch_y) + + def test_end(self, outputs): + evaluation = ROCEvaluation() + predictions = torch.stack([x['prediction'] for x in outputs]) + labels = torch.stack([x['label'] for x in outputs]) + + scores = evaluation(predictions.numpy(), labels.numpy()) + self.logger.log_metrics() + + + pass + def __init__(self, *params): super(ConvHomDetector, self).__init__(*params) @@ -70,6 +88,26 @@ class ConvHomDetector(LightningBaseModule): self.classifier = nn.Linear(self.hparams.model_param.classes * 10, 1) # self.hparams.model_param.classes) self.out_activation = nn.Sigmoid() # nn.Softmax + # Dataloaders + # ================================================================================ + # Train Dataloader + def train_dataloader(self): + return DataLoader(dataset=self.dataset.train_dataset, shuffle=True, + batch_size=self.hparams.data_param.batchsize, + num_workers=self.hparams.data_param.worker) + + # Test Dataloader + def test_dataloader(self): + return DataLoader(dataset=self.dataset.test_dataset, shuffle=True, + batch_size=self.hparams.data_param.batchsize, + num_workers=self.hparams.data_param.worker) + + # Validation Dataloader + def val_dataloader(self): + return DataLoader(dataset=self.dataset.val_dataset, shuffle=True, + batch_size=self.hparams.data_param.batchsize, + num_workers=self.hparams.data_param.worker) + def forward(self, x): tensor = self.map_conv_0(x) tensor = self.map_res_1(tensor) diff --git a/lib/modules/utils.py b/lib/modules/utils.py index 34d2caf..a29d4d9 100644 --- a/lib/modules/utils.py +++ b/lib/modules/utils.py @@ -105,24 +105,6 @@ class LightningBaseModule(pl.LightningModule, ABC): torch.save(self.__class__, f) return True - @pl.data_loader - def train_dataloader(self): - return DataLoader(dataset=self.dataset.train_dataset, shuffle=True, - batch_size=self.hparams.data_param.batchsize, - num_workers=self.hparams.data_param.worker) - - @pl.data_loader - def test_dataloader(self): - return DataLoader(dataset=self.dataset.test_dataset, shuffle=True, - batch_size=self.hparams.data_param.batchsize, - num_workers=self.hparams.data_param.worker) - - @pl.data_loader - def val_dataloader(self): - return DataLoader(dataset=self.dataset.val_dataset, shuffle=True, - batch_size=self.hparams.data_param.batchsize, - num_workers=self.hparams.data_param.worker) - @property def data_len(self): return len(self.dataset.train_dataset) diff --git a/lib/utils/config.py b/lib/utils/config.py index e02f627..0fe8910 100644 --- a/lib/utils/config.py +++ b/lib/utils/config.py @@ -79,20 +79,37 @@ class Config(ConfigParser): super(Config, self).__init__(**kwargs) pass + @staticmethod + def _sort_combined_section_key_mapping(dict_obj): + sorted_dict = defaultdict(dict) + for key in dict_obj: + section, *attr_name = key.split('_') + attr_name = '_'.join(attr_name) + value = str(dict_obj[key]) + + sorted_dict[section][attr_name] = value + # noinspection PyTypeChecker + return dict(sorted_dict) + @classmethod def read_namespace(cls, namespace: Namespace): - space_dict = defaultdict(dict) - for key in namespace.__dict__: - section, *attr_name = key.split('_') - attr_name = '_'.join(attr_name) - value = str(namespace.__getattribute__(key)) - - space_dict[section][attr_name] = value + sorted_dict = cls._sort_combined_section_key_mapping(namespace.__dict__) new_config = cls() - new_config.read_dict(space_dict) + new_config.read_dict(sorted_dict) return new_config + def update(self, mapping): + sorted_dict = self._sort_combined_section_key_mapping(mapping) + for section in sorted_dict: + if self.has_section(section): + pass + else: + self.add_section(section) + for option, value in sorted_dict[section].items(): + self.set(section, option, value) + return self + def get(self, *args, **kwargs): item = super(Config, self).get(*args, **kwargs) try: @@ -108,5 +125,4 @@ class Config(ConfigParser): with path.open('w') as configfile: super().write(configfile) - return True diff --git a/lib/utils/logging.py b/lib/utils/logging.py index 359d7d5..c38b5f9 100644 --- a/lib/utils/logging.py +++ b/lib/utils/logging.py @@ -78,6 +78,14 @@ class Logger(LightningLoggerBase): def log_config_as_ini(self): self.config.write(self.log_dir) + def log_metric(self, metric_name, metric_value, **kwargs): + self.testtubelogger.log_metrics(dict(metric_name=metric_value)) + self.neptunelogger.log_metric(metric_name, metric_value, **kwargs) + + def log_image(self, name, image, **kwargs): + self.neptunelogger.log_image(name, image, **kwargs) + image.savefig(self.log_dir / name) + def save(self): self.testtubelogger.save() self.neptunelogger.save() diff --git a/main.py b/main.py index ca055a6..d7ebca0 100644 --- a/main.py +++ b/main.py @@ -3,18 +3,21 @@ import os from distutils.util import strtobool from pathlib import Path -from argparse import ArgumentParser +from argparse import ArgumentParser, Namespace import warnings import torch from pytorch_lightning import Trainer from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping +from torch.utils.data import DataLoader from lib.modules.utils import LightningBaseModule from lib.utils.config import Config from lib.utils.logging import Logger +from lib.evaluation.classification import ROCEvaluation + warnings.filterwarnings('ignore', category=FutureWarning) warnings.filterwarnings('ignore', category=UserWarning) @@ -31,7 +34,7 @@ main_arg_parser.add_argument("--main_eval", type=strtobool, default=False, help= main_arg_parser.add_argument("--main_seed", type=int, default=69, help="") # Data Parameters -main_arg_parser.add_argument("--data_worker", type=int, default=0, help="") +main_arg_parser.add_argument("--data_worker", type=int, default=10, help="") main_arg_parser.add_argument("--data_batchsize", type=int, default=100, help="") main_arg_parser.add_argument("--data_root", type=str, default='/data/', help="") main_arg_parser.add_argument("--data_map_root", type=str, default='/res/maps', help="") @@ -61,16 +64,15 @@ main_arg_parser.add_argument("--project_owner", type=str, default='si11ium', hel main_arg_parser.add_argument("--project_neptune_key", type=str, default=os.getenv('NEPTUNE_KEY'), help="") # Parse it -args = main_arg_parser.parse_args() -config = Config.read_namespace(args) +args: Namespace = main_arg_parser.parse_args() -if __name__ == "__main__": +def run_lightning_loop(config_obj): # Logging - # ============================================================================= + # ================================================================================ # Logger - with Logger(config) as logger: + with Logger(config_obj) as logger: # Callbacks # ============================================================================= # Checkpoint Saving @@ -90,12 +92,12 @@ if __name__ == "__main__": # Model # ============================================================================= # Init - model: LightningBaseModule = config.model_class(config.model_paramters) + model: LightningBaseModule = config_obj.model_class(config_obj.model_paramters) model.init_weights() # Trainer # ============================================================================= - trainer = Trainer(max_epochs=config.train.epochs, + trainer = Trainer(max_epochs=config_obj.train.epochs, show_progress_bar=True, weights_save_path=logger.log_dir, gpus=[0] if torch.cuda.is_available() else None, @@ -103,15 +105,23 @@ if __name__ == "__main__": log_save_interval=(model.data_len * 0.04), # TODO: Better Value / Setting checkpoint_callback=checkpoint_callback, logger=logger, - fast_dev_run=config.main.debug, + fast_dev_run=config_obj.main.debug, early_stop_callback=None ) - # Train it - trainer.fit(model) + # Train It + trainer.fit(model,) # Save the last state & all parameters trainer.save_checkpoint(logger.log_dir / 'weights.ckpt') model.save_to_disk(logger.log_dir) - pass - # TODO: Eval here! + + # Evaluate It + trainer.test() + return model + + +if __name__ == "__main__": + + config = Config.read_namespace(args) + trained_model = run_lightning_loop(config) diff --git a/main_post.py b/main_post.py deleted file mode 100644 index e69de29..0000000 diff --git a/multi_run.py b/multi_run.py index cf0eaef..6e9dd2f 100644 --- a/multi_run.py +++ b/multi_run.py @@ -7,29 +7,26 @@ warnings.filterwarnings('ignore', category=UserWarning) # Imports # ============================================================================= -from pathlib import Path -import os + +from main import run_training, args if __name__ == '__main__': # Model Settings - warnings.filterwarnings('ignore', category=FutureWarning) + config = Config().read_namespace(args) # use_bias, activation, model, use_norm, max_epochs, filters cnn_classifier = [True, 'leaky_relu', 'classifier_cnn', False, 2, [16, 32, 64]] # use_bias, activation, model, use_norm, max_epochs, sr, lat_dim, filters - # Data Settings - data_shortcodes = ['mid', 'mid_5'] + for use_bias, activation, model, use_norm, max_epochs, filters in [cnn_classifier]: + for seed in range(5): + arg_dict = dict(main_seed=seed, train_max_epochs=max_epochs, + model_use_bias=use_bias, model_use_norm=use_norm, + model_activation=activation, model_type=model, + model_filters=filters, + data_batch_size=512) - # Iteration over - for data_shortcode in data_shortcodes: - for use_bias, activation, model, use_norm, max_epochs, filters in [cnn_classifier]: - for seed in range(5): - arg_dict = dict(main_seed=seed, train_max_epochs=max_epochs, - model_use_bias=use_bias, model_use_norm=use_norm, - model_activation=activation, model_type=model, - model_filters=filters, - data_batch_size=512) + config = config.update(arg_dict) - os.system(f'/home/steffen/envs/traj_gen/bin/python main.py {arg_dict}') + run_training(config)