From 6cc978e4647151bca2c81f4c9c6b7938da2f2b3a Mon Sep 17 00:00:00 2001 From: Steffen Illium Date: Mon, 9 Mar 2020 21:41:50 +0100 Subject: [PATCH] train running --- datasets/trajectory_dataset.py | 14 +++------ lib/models/generators/cnn.py | 33 +++++++++++++------ lib/modules/utils.py | 2 +- lib/objects/map.py | 49 +++++++++++++---------------- lib/objects/trajectory.py | 6 ++++ lib/visualization/generator_eval.py | 12 +++---- main.py | 9 +++--- 7 files changed, 68 insertions(+), 57 deletions(-) diff --git a/datasets/trajectory_dataset.py b/datasets/trajectory_dataset.py index 9ecdd87..4c0c7b2 100644 --- a/datasets/trajectory_dataset.py +++ b/datasets/trajectory_dataset.py @@ -46,7 +46,6 @@ class TrajDataset(Dataset): while True: trajectory = self.map.get_random_trajectory() - # TODO: Sanity Check this while true loop... alternative = self.map.generate_alternative(trajectory) label = self.map.are_homotopic(trajectory, alternative) if self.preserve_equal_samples and label == self.last_label: @@ -56,18 +55,13 @@ class TrajDataset(Dataset): self.last_label = label if self.mode.lower() in ['all_in_map', 'separated_arrays']: - blank_trajectory_space = torch.zeros(self.map.shape) - blank_alternative_space = torch.zeros(self.map.shape) - for index in trajectory.vertices: - blank_trajectory_space[index] = 1 - for index in alternative.vertices: - blank_alternative_space[index] = 1 - map_array = torch.as_tensor(self.map.as_array).float() if self.mode == 'separated_arrays': - return (map_array, blank_trajectory_space, int(label)), blank_alternative_space + return (map_array, trajectory.draw_in_array(self.map_shape), int(label)), \ + alternative.draw_in_array(self.map_shape) else: - return torch.cat((map_array, blank_trajectory_space, blank_alternative_space)), int(label) + return torch.cat((map_array, trajectory.draw_in_array(self.map_shape), + alternative.draw_in_array(self.map_shape))), int(label) elif self.mode == 'vectors': return trajectory.vertices, alternative.vertices, label, self.mapname diff --git a/lib/models/generators/cnn.py b/lib/models/generators/cnn.py index 3b6d6ce..e938425 100644 --- a/lib/models/generators/cnn.py +++ b/lib/models/generators/cnn.py @@ -41,7 +41,7 @@ class CNNRouteGeneratorModel(LightningBaseModule): # 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2) kld_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) # Dimensional Resizing - kld_loss /= self.in_shape + kld_loss /= reduce(mul, self.in_shape) loss = (kld_loss + discriminated_bce_loss) / 2 return dict(loss=loss, log=dict(loss=loss, @@ -65,7 +65,10 @@ class CNNRouteGeneratorModel(LightningBaseModule): def validation_step(self, *args): return self._test_val_step(*args) - def validation_epoch_end(self, outputs): + def validation_epoch_end(self, outputs: list): + return self._test_val_epoch_end(outputs) + + def _test_val_epoch_end(self, outputs, test=False): evaluation = ROCEvaluation(plot_roc=True) pred_label = torch.cat([x['pred_label'] for x in outputs]) labels = torch.cat([x['label'] for x in outputs]).unsqueeze(1) @@ -73,8 +76,9 @@ class CNNRouteGeneratorModel(LightningBaseModule): # Sci-py call ROC eval call is eval(true_label, prediction) roc_auc, tpr, fpr = evaluation(labels.cpu().numpy(), pred_label.cpu().numpy(), ) - # self.logger.log_metrics(score_dict) - self.logger.log_image(f'{self.name}_ROC-Curve_E{self.current_epoch}', plt.gcf()) + if test: + # self.logger.log_metrics(score_dict) + self.logger.log_image(f'{self.name}_ROC-Curve_E{self.current_epoch}', plt.gcf()) plt.clf() maps, trajectories, labels, val_restul_dict = self.generate_random() @@ -88,6 +92,10 @@ class CNNRouteGeneratorModel(LightningBaseModule): def test_step(self, *args): return self._test_val_step(*args) + def test_epoch_end(self, outputs): + return self._test_val_epoch_end(outputs, test=True) + + @property def discriminator(self): if self._disc is None: @@ -247,8 +255,15 @@ class CNNRouteGeneratorModel(LightningBaseModule): return z, mu, logvar def generate_random(self, n=6): - maps = [self.map_storage[choice(self.map_storage.keys)] for _ in range(n)] - trajectories = torch.stack([x.get_random_trajectory() for x in maps] * 2) - maps = torch.stack([x.as_2d_array for x in maps] * 2) - labels = torch.as_tensor([0] * n + [1] * n) - return maps, trajectories, labels, self._test_val_step(maps, trajectories, labels) + maps = [self.map_storage[choice(self.map_storage.keys_list)] for _ in range(n)] + + trajectories = [x.get_random_trajectory() for x in maps] * 2 + trajectories = [x.draw_in_array(self.map_storage.max_map_size) for x in trajectories] + trajectories = [torch.as_tensor(x, dtype=torch.float32) for x in trajectories] + trajectories = self._move_to_model_device(torch.stack(trajectories)) + + maps = [torch.as_tensor(x.as_array, dtype=torch.float32) for x in maps] * 2 + maps = self._move_to_model_device(torch.stack(maps)) + + labels = self._move_to_model_device(torch.as_tensor([0] * n + [1] * n)) + return maps, trajectories, labels, self._test_val_step(([maps, trajectories], labels), -9999) diff --git a/lib/modules/utils.py b/lib/modules/utils.py index 0fb28de..1aa9aab 100644 --- a/lib/modules/utils.py +++ b/lib/modules/utils.py @@ -90,7 +90,7 @@ class LightningBaseModule(pl.LightningModule, ABC): # Data loading # ============================================================================= # Map Object - self.map_storage = MapStorage(self.hparams.data_param.map_root, load_all=True) + self.map_storage = MapStorage(self.hparams.data_param.map_root) def size(self): return self.shape diff --git a/lib/objects/map.py b/lib/objects/map.py index 1898ee2..c15cbee 100644 --- a/lib/objects/map.py +++ b/lib/objects/map.py @@ -1,4 +1,5 @@ import shelve +from collections import UserDict from pathlib import Path import copy @@ -69,7 +70,7 @@ class Map(object): # Check pixels for their color (determine if walkable) for idx, value in np.ndenumerate(self.map_array): - if value == self.white: + if value != self.black: # IF walkable, add node graph.add_node(idx, count=0) # Fully connect to all surrounding neighbors @@ -90,6 +91,7 @@ class Map(object): if image.mode != 'L': image = image.convert('L') map_array = np.expand_dims(np.array(image), axis=0) + map_array = np.where(np.asarray(map_array) == cls.white, 1, 0) if embedding_size: assert isinstance(embedding_size, tuple), f'embedding_size was of type: {type(embedding_size)}' embedding = np.zeros(embedding_size) @@ -147,9 +149,9 @@ class Map(object): img = Image.new('L', (self.height, self.width), 0) draw = ImageDraw.Draw(img) - draw.polygon(polyline, outline=self.white, fill=self.white) + draw.polygon(polyline, outline=1, fill=1) - a = (np.where(np.asarray(img) == self.white, 1, 0) * np.where(self.as_2d_array == self.black, 1, 0)).sum() + a = (np.asarray(img) * np.where(self.as_2d_array == self.black, 1, 0)).sum() if a: return V.ALTERNATIVE # Non-Homotoph @@ -165,32 +167,25 @@ class Map(object): return dict(img=img, fig=fig, ax=ax) -class MapStorage(object): +class MapStorage(UserDict): @property - def keys(self): - return list(self.data.keys()) + def keys_list(self): + return list(super(MapStorage, self).keys()) - def __init__(self, map_root, load_all=False): - self.data = dict() + def __init__(self, map_root, *args, **kwargs): + super(MapStorage, self).__init__(*args, **kwargs) self.map_root = Path(map_root) - if load_all: - for map_file in self.map_root.glob('*.bmp'): - _ = self[map_file.name] - - def __getitem__(self, item): - if item in self.data.keys(): - return self.data.get(item) - else: - current_map = Map().from_image(self.map_root / item) - self.data.__setitem__(item, np.asarray(current_map)) - return self[item] - - - - - - - - + map_files = list(self.map_root.glob('*.bmp')) + self.max_map_size = (1, ) + tuple( + reversed( + tuple( + map( + max, *[Image.open(map_file).size for map_file in map_files]) + ) + ) + ) + for map_file in map_files: + current_map = Map().from_image(map_file, embedding_size=self.max_map_size) + self.__setitem__(map_file.name, current_map) diff --git a/lib/objects/trajectory.py b/lib/objects/trajectory.py index 2492529..b2f99c0 100644 --- a/lib/objects/trajectory.py +++ b/lib/objects/trajectory.py @@ -41,6 +41,12 @@ class Trajectory(object): def as_paired_list(self): return list(zip(self._vertices[:-1], self._vertices[1:])) + def draw_in_array(self, shape): + trajectory_space = np.zeros(shape) + for index in self.vertices: + trajectory_space[index] = 1 + return trajectory_space + @property def np_vertices(self): return [np.array(vertice) for vertice in self._vertices] diff --git a/lib/visualization/generator_eval.py b/lib/visualization/generator_eval.py index e370f79..1ee7c07 100644 --- a/lib/visualization/generator_eval.py +++ b/lib/visualization/generator_eval.py @@ -19,9 +19,9 @@ class GeneratorVisualizer(object): def _build_column_dict_list(self): dict_list = [] - for idx in range(self.maps): - image = self.maps[idx] + self.trajectories[idx] + self.generated_alternatives - label = self.labels[idx] + for idx in range(self.maps.shape[0]): + image = (self.maps[idx] + self.trajectories[idx] + self.generated_alternatives[idx]).cpu().numpy().squeeze() + label = int(self.labels[idx]) dict_list.append(dict(image=image, label=label)) half_size = int(len(dict_list) // 2) return dict_list[:half_size], dict_list[half_size:] @@ -33,10 +33,10 @@ class GeneratorVisualizer(object): axes_pad=0.2, # pad between axes in inch. ) - for idx in grid.axes_all: + for idx in range(len(grid.axes_all)): row, col = divmod(idx, len(self.column_dict_list)) - current_image = self.column_dict_list[col]['image'][row] - current_label = self.column_dict_list[col]['label'][row] + current_image = self.column_dict_list[col][row]['image'] + current_label = self.column_dict_list[col][row]['label'] grid[idx].imshow(current_image) grid[idx].title.set_text(current_label) fig.cbar_mode = 'single' diff --git a/main.py b/main.py index ea11079..f5b8787 100644 --- a/main.py +++ b/main.py @@ -28,12 +28,12 @@ main_arg_parser = ArgumentParser(description="parser for fast-neural-style") # Main Parameters main_arg_parser.add_argument("--main_debug", type=strtobool, default=False, help="") -main_arg_parser.add_argument("--main_eval", type=strtobool, default=False, help="") +main_arg_parser.add_argument("--main_eval", type=strtobool, default=True, help="") main_arg_parser.add_argument("--main_seed", type=int, default=69, help="") # Data Parameters main_arg_parser.add_argument("--data_worker", type=int, default=10, help="") -main_arg_parser.add_argument("--data_dataset_length", type=int, default=10000, help="") +main_arg_parser.add_argument("--data_dataset_length", type=int, default=100000, help="") main_arg_parser.add_argument("--data_root", type=str, default='data', help="") main_arg_parser.add_argument("--data_map_root", type=str, default='res/shapes', help="") @@ -43,7 +43,7 @@ main_arg_parser.add_argument("--transformations_to_tensor", type=strtobool, defa # Transformations main_arg_parser.add_argument("--train_outpath", type=str, default="output", help="") main_arg_parser.add_argument("--train_version", type=strtobool, required=False, help="") -main_arg_parser.add_argument("--train_epochs", type=int, default=10, help="") +main_arg_parser.add_argument("--train_epochs", type=int, default=12, help="") main_arg_parser.add_argument("--train_batch_size", type=int, default=256, help="") main_arg_parser.add_argument("--train_lr", type=float, default=0.002, help="") @@ -123,7 +123,8 @@ def run_lightning_loop(config_obj): model.save_to_disk(logger.log_dir) # Evaluate It - trainer.test() + if config_obj.main.eval: + trainer.test() return model