diff --git a/.idea/.gitignore b/.idea/.gitignore
deleted file mode 100644
index e7e9d11..0000000
--- a/.idea/.gitignore
+++ /dev/null
@@ -1,2 +0,0 @@
-# Default ignored files
-/workspace.xml
diff --git a/.idea/deployment.xml b/.idea/deployment.xml
deleted file mode 100644
index ac729da..0000000
--- a/.idea/deployment.xml
+++ /dev/null
@@ -1,22 +0,0 @@
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
\ No newline at end of file
diff --git a/.idea/dictionaries/steffen.xml b/.idea/dictionaries/steffen.xml
deleted file mode 100644
index 3cf6c83..0000000
--- a/.idea/dictionaries/steffen.xml
+++ /dev/null
@@ -1,23 +0,0 @@
-
-
-
- autopad
- conv
- convolutional
- dataloader
- dataloaders
- datasets
- homotopic
- hparams
- hyperparamter
- kingma
- logvar
- mapname
- mapnames
- numlayers
- reparameterize
- softmax
- traj
-
-
-
\ No newline at end of file
diff --git a/.idea/hom_traj_gen.iml b/.idea/hom_traj_gen.iml
deleted file mode 100644
index 241d6f7..0000000
--- a/.idea/hom_traj_gen.iml
+++ /dev/null
@@ -1,8 +0,0 @@
-
-
-
-
-
-
-
-
\ No newline at end of file
diff --git a/.idea/inspectionProfiles/profiles_settings.xml b/.idea/inspectionProfiles/profiles_settings.xml
deleted file mode 100644
index dd4c951..0000000
--- a/.idea/inspectionProfiles/profiles_settings.xml
+++ /dev/null
@@ -1,7 +0,0 @@
-
-
-
-
-
-
-
\ No newline at end of file
diff --git a/.idea/misc.xml b/.idea/misc.xml
deleted file mode 100644
index 0e02653..0000000
--- a/.idea/misc.xml
+++ /dev/null
@@ -1,10 +0,0 @@
-
-
-
-
-
-
-
-
-
-
\ No newline at end of file
diff --git a/.idea/modules.xml b/.idea/modules.xml
deleted file mode 100644
index 0b3a4df..0000000
--- a/.idea/modules.xml
+++ /dev/null
@@ -1,8 +0,0 @@
-
-
-
-
-
-
-
-
\ No newline at end of file
diff --git a/.idea/vcs.xml b/.idea/vcs.xml
deleted file mode 100644
index 94a25f7..0000000
--- a/.idea/vcs.xml
+++ /dev/null
@@ -1,6 +0,0 @@
-
-
-
-
-
-
\ No newline at end of file
diff --git a/.idea/webResources.xml b/.idea/webResources.xml
deleted file mode 100644
index aac35f8..0000000
--- a/.idea/webResources.xml
+++ /dev/null
@@ -1,15 +0,0 @@
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
\ No newline at end of file
diff --git a/lib/models/generators/cnn.py b/lib/models/generators/cnn.py
index f5634c7..3b6d6ce 100644
--- a/lib/models/generators/cnn.py
+++ b/lib/models/generators/cnn.py
@@ -1,3 +1,5 @@
+from statistics import mean
+
from random import choice
import torch
@@ -65,13 +67,12 @@ class CNNRouteGeneratorModel(LightningBaseModule):
def validation_epoch_end(self, outputs):
evaluation = ROCEvaluation(plot_roc=True)
- predictions = torch.cat([x['prediction'] for x in outputs])
+ pred_label = torch.cat([x['pred_label'] for x in outputs])
labels = torch.cat([x['label'] for x in outputs]).unsqueeze(1)
- losses = torch.cat([x['discriminated_bce_loss'] for x in outputs]).unsqueeze(1)
- mean_losses = losses.mean()
+ mean_losses = torch.stack([x['discriminated_bce_loss'] for x in outputs]).mean()
# Sci-py call ROC eval call is eval(true_label, prediction)
- roc_auc, tpr, fpr = evaluation(labels.cpu().numpy(), predictions.cpu().numpy(), )
+ roc_auc, tpr, fpr = evaluation(labels.cpu().numpy(), pred_label.cpu().numpy(), )
# self.logger.log_metrics(score_dict)
self.logger.log_image(f'{self.name}_ROC-Curve_E{self.current_epoch}', plt.gcf())
plt.clf()
@@ -103,7 +104,7 @@ class CNNRouteGeneratorModel(LightningBaseModule):
# Dataset
self.dataset = TrajData(self.hparams.data_param.map_root, mode='just_route',
- length=self.hparams.train_param.batch_size * 1000)
+ length=self.hparams.data_param.dataset_length)
# Additional Attributes
self.in_shape = self.dataset.map_shapes_max
@@ -159,6 +160,10 @@ class CNNRouteGeneratorModel(LightningBaseModule):
self.traj_lin = nn.Linear(reduce(mul, self.traj_conv_3.shape), self.feature_dim)
+ #
+ # Mixed Encoder
+ self.mixed_lin = nn.Linear(self.lat_dim, self.lat_dim)
+
#
# Variational Bottleneck
self.mu = nn.Linear(self.lat_dim, self.hparams.model_param.lat_dim)
@@ -242,7 +247,7 @@ class CNNRouteGeneratorModel(LightningBaseModule):
return z, mu, logvar
def generate_random(self, n=6):
- maps = [self.map_storage[choice(self.map_storage.keys())] for _ in range(n)]
+ maps = [self.map_storage[choice(self.map_storage.keys)] for _ in range(n)]
trajectories = torch.stack([x.get_random_trajectory() for x in maps] * 2)
maps = torch.stack([x.as_2d_array for x in maps] * 2)
labels = torch.as_tensor([0] * n + [1] * n)
diff --git a/lib/models/homotopy_classification/cnn_based.py b/lib/models/homotopy_classification/cnn_based.py
index d3d73db..befa8d5 100644
--- a/lib/models/homotopy_classification/cnn_based.py
+++ b/lib/models/homotopy_classification/cnn_based.py
@@ -57,7 +57,8 @@ class ConvHomDetector(LightningBaseModule):
# Model Parameters
self.in_shape = self.dataset.map_shapes_max
assert len(self.in_shape) == 3, f'Image or map shape has to have 3 dims, but had: {len(self.in_shape)}'
- self.criterion = nn.BCEWithLogitsLoss()
+ self.criterion = nn.BCELoss()
+ self.sigmoid = nn.Sigmoid()
# NN Nodes
# ============================
@@ -100,4 +101,5 @@ class ConvHomDetector(LightningBaseModule):
tensor = self.flatten(tensor)
tensor = self.linear(tensor)
tensor = self.classifier(tensor)
+ tensor = self.sigmoid(tensor)
return tensor
diff --git a/lib/modules/utils.py b/lib/modules/utils.py
index 5c9225a..0fb28de 100644
--- a/lib/modules/utils.py
+++ b/lib/modules/utils.py
@@ -90,7 +90,7 @@ class LightningBaseModule(pl.LightningModule, ABC):
# Data loading
# =============================================================================
# Map Object
- self.map_storage = MapStorage(self.hparams.data_param.map_root)
+ self.map_storage = MapStorage(self.hparams.data_param.map_root, load_all=True)
def size(self):
return self.shape
@@ -143,19 +143,19 @@ class LightningBaseModule(pl.LightningModule, ABC):
# Train Dataloader
def train_dataloader(self):
return DataLoader(dataset=self.dataset.train_dataset, shuffle=True,
- batch_size=self.hparams.data_param.batchsize,
+ batch_size=self.hparams.train_param.batch_size,
num_workers=self.hparams.data_param.worker)
# Test Dataloader
def test_dataloader(self):
return DataLoader(dataset=self.dataset.test_dataset, shuffle=True,
- batch_size=self.hparams.data_param.batchsize,
+ batch_size=self.hparams.train_param.batch_size,
num_workers=self.hparams.data_param.worker)
# Validation Dataloader
def val_dataloader(self):
return DataLoader(dataset=self.dataset.val_dataset, shuffle=False,
- batch_size=self.hparams.data_param.batchsize,
+ batch_size=self.hparams.train_param.batch_size,
num_workers=self.hparams.data_param.worker)
diff --git a/lib/objects/map.py b/lib/objects/map.py
index 3c473fd..1898ee2 100644
--- a/lib/objects/map.py
+++ b/lib/objects/map.py
@@ -167,6 +167,10 @@ class Map(object):
class MapStorage(object):
+ @property
+ def keys(self):
+ return list(self.data.keys())
+
def __init__(self, map_root, load_all=False):
self.data = dict()
self.map_root = Path(map_root)
@@ -175,11 +179,11 @@ class MapStorage(object):
_ = self[map_file.name]
def __getitem__(self, item):
- if item in hasattr(self, item):
- return self.__getattribute__(item)
+ if item in self.data.keys():
+ return self.data.get(item)
else:
- with shelve.open(self.map_root / f'{item}.pik', flag='r') as d:
- self.__setattr__(item, d['map']['map'])
+ current_map = Map().from_image(self.map_root / item)
+ self.data.__setitem__(item, np.asarray(current_map))
return self[item]
diff --git a/main.py b/main.py
index c6b7925..ea11079 100644
--- a/main.py
+++ b/main.py
@@ -33,6 +33,7 @@ main_arg_parser.add_argument("--main_seed", type=int, default=69, help="")
# Data Parameters
main_arg_parser.add_argument("--data_worker", type=int, default=10, help="")
+main_arg_parser.add_argument("--data_dataset_length", type=int, default=10000, help="")
main_arg_parser.add_argument("--data_root", type=str, default='data', help="")
main_arg_parser.add_argument("--data_map_root", type=str, default='res/shapes', help="")
@@ -105,6 +106,7 @@ def run_lightning_loop(config_obj):
show_progress_bar=True,
weights_save_path=logger.log_dir,
gpus=[0] if torch.cuda.is_available() else None,
+ check_val_every_n_epoch=1,
# row_log_interval=(model.n_train_batches * 0.1), # TODO: Better Value / Setting
# log_save_interval=(model.n_train_batches * 0.2), # TODO: Better Value / Setting
checkpoint_callback=checkpoint_callback,