project Refactor, CNN Classifier Basics
This commit is contained in:
7
.idea/deployment.xml
generated
7
.idea/deployment.xml
generated
@ -2,6 +2,13 @@
|
|||||||
<project version="4">
|
<project version="4">
|
||||||
<component name="PublishConfigData" autoUpload="On explicit save action" serverName="steffen@aimachine:22" showAutoUploadSettingsWarning="false">
|
<component name="PublishConfigData" autoUpload="On explicit save action" serverName="steffen@aimachine:22" showAutoUploadSettingsWarning="false">
|
||||||
<serverData>
|
<serverData>
|
||||||
|
<paths name="erlowa@aimachine">
|
||||||
|
<serverdata>
|
||||||
|
<mappings>
|
||||||
|
<mapping deploy="/" local="$PROJECT_DIR$" web="/" />
|
||||||
|
</mappings>
|
||||||
|
</serverdata>
|
||||||
|
</paths>
|
||||||
<paths name="steffen@aimachine:22">
|
<paths name="steffen@aimachine:22">
|
||||||
<serverdata>
|
<serverdata>
|
||||||
<mappings>
|
<mappings>
|
||||||
|
9
.idea/dictionaries/steffen.xml
generated
9
.idea/dictionaries/steffen.xml
generated
@ -1,13 +1,22 @@
|
|||||||
<component name="ProjectDictionaryState">
|
<component name="ProjectDictionaryState">
|
||||||
<dictionary name="steffen">
|
<dictionary name="steffen">
|
||||||
<words>
|
<words>
|
||||||
|
<w>autopad</w>
|
||||||
<w>conv</w>
|
<w>conv</w>
|
||||||
|
<w>convolutional</w>
|
||||||
<w>dataloader</w>
|
<w>dataloader</w>
|
||||||
|
<w>dataloaders</w>
|
||||||
|
<w>datasets</w>
|
||||||
<w>homotopic</w>
|
<w>homotopic</w>
|
||||||
<w>hparams</w>
|
<w>hparams</w>
|
||||||
<w>hyperparamter</w>
|
<w>hyperparamter</w>
|
||||||
|
<w>kingma</w>
|
||||||
|
<w>logvar</w>
|
||||||
|
<w>mapname</w>
|
||||||
|
<w>mapnames</w>
|
||||||
<w>numlayers</w>
|
<w>numlayers</w>
|
||||||
<w>reparameterize</w>
|
<w>reparameterize</w>
|
||||||
|
<w>softmax</w>
|
||||||
<w>traj</w>
|
<w>traj</w>
|
||||||
</words>
|
</words>
|
||||||
</dictionary>
|
</dictionary>
|
||||||
|
2
.idea/hom_traj_gen.iml
generated
2
.idea/hom_traj_gen.iml
generated
@ -2,7 +2,7 @@
|
|||||||
<module type="PYTHON_MODULE" version="4">
|
<module type="PYTHON_MODULE" version="4">
|
||||||
<component name="NewModuleRootManager">
|
<component name="NewModuleRootManager">
|
||||||
<content url="file://$MODULE_DIR$" />
|
<content url="file://$MODULE_DIR$" />
|
||||||
<orderEntry type="jdk" jdkName="hom_traj_gen@aimachine" jdkType="Python SDK" />
|
<orderEntry type="jdk" jdkName="Remote Python 3.7.6 (sftp://steffen@aimachine:22/home/steffen/envs/traj_gen/bin/python)" jdkType="Python SDK" />
|
||||||
<orderEntry type="sourceFolder" forTests="false" />
|
<orderEntry type="sourceFolder" forTests="false" />
|
||||||
</component>
|
</component>
|
||||||
</module>
|
</module>
|
5
.idea/misc.xml
generated
5
.idea/misc.xml
generated
@ -3,5 +3,8 @@
|
|||||||
<component name="JavaScriptSettings">
|
<component name="JavaScriptSettings">
|
||||||
<option name="languageLevel" value="ES6" />
|
<option name="languageLevel" value="ES6" />
|
||||||
</component>
|
</component>
|
||||||
<component name="ProjectRootManager" version="2" project-jdk-name="hom_traj_gen@aimachine" project-jdk-type="Python SDK" />
|
<component name="ProjectRootManager" version="2" project-jdk-name="traj_gen@ai-machine" project-jdk-type="Python SDK" />
|
||||||
|
<component name="PyPackaging">
|
||||||
|
<option name="earlyReleasesAsUpgrades" value="true" />
|
||||||
|
</component>
|
||||||
</project>
|
</project>
|
@ -3,6 +3,7 @@ from pathlib import Path
|
|||||||
from typing import Union, List
|
from typing import Union, List
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
from random import choice
|
||||||
from torch.utils.data import ConcatDataset, Dataset
|
from torch.utils.data import ConcatDataset, Dataset
|
||||||
|
|
||||||
from lib.objects.map import Map
|
from lib.objects.map import Map
|
||||||
@ -16,10 +17,11 @@ class TrajDataset(Dataset):
|
|||||||
return self.map.as_array.shape
|
return self.map.as_array.shape
|
||||||
|
|
||||||
def __init__(self, *args, maps_root: Union[Path, str] = '', mapname='tate_sw',
|
def __init__(self, *args, maps_root: Union[Path, str] = '', mapname='tate_sw',
|
||||||
length=100000, all_in_map=True, embedding_size=None, preserve_equal_samples=False, **kwargs):
|
length=100000, mode='separated_arrays', embedding_size=None, preserve_equal_samples=False, **kwargs):
|
||||||
super(TrajDataset, self).__init__()
|
super(TrajDataset, self).__init__()
|
||||||
|
assert mode.lower() in ['vectors', 'all_in_map', 'separated_arrays', 'just_route']
|
||||||
self.preserve_equal_samples = preserve_equal_samples
|
self.preserve_equal_samples = preserve_equal_samples
|
||||||
self.all_in_map = all_in_map
|
self.mode = mode
|
||||||
self.mapname = mapname if mapname.endswith('.bmp') else f'{mapname}.bmp'
|
self.mapname = mapname if mapname.endswith('.bmp') else f'{mapname}.bmp'
|
||||||
self.maps_root = maps_root
|
self.maps_root = maps_root
|
||||||
self._len = length
|
self._len = length
|
||||||
@ -31,8 +33,19 @@ class TrajDataset(Dataset):
|
|||||||
return self._len
|
return self._len
|
||||||
|
|
||||||
def __getitem__(self, item):
|
def __getitem__(self, item):
|
||||||
|
|
||||||
|
if self.mode.lower() == 'just_route':
|
||||||
trajectory = self.map.get_random_trajectory()
|
trajectory = self.map.get_random_trajectory()
|
||||||
|
label = choice([0, 1])
|
||||||
|
blank_trajectory_space = torch.zeros(self.map.shape)
|
||||||
|
for index in trajectory.vertices:
|
||||||
|
blank_trajectory_space[index] = 1
|
||||||
|
|
||||||
|
map_array = torch.as_tensor(self.map.as_array).float()
|
||||||
|
return (map_array, blank_trajectory_space), label
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
|
trajectory = self.map.get_random_trajectory()
|
||||||
# TODO: Sanity Check this while true loop...
|
# TODO: Sanity Check this while true loop...
|
||||||
alternative = self.map.generate_alternative(trajectory)
|
alternative = self.map.generate_alternative(trajectory)
|
||||||
label = self.map.are_homotopic(trajectory, alternative)
|
label = self.map.are_homotopic(trajectory, alternative)
|
||||||
@ -42,18 +55,26 @@ class TrajDataset(Dataset):
|
|||||||
break
|
break
|
||||||
|
|
||||||
self.last_label = label
|
self.last_label = label
|
||||||
if self.all_in_map:
|
if self.mode.lower() in ['all_in_map', 'separated_arrays']:
|
||||||
blank_trajectory_space = torch.zeros(self.map.shape)
|
blank_trajectory_space = torch.zeros(self.map.shape)
|
||||||
blank_alternative_space = torch.zeros(self.map.shape)
|
blank_alternative_space = torch.zeros(self.map.shape)
|
||||||
for index in trajectory.vertices:
|
for index in trajectory.vertices:
|
||||||
blank_trajectory_space[index] = 1
|
blank_trajectory_space[index] = 1
|
||||||
|
for index in alternative.vertices:
|
||||||
blank_alternative_space[index] = 1
|
blank_alternative_space[index] = 1
|
||||||
|
|
||||||
map_array = torch.as_tensor(self.map.as_array).float()
|
map_array = torch.as_tensor(self.map.as_array).float()
|
||||||
return torch.cat((map_array, blank_trajectory_space, blank_alternative_space)), int(label)
|
if self.mode == 'separated_arrays':
|
||||||
|
return (map_array, blank_trajectory_space, int(label)), blank_alternative_space
|
||||||
else:
|
else:
|
||||||
|
return torch.cat((map_array, blank_trajectory_space, blank_alternative_space)), int(label)
|
||||||
|
|
||||||
|
elif self.mode == 'vectors':
|
||||||
return trajectory.vertices, alternative.vertices, label, self.mapname
|
return trajectory.vertices, alternative.vertices, label, self.mapname
|
||||||
|
|
||||||
|
else:
|
||||||
|
raise ValueError
|
||||||
|
|
||||||
|
|
||||||
class TrajData(object):
|
class TrajData(object):
|
||||||
@property
|
@property
|
||||||
@ -64,7 +85,7 @@ class TrajData(object):
|
|||||||
def map_shapes_max(self):
|
def map_shapes_max(self):
|
||||||
shapes = self.map_shapes
|
shapes = self.map_shapes
|
||||||
shape_list = list(map(max, zip(*shapes)))
|
shape_list = list(map(max, zip(*shapes)))
|
||||||
if self.all_in_map:
|
if self.mode == 'all_in_map':
|
||||||
shape_list[0] += 2
|
shape_list[0] += 2
|
||||||
return shape_list
|
return shape_list
|
||||||
|
|
||||||
@ -72,10 +93,10 @@ class TrajData(object):
|
|||||||
def name(self):
|
def name(self):
|
||||||
return self.__class__.__name__
|
return self.__class__.__name__
|
||||||
|
|
||||||
def __init__(self, *args, map_root: Union[Path, str] = '', length=100.000, all_in_map=True, **_):
|
def __init__(self, map_root, length=100000, mode='separated_arrays', **_):
|
||||||
|
|
||||||
self.all_in_map = all_in_map
|
self.mode = mode
|
||||||
self.maps_root = Path(map_root) if map_root else Path() / 'res' / 'maps'
|
self.maps_root = Path(map_root)
|
||||||
self.length = length
|
self.length = length
|
||||||
self._dataset = self._load_datasets()
|
self._dataset = self._load_datasets()
|
||||||
|
|
||||||
@ -86,8 +107,8 @@ class TrajData(object):
|
|||||||
# find max image size among available maps:
|
# find max image size among available maps:
|
||||||
max_map_size = (1, ) + tuple(reversed(tuple(map(max, *[Image.open(map_file).size for map_file in map_files]))))
|
max_map_size = (1, ) + tuple(reversed(tuple(map(max, *[Image.open(map_file).size for map_file in map_files]))))
|
||||||
return ConcatDataset([TrajDataset(maps_root=self.maps_root, mapname=map_file.name, length=equal_split,
|
return ConcatDataset([TrajDataset(maps_root=self.maps_root, mapname=map_file.name, length=equal_split,
|
||||||
all_in_map=self.all_in_map, embedding_size=max_map_size,
|
mode=self.mode, embedding_size=max_map_size,
|
||||||
preserve_equal_samples=False)
|
preserve_equal_samples=True)
|
||||||
for map_file in map_files])
|
for map_file in map_files])
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
@ -6,27 +6,29 @@ class ROCEvaluation(object):
|
|||||||
|
|
||||||
linewidth = 2
|
linewidth = 2
|
||||||
|
|
||||||
def __init__(self, prepare_figure=False):
|
def __init__(self, plot_roc=False):
|
||||||
self.prepare_figure = prepare_figure
|
self.plot_roc = plot_roc
|
||||||
self.epoch = 0
|
self.epoch = 0
|
||||||
|
|
||||||
def __call__(self, prediction, label, plotting=False):
|
def __call__(self, prediction, label):
|
||||||
|
|
||||||
# Compute ROC curve and ROC area
|
# Compute ROC curve and ROC area
|
||||||
fpr, tpr, _ = roc_curve(prediction, label)
|
fpr, tpr, _ = roc_curve(prediction, label)
|
||||||
roc_auc = auc(fpr, tpr)
|
roc_auc = auc(fpr, tpr)
|
||||||
if plotting:
|
if self.plot_roc:
|
||||||
fig = plt.gcf()
|
_ = plt.gcf()
|
||||||
fig.plot(fpr, tpr, color='darkorange', lw=self.linewidth, label=f'ROC curve (area = {roc_auc})')
|
plt.plot(fpr, tpr, color='darkorange', lw=self.linewidth, label=f'ROC curve (area = {roc_auc})')
|
||||||
|
self._prepare_fig()
|
||||||
return roc_auc, fpr, tpr
|
return roc_auc, fpr, tpr
|
||||||
|
|
||||||
def _prepare_fig(self):
|
def _prepare_fig(self):
|
||||||
fig = plt.gcf()
|
fig = plt.gcf()
|
||||||
fig.plot([0, 1], [0, 1], color='navy', lw=self.linewidth, linestyle='--')
|
ax = plt.gca()
|
||||||
fig.xlim([0.0, 1.0])
|
plt.plot([0, 1], [0, 1], color='navy', lw=self.linewidth, linestyle='--')
|
||||||
fig.ylim([0.0, 1.05])
|
plt.xlim([0.0, 1.0])
|
||||||
fig.xlabel('False Positive Rate')
|
plt.ylim([0.0, 1.05])
|
||||||
fig.ylabel('True Positive Rate')
|
plt.xlabel('False Positive Rate')
|
||||||
|
plt.ylabel('True Positive Rate')
|
||||||
fig.legend(loc="lower right")
|
fig.legend(loc="lower right")
|
||||||
|
|
||||||
return fig
|
return fig
|
||||||
|
@ -1,6 +1,13 @@
|
|||||||
from datasets.paired_dataset import TrajPairData
|
import torch
|
||||||
from lib.modules.blocks import ConvModule
|
from functools import reduce
|
||||||
from lib.modules.utils import LightningBaseModule
|
from operator import mul
|
||||||
|
|
||||||
|
from torch import nn
|
||||||
|
from torch.optim import Adam
|
||||||
|
|
||||||
|
from datasets.trajectory_dataset import TrajData
|
||||||
|
from lib.modules.blocks import ConvModule, ResidualModule, DeConvModule
|
||||||
|
from lib.modules.utils import LightningBaseModule, Flatten
|
||||||
|
|
||||||
|
|
||||||
class CNNRouteGeneratorModel(LightningBaseModule):
|
class CNNRouteGeneratorModel(LightningBaseModule):
|
||||||
@ -8,36 +15,169 @@ class CNNRouteGeneratorModel(LightningBaseModule):
|
|||||||
name = 'CNNRouteGenerator'
|
name = 'CNNRouteGenerator'
|
||||||
|
|
||||||
def configure_optimizers(self):
|
def configure_optimizers(self):
|
||||||
pass
|
return Adam(self.parameters(), lr=self.hparams.train_param.lr)
|
||||||
|
|
||||||
def validation_step(self, *args, **kwargs):
|
|
||||||
pass
|
|
||||||
|
|
||||||
def validation_end(self, outputs):
|
|
||||||
pass
|
|
||||||
|
|
||||||
def training_step(self, batch_xy, batch_nb, *args, **kwargs):
|
def training_step(self, batch_xy, batch_nb, *args, **kwargs):
|
||||||
pass
|
batch_x, label = batch_xy
|
||||||
|
|
||||||
|
generated_alternative, z, mu, logvar = self(batch_x + [label, ])
|
||||||
|
map_array, trajectory = batch_x
|
||||||
|
|
||||||
|
map_stack = torch.cat((map_array, trajectory, generated_alternative), dim=1)
|
||||||
|
pred_label = self.discriminator(map_stack)
|
||||||
|
|
||||||
|
discriminated_bce_loss = self.criterion(pred_label, label.float().unsqueeze(-1))
|
||||||
|
|
||||||
|
# see Appendix B from VAE paper:
|
||||||
|
# Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014
|
||||||
|
# https://arxiv.org/abs/1312.6114
|
||||||
|
# 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
|
||||||
|
kld_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
|
||||||
|
loss = (kld_loss + discriminated_bce_loss) / 2
|
||||||
|
return dict(loss=loss, log=dict(loss=loss,
|
||||||
|
discriminated_bce_loss=discriminated_bce_loss,
|
||||||
|
kld_loss=kld_loss)
|
||||||
|
)
|
||||||
|
|
||||||
def test_step(self, *args, **kwargs):
|
def test_step(self, *args, **kwargs):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@property
|
||||||
|
def discriminator(self):
|
||||||
|
if self._disc is None:
|
||||||
|
raise RuntimeError('Set the Discriminator first; "set_discriminator(disc_model)')
|
||||||
|
return self._disc
|
||||||
|
|
||||||
|
def set_discriminator(self, disc_model):
|
||||||
|
if self._disc is not None:
|
||||||
|
raise RuntimeError('Discriminator has already been set... What are trying to do?')
|
||||||
|
self._disc = disc_model
|
||||||
|
|
||||||
def __init__(self, *params):
|
def __init__(self, *params):
|
||||||
super(CNNRouteGeneratorModel, self).__init__(*params)
|
super(CNNRouteGeneratorModel, self).__init__(*params)
|
||||||
|
|
||||||
# Dataset
|
# Dataset
|
||||||
self.dataset = TrajPairData(self.hparams.data_param.data_root)
|
self.dataset = TrajData(self.hparams.data_param.map_root, mode='just_route')
|
||||||
|
|
||||||
# Additional Attributes
|
# Additional Attributes
|
||||||
self.in_shape = self.dataset.map_shapes_max
|
self.in_shape = self.dataset.map_shapes_max
|
||||||
|
# Todo: Better naming and size in Parameters
|
||||||
|
self.feature_dim = 10
|
||||||
|
self._disc = None
|
||||||
|
|
||||||
# NN Nodes
|
# NN Nodes
|
||||||
|
###################################################
|
||||||
|
#
|
||||||
|
# Utils
|
||||||
|
self.relu = nn.ReLU()
|
||||||
|
self.criterion = nn.MSELoss()
|
||||||
|
|
||||||
|
#
|
||||||
self.conv2 = ConvModule(self.conv1.shape, conv_kernel=3, conv_stride=1, conv_padding=0,
|
# Map Encoder
|
||||||
conv_filters=self.hparams.model_param.filters[0])
|
self.map_conv_0 = ConvModule(self.in_shape, conv_kernel=3, conv_stride=1, conv_padding=1,
|
||||||
self.conv3 = ConvModule(self.conv2.shape, conv_kernel=3, conv_stride=1, conv_padding=0,
|
|
||||||
conv_filters=self.hparams.model_param.filters[0])
|
conv_filters=self.hparams.model_param.filters[0])
|
||||||
|
|
||||||
def forward(self, x):
|
self.map_res_1 = ResidualModule(self.map_conv_0.shape, ConvModule, 2, conv_kernel=3, conv_stride=1,
|
||||||
pass
|
conv_padding=1, conv_filters=self.hparams.model_param.filters[0])
|
||||||
|
self.map_conv_1 = ConvModule(self.map_res_1.shape, conv_kernel=3, conv_stride=1, conv_padding=0,
|
||||||
|
conv_filters=self.hparams.model_param.filters[1])
|
||||||
|
|
||||||
|
self.map_res_2 = ResidualModule(self.map_conv_1.shape, ConvModule, 2, conv_kernel=3, conv_stride=1,
|
||||||
|
conv_padding=1, conv_filters=self.hparams.model_param.filters[1])
|
||||||
|
self.map_conv_2 = ConvModule(self.map_res_2.shape, conv_kernel=3, conv_stride=1, conv_padding=0,
|
||||||
|
conv_filters=self.hparams.model_param.filters[2])
|
||||||
|
|
||||||
|
self.map_res_3 = ResidualModule(self.map_conv_2.shape, ConvModule, 2, conv_kernel=3, conv_stride=1,
|
||||||
|
conv_padding=1, conv_filters=self.hparams.model_param.filters[2])
|
||||||
|
self.map_conv_3 = ConvModule(self.map_res_3.shape, conv_kernel=5, conv_stride=1, conv_padding=0,
|
||||||
|
conv_filters=self.hparams.model_param.filters[2]*2)
|
||||||
|
|
||||||
|
self.map_flat = Flatten(self.map_conv_3.shape)
|
||||||
|
|
||||||
|
self.map_lin = nn.Linear(reduce(mul, self.map_conv_3.shape), self.feature_dim)
|
||||||
|
|
||||||
|
#
|
||||||
|
# Trajectory Encoder
|
||||||
|
self.traj_conv_1 = ConvModule(self.in_shape, conv_kernel=3, conv_stride=1, conv_padding=0,
|
||||||
|
conv_filters=self.hparams.model_param.filters[0])
|
||||||
|
|
||||||
|
self.traj_conv_2 = ConvModule(self.traj_conv_1.shape, conv_kernel=3, conv_stride=1, conv_padding=0,
|
||||||
|
conv_filters=self.hparams.model_param.filters[0])
|
||||||
|
|
||||||
|
self.traj_conv_3 = ConvModule(self.traj_conv_2.shape, conv_kernel=3, conv_stride=1, conv_padding=0,
|
||||||
|
conv_filters=self.hparams.model_param.filters[0])
|
||||||
|
|
||||||
|
self.traj_flat = Flatten(self.traj_conv_3.shape)
|
||||||
|
|
||||||
|
self.traj_lin = nn.Linear(reduce(mul, self.traj_conv_3.shape), self.feature_dim)
|
||||||
|
|
||||||
|
#
|
||||||
|
# Variational Bottleneck
|
||||||
|
self.mu = nn.Linear(self.feature_dim + self.feature_dim + 1, self.hparams.model_param.lat_dim)
|
||||||
|
self.logvar = nn.Linear(self.feature_dim + self.feature_dim + 1, self.hparams.model_param.lat_dim)
|
||||||
|
|
||||||
|
#
|
||||||
|
# Alternative Generator
|
||||||
|
self.alt_lin_1 = nn.Linear(self.hparams.model_param.lat_dim, self.feature_dim)
|
||||||
|
self.alt_lin_2 = nn.Linear(self.feature_dim, reduce(mul, self.traj_conv_3.shape))
|
||||||
|
|
||||||
|
self.reshape_to_map = Flatten(reduce(mul, self.traj_conv_3.shape), self.traj_conv_3.shape)
|
||||||
|
|
||||||
|
self.alt_deconv_1 = DeConvModule(self.traj_conv_3.shape, self.hparams.model_param.filters[2],
|
||||||
|
conv_padding=0, conv_kernel=5, conv_stride=1)
|
||||||
|
self.alt_deconv_2 = DeConvModule(self.alt_deconv_1.shape, self.hparams.model_param.filters[1],
|
||||||
|
conv_padding=0, conv_kernel=3, conv_stride=1)
|
||||||
|
self.alt_deconv_3 = DeConvModule(self.alt_deconv_2.shape, self.hparams.model_param.filters[0],
|
||||||
|
conv_padding=1, conv_kernel=3, conv_stride=1)
|
||||||
|
self.alt_deconv_out = DeConvModule(self.alt_deconv_3.shape, 1, activation=None,
|
||||||
|
conv_padding=1, conv_kernel=3, conv_stride=1)
|
||||||
|
|
||||||
|
def forward(self, batch_x):
|
||||||
|
#
|
||||||
|
# Sorting the Input
|
||||||
|
map_array, trajectory, label = batch_x
|
||||||
|
|
||||||
|
#
|
||||||
|
# Encode
|
||||||
|
map_tensor = self.map_conv_0(map_array)
|
||||||
|
map_tensor = self.map_res_1(map_tensor)
|
||||||
|
map_tensor = self.map_conv_1(map_tensor)
|
||||||
|
map_tensor = self.map_res_2(map_tensor)
|
||||||
|
map_tensor = self.map_conv_2(map_tensor)
|
||||||
|
map_tensor = self.map_res_3(map_tensor)
|
||||||
|
map_tensor = self.map_conv_3(map_tensor)
|
||||||
|
map_tensor = self.map_flat(map_tensor)
|
||||||
|
map_tensor = self.map_lin(map_tensor)
|
||||||
|
|
||||||
|
traj_tensor = self.traj_conv_1(trajectory)
|
||||||
|
traj_tensor = self.traj_conv_2(traj_tensor)
|
||||||
|
traj_tensor = self.traj_conv_3(traj_tensor)
|
||||||
|
traj_tensor = self.traj_flat(traj_tensor)
|
||||||
|
traj_tensor = self.traj_lin(traj_tensor)
|
||||||
|
|
||||||
|
mixed_tensor = torch.cat((map_tensor, traj_tensor, label.float().unsqueeze(-1)), dim=1)
|
||||||
|
mixed_tensor = self.relu(mixed_tensor)
|
||||||
|
|
||||||
|
#
|
||||||
|
# Parameter and Sampling
|
||||||
|
mu = self.mu(mixed_tensor)
|
||||||
|
logvar = self.logvar(mixed_tensor)
|
||||||
|
z = self.reparameterize(mu, logvar)
|
||||||
|
|
||||||
|
#
|
||||||
|
# Generate
|
||||||
|
alt_tensor = self.alt_lin_1(z)
|
||||||
|
alt_tensor = self.alt_lin_2(alt_tensor)
|
||||||
|
alt_tensor = self.reshape_to_map(alt_tensor)
|
||||||
|
alt_tensor = self.alt_deconv_1(alt_tensor)
|
||||||
|
alt_tensor = self.alt_deconv_2(alt_tensor)
|
||||||
|
alt_tensor = self.alt_deconv_3(alt_tensor)
|
||||||
|
alt_tensor = self.alt_deconv_out(alt_tensor)
|
||||||
|
|
||||||
|
return alt_tensor, z, mu, logvar
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def reparameterize(mu, logvar):
|
||||||
|
std = torch.exp(0.5 * logvar)
|
||||||
|
eps = torch.randn_like(std)
|
||||||
|
return mu + eps * std
|
||||||
|
@ -1,14 +1,16 @@
|
|||||||
from lib.modules.blocks import LightningBaseModule
|
|
||||||
from lib.modules.losses import BinaryHomotopicLoss
|
from lib.modules.losses import BinaryHomotopicLoss
|
||||||
|
from lib.modules.utils import LightningBaseModule
|
||||||
from lib.objects.map import Map
|
from lib.objects.map import Map
|
||||||
from lib.objects.trajectory import Trajectory
|
from lib.objects.trajectory import Trajectory
|
||||||
|
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
||||||
nn.MSELoss
|
|
||||||
|
|
||||||
class LinearRouteGeneratorModel(LightningBaseModule):
|
class LinearRouteGeneratorModel(LightningBaseModule):
|
||||||
|
|
||||||
|
def test_epoch_end(self, outputs):
|
||||||
|
pass
|
||||||
|
|
||||||
name = 'LinearRouteGenerator'
|
name = 'LinearRouteGenerator'
|
||||||
|
|
||||||
def configure_optimizers(self):
|
def configure_optimizers(self):
|
||||||
@ -33,6 +35,12 @@ class LinearRouteGeneratorModel(LightningBaseModule):
|
|||||||
pred_y = self(map_x, traj_x, label_x)
|
pred_y = self(map_x, traj_x, label_x)
|
||||||
|
|
||||||
loss = self.loss(traj_x, pred_y)
|
loss = self.loss(traj_x, pred_y)
|
||||||
|
|
||||||
|
def training_step(self, batch_xy, batch_nb, *args, **kwargs):
|
||||||
|
batch_x, batch_y = batch_xy
|
||||||
|
pred_y = self(batch_x)
|
||||||
|
loss = self.criterion(pred_y, batch_y.unsqueeze(-1).float())
|
||||||
|
|
||||||
return dict(loss=loss, log=dict(loss=loss))
|
return dict(loss=loss, log=dict(loss=loss))
|
||||||
|
|
||||||
def test_step(self, *args, **kwargs):
|
def test_step(self, *args, **kwargs):
|
||||||
@ -41,7 +49,7 @@ class LinearRouteGeneratorModel(LightningBaseModule):
|
|||||||
def __init__(self, *params):
|
def __init__(self, *params):
|
||||||
super(LinearRouteGeneratorModel, self).__init__(*params)
|
super(LinearRouteGeneratorModel, self).__init__(*params)
|
||||||
|
|
||||||
self.loss = BinaryHomotopicLoss(self.map_storage)
|
self.criterion = BinaryHomotopicLoss(self.map_storage)
|
||||||
|
|
||||||
def forward(self, map_x, traj_x, label_x):
|
def forward(self, map_x, traj_x, label_x):
|
||||||
pass
|
pass
|
||||||
|
@ -24,41 +24,44 @@ class ConvHomDetector(LightningBaseModule):
|
|||||||
def training_step(self, batch_xy, batch_nb, *args, **kwargs):
|
def training_step(self, batch_xy, batch_nb, *args, **kwargs):
|
||||||
batch_x, batch_y = batch_xy
|
batch_x, batch_y = batch_xy
|
||||||
pred_y = self(batch_x)
|
pred_y = self(batch_x)
|
||||||
loss = F.binary_cross_entropy(pred_y, batch_y.float())
|
loss = self.criterion(pred_y, batch_y.unsqueeze(-1).float())
|
||||||
return {'loss': loss, 'log': dict(loss=loss)}
|
return {'loss': loss, 'log': dict(loss=loss)}
|
||||||
|
|
||||||
def test_step(self, batch_xy, **kwargs):
|
def test_step(self, batch_xy, batch_nb, **kwargs):
|
||||||
batch_x, batch_y = batch_xy
|
batch_x, batch_y = batch_xy
|
||||||
pred_y = self(batch_x)
|
pred_y = self(batch_x)
|
||||||
return dict(prediction=pred_y, label=batch_y)
|
return dict(prediction=pred_y, label=batch_y, batch_nb=batch_nb)
|
||||||
|
|
||||||
def test_end(self, outputs):
|
def test_epoch_end(self, outputs):
|
||||||
evaluation = ROCEvaluation()
|
evaluation = ROCEvaluation(plot_roc=True)
|
||||||
predictions = torch.stack([x['prediction'] for x in outputs])
|
predictions = torch.cat([x['prediction'] for x in outputs])
|
||||||
labels = torch.stack([x['label'] for x in outputs])
|
labels = torch.cat([x['label'] for x in outputs]).unsqueeze(1)
|
||||||
|
|
||||||
scores = evaluation(predictions.numpy(), labels.numpy(), )
|
# Sci-py call ROC eval call is eval(true_label, prediction)
|
||||||
self.logger.log_metrics({key:value for key, value in zip(['roc_auc', 'tpr', 'fpr'], scores)})
|
roc_auc, tpr, fpr = evaluation(labels.cpu().numpy(), predictions.cpu().numpy(), )
|
||||||
|
score_dict = dict(roc_auc=roc_auc)
|
||||||
|
# self.logger.log_metrics(score_dict)
|
||||||
self.logger.log_image(f'{self.name}', plt.gcf())
|
self.logger.log_image(f'{self.name}', plt.gcf())
|
||||||
pass
|
|
||||||
|
|
||||||
def __init__(self, *params):
|
return dict(log=score_dict)
|
||||||
super(ConvHomDetector, self).__init__(*params)
|
|
||||||
|
def __init__(self, hparams):
|
||||||
|
super(ConvHomDetector, self).__init__(hparams)
|
||||||
|
|
||||||
# Dataset
|
# Dataset
|
||||||
self.dataset = TrajData(self.hparams.data_param.root)
|
self.dataset = TrajData(self.hparams.data_param.map_root, mode='all_in_map')
|
||||||
|
|
||||||
# Additional Attributes
|
# Additional Attributes
|
||||||
self.map_shape = self.dataset.map_shapes_max
|
self.map_shape = self.dataset.map_shapes_max
|
||||||
|
|
||||||
# Model Paramters
|
# Model Parameters
|
||||||
self.in_shape = self.dataset.map_shapes_max
|
self.in_shape = self.dataset.map_shapes_max
|
||||||
assert len(self.in_shape) == 3, f'Image or map shape has to have 3 dims, but had: {len(self.in_shape)}'
|
assert len(self.in_shape) == 3, f'Image or map shape has to have 3 dims, but had: {len(self.in_shape)}'
|
||||||
|
self.criterion = nn.BCEWithLogitsLoss()
|
||||||
|
|
||||||
# NN Nodes
|
# NN Nodes
|
||||||
# ============================
|
# ============================
|
||||||
# Convolutional Map Processing
|
# Convolutional Map Processing
|
||||||
#
|
|
||||||
self.map_conv_0 = ConvModule(self.in_shape, conv_kernel=3, conv_stride=1,
|
self.map_conv_0 = ConvModule(self.in_shape, conv_kernel=3, conv_stride=1,
|
||||||
conv_padding=0, conv_filters=self.hparams.model_param.filters[0])
|
conv_padding=0, conv_filters=self.hparams.model_param.filters[0])
|
||||||
self.map_res_1 = ResidualModule(self.map_conv_0.shape, ConvModule, 3,
|
self.map_res_1 = ResidualModule(self.map_conv_0.shape, ConvModule, 3,
|
||||||
@ -86,7 +89,6 @@ class ConvHomDetector(LightningBaseModule):
|
|||||||
self.linear = nn.Linear(reduce(mul, self.flatten.shape), self.hparams.model_param.classes * 10)
|
self.linear = nn.Linear(reduce(mul, self.flatten.shape), self.hparams.model_param.classes * 10)
|
||||||
# Comments on Multi Class labels
|
# Comments on Multi Class labels
|
||||||
self.classifier = nn.Linear(self.hparams.model_param.classes * 10, 1) # self.hparams.model_param.classes)
|
self.classifier = nn.Linear(self.hparams.model_param.classes * 10, 1) # self.hparams.model_param.classes)
|
||||||
self.out_activation = nn.Sigmoid() # nn.Softmax
|
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
tensor = self.map_conv_0(x)
|
tensor = self.map_conv_0(x)
|
||||||
@ -98,25 +100,4 @@ class ConvHomDetector(LightningBaseModule):
|
|||||||
tensor = self.flatten(tensor)
|
tensor = self.flatten(tensor)
|
||||||
tensor = self.linear(tensor)
|
tensor = self.linear(tensor)
|
||||||
tensor = self.classifier(tensor)
|
tensor = self.classifier(tensor)
|
||||||
tensor = self.out_activation(tensor)
|
|
||||||
return tensor
|
return tensor
|
||||||
|
|
||||||
# Dataloaders
|
|
||||||
# ================================================================================
|
|
||||||
# Train Dataloader
|
|
||||||
def train_dataloader(self):
|
|
||||||
return DataLoader(dataset=self.dataset.train_dataset, shuffle=True,
|
|
||||||
batch_size=self.hparams.data_param.batchsize,
|
|
||||||
num_workers=self.hparams.data_param.worker)
|
|
||||||
|
|
||||||
# Test Dataloader
|
|
||||||
def test_dataloader(self):
|
|
||||||
return DataLoader(dataset=self.dataset.test_dataset, shuffle=True,
|
|
||||||
batch_size=self.hparams.data_param.batchsize,
|
|
||||||
num_workers=self.hparams.data_param.worker)
|
|
||||||
|
|
||||||
# Validation Dataloader
|
|
||||||
def val_dataloader(self):
|
|
||||||
return DataLoader(dataset=self.dataset.val_dataset, shuffle=True,
|
|
||||||
batch_size=self.hparams.data_param.batchsize,
|
|
||||||
num_workers=self.hparams.data_param.worker)
|
|
||||||
|
@ -1,11 +1,7 @@
|
|||||||
from abc import ABC
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import Union
|
from typing import Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
import torch.nn.functional as F
|
|
||||||
import pytorch_lightning as pl
|
|
||||||
from lib.modules.utils import AutoPad, Interpolate
|
from lib.modules.utils import AutoPad, Interpolate
|
||||||
|
|
||||||
#
|
#
|
||||||
@ -26,12 +22,12 @@ class ConvModule(nn.Module):
|
|||||||
conv_filters=64, conv_kernel=5, conv_stride=1, conv_padding=0):
|
conv_filters=64, conv_kernel=5, conv_stride=1, conv_padding=0):
|
||||||
super(ConvModule, self).__init__()
|
super(ConvModule, self).__init__()
|
||||||
|
|
||||||
# Module Paramters
|
# Module Parameters
|
||||||
self.in_shape = in_shape
|
self.in_shape = in_shape
|
||||||
in_channels, height, width = in_shape[0], in_shape[1], in_shape[2]
|
in_channels, height, width = in_shape[0], in_shape[1], in_shape[2]
|
||||||
self.activation = activation()
|
self.activation = activation()
|
||||||
|
|
||||||
# Convolution Paramters
|
# Convolution Parameters
|
||||||
self.padding = conv_padding
|
self.padding = conv_padding
|
||||||
self.stride = conv_stride
|
self.stride = conv_stride
|
||||||
|
|
||||||
@ -44,7 +40,7 @@ class ConvModule(nn.Module):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
x = self.norm(x) if self.norm else x
|
x = self.norm(x)
|
||||||
|
|
||||||
tensor = self.conv(x)
|
tensor = self.conv(x)
|
||||||
tensor = self.dropout(tensor)
|
tensor = self.dropout(tensor)
|
||||||
@ -100,13 +96,13 @@ class ResidualModule(nn.Module):
|
|||||||
output = self(x)
|
output = self(x)
|
||||||
return output.shape[1:]
|
return output.shape[1:]
|
||||||
|
|
||||||
def __init__(self, in_shape, module_class, n, activation=None, **module_paramters):
|
def __init__(self, in_shape, module_class, n, activation=None, **module_parameters):
|
||||||
assert n >= 1
|
assert n >= 1
|
||||||
super(ResidualModule, self).__init__()
|
super(ResidualModule, self).__init__()
|
||||||
self.in_shape = in_shape
|
self.in_shape = in_shape
|
||||||
module_paramters.update(in_shape=in_shape)
|
module_parameters.update(in_shape=in_shape)
|
||||||
self.activation = activation() if activation else lambda x: x
|
self.activation = activation() if activation else lambda x: x
|
||||||
self.residual_block = nn.ModuleList([module_class(**module_paramters) for _ in range(n)])
|
self.residual_block = nn.ModuleList([module_class(**module_parameters) for _ in range(n)])
|
||||||
assert self.in_shape == self.shape, f'The in_shape: {self.in_shape} - must match the out_shape: {self.shape}.'
|
assert self.in_shape == self.shape, f'The in_shape: {self.in_shape} - must match the out_shape: {self.shape}.'
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
@ -143,5 +139,3 @@ class RecurrentModule(nn.Module):
|
|||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
tensor = self.rnn(x)
|
tensor = self.rnn(x)
|
||||||
return tensor
|
return tensor
|
||||||
|
|
||||||
|
|
||||||
|
@ -1,8 +1,11 @@
|
|||||||
|
from typing import List
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
from lib.modules.utils import FlipTensor
|
from lib.modules.utils import FlipTensor
|
||||||
from lib.objects.map import MapStorage
|
from lib.objects.map import MapStorage, Map
|
||||||
|
from lib.objects.trajectory import Trajectory
|
||||||
|
|
||||||
|
|
||||||
class BinaryHomotopicLoss(nn.Module):
|
class BinaryHomotopicLoss(nn.Module):
|
||||||
@ -11,7 +14,10 @@ class BinaryHomotopicLoss(nn.Module):
|
|||||||
self.map_storage = map_storage
|
self.map_storage = map_storage
|
||||||
self.flipper = FlipTensor()
|
self.flipper = FlipTensor()
|
||||||
|
|
||||||
def forward(self, x:torch.Tensor, y: torch.Tensor, mapnames: str):
|
def forward(self, x: torch.Tensor, y: torch.Tensor, mapnames: str):
|
||||||
y_flipepd = self.flipper(y)
|
maps: List[Map] = [self.map_storage[mapname] for mapname in mapnames]
|
||||||
circle = torch.cat((x, y_flipepd), dim=-1)
|
for basemap in maps:
|
||||||
masp = self.map_storage[mapnames].are
|
basemap = basemap.as_2d_array
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
@ -83,9 +83,9 @@ class LightningBaseModule(pl.LightningModule, ABC):
|
|||||||
print(e)
|
print(e)
|
||||||
return -1
|
return -1
|
||||||
|
|
||||||
def __init__(self, params):
|
def __init__(self, hparams):
|
||||||
super(LightningBaseModule, self).__init__()
|
super(LightningBaseModule, self).__init__()
|
||||||
self.hparams = params
|
self.hparams = hparams
|
||||||
|
|
||||||
# Data loading
|
# Data loading
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
@ -109,6 +109,10 @@ class LightningBaseModule(pl.LightningModule, ABC):
|
|||||||
def data_len(self):
|
def data_len(self):
|
||||||
return len(self.dataset.train_dataset)
|
return len(self.dataset.train_dataset)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def n_train_batches(self):
|
||||||
|
return len(self.train_dataloader())
|
||||||
|
|
||||||
def configure_optimizers(self):
|
def configure_optimizers(self):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
@ -121,7 +125,7 @@ class LightningBaseModule(pl.LightningModule, ABC):
|
|||||||
def test_step(self, *args, **kwargs):
|
def test_step(self, *args, **kwargs):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def test_end(self, outputs):
|
def test_epoch_end(self, outputs):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def init_weights(self):
|
def init_weights(self):
|
||||||
@ -134,6 +138,26 @@ class LightningBaseModule(pl.LightningModule, ABC):
|
|||||||
m.bias.data.fill_(0.01)
|
m.bias.data.fill_(0.01)
|
||||||
self.apply(_weight_init)
|
self.apply(_weight_init)
|
||||||
|
|
||||||
|
# Dataloaders
|
||||||
|
# ================================================================================
|
||||||
|
# Train Dataloader
|
||||||
|
def train_dataloader(self):
|
||||||
|
return DataLoader(dataset=self.dataset.train_dataset, shuffle=True,
|
||||||
|
batch_size=self.hparams.data_param.batchsize,
|
||||||
|
num_workers=self.hparams.data_param.worker)
|
||||||
|
|
||||||
|
# Test Dataloader
|
||||||
|
def test_dataloader(self):
|
||||||
|
return DataLoader(dataset=self.dataset.test_dataset, shuffle=True,
|
||||||
|
batch_size=self.hparams.data_param.batchsize,
|
||||||
|
num_workers=self.hparams.data_param.worker)
|
||||||
|
|
||||||
|
# Validation Dataloader
|
||||||
|
def val_dataloader(self):
|
||||||
|
return DataLoader(dataset=self.dataset.val_dataset, shuffle=False,
|
||||||
|
batch_size=self.hparams.data_param.batchsize,
|
||||||
|
num_workers=self.hparams.data_param.worker)
|
||||||
|
|
||||||
|
|
||||||
class FilterLayer(nn.Module):
|
class FilterLayer(nn.Module):
|
||||||
|
|
||||||
|
@ -12,6 +12,7 @@ import networkx as nx
|
|||||||
from matplotlib import pyplot as plt
|
from matplotlib import pyplot as plt
|
||||||
|
|
||||||
from lib.objects.trajectory import Trajectory
|
from lib.objects.trajectory import Trajectory
|
||||||
|
import lib.variables as V
|
||||||
|
|
||||||
|
|
||||||
class Map(object):
|
class Map(object):
|
||||||
@ -145,14 +146,14 @@ class Map(object):
|
|||||||
|
|
||||||
img = Image.new('L', (self.height, self.width), 0)
|
img = Image.new('L', (self.height, self.width), 0)
|
||||||
draw = ImageDraw.Draw(img)
|
draw = ImageDraw.Draw(img)
|
||||||
draw.polygon(polyline, outline=1, fill=1)
|
draw.polygon(polyline, outline=self.white, fill=self.white)
|
||||||
|
|
||||||
a = (np.where(np.asarray(img) == self.white, 1, 0) * np.where(self.as_2d_array == self.white, 1, 0)).sum()
|
a = (np.where(np.asarray(img) == self.white, 1, 0) * np.where(self.as_2d_array == self.black, 1, 0)).sum()
|
||||||
|
|
||||||
if a:
|
if a:
|
||||||
return False # Non-Homotoph
|
return V.ALTERNATIVE # Non-Homotoph
|
||||||
else:
|
else:
|
||||||
return True # Homotoph
|
return V.HOMOTOPIC # Homotoph
|
||||||
|
|
||||||
def draw(self):
|
def draw(self):
|
||||||
fig, ax = plt.gcf(), plt.gca()
|
fig, ax = plt.gcf(), plt.gca()
|
||||||
|
@ -6,6 +6,7 @@ from lib import variables as V
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
class Trajectory(object):
|
class Trajectory(object):
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@ -57,7 +58,8 @@ class Trajectory(object):
|
|||||||
def draw(self, highlights=True, label=None, **kwargs):
|
def draw(self, highlights=True, label=None, **kwargs):
|
||||||
if label is not None:
|
if label is not None:
|
||||||
kwargs.update(color='red' if label == V.HOMOTOPIC else 'green',
|
kwargs.update(color='red' if label == V.HOMOTOPIC else 'green',
|
||||||
label='Homotopic' if label == V.HOMOTOPIC else 'Alternative')
|
label='Homotopic' if label == V.HOMOTOPIC else 'Alternative',
|
||||||
|
lw=1)
|
||||||
if highlights:
|
if highlights:
|
||||||
kwargs.update(marker='o')
|
kwargs.update(marker='o')
|
||||||
fig, ax = plt.gcf(), plt.gca()
|
fig, ax = plt.gcf(), plt.gca()
|
||||||
|
@ -5,6 +5,7 @@ from collections import defaultdict
|
|||||||
from configparser import ConfigParser
|
from configparser import ConfigParser
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
|
from lib.models.generators.cnn import CNNRouteGeneratorModel
|
||||||
from lib.models.homotopy_classification.cnn_based import ConvHomDetector
|
from lib.models.homotopy_classification.cnn_based import ConvHomDetector
|
||||||
from lib.utils.model_io import ModelParameters
|
from lib.utils.model_io import ModelParameters
|
||||||
|
|
||||||
@ -27,7 +28,7 @@ class Config(ConfigParser):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def model_class(self):
|
def model_class(self):
|
||||||
model_dict = dict(classifier_cnn=ConvHomDetector)
|
model_dict = dict(classifier_cnn=ConvHomDetector, generator_cnn=CNNRouteGeneratorModel)
|
||||||
try:
|
try:
|
||||||
return model_dict[self.get('model', 'type')]
|
return model_dict[self.get('model', 'type')]
|
||||||
except KeyError as e:
|
except KeyError as e:
|
||||||
|
@ -1,8 +1,8 @@
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
from pytorch_lightning.logging.base import LightningLoggerBase
|
from pytorch_lightning.loggers.base import LightningLoggerBase
|
||||||
from pytorch_lightning.logging.neptune import NeptuneLogger
|
from pytorch_lightning.loggers.neptune import NeptuneLogger
|
||||||
from pytorch_lightning.logging.test_tube import TestTubeLogger
|
from pytorch_lightning.loggers.test_tube import TestTubeLogger
|
||||||
|
|
||||||
from lib.utils.config import Config
|
from lib.utils.config import Config
|
||||||
|
|
||||||
|
@ -1,5 +1,7 @@
|
|||||||
from argparse import Namespace
|
from argparse import Namespace
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
|
import torch
|
||||||
from natsort import natsorted
|
from natsort import natsorted
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
@ -35,30 +37,25 @@ class ModelParameters(Namespace):
|
|||||||
class SavedLightningModels(object):
|
class SavedLightningModels(object):
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def load_checkpoint(cls, models_root_path, model, n=-1, tags_file_path=''):
|
def load_checkpoint(cls, models_root_path, model=None, n=-1, tags_file_path=''):
|
||||||
assert models_root_path.exists(), f'The path {models_root_path.absolute()} does not exist!'
|
assert models_root_path.exists(), f'The path {models_root_path.absolute()} does not exist!'
|
||||||
found_checkpoints = list(Path(models_root_path).rglob('*.ckpt'))
|
found_checkpoints = list(Path(models_root_path).rglob('*.ckpt'))
|
||||||
|
|
||||||
found_checkpoints = natsorted(found_checkpoints, key=lambda y: y.name)
|
found_checkpoints = natsorted(found_checkpoints, key=lambda y: y.name)
|
||||||
|
if model is None:
|
||||||
|
model = torch.load(models_root_path / 'model_class.obj')
|
||||||
|
assert model is not None
|
||||||
|
|
||||||
if not tags_file_path:
|
return cls(weights=found_checkpoints[n], model=model)
|
||||||
tag_files = models_root_path.rglob('meta_tags.csv')
|
|
||||||
tags_file_path = list(tag_files)[0]
|
|
||||||
|
|
||||||
return cls(weights=found_checkpoints[n], model=model, tags=tags_file_path)
|
|
||||||
|
|
||||||
def __init__(self, **kwargs):
|
def __init__(self, **kwargs):
|
||||||
self.weights: str = kwargs.get('weights', '')
|
self.weights: str = kwargs.get('weights', '')
|
||||||
self.tags: str = kwargs.get('tags', '')
|
|
||||||
|
|
||||||
self.model = kwargs.get('model', None)
|
self.model = kwargs.get('model', None)
|
||||||
assert self.model is not None
|
assert self.model is not None
|
||||||
|
|
||||||
def restore(self):
|
def restore(self):
|
||||||
pretrained_model = self.model.load_from_metrics(
|
pretrained_model = self.model.load_from_checkpoint(self.weights)
|
||||||
weights_path=self.weights,
|
|
||||||
tags_csv=self.tags
|
|
||||||
)
|
|
||||||
pretrained_model.eval()
|
pretrained_model.eval()
|
||||||
pretrained_model.freeze()
|
pretrained_model.freeze()
|
||||||
return pretrained_model
|
return pretrained_model
|
@ -1,5 +1,5 @@
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
_ROOT = Path('..')
|
_ROOT = Path('..')
|
||||||
|
|
||||||
HOMOTOPIC = 0
|
HOMOTOPIC = 1
|
||||||
ALTERNATIVE = 1
|
ALTERNATIVE = 0
|
||||||
|
26
main.py
26
main.py
@ -10,13 +10,11 @@ import warnings
|
|||||||
import torch
|
import torch
|
||||||
from pytorch_lightning import Trainer
|
from pytorch_lightning import Trainer
|
||||||
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
|
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
|
||||||
from torch.utils.data import DataLoader
|
|
||||||
|
|
||||||
from lib.modules.utils import LightningBaseModule
|
from lib.modules.utils import LightningBaseModule
|
||||||
from lib.utils.config import Config
|
from lib.utils.config import Config
|
||||||
from lib.utils.logging import Logger
|
from lib.utils.logging import Logger
|
||||||
|
from lib.utils.model_io import SavedLightningModels
|
||||||
from lib.evaluation.classification import ROCEvaluation
|
|
||||||
|
|
||||||
warnings.filterwarnings('ignore', category=FutureWarning)
|
warnings.filterwarnings('ignore', category=FutureWarning)
|
||||||
warnings.filterwarnings('ignore', category=UserWarning)
|
warnings.filterwarnings('ignore', category=UserWarning)
|
||||||
@ -36,8 +34,8 @@ main_arg_parser.add_argument("--main_seed", type=int, default=69, help="")
|
|||||||
# Data Parameters
|
# Data Parameters
|
||||||
main_arg_parser.add_argument("--data_worker", type=int, default=10, help="")
|
main_arg_parser.add_argument("--data_worker", type=int, default=10, help="")
|
||||||
main_arg_parser.add_argument("--data_batchsize", type=int, default=100, help="")
|
main_arg_parser.add_argument("--data_batchsize", type=int, default=100, help="")
|
||||||
main_arg_parser.add_argument("--data_root", type=str, default='/data/', help="")
|
main_arg_parser.add_argument("--data_root", type=str, default='data', help="")
|
||||||
main_arg_parser.add_argument("--data_map_root", type=str, default='/res/maps', help="")
|
main_arg_parser.add_argument("--data_map_root", type=str, default='res/shapes', help="")
|
||||||
|
|
||||||
# Transformations
|
# Transformations
|
||||||
main_arg_parser.add_argument("--transformations_to_tensor", type=strtobool, default=False, help="")
|
main_arg_parser.add_argument("--transformations_to_tensor", type=strtobool, default=False, help="")
|
||||||
@ -50,10 +48,11 @@ main_arg_parser.add_argument("--train_batch_size", type=int, default=256, help="
|
|||||||
main_arg_parser.add_argument("--train_lr", type=float, default=0.002, help="")
|
main_arg_parser.add_argument("--train_lr", type=float, default=0.002, help="")
|
||||||
|
|
||||||
# Model
|
# Model
|
||||||
main_arg_parser.add_argument("--model_type", type=str, default="classifier_cnn", help="")
|
main_arg_parser.add_argument("--model_type", type=str, default="generator_cnn", help="")
|
||||||
main_arg_parser.add_argument("--model_activation", type=str, default="relu", help="")
|
main_arg_parser.add_argument("--model_activation", type=str, default="relu", help="")
|
||||||
main_arg_parser.add_argument("--model_filters", type=str, default="[32, 16, 4]", help="")
|
main_arg_parser.add_argument("--model_filters", type=str, default="[32, 16, 4]", help="")
|
||||||
main_arg_parser.add_argument("--model_classes", type=int, default=2, help="")
|
main_arg_parser.add_argument("--model_classes", type=int, default=2, help="")
|
||||||
|
main_arg_parser.add_argument("--model_lat_dim", type=int, default=2, help="")
|
||||||
main_arg_parser.add_argument("--model_use_bias", type=strtobool, default=True, help="")
|
main_arg_parser.add_argument("--model_use_bias", type=strtobool, default=True, help="")
|
||||||
main_arg_parser.add_argument("--model_use_norm", type=strtobool, default=True, help="")
|
main_arg_parser.add_argument("--model_use_norm", type=strtobool, default=True, help="")
|
||||||
main_arg_parser.add_argument("--model_dropout", type=float, default=0.00, help="")
|
main_arg_parser.add_argument("--model_dropout", type=float, default=0.00, help="")
|
||||||
@ -78,8 +77,9 @@ def run_lightning_loop(config_obj):
|
|||||||
# Checkpoint Saving
|
# Checkpoint Saving
|
||||||
checkpoint_callback = ModelCheckpoint(
|
checkpoint_callback = ModelCheckpoint(
|
||||||
filepath=str(logger.log_dir / 'ckpt_weights'),
|
filepath=str(logger.log_dir / 'ckpt_weights'),
|
||||||
verbose=True, save_top_k=5,
|
verbose=True, save_top_k=0,
|
||||||
)
|
)
|
||||||
|
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
# Early Stopping
|
# Early Stopping
|
||||||
# TODO: For This to work, one must set a validation step and End Eval and Score
|
# TODO: For This to work, one must set a validation step and End Eval and Score
|
||||||
@ -94,6 +94,11 @@ def run_lightning_loop(config_obj):
|
|||||||
# Init
|
# Init
|
||||||
model: LightningBaseModule = config_obj.model_class(config_obj.model_paramters)
|
model: LightningBaseModule = config_obj.model_class(config_obj.model_paramters)
|
||||||
model.init_weights()
|
model.init_weights()
|
||||||
|
if model.name == 'CNNRouteGenerator':
|
||||||
|
# ToDo: Make this dependent on the used seed
|
||||||
|
path = Path(Path(config_obj.train.outpath) / 'classifier_cnn' / 'version_0')
|
||||||
|
disc_model = SavedLightningModels.load_checkpoint(path).restore()
|
||||||
|
model.set_discriminator(disc_model)
|
||||||
|
|
||||||
# Trainer
|
# Trainer
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
@ -101,8 +106,8 @@ def run_lightning_loop(config_obj):
|
|||||||
show_progress_bar=True,
|
show_progress_bar=True,
|
||||||
weights_save_path=logger.log_dir,
|
weights_save_path=logger.log_dir,
|
||||||
gpus=[0] if torch.cuda.is_available() else None,
|
gpus=[0] if torch.cuda.is_available() else None,
|
||||||
row_log_interval=(model.data_len * 0.01), # TODO: Better Value / Setting
|
# row_log_interval=(model.n_train_batches * 0.1), # TODO: Better Value / Setting
|
||||||
log_save_interval=(model.data_len * 0.04), # TODO: Better Value / Setting
|
# log_save_interval=(model.n_train_batches * 0.2), # TODO: Better Value / Setting
|
||||||
checkpoint_callback=checkpoint_callback,
|
checkpoint_callback=checkpoint_callback,
|
||||||
logger=logger,
|
logger=logger,
|
||||||
fast_dev_run=config_obj.main.debug,
|
fast_dev_run=config_obj.main.debug,
|
||||||
@ -110,7 +115,7 @@ def run_lightning_loop(config_obj):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Train It
|
# Train It
|
||||||
trainer.fit(model,)
|
trainer.fit(model)
|
||||||
|
|
||||||
# Save the last state & all parameters
|
# Save the last state & all parameters
|
||||||
trainer.save_checkpoint(logger.log_dir / 'weights.ckpt')
|
trainer.save_checkpoint(logger.log_dir / 'weights.ckpt')
|
||||||
@ -118,6 +123,7 @@ def run_lightning_loop(config_obj):
|
|||||||
|
|
||||||
# Evaluate It
|
# Evaluate It
|
||||||
trainer.test()
|
trainer.test()
|
||||||
|
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
@ -16,7 +16,7 @@ if __name__ == '__main__':
|
|||||||
# Model Settings
|
# Model Settings
|
||||||
config = Config().read_namespace(args)
|
config = Config().read_namespace(args)
|
||||||
# use_bias, activation, model, use_norm, max_epochs, filters
|
# use_bias, activation, model, use_norm, max_epochs, filters
|
||||||
cnn_classifier = dict(train_epochs=100, model_use_bias=True, model_use_norm=True, model_activation='leaky_relu',
|
cnn_classifier = dict(train_epochs=10, model_use_bias=True, model_use_norm=True, model_activation='leaky_relu',
|
||||||
model_type='classifier_cnn', model_filters=[16, 32, 64], data_batchsize=512)
|
model_type='classifier_cnn', model_filters=[16, 32, 64], data_batchsize=512)
|
||||||
# use_bias, activation, model, use_norm, max_epochs, sr, lat_dim, filters
|
# use_bias, activation, model, use_norm, max_epochs, sr, lat_dim, filters
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user