project Refactor, CNN Classifier Basics

Steffen Illium
2020-03-08 23:46:02 +01:00
parent 75e8a61628
commit cd4fdf2de3
20 changed files with 441 additions and 239 deletions

.idea/deployment.xml generated

@@ -2,6 +2,13 @@
 <project version="4">
   <component name="PublishConfigData" autoUpload="On explicit save action" serverName="steffen@aimachine:22" showAutoUploadSettingsWarning="false">
     <serverData>
+      <paths name="erlowa@aimachine">
+        <serverdata>
+          <mappings>
+            <mapping deploy="/" local="$PROJECT_DIR$" web="/" />
+          </mappings>
+        </serverdata>
+      </paths>
       <paths name="steffen@aimachine:22">
         <serverdata>
           <mappings>


@@ -1,13 +1,22 @@
 <component name="ProjectDictionaryState">
   <dictionary name="steffen">
     <words>
+      <w>autopad</w>
       <w>conv</w>
+      <w>convolutional</w>
      <w>dataloader</w>
+      <w>dataloaders</w>
+      <w>datasets</w>
       <w>homotopic</w>
       <w>hparams</w>
       <w>hyperparamter</w>
+      <w>kingma</w>
+      <w>logvar</w>
+      <w>mapname</w>
+      <w>mapnames</w>
       <w>numlayers</w>
       <w>reparameterize</w>
+      <w>softmax</w>
       <w>traj</w>
     </words>
   </dictionary>


@@ -2,7 +2,7 @@
 <module type="PYTHON_MODULE" version="4">
   <component name="NewModuleRootManager">
     <content url="file://$MODULE_DIR$" />
-    <orderEntry type="jdk" jdkName="hom_traj_gen@aimachine" jdkType="Python SDK" />
+    <orderEntry type="jdk" jdkName="Remote Python 3.7.6 (sftp://steffen@aimachine:22/home/steffen/envs/traj_gen/bin/python)" jdkType="Python SDK" />
     <orderEntry type="sourceFolder" forTests="false" />
   </component>
 </module>

.idea/misc.xml generated

@@ -3,5 +3,8 @@
   <component name="JavaScriptSettings">
     <option name="languageLevel" value="ES6" />
   </component>
-  <component name="ProjectRootManager" version="2" project-jdk-name="hom_traj_gen@aimachine" project-jdk-type="Python SDK" />
+  <component name="ProjectRootManager" version="2" project-jdk-name="traj_gen@ai-machine" project-jdk-type="Python SDK" />
+  <component name="PyPackaging">
+    <option name="earlyReleasesAsUpgrades" value="true" />
+  </component>
 </project>


@@ -3,6 +3,7 @@ from pathlib import Path
 from typing import Union, List
 
 import torch
+from random import choice
 from torch.utils.data import ConcatDataset, Dataset
 
 from lib.objects.map import Map
@@ -16,10 +17,11 @@ class TrajDataset(Dataset):
         return self.map.as_array.shape
 
     def __init__(self, *args, maps_root: Union[Path, str] = '', mapname='tate_sw',
-                 length=100000, all_in_map=True, embedding_size=None, preserve_equal_samples=False, **kwargs):
+                 length=100000, mode='separated_arrays', embedding_size=None, preserve_equal_samples=False, **kwargs):
         super(TrajDataset, self).__init__()
+        assert mode.lower() in ['vectors', 'all_in_map', 'separated_arrays', 'just_route']
         self.preserve_equal_samples = preserve_equal_samples
-        self.all_in_map = all_in_map
+        self.mode = mode
         self.mapname = mapname if mapname.endswith('.bmp') else f'{mapname}.bmp'
         self.maps_root = maps_root
         self._len = length
@@ -31,8 +33,19 @@ class TrajDataset(Dataset):
         return self._len
 
     def __getitem__(self, item):
-        trajectory = self.map.get_random_trajectory()
+        if self.mode.lower() == 'just_route':
+            trajectory = self.map.get_random_trajectory()
+            label = choice([0, 1])
+            blank_trajectory_space = torch.zeros(self.map.shape)
+            for index in trajectory.vertices:
+                blank_trajectory_space[index] = 1
+
+            map_array = torch.as_tensor(self.map.as_array).float()
+            return (map_array, blank_trajectory_space), label
+
         while True:
+            trajectory = self.map.get_random_trajectory()
             # TODO: Sanity Check this while true loop...
             alternative = self.map.generate_alternative(trajectory)
             label = self.map.are_homotopic(trajectory, alternative)
@@ -42,18 +55,26 @@ class TrajDataset(Dataset):
                 break
 
         self.last_label = label
-        if self.all_in_map:
+        if self.mode.lower() in ['all_in_map', 'separated_arrays']:
             blank_trajectory_space = torch.zeros(self.map.shape)
             blank_alternative_space = torch.zeros(self.map.shape)
             for index in trajectory.vertices:
                 blank_trajectory_space[index] = 1
+            for index in alternative.vertices:
                 blank_alternative_space[index] = 1
 
             map_array = torch.as_tensor(self.map.as_array).float()
-            return torch.cat((map_array, blank_trajectory_space, blank_alternative_space)), int(label)
-        else:
+            if self.mode == 'separated_arrays':
+                return (map_array, blank_trajectory_space, int(label)), blank_alternative_space
+            else:
+                return torch.cat((map_array, blank_trajectory_space, blank_alternative_space)), int(label)
+
+        elif self.mode == 'vectors':
             return trajectory.vertices, alternative.vertices, label, self.mapname
+        else:
+            raise ValueError
 
 
 class TrajData(object):
     @property
@@ -64,7 +85,7 @@ class TrajData(object):
     def map_shapes_max(self):
         shapes = self.map_shapes
         shape_list = list(map(max, zip(*shapes)))
-        if self.all_in_map:
+        if self.mode == 'all_in_map':
             shape_list[0] += 2
         return shape_list
 
@@ -72,10 +93,10 @@ class TrajData(object):
     def name(self):
         return self.__class__.__name__
 
-    def __init__(self, *args, map_root: Union[Path, str] = '', length=100.000, all_in_map=True, **_):
-        self.all_in_map = all_in_map
-        self.maps_root = Path(map_root) if map_root else Path() / 'res' / 'maps'
+    def __init__(self, map_root, length=100000, mode='separated_arrays', **_):
+        self.mode = mode
+        self.maps_root = Path(map_root)
         self.length = length
         self._dataset = self._load_datasets()
@@ -86,8 +107,8 @@ class TrajData(object):
         # find max image size among available maps:
         max_map_size = (1, ) + tuple(reversed(tuple(map(max, *[Image.open(map_file).size for map_file in map_files]))))
         return ConcatDataset([TrajDataset(maps_root=self.maps_root, mapname=map_file.name, length=equal_split,
-                                          all_in_map=self.all_in_map, embedding_size=max_map_size,
-                                          preserve_equal_samples=False)
+                                          mode=self.mode, embedding_size=max_map_size,
+                                          preserve_equal_samples=True)
                              for map_file in map_files])
 
     @property
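For orientation, a sketch of what a sample looks like under the new `mode` switch. This is hypothetical usage, not code from the commit: the map root is a placeholder, and `train_dataset` is assumed to be the property backed by `_load_datasets()`.

```python
# Hypothetical usage; paths and tensor shapes depend on the map files present.
from datasets.trajectory_dataset import TrajData

data = TrajData('res/maps', length=1000, mode='separated_arrays')
(map_array, trajectory, label), alternative = data.train_dataset[0]

# mode='all_in_map' instead stacks map, trajectory and alternative into one
# tensor and returns (stacked_tensor, label); mode='just_route' returns a
# ((map, trajectory), random_label) pair for the generator's dataset.
```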


@@ -6,27 +6,29 @@ class ROCEvaluation(object):
 
     linewidth = 2
 
-    def __init__(self, prepare_figure=False):
-        self.prepare_figure = prepare_figure
+    def __init__(self, plot_roc=False):
+        self.plot_roc = plot_roc
         self.epoch = 0
 
-    def __call__(self, prediction, label, plotting=False):
+    def __call__(self, prediction, label):
         # Compute ROC curve and ROC area
         fpr, tpr, _ = roc_curve(prediction, label)
         roc_auc = auc(fpr, tpr)
-        if plotting:
-            fig = plt.gcf()
-            fig.plot(fpr, tpr, color='darkorange', lw=self.linewidth, label=f'ROC curve (area = {roc_auc})')
+        if self.plot_roc:
+            _ = plt.gcf()
+            plt.plot(fpr, tpr, color='darkorange', lw=self.linewidth, label=f'ROC curve (area = {roc_auc})')
+            self._prepare_fig()
         return roc_auc, fpr, tpr
 
     def _prepare_fig(self):
         fig = plt.gcf()
-        fig.plot([0, 1], [0, 1], color='navy', lw=self.linewidth, linestyle='--')
-        fig.xlim([0.0, 1.0])
-        fig.ylim([0.0, 1.05])
-        fig.xlabel('False Positive Rate')
-        fig.ylabel('True Positive Rate')
+        ax = plt.gca()
+        plt.plot([0, 1], [0, 1], color='navy', lw=self.linewidth, linestyle='--')
+        plt.xlim([0.0, 1.0])
+        plt.ylim([0.0, 1.05])
+        plt.xlabel('False Positive Rate')
+        plt.ylabel('True Positive Rate')
         fig.legend(loc="lower right")
 
         return fig
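The bug fixed here is that `Figure` objects have no `plot`/`xlim`/`xlabel` methods; the rewrite uses the pyplot state-machine API instead. The underlying scikit-learn calls, with the true-labels-first argument order that the classifier's `test_epoch_end` now respects, are shown in this minimal sketch (the numbers are the canonical scikit-learn docs example):

```python
import numpy as np
from sklearn.metrics import roc_curve, auc

y_true = np.array([0, 0, 1, 1])
y_score = np.array([0.1, 0.4, 0.35, 0.8])

# roc_curve(y_true, y_score): ground truth first, scores second.
fpr, tpr, _ = roc_curve(y_true, y_score)
print(auc(fpr, tpr))  # 0.75
```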


@@ -1,6 +1,13 @@
-from datasets.paired_dataset import TrajPairData
-from lib.modules.blocks import ConvModule
-from lib.modules.utils import LightningBaseModule
+import torch
+from functools import reduce
+from operator import mul
+from torch import nn
+from torch.optim import Adam
+
+from datasets.trajectory_dataset import TrajData
+from lib.modules.blocks import ConvModule, ResidualModule, DeConvModule
+from lib.modules.utils import LightningBaseModule, Flatten
 
 
 class CNNRouteGeneratorModel(LightningBaseModule):
@@ -8,36 +15,169 @@ class CNNRouteGeneratorModel(LightningBaseModule):
 
     name = 'CNNRouteGenerator'
 
     def configure_optimizers(self):
-        pass
-
-    def validation_step(self, *args, **kwargs):
-        pass
-
-    def validation_end(self, outputs):
-        pass
+        return Adam(self.parameters(), lr=self.hparams.train_param.lr)
 
     def training_step(self, batch_xy, batch_nb, *args, **kwargs):
-        pass
+        batch_x, label = batch_xy
+        generated_alternative, z, mu, logvar = self(batch_x + [label, ])
+        map_array, trajectory = batch_x
+
+        map_stack = torch.cat((map_array, trajectory, generated_alternative), dim=1)
+        pred_label = self.discriminator(map_stack)
+        discriminated_bce_loss = self.criterion(pred_label, label.float().unsqueeze(-1))
+
+        # see Appendix B from VAE paper:
+        # Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014
+        # https://arxiv.org/abs/1312.6114
+        # 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
+        kld_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
+
+        loss = (kld_loss + discriminated_bce_loss) / 2
+        return dict(loss=loss, log=dict(loss=loss,
+                                        discriminated_bce_loss=discriminated_bce_loss,
+                                        kld_loss=kld_loss)
+                    )
 
     def test_step(self, *args, **kwargs):
         pass
 
+    @property
+    def discriminator(self):
+        if self._disc is None:
+            raise RuntimeError('Set the discriminator first; "set_discriminator(disc_model)"')
+        return self._disc
+
+    def set_discriminator(self, disc_model):
+        if self._disc is not None:
+            raise RuntimeError('Discriminator has already been set... What are you trying to do?')
+        self._disc = disc_model
+
     def __init__(self, *params):
         super(CNNRouteGeneratorModel, self).__init__(*params)
 
         # Dataset
-        self.dataset = TrajPairData(self.hparams.data_param.data_root)
+        self.dataset = TrajData(self.hparams.data_param.map_root, mode='just_route')
 
         # Additional Attributes
         self.in_shape = self.dataset.map_shapes_max
+        # Todo: Better naming and size in Parameters
+        self.feature_dim = 10
+        self._disc = None
 
         # NN Nodes
-        self.conv2 = ConvModule(self.conv1.shape, conv_kernel=3, conv_stride=1, conv_padding=0,
-                                conv_filters=self.hparams.model_param.filters[0])
-        self.conv3 = ConvModule(self.conv2.shape, conv_kernel=3, conv_stride=1, conv_padding=0,
-                                conv_filters=self.hparams.model_param.filters[0])
+        ###################################################
+        #
+        # Utils
+        self.relu = nn.ReLU()
+        self.criterion = nn.MSELoss()
+
+        #
+        # Map Encoder
+        self.map_conv_0 = ConvModule(self.in_shape, conv_kernel=3, conv_stride=1, conv_padding=1,
+                                     conv_filters=self.hparams.model_param.filters[0])
+        self.map_res_1 = ResidualModule(self.map_conv_0.shape, ConvModule, 2, conv_kernel=3, conv_stride=1,
+                                        conv_padding=1, conv_filters=self.hparams.model_param.filters[0])
+        self.map_conv_1 = ConvModule(self.map_res_1.shape, conv_kernel=3, conv_stride=1, conv_padding=0,
+                                     conv_filters=self.hparams.model_param.filters[1])
+        self.map_res_2 = ResidualModule(self.map_conv_1.shape, ConvModule, 2, conv_kernel=3, conv_stride=1,
+                                        conv_padding=1, conv_filters=self.hparams.model_param.filters[1])
+        self.map_conv_2 = ConvModule(self.map_res_2.shape, conv_kernel=3, conv_stride=1, conv_padding=0,
+                                     conv_filters=self.hparams.model_param.filters[2])
+        self.map_res_3 = ResidualModule(self.map_conv_2.shape, ConvModule, 2, conv_kernel=3, conv_stride=1,
+                                        conv_padding=1, conv_filters=self.hparams.model_param.filters[2])
+        self.map_conv_3 = ConvModule(self.map_res_3.shape, conv_kernel=5, conv_stride=1, conv_padding=0,
+                                     conv_filters=self.hparams.model_param.filters[2] * 2)
+        self.map_flat = Flatten(self.map_conv_3.shape)
+        self.map_lin = nn.Linear(reduce(mul, self.map_conv_3.shape), self.feature_dim)
+
+        #
+        # Trajectory Encoder
+        self.traj_conv_1 = ConvModule(self.in_shape, conv_kernel=3, conv_stride=1, conv_padding=0,
+                                      conv_filters=self.hparams.model_param.filters[0])
+        self.traj_conv_2 = ConvModule(self.traj_conv_1.shape, conv_kernel=3, conv_stride=1, conv_padding=0,
+                                      conv_filters=self.hparams.model_param.filters[0])
+        self.traj_conv_3 = ConvModule(self.traj_conv_2.shape, conv_kernel=3, conv_stride=1, conv_padding=0,
+                                      conv_filters=self.hparams.model_param.filters[0])
+        self.traj_flat = Flatten(self.traj_conv_3.shape)
+        self.traj_lin = nn.Linear(reduce(mul, self.traj_conv_3.shape), self.feature_dim)
+
+        #
+        # Variational Bottleneck
+        self.mu = nn.Linear(self.feature_dim + self.feature_dim + 1, self.hparams.model_param.lat_dim)
+        self.logvar = nn.Linear(self.feature_dim + self.feature_dim + 1, self.hparams.model_param.lat_dim)
+
+        #
+        # Alternative Generator
+        self.alt_lin_1 = nn.Linear(self.hparams.model_param.lat_dim, self.feature_dim)
+        self.alt_lin_2 = nn.Linear(self.feature_dim, reduce(mul, self.traj_conv_3.shape))
+        self.reshape_to_map = Flatten(reduce(mul, self.traj_conv_3.shape), self.traj_conv_3.shape)
+
+        self.alt_deconv_1 = DeConvModule(self.traj_conv_3.shape, self.hparams.model_param.filters[2],
+                                         conv_padding=0, conv_kernel=5, conv_stride=1)
+        self.alt_deconv_2 = DeConvModule(self.alt_deconv_1.shape, self.hparams.model_param.filters[1],
+                                         conv_padding=0, conv_kernel=3, conv_stride=1)
+        self.alt_deconv_3 = DeConvModule(self.alt_deconv_2.shape, self.hparams.model_param.filters[0],
+                                         conv_padding=1, conv_kernel=3, conv_stride=1)
+        self.alt_deconv_out = DeConvModule(self.alt_deconv_3.shape, 1, activation=None,
+                                           conv_padding=1, conv_kernel=3, conv_stride=1)
 
-    def forward(self, x):
-        pass
+    def forward(self, batch_x):
+        #
+        # Sorting the Input
+        map_array, trajectory, label = batch_x
+
+        #
+        # Encode
+        map_tensor = self.map_conv_0(map_array)
+        map_tensor = self.map_res_1(map_tensor)
+        map_tensor = self.map_conv_1(map_tensor)
+        map_tensor = self.map_res_2(map_tensor)
+        map_tensor = self.map_conv_2(map_tensor)
+        map_tensor = self.map_res_3(map_tensor)
+        map_tensor = self.map_conv_3(map_tensor)
+        map_tensor = self.map_flat(map_tensor)
+        map_tensor = self.map_lin(map_tensor)
+
+        traj_tensor = self.traj_conv_1(trajectory)
+        traj_tensor = self.traj_conv_2(traj_tensor)
+        traj_tensor = self.traj_conv_3(traj_tensor)
+        traj_tensor = self.traj_flat(traj_tensor)
+        traj_tensor = self.traj_lin(traj_tensor)
+
+        mixed_tensor = torch.cat((map_tensor, traj_tensor, label.float().unsqueeze(-1)), dim=1)
+        mixed_tensor = self.relu(mixed_tensor)
+
+        #
+        # Parameter and Sampling
+        mu = self.mu(mixed_tensor)
+        logvar = self.logvar(mixed_tensor)
+        z = self.reparameterize(mu, logvar)
+
+        #
+        # Generate
+        alt_tensor = self.alt_lin_1(z)
+        alt_tensor = self.alt_lin_2(alt_tensor)
+        alt_tensor = self.reshape_to_map(alt_tensor)
+        alt_tensor = self.alt_deconv_1(alt_tensor)
+        alt_tensor = self.alt_deconv_2(alt_tensor)
+        alt_tensor = self.alt_deconv_3(alt_tensor)
+        alt_tensor = self.alt_deconv_out(alt_tensor)
+        return alt_tensor, z, mu, logvar
+
+    @staticmethod
+    def reparameterize(mu, logvar):
+        std = torch.exp(0.5 * logvar)
+        eps = torch.randn_like(std)
+        return mu + eps * std
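A standalone sketch of the two VAE ingredients this model uses, the reparameterization trick and the KL term from Kingma and Welling; the tensors are dummies chosen only to make the shapes concrete (batch of 4, lat_dim=2):

```python
import torch

mu = torch.zeros(4, 2)       # dummy batch of latent means
logvar = torch.zeros(4, 2)   # log(sigma^2); zeros mean unit variance

# Reparameterization: sample z = mu + sigma * eps so that the sample stays
# differentiable with respect to mu and logvar.
std = torch.exp(0.5 * logvar)
z = mu + torch.randn_like(std) * std

# KL divergence of N(mu, sigma^2) from the N(0, I) prior.
kld = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

print(z.shape)  # torch.Size([4, 2])
print(kld)      # 0 here, since mu=0 and logvar=0 match the prior exactly
```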


@@ -1,14 +1,16 @@
-from lib.modules.blocks import LightningBaseModule
 from lib.modules.losses import BinaryHomotopicLoss
+from lib.modules.utils import LightningBaseModule
 from lib.objects.map import Map
 from lib.objects.trajectory import Trajectory
 
 import torch.nn as nn
 
-nn.MSELoss
 
 class LinearRouteGeneratorModel(LightningBaseModule):
 
+    def test_epoch_end(self, outputs):
+        pass
+
     name = 'LinearRouteGenerator'
 
     def configure_optimizers(self):
@@ -33,6 +35,12 @@ class LinearRouteGeneratorModel(LightningBaseModule):
         pred_y = self(map_x, traj_x, label_x)
         loss = self.loss(traj_x, pred_y)
+
+    def training_step(self, batch_xy, batch_nb, *args, **kwargs):
+        batch_x, batch_y = batch_xy
+        pred_y = self(batch_x)
+        loss = self.criterion(pred_y, batch_y.unsqueeze(-1).float())
         return dict(loss=loss, log=dict(loss=loss))
 
     def test_step(self, *args, **kwargs):
@@ -41,7 +49,7 @@ class LinearRouteGeneratorModel(LightningBaseModule):
     def __init__(self, *params):
         super(LinearRouteGeneratorModel, self).__init__(*params)
 
-        self.loss = BinaryHomotopicLoss(self.map_storage)
+        self.criterion = BinaryHomotopicLoss(self.map_storage)
 
     def forward(self, map_x, traj_x, label_x):
         pass


@@ -24,41 +24,44 @@ class ConvHomDetector(LightningBaseModule):
 
     def training_step(self, batch_xy, batch_nb, *args, **kwargs):
         batch_x, batch_y = batch_xy
         pred_y = self(batch_x)
-        loss = F.binary_cross_entropy(pred_y, batch_y.float())
+        loss = self.criterion(pred_y, batch_y.unsqueeze(-1).float())
         return {'loss': loss, 'log': dict(loss=loss)}
 
-    def test_step(self, batch_xy, **kwargs):
+    def test_step(self, batch_xy, batch_nb, **kwargs):
         batch_x, batch_y = batch_xy
         pred_y = self(batch_x)
-        return dict(prediction=pred_y, label=batch_y)
+        return dict(prediction=pred_y, label=batch_y, batch_nb=batch_nb)
 
-    def test_end(self, outputs):
-        evaluation = ROCEvaluation()
-        predictions = torch.stack([x['prediction'] for x in outputs])
-        labels = torch.stack([x['label'] for x in outputs])
-        scores = evaluation(predictions.numpy(), labels.numpy(), )
-        self.logger.log_metrics({key: value for key, value in zip(['roc_auc', 'tpr', 'fpr'], scores)})
+    def test_epoch_end(self, outputs):
+        evaluation = ROCEvaluation(plot_roc=True)
+        predictions = torch.cat([x['prediction'] for x in outputs])
+        labels = torch.cat([x['label'] for x in outputs]).unsqueeze(1)
+
+        # The scikit-learn ROC evaluation call is eval(true_label, prediction)
+        roc_auc, tpr, fpr = evaluation(labels.cpu().numpy(), predictions.cpu().numpy(), )
+        score_dict = dict(roc_auc=roc_auc)
+        # self.logger.log_metrics(score_dict)
         self.logger.log_image(f'{self.name}', plt.gcf())
-        pass
 
-    def __init__(self, *params):
-        super(ConvHomDetector, self).__init__(*params)
+        return dict(log=score_dict)
+
+    def __init__(self, hparams):
+        super(ConvHomDetector, self).__init__(hparams)
 
         # Dataset
-        self.dataset = TrajData(self.hparams.data_param.root)
+        self.dataset = TrajData(self.hparams.data_param.map_root, mode='all_in_map')
 
         # Additional Attributes
         self.map_shape = self.dataset.map_shapes_max
 
-        # Model Paramters
+        # Model Parameters
         self.in_shape = self.dataset.map_shapes_max
         assert len(self.in_shape) == 3, f'Image or map shape has to have 3 dims, but had: {len(self.in_shape)}'
+        self.criterion = nn.BCEWithLogitsLoss()
 
         # NN Nodes
         # ============================
         # Convolutional Map Processing
+        #
         self.map_conv_0 = ConvModule(self.in_shape, conv_kernel=3, conv_stride=1,
                                      conv_padding=0, conv_filters=self.hparams.model_param.filters[0])
         self.map_res_1 = ResidualModule(self.map_conv_0.shape, ConvModule, 3,
@@ -86,7 +89,6 @@ class ConvHomDetector(LightningBaseModule):
         self.linear = nn.Linear(reduce(mul, self.flatten.shape), self.hparams.model_param.classes * 10)
 
         # Comments on Multi Class labels
         self.classifier = nn.Linear(self.hparams.model_param.classes * 10, 1)  # self.hparams.model_param.classes)
-        self.out_activation = nn.Sigmoid()  # nn.Softmax
 
     def forward(self, x):
         tensor = self.map_conv_0(x)
@@ -98,25 +100,4 @@ class ConvHomDetector(LightningBaseModule):
         tensor = self.flatten(tensor)
         tensor = self.linear(tensor)
         tensor = self.classifier(tensor)
-        tensor = self.out_activation(tensor)
         return tensor
-
-    # Dataloaders
-    # ================================================================================
-    # Train Dataloader
-    def train_dataloader(self):
-        return DataLoader(dataset=self.dataset.train_dataset, shuffle=True,
-                          batch_size=self.hparams.data_param.batchsize,
-                          num_workers=self.hparams.data_param.worker)
-
-    # Test Dataloader
-    def test_dataloader(self):
-        return DataLoader(dataset=self.dataset.test_dataset, shuffle=True,
-                          batch_size=self.hparams.data_param.batchsize,
-                          num_workers=self.hparams.data_param.worker)
-
-    # Validation Dataloader
-    def val_dataloader(self):
-        return DataLoader(dataset=self.dataset.val_dataset, shuffle=True,
-                          batch_size=self.hparams.data_param.batchsize,
-                          num_workers=self.hparams.data_param.worker)
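Dropping the trailing Sigmoid goes together with the switch from `F.binary_cross_entropy` to `nn.BCEWithLogitsLoss`: the fused loss applies the sigmoid internally in a numerically stabler log-sum-exp form, so the classifier now emits raw logits. A quick sketch of the equivalence (dummy values, not from the commit):

```python
import torch
from torch import nn

logits = torch.tensor([[0.5], [-1.2]])   # raw classifier outputs
target = torch.tensor([[1.0], [0.0]])

fused = nn.BCEWithLogitsLoss()(logits, target)
manual = nn.BCELoss()(torch.sigmoid(logits), target)
assert torch.allclose(fused, manual)     # same value, better stability fused
```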


@@ -1,11 +1,7 @@
-from abc import ABC
-from pathlib import Path
 from typing import Union
 
 import torch
 from torch import nn
-import torch.nn.functional as F
-import pytorch_lightning as pl
 
 from lib.modules.utils import AutoPad, Interpolate
 
 #
@@ -26,12 +22,12 @@ class ConvModule(nn.Module):
                  conv_filters=64, conv_kernel=5, conv_stride=1, conv_padding=0):
         super(ConvModule, self).__init__()
 
-        # Module Paramters
+        # Module Parameters
         self.in_shape = in_shape
         in_channels, height, width = in_shape[0], in_shape[1], in_shape[2]
         self.activation = activation()
 
-        # Convolution Paramters
+        # Convolution Parameters
         self.padding = conv_padding
         self.stride = conv_stride
@@ -44,7 +40,7 @@ class ConvModule(nn.Module):
         )
 
     def forward(self, x):
-        x = self.norm(x) if self.norm else x
+        x = self.norm(x)
 
         tensor = self.conv(x)
         tensor = self.dropout(tensor)
@@ -100,13 +96,13 @@ class ResidualModule(nn.Module):
         output = self(x)
         return output.shape[1:]
 
-    def __init__(self, in_shape, module_class, n, activation=None, **module_paramters):
+    def __init__(self, in_shape, module_class, n, activation=None, **module_parameters):
         assert n >= 1
         super(ResidualModule, self).__init__()
         self.in_shape = in_shape
-        module_paramters.update(in_shape=in_shape)
+        module_parameters.update(in_shape=in_shape)
         self.activation = activation() if activation else lambda x: x
-        self.residual_block = nn.ModuleList([module_class(**module_paramters) for _ in range(n)])
+        self.residual_block = nn.ModuleList([module_class(**module_parameters) for _ in range(n)])
         assert self.in_shape == self.shape, f'The in_shape: {self.in_shape} - must match the out_shape: {self.shape}.'
 
     def forward(self, x):
@@ -143,5 +139,3 @@ class RecurrentModule(nn.Module):
     def forward(self, x):
         tensor = self.rnn(x)
         return tensor
-
-
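The `.shape` properties that chain these modules (e.g. `self.map_conv_0.shape` feeding the next block) follow standard convolution arithmetic. A one-dimensional sketch of the rule, using the same floor division as `torch.nn.Conv2d`:

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Output length of one spatial dimension after a convolution."""
    return (size + 2 * padding - kernel) // stride + 1

assert conv_out(64, kernel=3, padding=1) == 64  # 3x3 with padding 1 preserves size
assert conv_out(64, kernel=3, padding=0) == 62  # padding 0 shrinks each side by 1
```

This is also why the residual blocks in the generator pair kernel 3 with `conv_padding=1`: the residual branch must return the shape it received, which the `assert` in `ResidualModule.__init__` enforces.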


@@ -1,8 +1,11 @@
+from typing import List
+
 import torch
 from torch import nn
 
 from lib.modules.utils import FlipTensor
-from lib.objects.map import MapStorage
+from lib.objects.map import MapStorage, Map
+from lib.objects.trajectory import Trajectory
 
 
 class BinaryHomotopicLoss(nn.Module):
@@ -11,7 +14,10 @@ class BinaryHomotopicLoss(nn.Module):
         self.map_storage = map_storage
         self.flipper = FlipTensor()
 
-    def forward(self, x:torch.Tensor, y: torch.Tensor, mapnames: str):
-        y_flipepd = self.flipper(y)
-        circle = torch.cat((x, y_flipepd), dim=-1)
-        masp = self.map_storage[mapnames].are
+    def forward(self, x: torch.Tensor, y: torch.Tensor, mapnames: str):
+        maps: List[Map] = [self.map_storage[mapname] for mapname in mapnames]
+        for basemap in maps:
+            basemap = basemap.as_2d_array


@@ -83,9 +83,9 @@ class LightningBaseModule(pl.LightningModule, ABC):
             print(e)
             return -1
 
-    def __init__(self, params):
+    def __init__(self, hparams):
         super(LightningBaseModule, self).__init__()
-        self.hparams = params
+        self.hparams = hparams
 
         # Data loading
         # =============================================================================
@@ -109,6 +109,10 @@ class LightningBaseModule(pl.LightningModule, ABC):
     def data_len(self):
         return len(self.dataset.train_dataset)
 
+    @property
+    def n_train_batches(self):
+        return len(self.train_dataloader())
+
     def configure_optimizers(self):
         raise NotImplementedError
 
@@ -121,7 +125,7 @@ class LightningBaseModule(pl.LightningModule, ABC):
     def test_step(self, *args, **kwargs):
         raise NotImplementedError
 
-    def test_end(self, outputs):
+    def test_epoch_end(self, outputs):
         raise NotImplementedError
 
     def init_weights(self):
@@ -134,6 +138,26 @@ class LightningBaseModule(pl.LightningModule, ABC):
                 m.bias.data.fill_(0.01)
         self.apply(_weight_init)
 
+    # Dataloaders
+    # ================================================================================
+    # Train Dataloader
+    def train_dataloader(self):
+        return DataLoader(dataset=self.dataset.train_dataset, shuffle=True,
+                          batch_size=self.hparams.data_param.batchsize,
+                          num_workers=self.hparams.data_param.worker)
+
+    # Test Dataloader
+    def test_dataloader(self):
+        return DataLoader(dataset=self.dataset.test_dataset, shuffle=True,
+                          batch_size=self.hparams.data_param.batchsize,
+                          num_workers=self.hparams.data_param.worker)
+
+    # Validation Dataloader
+    def val_dataloader(self):
+        return DataLoader(dataset=self.dataset.val_dataset, shuffle=False,
+                          batch_size=self.hparams.data_param.batchsize,
+                          num_workers=self.hparams.data_param.worker)
+
 
 class FilterLayer(nn.Module):
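Moving the dataloaders from the concrete classifier onto the base module follows the idiomatic PyTorch Lightning pattern: `Trainer.fit(model)` discovers `train_dataloader()`, `val_dataloader()` and `test_dataloader()` on the module itself, so `main.py` (see below) never constructs loaders by hand. A condensed, hypothetical sketch of the hook with a placeholder dataset in place of `TrajData`:

```python
import pytorch_lightning as pl
from torch.utils.data import DataLoader

class Sketch(pl.LightningModule):
    # Only the loader hook is shown; a real module also needs
    # forward, training_step and configure_optimizers.
    def train_dataloader(self):
        return DataLoader(list(range(100)), batch_size=10, shuffle=True)
```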


@@ -12,6 +12,7 @@ import networkx as nx
 from matplotlib import pyplot as plt
 
 from lib.objects.trajectory import Trajectory
+import lib.variables as V
 
 
 class Map(object):
@@ -145,14 +146,14 @@ class Map(object):
 
         img = Image.new('L', (self.height, self.width), 0)
         draw = ImageDraw.Draw(img)
-        draw.polygon(polyline, outline=1, fill=1)
+        draw.polygon(polyline, outline=self.white, fill=self.white)
 
-        a = (np.where(np.asarray(img) == self.white, 1, 0) * np.where(self.as_2d_array == self.white, 1, 0)).sum()
+        a = (np.where(np.asarray(img) == self.white, 1, 0) * np.where(self.as_2d_array == self.black, 1, 0)).sum()
 
         if a:
-            return False  # Non-homotopic
+            return V.ALTERNATIVE  # Non-homotopic
         else:
-            return True  # Homotopic
+            return V.HOMOTOPIC  # Homotopic
 
     def draw(self):
         fig, ax = plt.gcf(), plt.gca()
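The homotopy test rasterizes the closed curve formed by the two trajectories and checks whether the filled polygon touches any blocked pixel: two paths are homotopic exactly when the region between them is obstacle-free. A self-contained toy version of that check (8x8 grid; the real `Map` class encodes free versus blocked space through its `white`/`black` constants):

```python
import numpy as np
from PIL import Image, ImageDraw

blocked = np.zeros((8, 8), dtype=np.uint8)
blocked[3:5, 3:5] = 1  # an obstacle patch

# Rasterize the closed polyline joining the two trajectories.
img = Image.new('L', (8, 8), 0)
ImageDraw.Draw(img).polygon([(1, 1), (6, 1), (6, 6), (1, 6)], outline=1, fill=1)

overlap = (np.asarray(img) * blocked).sum()
print(bool(overlap))  # True -> the enclosed region hits an obstacle: not homotopic
```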


@@ -6,6 +6,7 @@ from lib import variables as V
 
 import numpy as np
 
+
 class Trajectory(object):
 
     @property
@@ -57,7 +58,8 @@ class Trajectory(object):
     def draw(self, highlights=True, label=None, **kwargs):
         if label is not None:
             kwargs.update(color='red' if label == V.HOMOTOPIC else 'green',
-                          label='Homotopic' if label == V.HOMOTOPIC else 'Alternative')
+                          label='Homotopic' if label == V.HOMOTOPIC else 'Alternative',
+                          lw=1)
         if highlights:
             kwargs.update(marker='o')
         fig, ax = plt.gcf(), plt.gca()


@@ -5,6 +5,7 @@ from collections import defaultdict
 from configparser import ConfigParser
 from pathlib import Path
 
+from lib.models.generators.cnn import CNNRouteGeneratorModel
 from lib.models.homotopy_classification.cnn_based import ConvHomDetector
 from lib.utils.model_io import ModelParameters
 
@@ -27,7 +28,7 @@ class Config(ConfigParser):
 
     @property
     def model_class(self):
-        model_dict = dict(classifier_cnn=ConvHomDetector)
+        model_dict = dict(classifier_cnn=ConvHomDetector, generator_cnn=CNNRouteGeneratorModel)
         try:
             return model_dict[self.get('model', 'type')]
         except KeyError as e:


@@ -1,8 +1,8 @@
 from pathlib import Path
 
-from pytorch_lightning.logging.base import LightningLoggerBase
-from pytorch_lightning.logging.neptune import NeptuneLogger
-from pytorch_lightning.logging.test_tube import TestTubeLogger
+from pytorch_lightning.loggers.base import LightningLoggerBase
+from pytorch_lightning.loggers.neptune import NeptuneLogger
+from pytorch_lightning.loggers.test_tube import TestTubeLogger
 
 from lib.utils.config import Config


@@ -1,5 +1,7 @@
 from argparse import Namespace
 from pathlib import Path
 
+import torch
+
 from natsort import natsorted
 from torch import nn
 
@@ -35,30 +37,25 @@ class ModelParameters(Namespace):
 
 class SavedLightningModels(object):
 
     @classmethod
-    def load_checkpoint(cls, models_root_path, model, n=-1, tags_file_path=''):
+    def load_checkpoint(cls, models_root_path, model=None, n=-1, tags_file_path=''):
         assert models_root_path.exists(), f'The path {models_root_path.absolute()} does not exist!'
         found_checkpoints = list(Path(models_root_path).rglob('*.ckpt'))
 
         found_checkpoints = natsorted(found_checkpoints, key=lambda y: y.name)
+        if model is None:
+            model = torch.load(models_root_path / 'model_class.obj')
+        assert model is not None
 
-        if not tags_file_path:
-            tag_files = models_root_path.rglob('meta_tags.csv')
-            tags_file_path = list(tag_files)[0]
-
-        return cls(weights=found_checkpoints[n], model=model, tags=tags_file_path)
+        return cls(weights=found_checkpoints[n], model=model)
 
     def __init__(self, **kwargs):
         self.weights: str = kwargs.get('weights', '')
-        self.tags: str = kwargs.get('tags', '')
 
         self.model = kwargs.get('model', None)
         assert self.model is not None
 
     def restore(self):
-        pretrained_model = self.model.load_from_metrics(
-            weights_path=self.weights,
-            tags_csv=self.tags
-        )
+        pretrained_model = self.model.load_from_checkpoint(self.weights)
         pretrained_model.eval()
         pretrained_model.freeze()
 
         return pretrained_model
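This loader is what lets `main.py` pull in the pretrained classifier as a frozen discriminator with one call. A hypothetical invocation; the output directory layout and the pickled `model_class.obj` fallback are assumptions read off the loader above, not guaranteed by the commit:

```python
from pathlib import Path
from lib.utils.model_io import SavedLightningModels

ckpt_root = Path('output') / 'classifier_cnn' / 'version_0'  # hypothetical layout
classifier = SavedLightningModels.load_checkpoint(ckpt_root).restore()
# restore() returns the model in eval mode with frozen weights.
```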


@@ -1,5 +1,5 @@
 from pathlib import Path
 
 _ROOT = Path('..')
 
-HOMOTOPIC = 0
-ALTERNATIVE = 1
+HOMOTOPIC = 1
+ALTERNATIVE = 0

main.py

@@ -10,13 +10,11 @@ import warnings
 
 import torch
 from pytorch_lightning import Trainer
 from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
-from torch.utils.data import DataLoader
 
 from lib.modules.utils import LightningBaseModule
 from lib.utils.config import Config
 from lib.utils.logging import Logger
-from lib.evaluation.classification import ROCEvaluation
+from lib.utils.model_io import SavedLightningModels
 
 warnings.filterwarnings('ignore', category=FutureWarning)
 warnings.filterwarnings('ignore', category=UserWarning)
@@ -36,8 +34,8 @@ main_arg_parser.add_argument("--main_seed", type=int, default=69, help="")
 
 # Data Parameters
 main_arg_parser.add_argument("--data_worker", type=int, default=10, help="")
 main_arg_parser.add_argument("--data_batchsize", type=int, default=100, help="")
-main_arg_parser.add_argument("--data_root", type=str, default='/data/', help="")
-main_arg_parser.add_argument("--data_map_root", type=str, default='/res/maps', help="")
+main_arg_parser.add_argument("--data_root", type=str, default='data', help="")
+main_arg_parser.add_argument("--data_map_root", type=str, default='res/shapes', help="")
 
 # Transformations
 main_arg_parser.add_argument("--transformations_to_tensor", type=strtobool, default=False, help="")
@@ -50,10 +48,11 @@ main_arg_parser.add_argument("--train_batch_size", type=int, default=256, help="")
 main_arg_parser.add_argument("--train_lr", type=float, default=0.002, help="")
 
 # Model
-main_arg_parser.add_argument("--model_type", type=str, default="classifier_cnn", help="")
+main_arg_parser.add_argument("--model_type", type=str, default="generator_cnn", help="")
 main_arg_parser.add_argument("--model_activation", type=str, default="relu", help="")
 main_arg_parser.add_argument("--model_filters", type=str, default="[32, 16, 4]", help="")
 main_arg_parser.add_argument("--model_classes", type=int, default=2, help="")
+main_arg_parser.add_argument("--model_lat_dim", type=int, default=2, help="")
 main_arg_parser.add_argument("--model_use_bias", type=strtobool, default=True, help="")
 main_arg_parser.add_argument("--model_use_norm", type=strtobool, default=True, help="")
 main_arg_parser.add_argument("--model_dropout", type=float, default=0.00, help="")
@@ -78,8 +77,9 @@ def run_lightning_loop(config_obj):
     # Checkpoint Saving
     checkpoint_callback = ModelCheckpoint(
         filepath=str(logger.log_dir / 'ckpt_weights'),
-        verbose=True, save_top_k=5,
+        verbose=True, save_top_k=0,
     )
+
     # =============================================================================
     # Early Stopping
     # TODO: For this to work, one must set a validation step plus an end-of-epoch eval and score
@@ -94,6 +94,11 @@ def run_lightning_loop(config_obj):
     # Init
     model: LightningBaseModule = config_obj.model_class(config_obj.model_paramters)
     model.init_weights()
+    if model.name == 'CNNRouteGenerator':
+        # ToDo: Make this dependent on the used seed
+        path = Path(Path(config_obj.train.outpath) / 'classifier_cnn' / 'version_0')
+        disc_model = SavedLightningModels.load_checkpoint(path).restore()
+        model.set_discriminator(disc_model)
 
     # Trainer
     # =============================================================================
@@ -101,8 +106,8 @@ def run_lightning_loop(config_obj):
                       show_progress_bar=True,
                       weights_save_path=logger.log_dir,
                       gpus=[0] if torch.cuda.is_available() else None,
-                      row_log_interval=(model.data_len * 0.01),  # TODO: Better Value / Setting
-                      log_save_interval=(model.data_len * 0.04),  # TODO: Better Value / Setting
+                      # row_log_interval=(model.n_train_batches * 0.1),  # TODO: Better Value / Setting
+                      # log_save_interval=(model.n_train_batches * 0.2),  # TODO: Better Value / Setting
                       checkpoint_callback=checkpoint_callback,
                       logger=logger,
                       fast_dev_run=config_obj.main.debug,
@@ -110,7 +115,7 @@ def run_lightning_loop(config_obj):
                       )
 
     # Train It
-    trainer.fit(model,)
+    trainer.fit(model)
 
     # Save the last state & all parameters
     trainer.save_checkpoint(logger.log_dir / 'weights.ckpt')
@@ -118,6 +123,7 @@ def run_lightning_loop(config_obj):
 
     # Evaluate It
     trainer.test()
+
     return model


@@ -16,7 +16,7 @@ if __name__ == '__main__':
 
     # Model Settings
     config = Config().read_namespace(args)
     # use_bias, activation, model, use_norm, max_epochs, filters
-    cnn_classifier = dict(train_epochs=100, model_use_bias=True, model_use_norm=True, model_activation='leaky_relu',
+    cnn_classifier = dict(train_epochs=10, model_use_bias=True, model_use_norm=True, model_activation='leaky_relu',
                           model_type='classifier_cnn', model_filters=[16, 32, 64], data_batchsize=512)
     # use_bias, activation, model, use_norm, max_epochs, sr, lat_dim, filters