diff --git a/datasets/trajectory_dataset.py b/datasets/trajectory_dataset.py
index de019d2..69a80de 100644
--- a/datasets/trajectory_dataset.py
+++ b/datasets/trajectory_dataset.py
@@ -3,7 +3,6 @@ from pathlib import Path
 from typing import Union, List
 
 import torch
-from random import choice
 from torch.utils.data import ConcatDataset, Dataset
 
 from lib.objects.map import Map
@@ -17,12 +16,14 @@ class TrajDataset(Dataset):
         return self.map.as_array.shape
 
     def __init__(self, *args, maps_root: Union[Path, str] = '', mapname='tate_sw',
-                 length=100000, all_in_map=True, embedding_size=None, **kwargs):
+                 length=100000, all_in_map=True, embedding_size=None, preserve_equal_samples=False, **kwargs):
         super(TrajDataset, self).__init__()
+        self.preserve_equal_samples = preserve_equal_samples
         self.all_in_map = all_in_map
         self.mapname = mapname if mapname.endswith('.bmp') else f'{mapname}.bmp'
         self.maps_root = maps_root
         self._len = length
+        self.last_label = -1
 
         self.map = Map(self.mapname).from_image(self.maps_root / self.mapname,
                                                 embedding_size=embedding_size)
@@ -31,8 +32,16 @@
 
     def __getitem__(self, item):
         trajectory = self.map.get_random_trajectory()
-        alternative = self.map.generate_alternative(trajectory)
-        label = choice([0, 1])
+        while True:
+            # TODO: Sanity-check this loop; it can spin forever if every alternative keeps yielding the same label
+            alternative = self.map.generate_alternative(trajectory)
+            label = self.map.are_homotopic(trajectory, alternative)
+            if self.preserve_equal_samples and label == self.last_label:
+                continue
+            else:
+                break
+
+        self.last_label = label
         if self.all_in_map:
             blank_trajectory_space = torch.zeros(self.map.shape)
             blank_alternative_space = torch.zeros(self.map.shape)
@@ -41,7 +50,6 @@
             blank_alternative_space[index] = 1
 
             map_array = torch.as_tensor(self.map.as_array).float()
-            label = self.map.are_homotopic(trajectory, alternative)
             return torch.cat((map_array, blank_trajectory_space, blank_alternative_space)), int(label)
         else:
             return trajectory.vertices, alternative.vertices, label, self.mapname
@@ -78,7 +86,8 @@ class TrajData(object):
         # find max image size among available maps:
         max_map_size = (1, ) + tuple(reversed(tuple(map(max, *[Image.open(map_file).size for map_file in map_files]))))
         return ConcatDataset([TrajDataset(maps_root=self.maps_root, mapname=map_file.name, length=equal_split,
-                                          all_in_map=self.all_in_map, embedding_size=max_map_size)
+                                          all_in_map=self.all_in_map, embedding_size=max_map_size,
+                                          preserve_equal_samples=True)
                              for map_file in map_files])
 
     @property
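
Note: the rewritten `__getitem__` derives the label from `Map.are_homotopic` instead of a coin flip, and with `preserve_equal_samples` it resamples until the label differs from the previous sample's, which roughly balances the two classes. A minimal sketch of the same idea with a retry cap (the `Map` method names come from the diff above; `sample_balanced` and `max_retries` are illustrative, not repo code):

    def sample_balanced(map_, last_label, max_retries=100):
        # Resample until the homotopy label flips, but cap the attempts
        # so a map that mostly yields one label cannot spin forever.
        trajectory = map_.get_random_trajectory()
        for _ in range(max_retries):
            alternative = map_.generate_alternative(trajectory)
            label = map_.are_homotopic(trajectory, alternative)
            if label != last_label:
                break
        return trajectory, alternative, label
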
diff --git a/lib/utils/logging.py b/lib/utils/logging.py
index b7032c0..359d7d5 100644
--- a/lib/utils/logging.py
+++ b/lib/utils/logging.py
@@ -18,9 +18,7 @@ class Logger(LightningLoggerBase):
 
     @property
     def log_dir(self):
-        if self.debug:
-            return Path(self.outpath)
-        return Path(self.experiment.log_dir).parent
+        return Path(self.testtubelogger.experiment.get_logdir()).parent
 
     @property
     def name(self):
@@ -32,14 +30,14 @@
 
     @property
     def version(self):
-        return f"version_{self.config.get('main', 'seed')}"
+        return self.config.get('main', 'seed')
 
     @property
     def outpath(self):
         # ToDo: Add further path modification such as dataset config etc.
         return Path(self.config.train.outpath)
 
-    def __init__(self, config: Config, debug=False):
+    def __init__(self, config: Config):
         """
         params (dict|None): Optional. Parameters of the experiment. After experiment creation params are read-only.
             Parameters are displayed in the experiment’s Parameters section and each key-value pair can be
@@ -53,8 +51,8 @@
         """
         super(Logger, self).__init__()
 
-        self.debug = debug
         self.config = config
+        self.debug = self.config.main.debug
 
         self._testtube_kwargs = dict(save_dir=self.outpath, version=self.version, name=self.name)
         self._neptune_kwargs = dict(offline_mode=self.debug, api_key=self.config.project.neptune_key,
@@ -68,10 +66,30 @@
         self.testtubelogger.log_hyperparams(params)
         pass
 
-    def log_metrics(self, metrics, step_num):
-        self.neptunelogger.log_metrics(metrics, step_num)
-        self.testtubelogger.log_metrics(metrics, step_num)
+    def log_metrics(self, metrics, step=None):
+        self.neptunelogger.log_metrics(metrics, step=step)
+        self.testtubelogger.log_metrics(metrics, step=step)
         pass
 
+    def close(self):
+        self.testtubelogger.close()
+        self.neptunelogger.close()
+
     def log_config_as_ini(self):
         self.config.write(self.log_dir)
+
+    def save(self):
+        self.testtubelogger.save()
+        self.neptunelogger.save()
+
+    def finalize(self, status):
+        self.testtubelogger.finalize(status)
+        self.neptunelogger.finalize(status)
+        self.log_config_as_ini()
+
+    def __enter__(self):
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        self.finalize('success')  # NOTE: reports 'success' even when an exception reached __exit__
+        pass
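
Note: `__exit__` unconditionally calls `finalize('success')`, so the logger is designed to be driven as a context manager, which is exactly how `main.py` below uses it. A short usage sketch (assuming `config` was built via `Config.read_namespace`, as in `main.py`):

    from lib.utils.logging import Logger

    with Logger(config) as logger:
        logger.log_metrics({'train_loss': 0.5}, step=0)  # fans out to Neptune and TestTube
    # on exit: finalize('success') runs, which also writes the config as an .ini
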
diff --git a/main.py b/main.py
index d64b765..ca055a6 100644
--- a/main.py
+++ b/main.py
@@ -9,7 +9,7 @@ import warnings
 
 import torch
 from pytorch_lightning import Trainer
-from pytorch_lightning.callbacks import ModelCheckpoint
+from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
 
 from lib.modules.utils import LightningBaseModule
 from lib.utils.config import Config
@@ -43,7 +43,7 @@ main_arg_parser.add_argument("--transformations_to_tensor", type=strtobool, defa
 main_arg_parser.add_argument("--train_outpath", type=str, default="output", help="")
 main_arg_parser.add_argument("--train_version", type=strtobool, required=False, help="")
 main_arg_parser.add_argument("--train_epochs", type=int, default=10, help="")
-main_arg_parser.add_argument("--train_batch_size", type=int, default=512, help="")
+main_arg_parser.add_argument("--train_batch_size", type=int, default=256, help="")
 main_arg_parser.add_argument("--train_lr", type=float, default=0.002, help="")
 
 # Model
@@ -64,47 +64,54 @@ main_arg_parser.add_argument("--project_neptune_key", type=str, default=os.geten
 args = main_arg_parser.parse_args()
 config = Config.read_namespace(args)
 
-# Logger
-# =============================================================================
-logger = Logger(config, debug=True)
-
-# Checkpoint Callback
-# =============================================================================
-checkpoint_callback = ModelCheckpoint(
-    filepath=str(logger.log_dir / 'ckpt_weights'),
-    verbose=True,
-    period=1
-)
-
 if __name__ == "__main__":
-    # Model
+    # Logging
     # =============================================================================
-    # Init
-    model: LightningBaseModule = config.model_class(config.model_paramters)
-    model.init_weights()
+    # Logger
+    with Logger(config) as logger:
+        # Callbacks
+        # =============================================================================
+        # Checkpoint Saving
+        checkpoint_callback = ModelCheckpoint(
+            filepath=str(logger.log_dir / 'ckpt_weights'),
+            verbose=True, save_top_k=5,
+        )
+        # =============================================================================
+        # Early Stopping
+        # TODO: For this to work, a validation step with an end-of-epoch eval and score must be defined
+        early_stopping_callback = EarlyStopping(
+            monitor='val_loss',
+            min_delta=0.0,
+            patience=0,
+        )
 
-    # Trainer
-    # =============================================================================
-    trainer = Trainer(max_nb_epochs=config.train.epochs,
-                      show_progress_bar=True,
-                      weights_save_path=logger.log_dir,
-                      gpus=[0] if torch.cuda.is_available() else None,
-                      row_log_interval=model.data_len // 40,  # TODO: Better Value / Setting
-                      log_save_interval=model.data_len // 10,  # TODO: Better Value / Setting
-                      checkpoint_callback=checkpoint_callback,
-                      logger=logger,
-                      fast_dev_run=config.get('main', 'debug'),
-                      early_stop_callback=None
-                      )
+        # Model
+        # =============================================================================
+        # Init
+        model: LightningBaseModule = config.model_class(config.model_paramters)
+        model.init_weights()
 
-    # Train it
-    trainer.fit(model)
+        # Trainer
+        # =============================================================================
+        trainer = Trainer(max_epochs=config.train.epochs,
+                          show_progress_bar=True,
+                          weights_save_path=logger.log_dir,
+                          gpus=[0] if torch.cuda.is_available() else None,
+                          row_log_interval=int(model.data_len * 0.01),  # TODO: Better Value / Setting
+                          log_save_interval=int(model.data_len * 0.04),  # TODO: Better Value / Setting
+                          checkpoint_callback=checkpoint_callback,
+                          logger=logger,
+                          fast_dev_run=config.main.debug,
+                          early_stop_callback=None  # early_stopping_callback is not passed yet, see TODO above
+                          )
 
-    # Save the last state & all parameters
-    config.exp_path.mkdir(parents=True, exist_ok=True)  # Todo: do i need this?
-    trainer.save_checkpoint(logger.log_dir / 'weights.ckpt')
-    model.save_to_disk(logger.log_dir)
+        # Train it
+        trainer.fit(model)
+        # Save the last state & all parameters
+        trainer.save_checkpoint(logger.log_dir / 'weights.ckpt')
+        model.save_to_disk(logger.log_dir)
+        pass  # TODO: Eval here!
diff --git a/main_post.py b/main_post.py
new file mode 100644
index 0000000..e69de29
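
Note: the new `EarlyStopping` callback monitors `val_loss`, so the TODO above amounts to giving the model validation hooks that produce that key. A sketch of such hooks in the pre-1.0 Lightning API this diff targets (`LightningBaseModule` is the repo's base class; the loss call is an assumption for illustration):

    import torch

    class ModelWithValidation(LightningBaseModule):
        def validation_step(self, batch, batch_idx):
            # per-batch validation loss
            x, y = batch
            return {'val_loss': self.loss(self(x), y)}

        def validation_end(self, outputs):
            # aggregate into the 'val_loss' key that EarlyStopping monitors
            avg_loss = torch.stack([out['val_loss'] for out in outputs]).mean()
            return {'val_loss': avg_loss, 'log': {'val_loss': avg_loss}}
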
diff --git a/multi_run.py b/multi_run.py
new file mode 100644
index 0000000..cf0eaef
--- /dev/null
+++ b/multi_run.py
@@ -0,0 +1,36 @@
+import warnings
+
+from lib.utils.config import Config
+
+warnings.filterwarnings('ignore', category=FutureWarning)
+warnings.filterwarnings('ignore', category=UserWarning)
+
+# Imports
+# =============================================================================
+from pathlib import Path
+import os
+
+
+if __name__ == '__main__':
+
+    # Model Settings
+    warnings.filterwarnings('ignore', category=FutureWarning)
+    # use_bias, activation, model, use_norm, max_epochs, filters
+    cnn_classifier = [True, 'leaky_relu', 'classifier_cnn', False, 2, [16, 32, 64]]
+    # use_bias, activation, model, use_norm, max_epochs, sr, lat_dim, filters
+
+    # Data Settings
+    data_shortcodes = ['mid', 'mid_5']
+
+    # Iteration over data variants, model configs and seeds
+    for data_shortcode in data_shortcodes:
+        for use_bias, activation, model, use_norm, max_epochs, filters in [cnn_classifier]:
+            for seed in range(5):
+                arg_dict = dict(main_seed=seed, train_max_epochs=max_epochs,
+                                model_use_bias=use_bias, model_use_norm=use_norm,
+                                model_activation=activation, model_type=model,
+                                model_filters=filters,
+                                data_batch_size=512)
+
+                # build proper --key value CLI flags; a raw dict repr would not parse as arguments
+                arg_str = ' '.join(f'--{key} "{val}"' for key, val in arg_dict.items())
+                os.system(f'/home/steffen/envs/traj_gen/bin/python main.py {arg_str}')
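
Note: `os.system` pushes one flat string through the shell, which is why the flag values above are quoted; list-valued flags such as `model_filters` would otherwise split into several tokens. A `subprocess` variant (a sketch reusing `arg_dict` from the loop above) avoids shell quoting entirely:

    import subprocess

    cmd = ['/home/steffen/envs/traj_gen/bin/python', 'main.py']
    for key, val in arg_dict.items():
        cmd += [f'--{key}', str(val)]
    subprocess.run(cmd, check=True)  # raises CalledProcessError if a run fails
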