project Refactor, CNN Classifier Basics

This commit is contained in:
Steffen Illium
2020-03-08 23:46:02 +01:00
parent 75e8a61628
commit cd4fdf2de3
20 changed files with 441 additions and 239 deletions

View File

@ -1,32 +1,34 @@
import matplotlib.pyplot as plt
from sklearn.metrics import roc_curve, auc
class ROCEvaluation(object):
linewidth = 2
def __init__(self, prepare_figure=False):
self.prepare_figure = prepare_figure
self.epoch = 0
def __call__(self, prediction, label, plotting=False):
# Compute ROC curve and ROC area
fpr, tpr, _ = roc_curve(prediction, label)
roc_auc = auc(fpr, tpr)
if plotting:
fig = plt.gcf()
fig.plot(fpr, tpr, color='darkorange', lw=self.linewidth, label=f'ROC curve (area = {roc_auc})')
return roc_auc, fpr, tpr
def _prepare_fig(self):
fig = plt.gcf()
fig.plot([0, 1], [0, 1], color='navy', lw=self.linewidth, linestyle='--')
fig.xlim([0.0, 1.0])
fig.ylim([0.0, 1.05])
fig.xlabel('False Positive Rate')
fig.ylabel('True Positive Rate')
fig.legend(loc="lower right")
return fig
import matplotlib.pyplot as plt
from sklearn.metrics import roc_curve, auc
class ROCEvaluation(object):
linewidth = 2
def __init__(self, plot_roc=False):
self.plot_roc = plot_roc
self.epoch = 0
def __call__(self, prediction, label):
# Compute ROC curve and ROC area
fpr, tpr, _ = roc_curve(prediction, label)
roc_auc = auc(fpr, tpr)
if self.plot_roc:
_ = plt.gcf()
plt.plot(fpr, tpr, color='darkorange', lw=self.linewidth, label=f'ROC curve (area = {roc_auc})')
self._prepare_fig()
return roc_auc, fpr, tpr
def _prepare_fig(self):
fig = plt.gcf()
ax = plt.gca()
plt.plot([0, 1], [0, 1], color='navy', lw=self.linewidth, linestyle='--')
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
fig.legend(loc="lower right")
return fig

View File

@ -1,6 +1,13 @@
from datasets.paired_dataset import TrajPairData
from lib.modules.blocks import ConvModule
from lib.modules.utils import LightningBaseModule
import torch
from functools import reduce
from operator import mul
from torch import nn
from torch.optim import Adam
from datasets.trajectory_dataset import TrajData
from lib.modules.blocks import ConvModule, ResidualModule, DeConvModule
from lib.modules.utils import LightningBaseModule, Flatten
class CNNRouteGeneratorModel(LightningBaseModule):
@ -8,36 +15,169 @@ class CNNRouteGeneratorModel(LightningBaseModule):
name = 'CNNRouteGenerator'
def configure_optimizers(self):
pass
def validation_step(self, *args, **kwargs):
pass
def validation_end(self, outputs):
pass
return Adam(self.parameters(), lr=self.hparams.train_param.lr)
def training_step(self, batch_xy, batch_nb, *args, **kwargs):
pass
batch_x, label = batch_xy
generated_alternative, z, mu, logvar = self(batch_x + [label, ])
map_array, trajectory = batch_x
map_stack = torch.cat((map_array, trajectory, generated_alternative), dim=1)
pred_label = self.discriminator(map_stack)
discriminated_bce_loss = self.criterion(pred_label, label.float().unsqueeze(-1))
# see Appendix B from VAE paper:
# Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014
# https://arxiv.org/abs/1312.6114
# 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
kld_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
loss = (kld_loss + discriminated_bce_loss) / 2
return dict(loss=loss, log=dict(loss=loss,
discriminated_bce_loss=discriminated_bce_loss,
kld_loss=kld_loss)
)
def test_step(self, *args, **kwargs):
pass
@property
def discriminator(self):
if self._disc is None:
raise RuntimeError('Set the Discriminator first; "set_discriminator(disc_model)')
return self._disc
def set_discriminator(self, disc_model):
if self._disc is not None:
raise RuntimeError('Discriminator has already been set... What are trying to do?')
self._disc = disc_model
def __init__(self, *params):
super(CNNRouteGeneratorModel, self).__init__(*params)
# Dataset
self.dataset = TrajPairData(self.hparams.data_param.data_root)
self.dataset = TrajData(self.hparams.data_param.map_root, mode='just_route')
# Additional Attributes
self.in_shape = self.dataset.map_shapes_max
# Todo: Better naming and size in Parameters
self.feature_dim = 10
self._disc = None
# NN Nodes
###################################################
#
# Utils
self.relu = nn.ReLU()
self.criterion = nn.MSELoss()
#
# Map Encoder
self.map_conv_0 = ConvModule(self.in_shape, conv_kernel=3, conv_stride=1, conv_padding=1,
conv_filters=self.hparams.model_param.filters[0])
self.conv2 = ConvModule(self.conv1.shape, conv_kernel=3, conv_stride=1, conv_padding=0,
conv_filters=self.hparams.model_param.filters[0])
self.conv3 = ConvModule(self.conv2.shape, conv_kernel=3, conv_stride=1, conv_padding=0,
conv_filters=self.hparams.model_param.filters[0])
self.map_res_1 = ResidualModule(self.map_conv_0.shape, ConvModule, 2, conv_kernel=3, conv_stride=1,
conv_padding=1, conv_filters=self.hparams.model_param.filters[0])
self.map_conv_1 = ConvModule(self.map_res_1.shape, conv_kernel=3, conv_stride=1, conv_padding=0,
conv_filters=self.hparams.model_param.filters[1])
def forward(self, x):
pass
self.map_res_2 = ResidualModule(self.map_conv_1.shape, ConvModule, 2, conv_kernel=3, conv_stride=1,
conv_padding=1, conv_filters=self.hparams.model_param.filters[1])
self.map_conv_2 = ConvModule(self.map_res_2.shape, conv_kernel=3, conv_stride=1, conv_padding=0,
conv_filters=self.hparams.model_param.filters[2])
self.map_res_3 = ResidualModule(self.map_conv_2.shape, ConvModule, 2, conv_kernel=3, conv_stride=1,
conv_padding=1, conv_filters=self.hparams.model_param.filters[2])
self.map_conv_3 = ConvModule(self.map_res_3.shape, conv_kernel=5, conv_stride=1, conv_padding=0,
conv_filters=self.hparams.model_param.filters[2]*2)
self.map_flat = Flatten(self.map_conv_3.shape)
self.map_lin = nn.Linear(reduce(mul, self.map_conv_3.shape), self.feature_dim)
#
# Trajectory Encoder
self.traj_conv_1 = ConvModule(self.in_shape, conv_kernel=3, conv_stride=1, conv_padding=0,
conv_filters=self.hparams.model_param.filters[0])
self.traj_conv_2 = ConvModule(self.traj_conv_1.shape, conv_kernel=3, conv_stride=1, conv_padding=0,
conv_filters=self.hparams.model_param.filters[0])
self.traj_conv_3 = ConvModule(self.traj_conv_2.shape, conv_kernel=3, conv_stride=1, conv_padding=0,
conv_filters=self.hparams.model_param.filters[0])
self.traj_flat = Flatten(self.traj_conv_3.shape)
self.traj_lin = nn.Linear(reduce(mul, self.traj_conv_3.shape), self.feature_dim)
#
# Variational Bottleneck
self.mu = nn.Linear(self.feature_dim + self.feature_dim + 1, self.hparams.model_param.lat_dim)
self.logvar = nn.Linear(self.feature_dim + self.feature_dim + 1, self.hparams.model_param.lat_dim)
#
# Alternative Generator
self.alt_lin_1 = nn.Linear(self.hparams.model_param.lat_dim, self.feature_dim)
self.alt_lin_2 = nn.Linear(self.feature_dim, reduce(mul, self.traj_conv_3.shape))
self.reshape_to_map = Flatten(reduce(mul, self.traj_conv_3.shape), self.traj_conv_3.shape)
self.alt_deconv_1 = DeConvModule(self.traj_conv_3.shape, self.hparams.model_param.filters[2],
conv_padding=0, conv_kernel=5, conv_stride=1)
self.alt_deconv_2 = DeConvModule(self.alt_deconv_1.shape, self.hparams.model_param.filters[1],
conv_padding=0, conv_kernel=3, conv_stride=1)
self.alt_deconv_3 = DeConvModule(self.alt_deconv_2.shape, self.hparams.model_param.filters[0],
conv_padding=1, conv_kernel=3, conv_stride=1)
self.alt_deconv_out = DeConvModule(self.alt_deconv_3.shape, 1, activation=None,
conv_padding=1, conv_kernel=3, conv_stride=1)
def forward(self, batch_x):
#
# Sorting the Input
map_array, trajectory, label = batch_x
#
# Encode
map_tensor = self.map_conv_0(map_array)
map_tensor = self.map_res_1(map_tensor)
map_tensor = self.map_conv_1(map_tensor)
map_tensor = self.map_res_2(map_tensor)
map_tensor = self.map_conv_2(map_tensor)
map_tensor = self.map_res_3(map_tensor)
map_tensor = self.map_conv_3(map_tensor)
map_tensor = self.map_flat(map_tensor)
map_tensor = self.map_lin(map_tensor)
traj_tensor = self.traj_conv_1(trajectory)
traj_tensor = self.traj_conv_2(traj_tensor)
traj_tensor = self.traj_conv_3(traj_tensor)
traj_tensor = self.traj_flat(traj_tensor)
traj_tensor = self.traj_lin(traj_tensor)
mixed_tensor = torch.cat((map_tensor, traj_tensor, label.float().unsqueeze(-1)), dim=1)
mixed_tensor = self.relu(mixed_tensor)
#
# Parameter and Sampling
mu = self.mu(mixed_tensor)
logvar = self.logvar(mixed_tensor)
z = self.reparameterize(mu, logvar)
#
# Generate
alt_tensor = self.alt_lin_1(z)
alt_tensor = self.alt_lin_2(alt_tensor)
alt_tensor = self.reshape_to_map(alt_tensor)
alt_tensor = self.alt_deconv_1(alt_tensor)
alt_tensor = self.alt_deconv_2(alt_tensor)
alt_tensor = self.alt_deconv_3(alt_tensor)
alt_tensor = self.alt_deconv_out(alt_tensor)
return alt_tensor, z, mu, logvar
@staticmethod
def reparameterize(mu, logvar):
std = torch.exp(0.5 * logvar)
eps = torch.randn_like(std)
return mu + eps * std

View File

@ -1,14 +1,16 @@
from lib.modules.blocks import LightningBaseModule
from lib.modules.losses import BinaryHomotopicLoss
from lib.modules.utils import LightningBaseModule
from lib.objects.map import Map
from lib.objects.trajectory import Trajectory
import torch.nn as nn
nn.MSELoss
class LinearRouteGeneratorModel(LightningBaseModule):
def test_epoch_end(self, outputs):
pass
name = 'LinearRouteGenerator'
def configure_optimizers(self):
@ -33,6 +35,12 @@ class LinearRouteGeneratorModel(LightningBaseModule):
pred_y = self(map_x, traj_x, label_x)
loss = self.loss(traj_x, pred_y)
def training_step(self, batch_xy, batch_nb, *args, **kwargs):
batch_x, batch_y = batch_xy
pred_y = self(batch_x)
loss = self.criterion(pred_y, batch_y.unsqueeze(-1).float())
return dict(loss=loss, log=dict(loss=loss))
def test_step(self, *args, **kwargs):
@ -41,7 +49,7 @@ class LinearRouteGeneratorModel(LightningBaseModule):
def __init__(self, *params):
super(LinearRouteGeneratorModel, self).__init__(*params)
self.loss = BinaryHomotopicLoss(self.map_storage)
self.criterion = BinaryHomotopicLoss(self.map_storage)
def forward(self, map_x, traj_x, label_x):
pass

View File

@ -24,41 +24,44 @@ class ConvHomDetector(LightningBaseModule):
def training_step(self, batch_xy, batch_nb, *args, **kwargs):
batch_x, batch_y = batch_xy
pred_y = self(batch_x)
loss = F.binary_cross_entropy(pred_y, batch_y.float())
loss = self.criterion(pred_y, batch_y.unsqueeze(-1).float())
return {'loss': loss, 'log': dict(loss=loss)}
def test_step(self, batch_xy, **kwargs):
def test_step(self, batch_xy, batch_nb, **kwargs):
batch_x, batch_y = batch_xy
pred_y = self(batch_x)
return dict(prediction=pred_y, label=batch_y)
return dict(prediction=pred_y, label=batch_y, batch_nb=batch_nb)
def test_end(self, outputs):
evaluation = ROCEvaluation()
predictions = torch.stack([x['prediction'] for x in outputs])
labels = torch.stack([x['label'] for x in outputs])
def test_epoch_end(self, outputs):
evaluation = ROCEvaluation(plot_roc=True)
predictions = torch.cat([x['prediction'] for x in outputs])
labels = torch.cat([x['label'] for x in outputs]).unsqueeze(1)
scores = evaluation(predictions.numpy(), labels.numpy(), )
self.logger.log_metrics({key:value for key, value in zip(['roc_auc', 'tpr', 'fpr'], scores)})
# Sci-py call ROC eval call is eval(true_label, prediction)
roc_auc, tpr, fpr = evaluation(labels.cpu().numpy(), predictions.cpu().numpy(), )
score_dict = dict(roc_auc=roc_auc)
# self.logger.log_metrics(score_dict)
self.logger.log_image(f'{self.name}', plt.gcf())
pass
def __init__(self, *params):
super(ConvHomDetector, self).__init__(*params)
return dict(log=score_dict)
def __init__(self, hparams):
super(ConvHomDetector, self).__init__(hparams)
# Dataset
self.dataset = TrajData(self.hparams.data_param.root)
self.dataset = TrajData(self.hparams.data_param.map_root, mode='all_in_map')
# Additional Attributes
self.map_shape = self.dataset.map_shapes_max
# Model Paramters
# Model Parameters
self.in_shape = self.dataset.map_shapes_max
assert len(self.in_shape) == 3, f'Image or map shape has to have 3 dims, but had: {len(self.in_shape)}'
self.criterion = nn.BCEWithLogitsLoss()
# NN Nodes
# ============================
# Convolutional Map Processing
#
self.map_conv_0 = ConvModule(self.in_shape, conv_kernel=3, conv_stride=1,
conv_padding=0, conv_filters=self.hparams.model_param.filters[0])
self.map_res_1 = ResidualModule(self.map_conv_0.shape, ConvModule, 3,
@ -86,7 +89,6 @@ class ConvHomDetector(LightningBaseModule):
self.linear = nn.Linear(reduce(mul, self.flatten.shape), self.hparams.model_param.classes * 10)
# Comments on Multi Class labels
self.classifier = nn.Linear(self.hparams.model_param.classes * 10, 1) # self.hparams.model_param.classes)
self.out_activation = nn.Sigmoid() # nn.Softmax
def forward(self, x):
tensor = self.map_conv_0(x)
@ -98,25 +100,4 @@ class ConvHomDetector(LightningBaseModule):
tensor = self.flatten(tensor)
tensor = self.linear(tensor)
tensor = self.classifier(tensor)
tensor = self.out_activation(tensor)
return tensor
# Dataloaders
# ================================================================================
# Train Dataloader
def train_dataloader(self):
return DataLoader(dataset=self.dataset.train_dataset, shuffle=True,
batch_size=self.hparams.data_param.batchsize,
num_workers=self.hparams.data_param.worker)
# Test Dataloader
def test_dataloader(self):
return DataLoader(dataset=self.dataset.test_dataset, shuffle=True,
batch_size=self.hparams.data_param.batchsize,
num_workers=self.hparams.data_param.worker)
# Validation Dataloader
def val_dataloader(self):
return DataLoader(dataset=self.dataset.val_dataset, shuffle=True,
batch_size=self.hparams.data_param.batchsize,
num_workers=self.hparams.data_param.worker)

View File

@ -1,11 +1,7 @@
from abc import ABC
from pathlib import Path
from typing import Union
import torch
from torch import nn
import torch.nn.functional as F
import pytorch_lightning as pl
from lib.modules.utils import AutoPad, Interpolate
#
@ -26,12 +22,12 @@ class ConvModule(nn.Module):
conv_filters=64, conv_kernel=5, conv_stride=1, conv_padding=0):
super(ConvModule, self).__init__()
# Module Paramters
# Module Parameters
self.in_shape = in_shape
in_channels, height, width = in_shape[0], in_shape[1], in_shape[2]
self.activation = activation()
# Convolution Paramters
# Convolution Parameters
self.padding = conv_padding
self.stride = conv_stride
@ -44,7 +40,7 @@ class ConvModule(nn.Module):
)
def forward(self, x):
x = self.norm(x) if self.norm else x
x = self.norm(x)
tensor = self.conv(x)
tensor = self.dropout(tensor)
@ -72,10 +68,10 @@ class DeConvModule(nn.Module):
self.in_shape = in_shape
self.conv_filters = conv_filters
self.autopad = AutoPad() if autopad else lambda x: x
self.interpolation = Interpolate(scale_factor=interpolation_scale) if interpolation_scale else lambda x: x
self.norm = nn.BatchNorm2d(in_channels, eps=1e-04, affine=False) if normalize else lambda x: x
self.dropout = nn.Dropout2d(dropout) if dropout else lambda x: x
self.autopad = AutoPad() if autopad else lambda x: x
self.interpolation = Interpolate(scale_factor=interpolation_scale) if interpolation_scale else lambda x: x
self.norm = nn.BatchNorm2d(in_channels, eps=1e-04, affine=False) if normalize else lambda x: x
self.dropout = nn.Dropout2d(dropout) if dropout else lambda x: x
self.de_conv = nn.ConvTranspose2d(in_channels, self.conv_filters, conv_kernel, bias=use_bias,
padding=self.padding, stride=self.stride)
@ -100,13 +96,13 @@ class ResidualModule(nn.Module):
output = self(x)
return output.shape[1:]
def __init__(self, in_shape, module_class, n, activation=None, **module_paramters):
def __init__(self, in_shape, module_class, n, activation=None, **module_parameters):
assert n >= 1
super(ResidualModule, self).__init__()
self.in_shape = in_shape
module_paramters.update(in_shape=in_shape)
module_parameters.update(in_shape=in_shape)
self.activation = activation() if activation else lambda x: x
self.residual_block = nn.ModuleList([module_class(**module_paramters) for _ in range(n)])
self.residual_block = nn.ModuleList([module_class(**module_parameters) for _ in range(n)])
assert self.in_shape == self.shape, f'The in_shape: {self.in_shape} - must match the out_shape: {self.shape}.'
def forward(self, x):
@ -143,5 +139,3 @@ class RecurrentModule(nn.Module):
def forward(self, x):
tensor = self.rnn(x)
return tensor

View File

@ -1,8 +1,11 @@
from typing import List
import torch
from torch import nn
from lib.modules.utils import FlipTensor
from lib.objects.map import MapStorage
from lib.objects.map import MapStorage, Map
from lib.objects.trajectory import Trajectory
class BinaryHomotopicLoss(nn.Module):
@ -11,7 +14,10 @@ class BinaryHomotopicLoss(nn.Module):
self.map_storage = map_storage
self.flipper = FlipTensor()
def forward(self, x:torch.Tensor, y: torch.Tensor, mapnames: str):
y_flipepd = self.flipper(y)
circle = torch.cat((x, y_flipepd), dim=-1)
masp = self.map_storage[mapnames].are
def forward(self, x: torch.Tensor, y: torch.Tensor, mapnames: str):
maps: List[Map] = [self.map_storage[mapname] for mapname in mapnames]
for basemap in maps:
basemap = basemap.as_2d_array

View File

@ -83,9 +83,9 @@ class LightningBaseModule(pl.LightningModule, ABC):
print(e)
return -1
def __init__(self, params):
def __init__(self, hparams):
super(LightningBaseModule, self).__init__()
self.hparams = params
self.hparams = hparams
# Data loading
# =============================================================================
@ -109,6 +109,10 @@ class LightningBaseModule(pl.LightningModule, ABC):
def data_len(self):
return len(self.dataset.train_dataset)
@property
def n_train_batches(self):
return len(self.train_dataloader())
def configure_optimizers(self):
raise NotImplementedError
@ -121,7 +125,7 @@ class LightningBaseModule(pl.LightningModule, ABC):
def test_step(self, *args, **kwargs):
raise NotImplementedError
def test_end(self, outputs):
def test_epoch_end(self, outputs):
raise NotImplementedError
def init_weights(self):
@ -134,6 +138,26 @@ class LightningBaseModule(pl.LightningModule, ABC):
m.bias.data.fill_(0.01)
self.apply(_weight_init)
# Dataloaders
# ================================================================================
# Train Dataloader
def train_dataloader(self):
return DataLoader(dataset=self.dataset.train_dataset, shuffle=True,
batch_size=self.hparams.data_param.batchsize,
num_workers=self.hparams.data_param.worker)
# Test Dataloader
def test_dataloader(self):
return DataLoader(dataset=self.dataset.test_dataset, shuffle=True,
batch_size=self.hparams.data_param.batchsize,
num_workers=self.hparams.data_param.worker)
# Validation Dataloader
def val_dataloader(self):
return DataLoader(dataset=self.dataset.val_dataset, shuffle=False,
batch_size=self.hparams.data_param.batchsize,
num_workers=self.hparams.data_param.worker)
class FilterLayer(nn.Module):

View File

@ -12,6 +12,7 @@ import networkx as nx
from matplotlib import pyplot as plt
from lib.objects.trajectory import Trajectory
import lib.variables as V
class Map(object):
@ -145,14 +146,14 @@ class Map(object):
img = Image.new('L', (self.height, self.width), 0)
draw = ImageDraw.Draw(img)
draw.polygon(polyline, outline=1, fill=1)
draw.polygon(polyline, outline=self.white, fill=self.white)
a = (np.where(np.asarray(img) == self.white, 1, 0) * np.where(self.as_2d_array == self.white, 1, 0)).sum()
a = (np.where(np.asarray(img) == self.white, 1, 0) * np.where(self.as_2d_array == self.black, 1, 0)).sum()
if a:
return False # Non-Homotoph
return V.ALTERNATIVE # Non-Homotoph
else:
return True # Homotoph
return V.HOMOTOPIC # Homotoph
def draw(self):
fig, ax = plt.gcf(), plt.gca()

View File

@ -1,78 +1,80 @@
from math import atan2
from typing import List, Tuple, Union
from matplotlib import pyplot as plt
from lib import variables as V
import numpy as np
class Trajectory(object):
@property
def vertices(self):
return self._vertices
@property
def xy_vertices(self):
return [(x, y) for _, y, x in self._vertices]
@property
def endpoints(self):
return self.start, self.dest
@property
def start(self):
return self._vertices[0]
@property
def dest(self):
return self._vertices[-1]
@property
def xs(self):
return [x[2] for x in self._vertices]
@property
def ys(self):
return [x[1] for x in self._vertices]
@property
def as_paired_list(self):
return list(zip(self._vertices[:-1], self._vertices[1:]))
@property
def np_vertices(self):
return [np.array(vertice) for vertice in self._vertices]
def __init__(self, vertices: Union[List[Tuple[int]], None] = None):
assert any((isinstance(vertices, list), vertices is None))
if vertices is not None:
self._vertices = vertices
pass
def is_equal_to(self, other_trajectory):
# ToDo: do further equality Checks here
return self._vertices == other_trajectory.vertices
def draw(self, highlights=True, label=None, **kwargs):
if label is not None:
kwargs.update(color='red' if label == V.HOMOTOPIC else 'green',
label='Homotopic' if label == V.HOMOTOPIC else 'Alternative')
if highlights:
kwargs.update(marker='o')
fig, ax = plt.gcf(), plt.gca()
img = plt.plot(self.xs, self.ys, **kwargs)
return dict(img=img, fig=fig, ax=ax)
def min_vertices(self, vertices):
vertices, last_angle = [self.start], 0
for (x1, y1), (x2, y2) in self.as_paired_list:
current_angle = atan2(x1-x2, y1-y2)
if current_angle != last_angle:
vertices.append((x2, y2))
last_angle = current_angle
else:
continue
if vertices[-1] != self.dest:
vertices.append(self.dest)
return self.__class__(vertices=vertices)
from math import atan2
from typing import List, Tuple, Union
from matplotlib import pyplot as plt
from lib import variables as V
import numpy as np
class Trajectory(object):
@property
def vertices(self):
return self._vertices
@property
def xy_vertices(self):
return [(x, y) for _, y, x in self._vertices]
@property
def endpoints(self):
return self.start, self.dest
@property
def start(self):
return self._vertices[0]
@property
def dest(self):
return self._vertices[-1]
@property
def xs(self):
return [x[2] for x in self._vertices]
@property
def ys(self):
return [x[1] for x in self._vertices]
@property
def as_paired_list(self):
return list(zip(self._vertices[:-1], self._vertices[1:]))
@property
def np_vertices(self):
return [np.array(vertice) for vertice in self._vertices]
def __init__(self, vertices: Union[List[Tuple[int]], None] = None):
assert any((isinstance(vertices, list), vertices is None))
if vertices is not None:
self._vertices = vertices
pass
def is_equal_to(self, other_trajectory):
# ToDo: do further equality Checks here
return self._vertices == other_trajectory.vertices
def draw(self, highlights=True, label=None, **kwargs):
if label is not None:
kwargs.update(color='red' if label == V.HOMOTOPIC else 'green',
label='Homotopic' if label == V.HOMOTOPIC else 'Alternative',
lw=1)
if highlights:
kwargs.update(marker='o')
fig, ax = plt.gcf(), plt.gca()
img = plt.plot(self.xs, self.ys, **kwargs)
return dict(img=img, fig=fig, ax=ax)
def min_vertices(self, vertices):
vertices, last_angle = [self.start], 0
for (x1, y1), (x2, y2) in self.as_paired_list:
current_angle = atan2(x1-x2, y1-y2)
if current_angle != last_angle:
vertices.append((x2, y2))
last_angle = current_angle
else:
continue
if vertices[-1] != self.dest:
vertices.append(self.dest)
return self.__class__(vertices=vertices)

View File

@ -5,6 +5,7 @@ from collections import defaultdict
from configparser import ConfigParser
from pathlib import Path
from lib.models.generators.cnn import CNNRouteGeneratorModel
from lib.models.homotopy_classification.cnn_based import ConvHomDetector
from lib.utils.model_io import ModelParameters
@ -27,7 +28,7 @@ class Config(ConfigParser):
@property
def model_class(self):
model_dict = dict(classifier_cnn=ConvHomDetector)
model_dict = dict(classifier_cnn=ConvHomDetector, generator_cnn=CNNRouteGeneratorModel)
try:
return model_dict[self.get('model', 'type')]
except KeyError as e:

View File

@ -1,8 +1,8 @@
from pathlib import Path
from pytorch_lightning.logging.base import LightningLoggerBase
from pytorch_lightning.logging.neptune import NeptuneLogger
from pytorch_lightning.logging.test_tube import TestTubeLogger
from pytorch_lightning.loggers.base import LightningLoggerBase
from pytorch_lightning.loggers.neptune import NeptuneLogger
from pytorch_lightning.loggers.test_tube import TestTubeLogger
from lib.utils.config import Config

View File

@ -1,5 +1,7 @@
from argparse import Namespace
from pathlib import Path
import torch
from natsort import natsorted
from torch import nn
@ -35,30 +37,25 @@ class ModelParameters(Namespace):
class SavedLightningModels(object):
@classmethod
def load_checkpoint(cls, models_root_path, model, n=-1, tags_file_path=''):
def load_checkpoint(cls, models_root_path, model=None, n=-1, tags_file_path=''):
assert models_root_path.exists(), f'The path {models_root_path.absolute()} does not exist!'
found_checkpoints = list(Path(models_root_path).rglob('*.ckpt'))
found_checkpoints = natsorted(found_checkpoints, key=lambda y: y.name)
if model is None:
model = torch.load(models_root_path / 'model_class.obj')
assert model is not None
if not tags_file_path:
tag_files = models_root_path.rglob('meta_tags.csv')
tags_file_path = list(tag_files)[0]
return cls(weights=found_checkpoints[n], model=model, tags=tags_file_path)
return cls(weights=found_checkpoints[n], model=model)
def __init__(self, **kwargs):
self.weights: str = kwargs.get('weights', '')
self.tags: str = kwargs.get('tags', '')
self.model = kwargs.get('model', None)
assert self.model is not None
def restore(self):
pretrained_model = self.model.load_from_metrics(
weights_path=self.weights,
tags_csv=self.tags
)
pretrained_model = self.model.load_from_checkpoint(self.weights)
pretrained_model.eval()
pretrained_model.freeze()
return pretrained_model

View File

@ -1,5 +1,5 @@
from pathlib import Path
_ROOT = Path('..')
HOMOTOPIC = 0
ALTERNATIVE = 1
HOMOTOPIC = 1
ALTERNATIVE = 0