Project Refactor, CNN Classifier Basics

Steffen Illium
2020-03-08 23:46:02 +01:00
parent 75e8a61628
commit cd4fdf2de3
20 changed files with 441 additions and 239 deletions

.idea/deployment.xml (generated): 7 changes

@@ -2,6 +2,13 @@
<project version="4">
<component name="PublishConfigData" autoUpload="On explicit save action" serverName="steffen@aimachine:22" showAutoUploadSettingsWarning="false">
<serverData>
<paths name="erlowa@aimachine">
<serverdata>
<mappings>
<mapping deploy="/" local="$PROJECT_DIR$" web="/" />
</mappings>
</serverdata>
</paths>
<paths name="steffen@aimachine:22">
<serverdata>
<mappings>

@@ -1,13 +1,22 @@
<component name="ProjectDictionaryState">
<dictionary name="steffen">
<words>
<w>autopad</w>
<w>conv</w>
<w>convolutional</w>
<w>dataloader</w>
<w>dataloaders</w>
<w>datasets</w>
<w>homotopic</w>
<w>hparams</w>
<w>hyperparamter</w>
<w>kingma</w>
<w>logvar</w>
<w>mapname</w>
<w>mapnames</w>
<w>numlayers</w>
<w>reparameterize</w>
<w>softmax</w>
<w>traj</w>
</words>
</dictionary>

@@ -2,7 +2,7 @@
<module type="PYTHON_MODULE" version="4">
<component name="NewModuleRootManager">
<content url="file://$MODULE_DIR$" />
<orderEntry type="jdk" jdkName="hom_traj_gen@aimachine" jdkType="Python SDK" />
<orderEntry type="jdk" jdkName="Remote Python 3.7.6 (sftp://steffen@aimachine:22/home/steffen/envs/traj_gen/bin/python)" jdkType="Python SDK" />
<orderEntry type="sourceFolder" forTests="false" />
</component>
</module>

.idea/misc.xml (generated): 5 changes

@@ -3,5 +3,8 @@
<component name="JavaScriptSettings">
<option name="languageLevel" value="ES6" />
</component>
<component name="ProjectRootManager" version="2" project-jdk-name="hom_traj_gen@aimachine" project-jdk-type="Python SDK" />
<component name="ProjectRootManager" version="2" project-jdk-name="traj_gen@ai-machine" project-jdk-type="Python SDK" />
<component name="PyPackaging">
<option name="earlyReleasesAsUpgrades" value="true" />
</component>
</project>

@@ -3,6 +3,7 @@ from pathlib import Path
from typing import Union, List
import torch
from random import choice
from torch.utils.data import ConcatDataset, Dataset
from lib.objects.map import Map
@@ -16,10 +17,11 @@ class TrajDataset(Dataset):
return self.map.as_array.shape
def __init__(self, *args, maps_root: Union[Path, str] = '', mapname='tate_sw',
length=100000, all_in_map=True, embedding_size=None, preserve_equal_samples=False, **kwargs):
length=100000, mode='separated_arrays', embedding_size=None, preserve_equal_samples=False, **kwargs):
super(TrajDataset, self).__init__()
assert mode.lower() in ['vectors', 'all_in_map', 'separated_arrays', 'just_route']
self.preserve_equal_samples = preserve_equal_samples
self.all_in_map = all_in_map
self.mode = mode
self.mapname = mapname if mapname.endswith('.bmp') else f'{mapname}.bmp'
self.maps_root = maps_root
self._len = length
@@ -31,8 +33,19 @@ class TrajDataset(Dataset):
return self._len
def __getitem__(self, item):
if self.mode.lower() == 'just_route':
trajectory = self.map.get_random_trajectory()
label = choice([0, 1])
blank_trajectory_space = torch.zeros(self.map.shape)
for index in trajectory.vertices:
blank_trajectory_space[index] = 1
map_array = torch.as_tensor(self.map.as_array).float()
return (map_array, blank_trajectory_space), label
while True:
trajectory = self.map.get_random_trajectory()
# TODO: Sanity-check this while-True loop...
alternative = self.map.generate_alternative(trajectory)
label = self.map.are_homotopic(trajectory, alternative)
@@ -42,18 +55,26 @@ class TrajDataset(Dataset):
break
self.last_label = label
if self.all_in_map:
if self.mode.lower() in ['all_in_map', 'separated_arrays']:
blank_trajectory_space = torch.zeros(self.map.shape)
blank_alternative_space = torch.zeros(self.map.shape)
for index in trajectory.vertices:
blank_trajectory_space[index] = 1
for index in alternative.vertices:
blank_alternative_space[index] = 1
map_array = torch.as_tensor(self.map.as_array).float()
return torch.cat((map_array, blank_trajectory_space, blank_alternative_space)), int(label)
if self.mode == 'separated_arrays':
return (map_array, blank_trajectory_space, int(label)), blank_alternative_space
else:
return torch.cat((map_array, blank_trajectory_space, blank_alternative_space)), int(label)
elif self.mode == 'vectors':
return trajectory.vertices, alternative.vertices, label, self.mapname
else:
raise ValueError
class TrajData(object):
@property
@@ -64,7 +85,7 @@ class TrajData(object):
def map_shapes_max(self):
shapes = self.map_shapes
shape_list = list(map(max, zip(*shapes)))
if self.all_in_map:
if self.mode == 'all_in_map':
shape_list[0] += 2
return shape_list
@@ -72,10 +93,10 @@ class TrajData(object):
def name(self):
return self.__class__.__name__
def __init__(self, *args, map_root: Union[Path, str] = '', length=100.000, all_in_map=True, **_):
def __init__(self, map_root, length=100000, mode='separated_arrays', **_):
self.all_in_map = all_in_map
self.maps_root = Path(map_root) if map_root else Path() / 'res' / 'maps'
self.mode = mode
self.maps_root = Path(map_root)
self.length = length
self._dataset = self._load_datasets()
@@ -86,8 +107,8 @@ class TrajData(object):
# find max image size among available maps:
max_map_size = (1, ) + tuple(reversed(tuple(map(max, *[Image.open(map_file).size for map_file in map_files]))))
return ConcatDataset([TrajDataset(maps_root=self.maps_root, mapname=map_file.name, length=equal_split,
all_in_map=self.all_in_map, embedding_size=max_map_size,
preserve_equal_samples=False)
mode=self.mode, embedding_size=max_map_size,
preserve_equal_samples=True)
for map_file in map_files])
@property
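The old all_in_map flag widens here into a four-way mode switch ('vectors', 'all_in_map', 'separated_arrays', 'just_route'). The three array modes differ only in how the rasterized layers are packed; a toy sketch of the returned structures, assuming 1-channel 64x64 maps:

import torch

map_array = torch.zeros(1, 64, 64)                           # map as float tensor
traj, alt = torch.zeros(1, 64, 64), torch.zeros(1, 64, 64)   # rasterized routes

all_in_map = torch.cat((map_array, traj, alt))   # (3, 64, 64): one stacked CNN input
separated = ((map_array, traj, 1), alt)          # generator input, generation target
just_route = ((map_array, traj), 1)              # map + route, coin-flip label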

@@ -6,27 +6,29 @@ class ROCEvaluation(object):
linewidth = 2
def __init__(self, prepare_figure=False):
self.prepare_figure = prepare_figure
def __init__(self, plot_roc=False):
self.plot_roc = plot_roc
self.epoch = 0
def __call__(self, prediction, label, plotting=False):
def __call__(self, prediction, label):
# Compute ROC curve and ROC area
fpr, tpr, _ = roc_curve(prediction, label)
roc_auc = auc(fpr, tpr)
if plotting:
fig = plt.gcf()
fig.plot(fpr, tpr, color='darkorange', lw=self.linewidth, label=f'ROC curve (area = {roc_auc})')
if self.plot_roc:
_ = plt.gcf()
plt.plot(fpr, tpr, color='darkorange', lw=self.linewidth, label=f'ROC curve (area = {roc_auc})')
self._prepare_fig()
return roc_auc, fpr, tpr
def _prepare_fig(self):
fig = plt.gcf()
fig.plot([0, 1], [0, 1], color='navy', lw=self.linewidth, linestyle='--')
fig.xlim([0.0, 1.0])
fig.ylim([0.0, 1.05])
fig.xlabel('False Positive Rate')
fig.ylabel('True Positive Rate')
ax = plt.gca()
plt.plot([0, 1], [0, 1], color='navy', lw=self.linewidth, linestyle='--')
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
fig.legend(loc="lower right")
return fig
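One caveat on argument order: scikit-learn's roc_curve expects the ground truth first, i.e. roc_curve(y_true, y_score), so the parameter names in __call__ above are effectively swapped; the classifier's test_epoch_end below compensates by passing labels first. A minimal sanity check (the values are scikit-learn's own doc example):

from sklearn.metrics import auc, roc_curve

y_true = [0, 0, 1, 1]            # ground-truth labels go first ...
y_score = [0.1, 0.4, 0.35, 0.8]  # ... predicted scores second
fpr, tpr, _ = roc_curve(y_true, y_score)
print(auc(fpr, tpr))             # 0.75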

@@ -1,6 +1,13 @@
from datasets.paired_dataset import TrajPairData
from lib.modules.blocks import ConvModule
from lib.modules.utils import LightningBaseModule
import torch
from functools import reduce
from operator import mul
from torch import nn
from torch.optim import Adam
from datasets.trajectory_dataset import TrajData
from lib.modules.blocks import ConvModule, ResidualModule, DeConvModule
from lib.modules.utils import LightningBaseModule, Flatten
class CNNRouteGeneratorModel(LightningBaseModule):
@@ -8,36 +15,169 @@ class CNNRouteGeneratorModel(LightningBaseModule):
name = 'CNNRouteGenerator'
def configure_optimizers(self):
pass
def validation_step(self, *args, **kwargs):
pass
def validation_end(self, outputs):
pass
return Adam(self.parameters(), lr=self.hparams.train_param.lr)
def training_step(self, batch_xy, batch_nb, *args, **kwargs):
pass
batch_x, label = batch_xy
generated_alternative, z, mu, logvar = self(batch_x + [label, ])
map_array, trajectory = batch_x
map_stack = torch.cat((map_array, trajectory, generated_alternative), dim=1)
pred_label = self.discriminator(map_stack)
discriminated_bce_loss = self.criterion(pred_label, label.float().unsqueeze(-1))
# see Appendix B from VAE paper:
# Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014
# https://arxiv.org/abs/1312.6114
# 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
kld_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
loss = (kld_loss + discriminated_bce_loss) / 2
return dict(loss=loss, log=dict(loss=loss,
discriminated_bce_loss=discriminated_bce_loss,
kld_loss=kld_loss)
)
def test_step(self, *args, **kwargs):
pass
@property
def discriminator(self):
if self._disc is None:
raise RuntimeError('Set the discriminator first: set_discriminator(disc_model)')
return self._disc
def set_discriminator(self, disc_model):
if self._disc is not None:
raise RuntimeError('Discriminator has already been set... What are you trying to do?')
self._disc = disc_model
def __init__(self, *params):
super(CNNRouteGeneratorModel, self).__init__(*params)
# Dataset
self.dataset = TrajPairData(self.hparams.data_param.data_root)
self.dataset = TrajData(self.hparams.data_param.map_root, mode='just_route')
# Additional Attributes
self.in_shape = self.dataset.map_shapes_max
# Todo: Better naming and size in Parameters
self.feature_dim = 10
self._disc = None
# NN Nodes
###################################################
#
# Utils
self.relu = nn.ReLU()
self.criterion = nn.MSELoss()
self.conv2 = ConvModule(self.conv1.shape, conv_kernel=3, conv_stride=1, conv_padding=0,
conv_filters=self.hparams.model_param.filters[0])
self.conv3 = ConvModule(self.conv2.shape, conv_kernel=3, conv_stride=1, conv_padding=0,
#
# Map Encoder
self.map_conv_0 = ConvModule(self.in_shape, conv_kernel=3, conv_stride=1, conv_padding=1,
conv_filters=self.hparams.model_param.filters[0])
def forward(self, x):
pass
self.map_res_1 = ResidualModule(self.map_conv_0.shape, ConvModule, 2, conv_kernel=3, conv_stride=1,
conv_padding=1, conv_filters=self.hparams.model_param.filters[0])
self.map_conv_1 = ConvModule(self.map_res_1.shape, conv_kernel=3, conv_stride=1, conv_padding=0,
conv_filters=self.hparams.model_param.filters[1])
self.map_res_2 = ResidualModule(self.map_conv_1.shape, ConvModule, 2, conv_kernel=3, conv_stride=1,
conv_padding=1, conv_filters=self.hparams.model_param.filters[1])
self.map_conv_2 = ConvModule(self.map_res_2.shape, conv_kernel=3, conv_stride=1, conv_padding=0,
conv_filters=self.hparams.model_param.filters[2])
self.map_res_3 = ResidualModule(self.map_conv_2.shape, ConvModule, 2, conv_kernel=3, conv_stride=1,
conv_padding=1, conv_filters=self.hparams.model_param.filters[2])
self.map_conv_3 = ConvModule(self.map_res_3.shape, conv_kernel=5, conv_stride=1, conv_padding=0,
conv_filters=self.hparams.model_param.filters[2]*2)
self.map_flat = Flatten(self.map_conv_3.shape)
self.map_lin = nn.Linear(reduce(mul, self.map_conv_3.shape), self.feature_dim)
#
# Trajectory Encoder
self.traj_conv_1 = ConvModule(self.in_shape, conv_kernel=3, conv_stride=1, conv_padding=0,
conv_filters=self.hparams.model_param.filters[0])
self.traj_conv_2 = ConvModule(self.traj_conv_1.shape, conv_kernel=3, conv_stride=1, conv_padding=0,
conv_filters=self.hparams.model_param.filters[0])
self.traj_conv_3 = ConvModule(self.traj_conv_2.shape, conv_kernel=3, conv_stride=1, conv_padding=0,
conv_filters=self.hparams.model_param.filters[0])
self.traj_flat = Flatten(self.traj_conv_3.shape)
self.traj_lin = nn.Linear(reduce(mul, self.traj_conv_3.shape), self.feature_dim)
#
# Variational Bottleneck
self.mu = nn.Linear(self.feature_dim + self.feature_dim + 1, self.hparams.model_param.lat_dim)
self.logvar = nn.Linear(self.feature_dim + self.feature_dim + 1, self.hparams.model_param.lat_dim)
#
# Alternative Generator
self.alt_lin_1 = nn.Linear(self.hparams.model_param.lat_dim, self.feature_dim)
self.alt_lin_2 = nn.Linear(self.feature_dim, reduce(mul, self.traj_conv_3.shape))
self.reshape_to_map = Flatten(reduce(mul, self.traj_conv_3.shape), self.traj_conv_3.shape)
self.alt_deconv_1 = DeConvModule(self.traj_conv_3.shape, self.hparams.model_param.filters[2],
conv_padding=0, conv_kernel=5, conv_stride=1)
self.alt_deconv_2 = DeConvModule(self.alt_deconv_1.shape, self.hparams.model_param.filters[1],
conv_padding=0, conv_kernel=3, conv_stride=1)
self.alt_deconv_3 = DeConvModule(self.alt_deconv_2.shape, self.hparams.model_param.filters[0],
conv_padding=1, conv_kernel=3, conv_stride=1)
self.alt_deconv_out = DeConvModule(self.alt_deconv_3.shape, 1, activation=None,
conv_padding=1, conv_kernel=3, conv_stride=1)
def forward(self, batch_x):
#
# Sorting the Input
map_array, trajectory, label = batch_x
#
# Encode
map_tensor = self.map_conv_0(map_array)
map_tensor = self.map_res_1(map_tensor)
map_tensor = self.map_conv_1(map_tensor)
map_tensor = self.map_res_2(map_tensor)
map_tensor = self.map_conv_2(map_tensor)
map_tensor = self.map_res_3(map_tensor)
map_tensor = self.map_conv_3(map_tensor)
map_tensor = self.map_flat(map_tensor)
map_tensor = self.map_lin(map_tensor)
traj_tensor = self.traj_conv_1(trajectory)
traj_tensor = self.traj_conv_2(traj_tensor)
traj_tensor = self.traj_conv_3(traj_tensor)
traj_tensor = self.traj_flat(traj_tensor)
traj_tensor = self.traj_lin(traj_tensor)
mixed_tensor = torch.cat((map_tensor, traj_tensor, label.float().unsqueeze(-1)), dim=1)
mixed_tensor = self.relu(mixed_tensor)
#
# Parameter and Sampling
mu = self.mu(mixed_tensor)
logvar = self.logvar(mixed_tensor)
z = self.reparameterize(mu, logvar)
#
# Generate
alt_tensor = self.alt_lin_1(z)
alt_tensor = self.alt_lin_2(alt_tensor)
alt_tensor = self.reshape_to_map(alt_tensor)
alt_tensor = self.alt_deconv_1(alt_tensor)
alt_tensor = self.alt_deconv_2(alt_tensor)
alt_tensor = self.alt_deconv_3(alt_tensor)
alt_tensor = self.alt_deconv_out(alt_tensor)
return alt_tensor, z, mu, logvar
@staticmethod
def reparameterize(mu, logvar):
std = torch.exp(0.5 * logvar)
eps = torch.randn_like(std)
return mu + eps * std
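The bottleneck is the textbook VAE recipe from the cited Kingma & Welling paper: sample z = mu + sigma * eps with eps ~ N(0, I), so gradients flow through mu and logvar, and regularize with the closed-form KL divergence against a unit Gaussian. A self-contained sketch with toy dimensions (batch 8, lat_dim 2):

import torch

mu = torch.zeros(8, 2, requires_grad=True)      # batch of 8, lat_dim = 2
logvar = torch.zeros(8, 2, requires_grad=True)  # log(sigma^2)

std = torch.exp(0.5 * logvar)                   # sigma
z = mu + torch.randn_like(std) * std            # reparameterized, differentiable sample

# KL(q(z|x) || N(0, I)) = -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
kld = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

Note that training_step above weights this KL term and the discriminator BCE equally; a tunable beta on the KL term would be the obvious knob if one of the two starts to dominate.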

@@ -1,14 +1,16 @@
from lib.modules.blocks import LightningBaseModule
from lib.modules.losses import BinaryHomotopicLoss
from lib.modules.utils import LightningBaseModule
from lib.objects.map import Map
from lib.objects.trajectory import Trajectory
import torch.nn as nn
nn.MSELoss
class LinearRouteGeneratorModel(LightningBaseModule):
def test_epoch_end(self, outputs):
pass
name = 'LinearRouteGenerator'
def configure_optimizers(self):
@@ -33,6 +35,12 @@ class LinearRouteGeneratorModel(LightningBaseModule):
pred_y = self(map_x, traj_x, label_x)
loss = self.loss(traj_x, pred_y)
def training_step(self, batch_xy, batch_nb, *args, **kwargs):
batch_x, batch_y = batch_xy
pred_y = self(batch_x)
loss = self.criterion(pred_y, batch_y.unsqueeze(-1).float())
return dict(loss=loss, log=dict(loss=loss))
def test_step(self, *args, **kwargs):
@@ -41,7 +49,7 @@ class LinearRouteGeneratorModel(LightningBaseModule):
def __init__(self, *params):
super(LinearRouteGeneratorModel, self).__init__(*params)
self.loss = BinaryHomotopicLoss(self.map_storage)
self.criterion = BinaryHomotopicLoss(self.map_storage)
def forward(self, map_x, traj_x, label_x):
pass

@@ -24,41 +24,44 @@ class ConvHomDetector(LightningBaseModule):
def training_step(self, batch_xy, batch_nb, *args, **kwargs):
batch_x, batch_y = batch_xy
pred_y = self(batch_x)
loss = F.binary_cross_entropy(pred_y, batch_y.float())
loss = self.criterion(pred_y, batch_y.unsqueeze(-1).float())
return {'loss': loss, 'log': dict(loss=loss)}
def test_step(self, batch_xy, **kwargs):
def test_step(self, batch_xy, batch_nb, **kwargs):
batch_x, batch_y = batch_xy
pred_y = self(batch_x)
return dict(prediction=pred_y, label=batch_y)
return dict(prediction=pred_y, label=batch_y, batch_nb=batch_nb)
def test_end(self, outputs):
evaluation = ROCEvaluation()
predictions = torch.stack([x['prediction'] for x in outputs])
labels = torch.stack([x['label'] for x in outputs])
def test_epoch_end(self, outputs):
evaluation = ROCEvaluation(plot_roc=True)
predictions = torch.cat([x['prediction'] for x in outputs])
labels = torch.cat([x['label'] for x in outputs]).unsqueeze(1)
scores = evaluation(predictions.numpy(), labels.numpy(), )
self.logger.log_metrics({key:value for key, value in zip(['roc_auc', 'tpr', 'fpr'], scores)})
# Note: scikit-learn's ROC call signature is roc_curve(true_label, prediction)
roc_auc, fpr, tpr = evaluation(labels.cpu().numpy(), predictions.cpu().numpy())
score_dict = dict(roc_auc=roc_auc)
# self.logger.log_metrics(score_dict)
self.logger.log_image(f'{self.name}', plt.gcf())
pass
def __init__(self, *params):
super(ConvHomDetector, self).__init__(*params)
return dict(log=score_dict)
def __init__(self, hparams):
super(ConvHomDetector, self).__init__(hparams)
# Dataset
self.dataset = TrajData(self.hparams.data_param.root)
self.dataset = TrajData(self.hparams.data_param.map_root, mode='all_in_map')
# Additional Attributes
self.map_shape = self.dataset.map_shapes_max
# Model Paramters
# Model Parameters
self.in_shape = self.dataset.map_shapes_max
assert len(self.in_shape) == 3, f'Image or map shape has to have 3 dims, but had: {len(self.in_shape)}'
self.criterion = nn.BCEWithLogitsLoss()
# NN Nodes
# ============================
# Convolutional Map Processing
#
self.map_conv_0 = ConvModule(self.in_shape, conv_kernel=3, conv_stride=1,
conv_padding=0, conv_filters=self.hparams.model_param.filters[0])
self.map_res_1 = ResidualModule(self.map_conv_0.shape, ConvModule, 3,
@@ -86,7 +89,6 @@ class ConvHomDetector(LightningBaseModule):
self.linear = nn.Linear(reduce(mul, self.flatten.shape), self.hparams.model_param.classes * 10)
# Comments on Multi Class labels
self.classifier = nn.Linear(self.hparams.model_param.classes * 10, 1) # self.hparams.model_param.classes)
self.out_activation = nn.Sigmoid() # nn.Softmax
def forward(self, x):
tensor = self.map_conv_0(x)
@@ -98,25 +100,4 @@ class ConvHomDetector(LightningBaseModule):
tensor = self.flatten(tensor)
tensor = self.linear(tensor)
tensor = self.classifier(tensor)
tensor = self.out_activation(tensor)
return tensor
# Dataloaders
# ================================================================================
# Train Dataloader
def train_dataloader(self):
return DataLoader(dataset=self.dataset.train_dataset, shuffle=True,
batch_size=self.hparams.data_param.batchsize,
num_workers=self.hparams.data_param.worker)
# Test Dataloader
def test_dataloader(self):
return DataLoader(dataset=self.dataset.test_dataset, shuffle=True,
batch_size=self.hparams.data_param.batchsize,
num_workers=self.hparams.data_param.worker)
# Validation Dataloader
def val_dataloader(self):
return DataLoader(dataset=self.dataset.val_dataset, shuffle=True,
batch_size=self.hparams.data_param.batchsize,
num_workers=self.hparams.data_param.worker)
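The loss switch explains the deletions above: nn.BCEWithLogitsLoss applies the sigmoid internally (in a numerically stable log-sum-exp form), so the nn.Sigmoid output activation has to go and the network emits raw logits. Sketch of the convention:

import torch
from torch import nn

criterion = nn.BCEWithLogitsLoss()   # fuses sigmoid + binary cross-entropy
logits = torch.randn(4, 1)           # raw classifier output, no Sigmoid layer
labels = torch.tensor([[1.], [0.], [1.], [0.]])

loss = criterion(logits, labels)     # training: feed logits directly
probs = torch.sigmoid(logits)        # inference: apply sigmoid only here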

@@ -1,11 +1,7 @@
from abc import ABC
from pathlib import Path
from typing import Union
import torch
from torch import nn
import torch.nn.functional as F
import pytorch_lightning as pl
from lib.modules.utils import AutoPad, Interpolate
#
@@ -26,12 +22,12 @@
conv_filters=64, conv_kernel=5, conv_stride=1, conv_padding=0):
super(ConvModule, self).__init__()
# Module Paramters
# Module Parameters
self.in_shape = in_shape
in_channels, height, width = in_shape[0], in_shape[1], in_shape[2]
self.activation = activation()
# Convolution Paramters
# Convolution Parameters
self.padding = conv_padding
self.stride = conv_stride
@@ -44,7 +40,7 @@
)
def forward(self, x):
x = self.norm(x) if self.norm else x
x = self.norm(x)
tensor = self.conv(x)
tensor = self.dropout(tensor)
@@ -100,13 +96,13 @@ class ResidualModule(nn.Module):
output = self(x)
return output.shape[1:]
def __init__(self, in_shape, module_class, n, activation=None, **module_paramters):
def __init__(self, in_shape, module_class, n, activation=None, **module_parameters):
assert n >= 1
super(ResidualModule, self).__init__()
self.in_shape = in_shape
module_paramters.update(in_shape=in_shape)
module_parameters.update(in_shape=in_shape)
self.activation = activation() if activation else lambda x: x
self.residual_block = nn.ModuleList([module_class(**module_paramters) for _ in range(n)])
self.residual_block = nn.ModuleList([module_class(**module_parameters) for _ in range(n)])
assert self.in_shape == self.shape, f'The in_shape: {self.in_shape} - must match the out_shape: {self.shape}.'
def forward(self, x):
@@ -143,5 +139,3 @@ class RecurrentModule(nn.Module):
def forward(self, x):
tensor = self.rnn(x)
return tensor
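The hunk does not show ResidualModule.forward, but the shape assert only makes sense under the usual residual pattern: every wrapped module must be shape-preserving so the skip connection can be added element-wise. A hedged sketch with plain Conv2d blocks standing in for ConvModule:

import torch
from torch import nn

blocks = nn.ModuleList([nn.Conv2d(8, 8, kernel_size=3, padding=1) for _ in range(2)])
x = torch.randn(1, 8, 16, 16)
tensor = x
for block in blocks:
    tensor = block(tensor)
out = tensor + x   # element-wise skip; only valid because shapes match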

@@ -1,8 +1,11 @@
from typing import List
import torch
from torch import nn
from lib.modules.utils import FlipTensor
from lib.objects.map import MapStorage
from lib.objects.map import MapStorage, Map
from lib.objects.trajectory import Trajectory
class BinaryHomotopicLoss(nn.Module):
@@ -12,6 +15,9 @@ class BinaryHomotopicLoss(nn.Module):
self.flipper = FlipTensor()
def forward(self, x: torch.Tensor, y: torch.Tensor, mapnames: str):
y_flipped = self.flipper(y)
circle = torch.cat((x, y_flipped), dim=-1)
masp = self.map_storage[mapnames].are
maps: List[Map] = [self.map_storage[mapname] for mapname in mapnames]
for basemap in maps:
basemap = basemap.as_2d_array
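The fragment cuts off before the loss is computed, but the geometric intent is visible: reverse one trajectory so that, concatenated with the other, the vertices trace a single closed loop, then test that loop against each named map. A heavily hedged sketch, using torch.flip as a stand-in for FlipTensor and assuming a (batch, n_vertices, 2) vertex layout:

import torch

x = torch.randint(0, 64, (1, 10, 2)).float()  # route a: 10 (x, y) vertices
y = torch.randint(0, 64, (1, 10, 2)).float()  # route b, same endpoints in practice
y_flipped = torch.flip(y, dims=(1,))          # traverse route b backwards
circle = torch.cat((x, y_flipped), dim=1)     # a forward + b backward = closed loop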

@@ -83,9 +83,9 @@ class LightningBaseModule(pl.LightningModule, ABC):
print(e)
return -1
def __init__(self, params):
def __init__(self, hparams):
super(LightningBaseModule, self).__init__()
self.hparams = params
self.hparams = hparams
# Data loading
# =============================================================================
@@ -109,6 +109,10 @@ class LightningBaseModule(pl.LightningModule, ABC):
def data_len(self):
return len(self.dataset.train_dataset)
@property
def n_train_batches(self):
return len(self.train_dataloader())
def configure_optimizers(self):
raise NotImplementedError
@@ -121,7 +125,7 @@ class LightningBaseModule(pl.LightningModule, ABC):
def test_step(self, *args, **kwargs):
raise NotImplementedError
def test_end(self, outputs):
def test_epoch_end(self, outputs):
raise NotImplementedError
def init_weights(self):
@@ -134,6 +138,26 @@ class LightningBaseModule(pl.LightningModule, ABC):
m.bias.data.fill_(0.01)
self.apply(_weight_init)
# Dataloaders
# ================================================================================
# Train Dataloader
def train_dataloader(self):
return DataLoader(dataset=self.dataset.train_dataset, shuffle=True,
batch_size=self.hparams.data_param.batchsize,
num_workers=self.hparams.data_param.worker)
# Test Dataloader
def test_dataloader(self):
return DataLoader(dataset=self.dataset.test_dataset, shuffle=True,
batch_size=self.hparams.data_param.batchsize,
num_workers=self.hparams.data_param.worker)
# Validation Dataloader
def val_dataloader(self):
return DataLoader(dataset=self.dataset.val_dataset, shuffle=False,
batch_size=self.hparams.data_param.batchsize,
num_workers=self.hparams.data_param.worker)
class FilterLayer(nn.Module):

@@ -12,6 +12,7 @@ import networkx as nx
from matplotlib import pyplot as plt
from lib.objects.trajectory import Trajectory
import lib.variables as V
class Map(object):
@@ -145,14 +146,14 @@ class Map(object):
img = Image.new('L', (self.height, self.width), 0)
draw = ImageDraw.Draw(img)
draw.polygon(polyline, outline=1, fill=1)
draw.polygon(polyline, outline=self.white, fill=self.white)
a = (np.where(np.asarray(img) == self.white, 1, 0) * np.where(self.as_2d_array == self.white, 1, 0)).sum()
a = (np.where(np.asarray(img) == self.white, 1, 0) * np.where(self.as_2d_array == self.black, 1, 0)).sum()
if a:
return False  # not homotopic
return V.ALTERNATIVE  # not homotopic
else:
return True  # homotopic
return V.HOMOTOPIC  # homotopic
def draw(self):
fig, ax = plt.gcf(), plt.gca()
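The mask fix above is the substantive change: the filled polygon spanned by the two routes must be intersected with obstacle (black) pixels, not free (white) ones, since two routes are homotopic exactly when no obstacle lies between them. A condensed sketch of the test, assuming a 2-D uint8 grid with white = 255 free and black = 0 blocked, and routes given as (x, y) point lists:

import numpy as np
from PIL import Image, ImageDraw

WHITE, BLACK = 255, 0

def are_homotopic(grid, route_a, route_b):
    polyline = list(route_a) + list(route_b)[::-1]  # close the loop
    img = Image.new('L', (grid.shape[1], grid.shape[0]), 0)
    ImageDraw.Draw(img).polygon(polyline, outline=WHITE, fill=WHITE)
    enclosed = np.asarray(img) == WHITE   # area between the two routes
    blocked = grid == BLACK               # obstacle pixels
    return not (enclosed & blocked).any() # any obstacle inside -> not homotopic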

@@ -6,6 +6,7 @@ from lib import variables as V
import numpy as np
class Trajectory(object):
@property
@@ -57,7 +58,8 @@ class Trajectory(object):
def draw(self, highlights=True, label=None, **kwargs):
if label is not None:
kwargs.update(color='red' if label == V.HOMOTOPIC else 'green',
label='Homotopic' if label == V.HOMOTOPIC else 'Alternative')
label='Homotopic' if label == V.HOMOTOPIC else 'Alternative',
lw=1)
if highlights:
kwargs.update(marker='o')
fig, ax = plt.gcf(), plt.gca()

@@ -5,6 +5,7 @@ from collections import defaultdict
from configparser import ConfigParser
from pathlib import Path
from lib.models.generators.cnn import CNNRouteGeneratorModel
from lib.models.homotopy_classification.cnn_based import ConvHomDetector
from lib.utils.model_io import ModelParameters
@@ -27,7 +28,7 @@ class Config(ConfigParser):
@property
def model_class(self):
model_dict = dict(classifier_cnn=ConvHomDetector)
model_dict = dict(classifier_cnn=ConvHomDetector, generator_cnn=CNNRouteGeneratorModel)
try:
return model_dict[self.get('model', 'type')]
except KeyError as e:

@@ -1,8 +1,8 @@
from pathlib import Path
from pytorch_lightning.logging.base import LightningLoggerBase
from pytorch_lightning.logging.neptune import NeptuneLogger
from pytorch_lightning.logging.test_tube import TestTubeLogger
from pytorch_lightning.loggers.base import LightningLoggerBase
from pytorch_lightning.loggers.neptune import NeptuneLogger
from pytorch_lightning.loggers.test_tube import TestTubeLogger
from lib.utils.config import Config

@@ -1,5 +1,7 @@
from argparse import Namespace
from pathlib import Path
import torch
from natsort import natsorted
from torch import nn
@@ -35,30 +37,25 @@ class ModelParameters(Namespace):
class SavedLightningModels(object):
@classmethod
def load_checkpoint(cls, models_root_path, model, n=-1, tags_file_path=''):
def load_checkpoint(cls, models_root_path, model=None, n=-1, tags_file_path=''):
assert models_root_path.exists(), f'The path {models_root_path.absolute()} does not exist!'
found_checkpoints = list(Path(models_root_path).rglob('*.ckpt'))
found_checkpoints = natsorted(found_checkpoints, key=lambda y: y.name)
if model is None:
model = torch.load(models_root_path / 'model_class.obj')
assert model is not None
if not tags_file_path:
tag_files = models_root_path.rglob('meta_tags.csv')
tags_file_path = list(tag_files)[0]
return cls(weights=found_checkpoints[n], model=model, tags=tags_file_path)
return cls(weights=found_checkpoints[n], model=model)
def __init__(self, **kwargs):
self.weights: str = kwargs.get('weights', '')
self.tags: str = kwargs.get('tags', '')
self.model = kwargs.get('model', None)
assert self.model is not None
def restore(self):
pretrained_model = self.model.load_from_metrics(
weights_path=self.weights,
tags_csv=self.tags
)
pretrained_model = self.model.load_from_checkpoint(self.weights)
pretrained_model.eval()
pretrained_model.freeze()
return pretrained_model
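Usage-wise, the loader now recovers the model class from a pickled model_class.obj when none is passed, restores the newest (natsorted) *.ckpt via Lightning's load_from_checkpoint, and returns it frozen in eval mode. A short sketch with a hypothetical output directory:

from pathlib import Path

# 'output/classifier_cnn/version_0' is a placeholder for a real training run.
checkpoint_dir = Path('output') / 'classifier_cnn' / 'version_0'
classifier = SavedLightningModels.load_checkpoint(checkpoint_dir).restore()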

@@ -1,5 +1,5 @@
from pathlib import Path
_ROOT = Path('..')
HOMOTOPIC = 0
ALTERNATIVE = 1
HOMOTOPIC = 1
ALTERNATIVE = 0

main.py: 26 changes

@@ -10,13 +10,11 @@ import warnings
import torch
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
from torch.utils.data import DataLoader
from lib.modules.utils import LightningBaseModule
from lib.utils.config import Config
from lib.utils.logging import Logger
from lib.evaluation.classification import ROCEvaluation
from lib.utils.model_io import SavedLightningModels
warnings.filterwarnings('ignore', category=FutureWarning)
warnings.filterwarnings('ignore', category=UserWarning)
@@ -36,8 +34,8 @@ main_arg_parser.add_argument("--main_seed", type=int, default=69, help="")
# Data Parameters
main_arg_parser.add_argument("--data_worker", type=int, default=10, help="")
main_arg_parser.add_argument("--data_batchsize", type=int, default=100, help="")
main_arg_parser.add_argument("--data_root", type=str, default='/data/', help="")
main_arg_parser.add_argument("--data_map_root", type=str, default='/res/maps', help="")
main_arg_parser.add_argument("--data_root", type=str, default='data', help="")
main_arg_parser.add_argument("--data_map_root", type=str, default='res/shapes', help="")
# Transformations
main_arg_parser.add_argument("--transformations_to_tensor", type=strtobool, default=False, help="")
@@ -50,10 +48,11 @@ main_arg_parser.add_argument("--train_batch_size", type=int, default=256, help="")
main_arg_parser.add_argument("--train_lr", type=float, default=0.002, help="")
# Model
main_arg_parser.add_argument("--model_type", type=str, default="classifier_cnn", help="")
main_arg_parser.add_argument("--model_type", type=str, default="generator_cnn", help="")
main_arg_parser.add_argument("--model_activation", type=str, default="relu", help="")
main_arg_parser.add_argument("--model_filters", type=str, default="[32, 16, 4]", help="")
main_arg_parser.add_argument("--model_classes", type=int, default=2, help="")
main_arg_parser.add_argument("--model_lat_dim", type=int, default=2, help="")
main_arg_parser.add_argument("--model_use_bias", type=strtobool, default=True, help="")
main_arg_parser.add_argument("--model_use_norm", type=strtobool, default=True, help="")
main_arg_parser.add_argument("--model_dropout", type=float, default=0.00, help="")
@@ -78,8 +77,9 @@ def run_lightning_loop(config_obj):
# Checkpoint Saving
checkpoint_callback = ModelCheckpoint(
filepath=str(logger.log_dir / 'ckpt_weights'),
verbose=True, save_top_k=5,
verbose=True, save_top_k=0,
)
# =============================================================================
# Early Stopping
# TODO: For this to work, one must set a validation step and an end-of-epoch eval and score
@@ -94,6 +94,11 @@ def run_lightning_loop(config_obj):
# Init
model: LightningBaseModule = config_obj.model_class(config_obj.model_paramters)
model.init_weights()
if model.name == 'CNNRouteGenerator':
# ToDo: Make this dependent on the used seed
path = Path(Path(config_obj.train.outpath) / 'classifier_cnn' / 'version_0')
disc_model = SavedLightningModels.load_checkpoint(path).restore()
model.set_discriminator(disc_model)
# Trainer
# =============================================================================
@@ -101,8 +106,8 @@
show_progress_bar=True,
weights_save_path=logger.log_dir,
gpus=[0] if torch.cuda.is_available() else None,
row_log_interval=(model.data_len * 0.01), # TODO: Better Value / Setting
log_save_interval=(model.data_len * 0.04), # TODO: Better Value / Setting
# row_log_interval=(model.n_train_batches * 0.1), # TODO: Better Value / Setting
# log_save_interval=(model.n_train_batches * 0.2), # TODO: Better Value / Setting
checkpoint_callback=checkpoint_callback,
logger=logger,
fast_dev_run=config_obj.main.debug,
@@ -110,7 +115,7 @@
)
# Train It
trainer.fit(model,)
trainer.fit(model)
# Save the last state & all parameters
trainer.save_checkpoint(logger.log_dir / 'weights.ckpt')
@@ -118,6 +123,7 @@
# Evaluate It
trainer.test()
return model
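Taken together, the loop now wires the two models GAN-style: a previously trained classifier_cnn run is restored (frozen, in eval mode, via restore() above) and injected as the generator's discriminator before fitting. Condensed, with names as in the diff and the seed-dependence of the path still an open TODO:

model = config_obj.model_class(config_obj.model_paramters)  # e.g. CNNRouteGeneratorModel
model.init_weights()
if model.name == 'CNNRouteGenerator':
    ckpt_dir = Path(config_obj.train.outpath) / 'classifier_cnn' / 'version_0'
    model.set_discriminator(SavedLightningModels.load_checkpoint(ckpt_dir).restore())
trainer.fit(model)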

@@ -16,7 +16,7 @@ if __name__ == '__main__':
# Model Settings
config = Config().read_namespace(args)
# use_bias, activation, model, use_norm, max_epochs, filters
cnn_classifier = dict(train_epochs=100, model_use_bias=True, model_use_norm=True, model_activation='leaky_relu',
cnn_classifier = dict(train_epochs=10, model_use_bias=True, model_use_norm=True, model_activation='leaky_relu',
model_type='classifier_cnn', model_filters=[16, 32, 64], data_batchsize=512)
# use_bias, activation, model, use_norm, max_epochs, sr, lat_dim, filters