diff --git a/.idea/.gitignore b/.idea/.gitignore deleted file mode 100644 index e7e9d11..0000000 --- a/.idea/.gitignore +++ /dev/null @@ -1,2 +0,0 @@ -# Default ignored files -/workspace.xml diff --git a/.idea/deployment.xml b/.idea/deployment.xml deleted file mode 100644 index ac729da..0000000 --- a/.idea/deployment.xml +++ /dev/null @@ -1,22 +0,0 @@ - - - - - - - - - - - - - - - - - - - - - \ No newline at end of file diff --git a/.idea/dictionaries/steffen.xml b/.idea/dictionaries/steffen.xml deleted file mode 100644 index 3cf6c83..0000000 --- a/.idea/dictionaries/steffen.xml +++ /dev/null @@ -1,23 +0,0 @@ - - - - autopad - conv - convolutional - dataloader - dataloaders - datasets - homotopic - hparams - hyperparamter - kingma - logvar - mapname - mapnames - numlayers - reparameterize - softmax - traj - - - \ No newline at end of file diff --git a/.idea/hom_traj_gen.iml b/.idea/hom_traj_gen.iml deleted file mode 100644 index 241d6f7..0000000 --- a/.idea/hom_traj_gen.iml +++ /dev/null @@ -1,8 +0,0 @@ - - - - - - - - \ No newline at end of file diff --git a/.idea/inspectionProfiles/profiles_settings.xml b/.idea/inspectionProfiles/profiles_settings.xml deleted file mode 100644 index dd4c951..0000000 --- a/.idea/inspectionProfiles/profiles_settings.xml +++ /dev/null @@ -1,7 +0,0 @@ - - - - \ No newline at end of file diff --git a/.idea/misc.xml b/.idea/misc.xml deleted file mode 100644 index 0e02653..0000000 --- a/.idea/misc.xml +++ /dev/null @@ -1,10 +0,0 @@ - - - - - - - - \ No newline at end of file diff --git a/.idea/modules.xml b/.idea/modules.xml deleted file mode 100644 index 0b3a4df..0000000 --- a/.idea/modules.xml +++ /dev/null @@ -1,8 +0,0 @@ - - - - - - - - \ No newline at end of file diff --git a/.idea/vcs.xml b/.idea/vcs.xml deleted file mode 100644 index 94a25f7..0000000 --- a/.idea/vcs.xml +++ /dev/null @@ -1,6 +0,0 @@ - - - - - - \ No newline at end of file diff --git a/.idea/webResources.xml b/.idea/webResources.xml deleted file mode 100644 index aac35f8..0000000 --- a/.idea/webResources.xml +++ /dev/null @@ -1,15 +0,0 @@ - - - - - - - - - - - - - - - \ No newline at end of file diff --git a/lib/models/generators/cnn.py b/lib/models/generators/cnn.py index f5634c7..3b6d6ce 100644 --- a/lib/models/generators/cnn.py +++ b/lib/models/generators/cnn.py @@ -1,3 +1,5 @@ +from statistics import mean + from random import choice import torch @@ -65,13 +67,12 @@ class CNNRouteGeneratorModel(LightningBaseModule): def validation_epoch_end(self, outputs): evaluation = ROCEvaluation(plot_roc=True) - predictions = torch.cat([x['prediction'] for x in outputs]) + pred_label = torch.cat([x['pred_label'] for x in outputs]) labels = torch.cat([x['label'] for x in outputs]).unsqueeze(1) - losses = torch.cat([x['discriminated_bce_loss'] for x in outputs]).unsqueeze(1) - mean_losses = losses.mean() + mean_losses = torch.stack([x['discriminated_bce_loss'] for x in outputs]).mean() # Sci-py call ROC eval call is eval(true_label, prediction) - roc_auc, tpr, fpr = evaluation(labels.cpu().numpy(), predictions.cpu().numpy(), ) + roc_auc, tpr, fpr = evaluation(labels.cpu().numpy(), pred_label.cpu().numpy(), ) # self.logger.log_metrics(score_dict) self.logger.log_image(f'{self.name}_ROC-Curve_E{self.current_epoch}', plt.gcf()) plt.clf() @@ -103,7 +104,7 @@ class CNNRouteGeneratorModel(LightningBaseModule): # Dataset self.dataset = TrajData(self.hparams.data_param.map_root, mode='just_route', - length=self.hparams.train_param.batch_size * 1000) + length=self.hparams.data_param.dataset_length) # Additional Attributes self.in_shape = self.dataset.map_shapes_max @@ -159,6 +160,10 @@ class CNNRouteGeneratorModel(LightningBaseModule): self.traj_lin = nn.Linear(reduce(mul, self.traj_conv_3.shape), self.feature_dim) + # + # Mixed Encoder + self.mixed_lin = nn.Linear(self.lat_dim, self.lat_dim) + # # Variational Bottleneck self.mu = nn.Linear(self.lat_dim, self.hparams.model_param.lat_dim) @@ -242,7 +247,7 @@ class CNNRouteGeneratorModel(LightningBaseModule): return z, mu, logvar def generate_random(self, n=6): - maps = [self.map_storage[choice(self.map_storage.keys())] for _ in range(n)] + maps = [self.map_storage[choice(self.map_storage.keys)] for _ in range(n)] trajectories = torch.stack([x.get_random_trajectory() for x in maps] * 2) maps = torch.stack([x.as_2d_array for x in maps] * 2) labels = torch.as_tensor([0] * n + [1] * n) diff --git a/lib/models/homotopy_classification/cnn_based.py b/lib/models/homotopy_classification/cnn_based.py index d3d73db..befa8d5 100644 --- a/lib/models/homotopy_classification/cnn_based.py +++ b/lib/models/homotopy_classification/cnn_based.py @@ -57,7 +57,8 @@ class ConvHomDetector(LightningBaseModule): # Model Parameters self.in_shape = self.dataset.map_shapes_max assert len(self.in_shape) == 3, f'Image or map shape has to have 3 dims, but had: {len(self.in_shape)}' - self.criterion = nn.BCEWithLogitsLoss() + self.criterion = nn.BCELoss() + self.sigmoid = nn.Sigmoid() # NN Nodes # ============================ @@ -100,4 +101,5 @@ class ConvHomDetector(LightningBaseModule): tensor = self.flatten(tensor) tensor = self.linear(tensor) tensor = self.classifier(tensor) + tensor = self.sigmoid(tensor) return tensor diff --git a/lib/modules/utils.py b/lib/modules/utils.py index 5c9225a..0fb28de 100644 --- a/lib/modules/utils.py +++ b/lib/modules/utils.py @@ -90,7 +90,7 @@ class LightningBaseModule(pl.LightningModule, ABC): # Data loading # ============================================================================= # Map Object - self.map_storage = MapStorage(self.hparams.data_param.map_root) + self.map_storage = MapStorage(self.hparams.data_param.map_root, load_all=True) def size(self): return self.shape @@ -143,19 +143,19 @@ class LightningBaseModule(pl.LightningModule, ABC): # Train Dataloader def train_dataloader(self): return DataLoader(dataset=self.dataset.train_dataset, shuffle=True, - batch_size=self.hparams.data_param.batchsize, + batch_size=self.hparams.train_param.batch_size, num_workers=self.hparams.data_param.worker) # Test Dataloader def test_dataloader(self): return DataLoader(dataset=self.dataset.test_dataset, shuffle=True, - batch_size=self.hparams.data_param.batchsize, + batch_size=self.hparams.train_param.batch_size, num_workers=self.hparams.data_param.worker) # Validation Dataloader def val_dataloader(self): return DataLoader(dataset=self.dataset.val_dataset, shuffle=False, - batch_size=self.hparams.data_param.batchsize, + batch_size=self.hparams.train_param.batch_size, num_workers=self.hparams.data_param.worker) diff --git a/lib/objects/map.py b/lib/objects/map.py index 3c473fd..1898ee2 100644 --- a/lib/objects/map.py +++ b/lib/objects/map.py @@ -167,6 +167,10 @@ class Map(object): class MapStorage(object): + @property + def keys(self): + return list(self.data.keys()) + def __init__(self, map_root, load_all=False): self.data = dict() self.map_root = Path(map_root) @@ -175,11 +179,11 @@ class MapStorage(object): _ = self[map_file.name] def __getitem__(self, item): - if item in hasattr(self, item): - return self.__getattribute__(item) + if item in self.data.keys(): + return self.data.get(item) else: - with shelve.open(self.map_root / f'{item}.pik', flag='r') as d: - self.__setattr__(item, d['map']['map']) + current_map = Map().from_image(self.map_root / item) + self.data.__setitem__(item, np.asarray(current_map)) return self[item] diff --git a/main.py b/main.py index c6b7925..ea11079 100644 --- a/main.py +++ b/main.py @@ -33,6 +33,7 @@ main_arg_parser.add_argument("--main_seed", type=int, default=69, help="") # Data Parameters main_arg_parser.add_argument("--data_worker", type=int, default=10, help="") +main_arg_parser.add_argument("--data_dataset_length", type=int, default=10000, help="") main_arg_parser.add_argument("--data_root", type=str, default='data', help="") main_arg_parser.add_argument("--data_map_root", type=str, default='res/shapes', help="") @@ -105,6 +106,7 @@ def run_lightning_loop(config_obj): show_progress_bar=True, weights_save_path=logger.log_dir, gpus=[0] if torch.cuda.is_available() else None, + check_val_every_n_epoch=1, # row_log_interval=(model.n_train_batches * 0.1), # TODO: Better Value / Setting # log_save_interval=(model.n_train_batches * 0.2), # TODO: Better Value / Setting checkpoint_callback=checkpoint_callback,