Merge remote-tracking branch 'origin/master'

# Conflicts:
#	lib/visualization/generator_eval.py
This commit is contained in:
steffen
2020-03-09 22:01:17 +01:00
51 changed files with 553 additions and 405 deletions

45
.gitignore vendored
View File

@ -2,47 +2,11 @@
# Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839
# User-specific stuff
.idea/**/workspace.xml
.idea/**/tasks.xml
.idea/**/usage.statistics.xml
.idea/**/dictionaries
.idea/**/shelf
# Generated files
.idea/**/contentModel.xml
# Sensitive or high-churn files
.idea/**/dataSources/
.idea/**/dataSources.ids
.idea/**/dataSources.local.xml
.idea/**/sqlDataSources.xml
.idea/**/dynamic.xml
.idea/**/uiDesigner.xml
.idea/**/dbnavigator.xml
# Gradle
.idea/**/gradle.xml
.idea/**/libraries
# Gradle and Maven with auto-import
# When using Gradle or Maven with auto-import, you should exclude module files,
# since they will be recreated, and may cause churn. Uncomment if using
# auto-import.
# .idea/artifacts
# .idea/compiler.xml
# .idea/jarRepositories.xml
# .idea/modules.xml
# .idea/*.iml
# .idea/modules
# *.iml
# *.ipr
.idea/**
# CMake
cmake-build-*/
# Mongo Explorer plugin
.idea/**/mongoSettings.xml
# File-based project format
*.iws
@ -63,10 +27,3 @@ com_crashlytics_export_strings.xml
crashlytics.properties
crashlytics-build.properties
fabric.properties
# Editor-based Rest Client
.idea/httpRequests
# Android studio 3.1+ serialized cache file
.idea/caches/build_file_checksums.ser
/.idea/inspectionProfiles/

2
.idea/.gitignore generated vendored
View File

@ -1,2 +0,0 @@
# Default ignored files
/workspace.xml

15
.idea/deployment.xml generated
View File

@ -1,15 +0,0 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="PublishConfigData" autoUpload="On explicit save action" serverName="steffen@aimachine:22" showAutoUploadSettingsWarning="false">
<serverData>
<paths name="steffen@aimachine:22">
<serverdata>
<mappings>
<mapping deploy="/" local="$PROJECT_DIR$" web="/" />
</mappings>
</serverdata>
</paths>
</serverData>
<option name="myAutoUpload" value="ON_EXPLICIT_SAVE" />
</component>
</project>

View File

@ -1,14 +0,0 @@
<component name="ProjectDictionaryState">
<dictionary name="steffen">
<words>
<w>conv</w>
<w>dataloader</w>
<w>homotopic</w>
<w>hparams</w>
<w>hyperparamter</w>
<w>numlayers</w>
<w>reparameterize</w>
<w>traj</w>
</words>
</dictionary>
</component>

View File

@ -1,8 +0,0 @@
<?xml version="1.0" encoding="UTF-8"?>
<module type="PYTHON_MODULE" version="4">
<component name="NewModuleRootManager">
<content url="file://$MODULE_DIR$" />
<orderEntry type="jdk" jdkName="hom_traj_gen@aimachine" jdkType="Python SDK" />
<orderEntry type="sourceFolder" forTests="false" />
</component>
</module>

View File

@ -1,7 +0,0 @@
<component name="InspectionProjectProfileManager">
<settings>
<option name="PROJECT_PROFILE" value="Default" />
<option name="USE_PROJECT_PROFILE" value="false" />
<version value="1.0" />
</settings>
</component>

7
.idea/misc.xml generated
View File

@ -1,7 +0,0 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="JavaScriptSettings">
<option name="languageLevel" value="ES6" />
</component>
<component name="ProjectRootManager" version="2" project-jdk-name="hom_traj_gen@aimachine" project-jdk-type="Python SDK" />
</project>

8
.idea/modules.xml generated
View File

@ -1,8 +0,0 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="ProjectModuleManager">
<modules>
<module fileurl="file://$PROJECT_DIR$/.idea/hom_traj_gen.iml" filepath="$PROJECT_DIR$/.idea/hom_traj_gen.iml" />
</modules>
</component>
</project>

6
.idea/vcs.xml generated
View File

@ -1,6 +0,0 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="VcsDirectoryMappings">
<mapping directory="$PROJECT_DIR$" vcs="Git" />
</component>
</project>

15
.idea/webResources.xml generated
View File

@ -1,15 +0,0 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="WebResourcesPaths">
<contentEntries>
<entry url="file://$PROJECT_DIR$">
<entryData>
<resourceRoots>
<path value="file://$PROJECT_DIR$/res" />
<path value="file://$PROJECT_DIR$/data" />
</resourceRoots>
</entryData>
</entry>
</contentEntries>
</component>
</project>

View File

@ -3,6 +3,7 @@ from pathlib import Path
from typing import Union, List
import torch
from random import choice
from torch.utils.data import ConcatDataset, Dataset
from lib.objects.map import Map
@ -16,10 +17,11 @@ class TrajDataset(Dataset):
return self.map.as_array.shape
def __init__(self, *args, maps_root: Union[Path, str] = '', mapname='tate_sw',
length=100000, all_in_map=True, embedding_size=None, preserve_equal_samples=False, **kwargs):
length=100000, mode='separated_arrays', embedding_size=None, preserve_equal_samples=False, **kwargs):
super(TrajDataset, self).__init__()
assert mode.lower() in ['vectors', 'all_in_map', 'separated_arrays', 'just_route']
self.preserve_equal_samples = preserve_equal_samples
self.all_in_map = all_in_map
self.mode = mode
self.mapname = mapname if mapname.endswith('.bmp') else f'{mapname}.bmp'
self.maps_root = maps_root
self._len = length
@ -31,9 +33,19 @@ class TrajDataset(Dataset):
return self._len
def __getitem__(self, item):
trajectory = self.map.get_random_trajectory()
if self.mode.lower() == 'just_route':
trajectory = self.map.get_random_trajectory()
label = choice([0, 1])
blank_trajectory_space = torch.zeros(self.map.shape)
for index in trajectory.vertices:
blank_trajectory_space[index] = 1
map_array = torch.as_tensor(self.map.as_array).float()
return (map_array, blank_trajectory_space), label
while True:
# TODO: Sanity Check this while true loop...
trajectory = self.map.get_random_trajectory()
alternative = self.map.generate_alternative(trajectory)
label = self.map.are_homotopic(trajectory, alternative)
if self.preserve_equal_samples and label == self.last_label:
@ -42,18 +54,21 @@ class TrajDataset(Dataset):
break
self.last_label = label
if self.all_in_map:
blank_trajectory_space = torch.zeros(self.map.shape)
blank_alternative_space = torch.zeros(self.map.shape)
for index in trajectory.vertices:
blank_trajectory_space[index] = 1
blank_alternative_space[index] = 1
if self.mode.lower() in ['all_in_map', 'separated_arrays']:
map_array = torch.as_tensor(self.map.as_array).float()
return torch.cat((map_array, blank_trajectory_space, blank_alternative_space)), int(label)
else:
if self.mode == 'separated_arrays':
return (map_array, trajectory.draw_in_array(self.map_shape), int(label)), \
alternative.draw_in_array(self.map_shape)
else:
return torch.cat((map_array, trajectory.draw_in_array(self.map_shape),
alternative.draw_in_array(self.map_shape))), int(label)
elif self.mode == 'vectors':
return trajectory.vertices, alternative.vertices, label, self.mapname
else:
raise ValueError
class TrajData(object):
@property
@ -64,7 +79,7 @@ class TrajData(object):
def map_shapes_max(self):
shapes = self.map_shapes
shape_list = list(map(max, zip(*shapes)))
if self.all_in_map:
if self.mode == 'all_in_map':
shape_list[0] += 2
return shape_list
@ -72,22 +87,22 @@ class TrajData(object):
def name(self):
return self.__class__.__name__
def __init__(self, *args, map_root: Union[Path, str] = '', length=100.000, all_in_map=True, **_):
def __init__(self, map_root, length=100000, mode='separated_arrays', **_):
self.all_in_map = all_in_map
self.maps_root = Path(map_root) if map_root else Path() / 'res' / 'maps'
self.mode = mode
self.maps_root = Path(map_root)
self.length = length
self._dataset = self._load_datasets()
def _load_datasets(self):
map_files = list(self.maps_root.glob('*.bmp'))
equal_split = int(self.length // len(map_files))
equal_split = int(self.length // len(map_files)) or 1
# find max image size among available maps:
max_map_size = (1, ) + tuple(reversed(tuple(map(max, *[Image.open(map_file).size for map_file in map_files]))))
return ConcatDataset([TrajDataset(maps_root=self.maps_root, mapname=map_file.name, length=equal_split,
all_in_map=self.all_in_map, embedding_size=max_map_size,
preserve_equal_samples=False)
mode=self.mode, embedding_size=max_map_size,
preserve_equal_samples=True)
for map_file in map_files])
@property

View File

@ -1,32 +1,34 @@
import matplotlib.pyplot as plt
from sklearn.metrics import roc_curve, auc
class ROCEvaluation(object):
linewidth = 2
def __init__(self, prepare_figure=False):
self.prepare_figure = prepare_figure
self.epoch = 0
def __call__(self, prediction, label, plotting=False):
# Compute ROC curve and ROC area
fpr, tpr, _ = roc_curve(prediction, label)
roc_auc = auc(fpr, tpr)
if plotting:
fig = plt.gcf()
fig.plot(fpr, tpr, color='darkorange', lw=self.linewidth, label=f'ROC curve (area = {roc_auc})')
return roc_auc, fpr, tpr
def _prepare_fig(self):
fig = plt.gcf()
fig.plot([0, 1], [0, 1], color='navy', lw=self.linewidth, linestyle='--')
fig.xlim([0.0, 1.0])
fig.ylim([0.0, 1.05])
fig.xlabel('False Positive Rate')
fig.ylabel('True Positive Rate')
fig.legend(loc="lower right")
return fig
import matplotlib.pyplot as plt
from sklearn.metrics import roc_curve, auc
class ROCEvaluation(object):
linewidth = 2
def __init__(self, plot_roc=False):
self.plot_roc = plot_roc
self.epoch = 0
def __call__(self, prediction, label):
# Compute ROC curve and ROC area
fpr, tpr, _ = roc_curve(prediction, label)
roc_auc = auc(fpr, tpr)
if self.plot_roc:
_ = plt.gcf()
plt.plot(fpr, tpr, color='darkorange', lw=self.linewidth, label=f'ROC curve (area = {roc_auc})')
self._prepare_fig()
return roc_auc, fpr, tpr
def _prepare_fig(self):
fig = plt.gcf()
ax = plt.gca()
plt.plot([0, 1], [0, 1], color='navy', lw=self.linewidth, linestyle='--')
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
fig.legend(loc="lower right")
return fig

View File

@ -1,6 +1,20 @@
from datasets.paired_dataset import TrajPairData
from lib.modules.blocks import ConvModule
from lib.modules.utils import LightningBaseModule
from statistics import mean
from random import choice
import torch
from functools import reduce
from operator import mul
from torch import nn
from torch.optim import Adam
from datasets.trajectory_dataset import TrajData
from lib.evaluation.classification import ROCEvaluation
from lib.modules.blocks import ConvModule, ResidualModule, DeConvModule
from lib.modules.utils import LightningBaseModule, Flatten
import matplotlib.pyplot as plt
class CNNRouteGeneratorModel(LightningBaseModule):
@ -8,36 +22,248 @@ class CNNRouteGeneratorModel(LightningBaseModule):
name = 'CNNRouteGenerator'
def configure_optimizers(self):
pass
def validation_step(self, *args, **kwargs):
pass
def validation_end(self, outputs):
pass
return Adam(self.parameters(), lr=self.hparams.train_param.lr)
def training_step(self, batch_xy, batch_nb, *args, **kwargs):
pass
batch_x, label = batch_xy
def test_step(self, *args, **kwargs):
pass
generated_alternative, z, mu, logvar = self(batch_x + [label, ])
map_array, trajectory = batch_x
map_stack = torch.cat((map_array, trajectory, generated_alternative), dim=1)
pred_label = self.discriminator(map_stack)
discriminated_bce_loss = self.criterion(pred_label, label.float().unsqueeze(-1))
# see Appendix B from VAE paper:
# Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014
# https://arxiv.org/abs/1312.6114
# 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
kld_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
# Dimensional Resizing
kld_loss /= reduce(mul, self.in_shape)
loss = (kld_loss + discriminated_bce_loss) / 2
return dict(loss=loss, log=dict(loss=loss,
discriminated_bce_loss=discriminated_bce_loss,
kld_loss=kld_loss)
)
def _test_val_step(self, batch_xy, batch_nb, *args):
batch_x, label = batch_xy
generated_alternative, z, mu, logvar = self(batch_x + [label, ])
map_array, trajectory = batch_x
map_stack = torch.cat((map_array, trajectory, generated_alternative), dim=1)
pred_label = self.discriminator(map_stack)
discriminated_bce_loss = self.criterion(pred_label, label.float().unsqueeze(-1))
return dict(discriminated_bce_loss=discriminated_bce_loss, batch_nb=batch_nb,
pred_label=pred_label, label=label, generated_alternative=generated_alternative)
def validation_step(self, *args):
return self._test_val_step(*args)
def validation_epoch_end(self, outputs: list):
return self._test_val_epoch_end(outputs)
def _test_val_epoch_end(self, outputs, test=False):
evaluation = ROCEvaluation(plot_roc=True)
pred_label = torch.cat([x['pred_label'] for x in outputs])
labels = torch.cat([x['label'] for x in outputs]).unsqueeze(1)
mean_losses = torch.stack([x['discriminated_bce_loss'] for x in outputs]).mean()
# Sci-py call ROC eval call is eval(true_label, prediction)
roc_auc, tpr, fpr = evaluation(labels.cpu().numpy(), pred_label.cpu().numpy(), )
if test:
# self.logger.log_metrics(score_dict)
self.logger.log_image(f'{self.name}_ROC-Curve_E{self.current_epoch}', plt.gcf())
plt.clf()
maps, trajectories, labels, val_restul_dict = self.generate_random()
from lib.visualization.generator_eval import GeneratorVisualizer
g = GeneratorVisualizer(maps, trajectories, labels, val_restul_dict)
fig = g.draw()
self.logger.log_image(f'{self.name}_Output_E{self.current_epoch}', fig)
return dict(mean_losses=mean_losses, roc_auc=roc_auc, epoch=self.current_epoch)
def test_step(self, *args):
return self._test_val_step(*args)
def test_epoch_end(self, outputs):
return self._test_val_epoch_end(outputs, test=True)
@property
def discriminator(self):
if self._disc is None:
raise RuntimeError('Set the Discriminator first; "set_discriminator(disc_model)')
return self._disc
def set_discriminator(self, disc_model):
if self._disc is not None:
raise RuntimeError('Discriminator has already been set... What are trying to do?')
self._disc = disc_model
def __init__(self, *params):
super(CNNRouteGeneratorModel, self).__init__(*params)
# Dataset
self.dataset = TrajPairData(self.hparams.data_param.data_root)
self.dataset = TrajData(self.hparams.data_param.map_root, mode='just_route',
length=self.hparams.data_param.dataset_length)
# Additional Attributes
self.in_shape = self.dataset.map_shapes_max
# Todo: Better naming and size in Parameters
self.feature_dim = 10
self.lat_dim = self.feature_dim + self.feature_dim + 1
self._disc = None
# NN Nodes
###################################################
#
# Utils
self.relu = nn.ReLU()
self.sigmoid = nn.Sigmoid()
self.criterion = nn.MSELoss()
#
# Map Encoder
self.map_conv_0 = ConvModule(self.in_shape, conv_kernel=3, conv_stride=1, conv_padding=1,
conv_filters=self.hparams.model_param.filters[0])
self.conv2 = ConvModule(self.conv1.shape, conv_kernel=3, conv_stride=1, conv_padding=0,
conv_filters=self.hparams.model_param.filters[0])
self.conv3 = ConvModule(self.conv2.shape, conv_kernel=3, conv_stride=1, conv_padding=0,
conv_filters=self.hparams.model_param.filters[0])
self.map_res_1 = ResidualModule(self.map_conv_0.shape, ConvModule, 2, conv_kernel=3, conv_stride=1,
conv_padding=1, conv_filters=self.hparams.model_param.filters[0])
self.map_conv_1 = ConvModule(self.map_res_1.shape, conv_kernel=3, conv_stride=1, conv_padding=0,
conv_filters=self.hparams.model_param.filters[1])
def forward(self, x):
pass
self.map_res_2 = ResidualModule(self.map_conv_1.shape, ConvModule, 2, conv_kernel=3, conv_stride=1,
conv_padding=1, conv_filters=self.hparams.model_param.filters[1])
self.map_conv_2 = ConvModule(self.map_res_2.shape, conv_kernel=3, conv_stride=1, conv_padding=0,
conv_filters=self.hparams.model_param.filters[2])
self.map_res_3 = ResidualModule(self.map_conv_2.shape, ConvModule, 2, conv_kernel=3, conv_stride=1,
conv_padding=1, conv_filters=self.hparams.model_param.filters[2])
self.map_conv_3 = ConvModule(self.map_res_3.shape, conv_kernel=5, conv_stride=1, conv_padding=0,
conv_filters=self.hparams.model_param.filters[2]*2)
self.map_flat = Flatten(self.map_conv_3.shape)
self.map_lin = nn.Linear(reduce(mul, self.map_conv_3.shape), self.feature_dim)
#
# Trajectory Encoder
self.traj_conv_1 = ConvModule(self.in_shape, conv_kernel=3, conv_stride=1, conv_padding=0,
conv_filters=self.hparams.model_param.filters[0])
self.traj_conv_2 = ConvModule(self.traj_conv_1.shape, conv_kernel=3, conv_stride=1, conv_padding=0,
conv_filters=self.hparams.model_param.filters[0])
self.traj_conv_3 = ConvModule(self.traj_conv_2.shape, conv_kernel=3, conv_stride=1, conv_padding=0,
conv_filters=self.hparams.model_param.filters[0])
self.traj_flat = Flatten(self.traj_conv_3.shape)
self.traj_lin = nn.Linear(reduce(mul, self.traj_conv_3.shape), self.feature_dim)
#
# Mixed Encoder
self.mixed_lin = nn.Linear(self.lat_dim, self.lat_dim)
#
# Variational Bottleneck
self.mu = nn.Linear(self.lat_dim, self.hparams.model_param.lat_dim)
self.logvar = nn.Linear(self.lat_dim, self.hparams.model_param.lat_dim)
#
# Alternative Generator
self.alt_lin_1 = nn.Linear(self.hparams.model_param.lat_dim, self.feature_dim)
self.alt_lin_2 = nn.Linear(self.feature_dim, reduce(mul, self.traj_conv_3.shape))
self.reshape_to_map = Flatten(reduce(mul, self.traj_conv_3.shape), self.traj_conv_3.shape)
self.alt_deconv_1 = DeConvModule(self.traj_conv_3.shape, self.hparams.model_param.filters[2],
conv_padding=0, conv_kernel=5, conv_stride=1)
self.alt_deconv_2 = DeConvModule(self.alt_deconv_1.shape, self.hparams.model_param.filters[1],
conv_padding=0, conv_kernel=3, conv_stride=1)
self.alt_deconv_3 = DeConvModule(self.alt_deconv_2.shape, self.hparams.model_param.filters[0],
conv_padding=1, conv_kernel=3, conv_stride=1)
self.alt_deconv_out = DeConvModule(self.alt_deconv_3.shape, 1, activation=None,
conv_padding=1, conv_kernel=3, conv_stride=1)
def forward(self, batch_x):
#
# Sorting the Input
map_array, trajectory, label = batch_x
#
# Encode
z, mu, logvar = self.encode(map_array, trajectory, label)
#
# Generate
alt_tensor = self.generate(z)
return alt_tensor, z, mu, logvar
@staticmethod
def reparameterize(mu, logvar):
std = torch.exp(0.5 * logvar)
eps = torch.randn_like(std)
return mu + eps * std
def generate(self, z):
alt_tensor = self.alt_lin_1(z)
alt_tensor = self.alt_lin_2(alt_tensor)
alt_tensor = self.reshape_to_map(alt_tensor)
alt_tensor = self.alt_deconv_1(alt_tensor)
alt_tensor = self.alt_deconv_2(alt_tensor)
alt_tensor = self.alt_deconv_3(alt_tensor)
alt_tensor = self.alt_deconv_out(alt_tensor)
alt_tensor = self.sigmoid(alt_tensor)
return alt_tensor
def encode(self, map_array, trajectory, label):
map_tensor = self.map_conv_0(map_array)
map_tensor = self.map_res_1(map_tensor)
map_tensor = self.map_conv_1(map_tensor)
map_tensor = self.map_res_2(map_tensor)
map_tensor = self.map_conv_2(map_tensor)
map_tensor = self.map_res_3(map_tensor)
map_tensor = self.map_conv_3(map_tensor)
map_tensor = self.map_flat(map_tensor)
map_tensor = self.map_lin(map_tensor)
traj_tensor = self.traj_conv_1(trajectory)
traj_tensor = self.traj_conv_2(traj_tensor)
traj_tensor = self.traj_conv_3(traj_tensor)
traj_tensor = self.traj_flat(traj_tensor)
traj_tensor = self.traj_lin(traj_tensor)
mixed_tensor = torch.cat((map_tensor, traj_tensor, label.float().unsqueeze(-1)), dim=1)
mixed_tensor = self.relu(mixed_tensor)
mixed_tensor = self.mixed_lin(mixed_tensor)
mixed_tensor = self.relu(mixed_tensor)
#
# Parameter and Sampling
mu = self.mu(mixed_tensor)
logvar = self.logvar(mixed_tensor)
z = self.reparameterize(mu, logvar)
return z, mu, logvar
def generate_random(self, n=6):
maps = [self.map_storage[choice(self.map_storage.keys_list)] for _ in range(n)]
trajectories = [x.get_random_trajectory() for x in maps] * 2
trajectories = [x.draw_in_array(self.map_storage.max_map_size) for x in trajectories]
trajectories = [torch.as_tensor(x, dtype=torch.float32) for x in trajectories]
trajectories = self._move_to_model_device(torch.stack(trajectories))
maps = [torch.as_tensor(x.as_array, dtype=torch.float32) for x in maps] * 2
maps = self._move_to_model_device(torch.stack(maps))
labels = self._move_to_model_device(torch.as_tensor([0] * n + [1] * n))
return maps, trajectories, labels, self._test_val_step(([maps, trajectories], labels), -9999)

View File

@ -1,14 +1,16 @@
from lib.modules.blocks import LightningBaseModule
from lib.modules.losses import BinaryHomotopicLoss
from lib.modules.utils import LightningBaseModule
from lib.objects.map import Map
from lib.objects.trajectory import Trajectory
import torch.nn as nn
nn.MSELoss
class LinearRouteGeneratorModel(LightningBaseModule):
def test_epoch_end(self, outputs):
pass
name = 'LinearRouteGenerator'
def configure_optimizers(self):
@ -33,6 +35,12 @@ class LinearRouteGeneratorModel(LightningBaseModule):
pred_y = self(map_x, traj_x, label_x)
loss = self.loss(traj_x, pred_y)
def training_step(self, batch_xy, batch_nb, *args, **kwargs):
batch_x, batch_y = batch_xy
pred_y = self(batch_x)
loss = self.criterion(pred_y, batch_y.unsqueeze(-1).float())
return dict(loss=loss, log=dict(loss=loss))
def test_step(self, *args, **kwargs):
@ -41,7 +49,7 @@ class LinearRouteGeneratorModel(LightningBaseModule):
def __init__(self, *params):
super(LinearRouteGeneratorModel, self).__init__(*params)
self.loss = BinaryHomotopicLoss(self.map_storage)
self.criterion = BinaryHomotopicLoss(self.map_storage)
def forward(self, map_x, traj_x, label_x):
pass

View File

@ -24,41 +24,45 @@ class ConvHomDetector(LightningBaseModule):
def training_step(self, batch_xy, batch_nb, *args, **kwargs):
batch_x, batch_y = batch_xy
pred_y = self(batch_x)
loss = F.binary_cross_entropy(pred_y, batch_y.float())
loss = self.criterion(pred_y, batch_y.unsqueeze(-1).float())
return {'loss': loss, 'log': dict(loss=loss)}
def test_step(self, batch_xy, **kwargs):
def test_step(self, batch_xy, batch_nb, **kwargs):
batch_x, batch_y = batch_xy
pred_y = self(batch_x)
return dict(prediction=pred_y, label=batch_y)
return dict(prediction=pred_y, label=batch_y, batch_nb=batch_nb)
def test_end(self, outputs):
evaluation = ROCEvaluation()
predictions = torch.stack([x['prediction'] for x in outputs])
labels = torch.stack([x['label'] for x in outputs])
def test_epoch_end(self, outputs):
evaluation = ROCEvaluation(plot_roc=True)
predictions = torch.cat([x['prediction'] for x in outputs])
labels = torch.cat([x['label'] for x in outputs]).unsqueeze(1)
scores = evaluation(predictions.numpy(), labels.numpy(), )
self.logger.log_metrics({key:value for key, value in zip(['roc_auc', 'tpr', 'fpr'], scores)})
# Sci-py call ROC eval call is eval(true_label, prediction)
roc_auc, tpr, fpr = evaluation(labels.cpu().numpy(), predictions.cpu().numpy(), )
score_dict = dict(roc_auc=roc_auc)
# self.logger.log_metrics(score_dict)
self.logger.log_image(f'{self.name}', plt.gcf())
pass
def __init__(self, *params):
super(ConvHomDetector, self).__init__(*params)
return dict(log=score_dict)
def __init__(self, hparams):
super(ConvHomDetector, self).__init__(hparams)
# Dataset
self.dataset = TrajData(self.hparams.data_param.root)
self.dataset = TrajData(self.hparams.data_param.map_root, mode='all_in_map')
# Additional Attributes
self.map_shape = self.dataset.map_shapes_max
# Model Paramters
# Model Parameters
self.in_shape = self.dataset.map_shapes_max
assert len(self.in_shape) == 3, f'Image or map shape has to have 3 dims, but had: {len(self.in_shape)}'
self.criterion = nn.BCELoss()
self.sigmoid = nn.Sigmoid()
# NN Nodes
# ============================
# Convolutional Map Processing
#
self.map_conv_0 = ConvModule(self.in_shape, conv_kernel=3, conv_stride=1,
conv_padding=0, conv_filters=self.hparams.model_param.filters[0])
self.map_res_1 = ResidualModule(self.map_conv_0.shape, ConvModule, 3,
@ -86,7 +90,6 @@ class ConvHomDetector(LightningBaseModule):
self.linear = nn.Linear(reduce(mul, self.flatten.shape), self.hparams.model_param.classes * 10)
# Comments on Multi Class labels
self.classifier = nn.Linear(self.hparams.model_param.classes * 10, 1) # self.hparams.model_param.classes)
self.out_activation = nn.Sigmoid() # nn.Softmax
def forward(self, x):
tensor = self.map_conv_0(x)
@ -98,25 +101,5 @@ class ConvHomDetector(LightningBaseModule):
tensor = self.flatten(tensor)
tensor = self.linear(tensor)
tensor = self.classifier(tensor)
tensor = self.out_activation(tensor)
tensor = self.sigmoid(tensor)
return tensor
# Dataloaders
# ================================================================================
# Train Dataloader
def train_dataloader(self):
return DataLoader(dataset=self.dataset.train_dataset, shuffle=True,
batch_size=self.hparams.data_param.batchsize,
num_workers=self.hparams.data_param.worker)
# Test Dataloader
def test_dataloader(self):
return DataLoader(dataset=self.dataset.test_dataset, shuffle=True,
batch_size=self.hparams.data_param.batchsize,
num_workers=self.hparams.data_param.worker)
# Validation Dataloader
def val_dataloader(self):
return DataLoader(dataset=self.dataset.val_dataset, shuffle=True,
batch_size=self.hparams.data_param.batchsize,
num_workers=self.hparams.data_param.worker)

View File

@ -1,11 +1,7 @@
from abc import ABC
from pathlib import Path
from typing import Union
import torch
from torch import nn
import torch.nn.functional as F
import pytorch_lightning as pl
from lib.modules.utils import AutoPad, Interpolate
#
@ -26,12 +22,12 @@ class ConvModule(nn.Module):
conv_filters=64, conv_kernel=5, conv_stride=1, conv_padding=0):
super(ConvModule, self).__init__()
# Module Paramters
# Module Parameters
self.in_shape = in_shape
in_channels, height, width = in_shape[0], in_shape[1], in_shape[2]
self.activation = activation()
# Convolution Paramters
# Convolution Parameters
self.padding = conv_padding
self.stride = conv_stride
@ -44,7 +40,7 @@ class ConvModule(nn.Module):
)
def forward(self, x):
x = self.norm(x) if self.norm else x
x = self.norm(x)
tensor = self.conv(x)
tensor = self.dropout(tensor)
@ -72,10 +68,10 @@ class DeConvModule(nn.Module):
self.in_shape = in_shape
self.conv_filters = conv_filters
self.autopad = AutoPad() if autopad else lambda x: x
self.interpolation = Interpolate(scale_factor=interpolation_scale) if interpolation_scale else lambda x: x
self.norm = nn.BatchNorm2d(in_channels, eps=1e-04, affine=False) if normalize else lambda x: x
self.dropout = nn.Dropout2d(dropout) if dropout else lambda x: x
self.autopad = AutoPad() if autopad else lambda x: x
self.interpolation = Interpolate(scale_factor=interpolation_scale) if interpolation_scale else lambda x: x
self.norm = nn.BatchNorm2d(in_channels, eps=1e-04, affine=False) if normalize else lambda x: x
self.dropout = nn.Dropout2d(dropout) if dropout else lambda x: x
self.de_conv = nn.ConvTranspose2d(in_channels, self.conv_filters, conv_kernel, bias=use_bias,
padding=self.padding, stride=self.stride)
@ -100,13 +96,13 @@ class ResidualModule(nn.Module):
output = self(x)
return output.shape[1:]
def __init__(self, in_shape, module_class, n, activation=None, **module_paramters):
def __init__(self, in_shape, module_class, n, activation=None, **module_parameters):
assert n >= 1
super(ResidualModule, self).__init__()
self.in_shape = in_shape
module_paramters.update(in_shape=in_shape)
module_parameters.update(in_shape=in_shape)
self.activation = activation() if activation else lambda x: x
self.residual_block = nn.ModuleList([module_class(**module_paramters) for _ in range(n)])
self.residual_block = nn.ModuleList([module_class(**module_parameters) for _ in range(n)])
assert self.in_shape == self.shape, f'The in_shape: {self.in_shape} - must match the out_shape: {self.shape}.'
def forward(self, x):
@ -143,5 +139,3 @@ class RecurrentModule(nn.Module):
def forward(self, x):
tensor = self.rnn(x)
return tensor

View File

@ -1,8 +1,11 @@
from typing import List
import torch
from torch import nn
from lib.modules.utils import FlipTensor
from lib.objects.map import MapStorage
from lib.objects.map import MapStorage, Map
from lib.objects.trajectory import Trajectory
class BinaryHomotopicLoss(nn.Module):
@ -11,7 +14,10 @@ class BinaryHomotopicLoss(nn.Module):
self.map_storage = map_storage
self.flipper = FlipTensor()
def forward(self, x:torch.Tensor, y: torch.Tensor, mapnames: str):
y_flipepd = self.flipper(y)
circle = torch.cat((x, y_flipepd), dim=-1)
masp = self.map_storage[mapnames].are
def forward(self, x: torch.Tensor, y: torch.Tensor, mapnames: str):
maps: List[Map] = [self.map_storage[mapname] for mapname in mapnames]
for basemap in maps:
basemap = basemap.as_2d_array

View File

@ -83,9 +83,9 @@ class LightningBaseModule(pl.LightningModule, ABC):
print(e)
return -1
def __init__(self, params):
def __init__(self, hparams):
super(LightningBaseModule, self).__init__()
self.hparams = params
self.hparams = hparams
# Data loading
# =============================================================================
@ -109,6 +109,10 @@ class LightningBaseModule(pl.LightningModule, ABC):
def data_len(self):
return len(self.dataset.train_dataset)
@property
def n_train_batches(self):
return len(self.train_dataloader())
def configure_optimizers(self):
raise NotImplementedError
@ -121,7 +125,7 @@ class LightningBaseModule(pl.LightningModule, ABC):
def test_step(self, *args, **kwargs):
raise NotImplementedError
def test_end(self, outputs):
def test_epoch_end(self, outputs):
raise NotImplementedError
def init_weights(self):
@ -134,6 +138,26 @@ class LightningBaseModule(pl.LightningModule, ABC):
m.bias.data.fill_(0.01)
self.apply(_weight_init)
# Dataloaders
# ================================================================================
# Train Dataloader
def train_dataloader(self):
return DataLoader(dataset=self.dataset.train_dataset, shuffle=True,
batch_size=self.hparams.train_param.batch_size,
num_workers=self.hparams.data_param.worker)
# Test Dataloader
def test_dataloader(self):
return DataLoader(dataset=self.dataset.test_dataset, shuffle=True,
batch_size=self.hparams.train_param.batch_size,
num_workers=self.hparams.data_param.worker)
# Validation Dataloader
def val_dataloader(self):
return DataLoader(dataset=self.dataset.val_dataset, shuffle=False,
batch_size=self.hparams.train_param.batch_size,
num_workers=self.hparams.data_param.worker)
class FilterLayer(nn.Module):

View File

@ -1,4 +1,5 @@
import shelve
from collections import UserDict
from pathlib import Path
import copy
@ -12,6 +13,7 @@ import networkx as nx
from matplotlib import pyplot as plt
from lib.objects.trajectory import Trajectory
import lib.variables as V
class Map(object):
@ -68,7 +70,7 @@ class Map(object):
# Check pixels for their color (determine if walkable)
for idx, value in np.ndenumerate(self.map_array):
if value == self.white:
if value != self.black:
# IF walkable, add node
graph.add_node(idx, count=0)
# Fully connect to all surrounding neighbors
@ -89,6 +91,7 @@ class Map(object):
if image.mode != 'L':
image = image.convert('L')
map_array = np.expand_dims(np.array(image), axis=0)
map_array = np.where(np.asarray(map_array) == cls.white, 1, 0)
if embedding_size:
assert isinstance(embedding_size, tuple), f'embedding_size was of type: {type(embedding_size)}'
embedding = np.zeros(embedding_size)
@ -145,14 +148,15 @@ class Map(object):
img = Image.new('L', (self.height, self.width), 0)
draw = ImageDraw.Draw(img)
draw.polygon(polyline, outline=1, fill=1)
a = (np.where(np.asarray(img) == self.white, 1, 0) * np.where(self.as_2d_array == self.white, 1, 0)).sum()
a = (np.asarray(img) * np.where(self.as_2d_array == self.black, 1, 0)).sum()
if a:
return False # Non-Homotoph
return V.ALTERNATIVE # Non-Homotoph
else:
return True # Homotoph
return V.HOMOTOPIC # Homotoph
def draw(self):
fig, ax = plt.gcf(), plt.gca()
@ -163,28 +167,25 @@ class Map(object):
return dict(img=img, fig=fig, ax=ax)
class MapStorage(object):
class MapStorage(UserDict):
def __init__(self, map_root, load_all=False):
self.data = dict()
@property
def keys_list(self):
return list(super(MapStorage, self).keys())
def __init__(self, map_root, *args, **kwargs):
super(MapStorage, self).__init__(*args, **kwargs)
self.map_root = Path(map_root)
if load_all:
for map_file in self.map_root.glob('*.bmp'):
_ = self[map_file.name]
def __getitem__(self, item):
if item in hasattr(self, item):
return self.__getattribute__(item)
else:
with shelve.open(self.map_root / f'{item}.pik', flag='r') as d:
self.__setattr__(item, d['map']['map'])
return self[item]
map_files = list(self.map_root.glob('*.bmp'))
self.max_map_size = (1, ) + tuple(
reversed(
tuple(
map(
max, *[Image.open(map_file).size for map_file in map_files])
)
)
)
for map_file in map_files:
current_map = Map().from_image(map_file, embedding_size=self.max_map_size)
self.__setitem__(map_file.name, current_map)

View File

@ -1,78 +1,86 @@
from math import atan2
from typing import List, Tuple, Union
from matplotlib import pyplot as plt
from lib import variables as V
import numpy as np
class Trajectory(object):
@property
def vertices(self):
return self._vertices
@property
def xy_vertices(self):
return [(x, y) for _, y, x in self._vertices]
@property
def endpoints(self):
return self.start, self.dest
@property
def start(self):
return self._vertices[0]
@property
def dest(self):
return self._vertices[-1]
@property
def xs(self):
return [x[2] for x in self._vertices]
@property
def ys(self):
return [x[1] for x in self._vertices]
@property
def as_paired_list(self):
return list(zip(self._vertices[:-1], self._vertices[1:]))
@property
def np_vertices(self):
return [np.array(vertice) for vertice in self._vertices]
def __init__(self, vertices: Union[List[Tuple[int]], None] = None):
assert any((isinstance(vertices, list), vertices is None))
if vertices is not None:
self._vertices = vertices
pass
def is_equal_to(self, other_trajectory):
# ToDo: do further equality Checks here
return self._vertices == other_trajectory.vertices
def draw(self, highlights=True, label=None, **kwargs):
if label is not None:
kwargs.update(color='red' if label == V.HOMOTOPIC else 'green',
label='Homotopic' if label == V.HOMOTOPIC else 'Alternative')
if highlights:
kwargs.update(marker='o')
fig, ax = plt.gcf(), plt.gca()
img = plt.plot(self.xs, self.ys, **kwargs)
return dict(img=img, fig=fig, ax=ax)
def min_vertices(self, vertices):
vertices, last_angle = [self.start], 0
for (x1, y1), (x2, y2) in self.as_paired_list:
current_angle = atan2(x1-x2, y1-y2)
if current_angle != last_angle:
vertices.append((x2, y2))
last_angle = current_angle
else:
continue
if vertices[-1] != self.dest:
vertices.append(self.dest)
return self.__class__(vertices=vertices)
from math import atan2
from typing import List, Tuple, Union
from matplotlib import pyplot as plt
from lib import variables as V
import numpy as np
class Trajectory(object):
@property
def vertices(self):
return self._vertices
@property
def xy_vertices(self):
return [(x, y) for _, y, x in self._vertices]
@property
def endpoints(self):
return self.start, self.dest
@property
def start(self):
return self._vertices[0]
@property
def dest(self):
return self._vertices[-1]
@property
def xs(self):
return [x[2] for x in self._vertices]
@property
def ys(self):
return [x[1] for x in self._vertices]
@property
def as_paired_list(self):
return list(zip(self._vertices[:-1], self._vertices[1:]))
def draw_in_array(self, shape):
trajectory_space = np.zeros(shape)
for index in self.vertices:
trajectory_space[index] = 1
return trajectory_space
@property
def np_vertices(self):
return [np.array(vertice) for vertice in self._vertices]
def __init__(self, vertices: Union[List[Tuple[int]], None] = None):
assert any((isinstance(vertices, list), vertices is None))
if vertices is not None:
self._vertices = vertices
pass
def is_equal_to(self, other_trajectory):
# ToDo: do further equality Checks here
return self._vertices == other_trajectory.vertices
def draw(self, highlights=True, label=None, **kwargs):
if label is not None:
kwargs.update(color='red' if label == V.HOMOTOPIC else 'green',
label='Homotopic' if label == V.HOMOTOPIC else 'Alternative',
lw=1)
if highlights:
kwargs.update(marker='o')
fig, ax = plt.gcf(), plt.gca()
img = plt.plot(self.xs, self.ys, **kwargs)
return dict(img=img, fig=fig, ax=ax)
def min_vertices(self, vertices):
vertices, last_angle = [self.start], 0
for (x1, y1), (x2, y2) in self.as_paired_list:
current_angle = atan2(x1-x2, y1-y2)
if current_angle != last_angle:
vertices.append((x2, y2))
last_angle = current_angle
else:
continue
if vertices[-1] != self.dest:
vertices.append(self.dest)
return self.__class__(vertices=vertices)

View File

@ -5,6 +5,7 @@ from collections import defaultdict
from configparser import ConfigParser
from pathlib import Path
from lib.models.generators.cnn import CNNRouteGeneratorModel
from lib.models.homotopy_classification.cnn_based import ConvHomDetector
from lib.utils.model_io import ModelParameters
@ -27,7 +28,7 @@ class Config(ConfigParser):
@property
def model_class(self):
model_dict = dict(classifier_cnn=ConvHomDetector)
model_dict = dict(classifier_cnn=ConvHomDetector, generator_cnn=CNNRouteGeneratorModel)
try:
return model_dict[self.get('model', 'type')]
except KeyError as e:

View File

@ -1,8 +1,8 @@
from pathlib import Path
from pytorch_lightning.logging.base import LightningLoggerBase
from pytorch_lightning.logging.neptune import NeptuneLogger
from pytorch_lightning.logging.test_tube import TestTubeLogger
from pytorch_lightning.loggers.base import LightningLoggerBase
from pytorch_lightning.loggers.neptune import NeptuneLogger
from pytorch_lightning.loggers.test_tube import TestTubeLogger
from lib.utils.config import Config

View File

@ -1,5 +1,7 @@
from argparse import Namespace
from pathlib import Path
import torch
from natsort import natsorted
from torch import nn
@ -35,30 +37,25 @@ class ModelParameters(Namespace):
class SavedLightningModels(object):
@classmethod
def load_checkpoint(cls, models_root_path, model, n=-1, tags_file_path=''):
def load_checkpoint(cls, models_root_path, model=None, n=-1, tags_file_path=''):
assert models_root_path.exists(), f'The path {models_root_path.absolute()} does not exist!'
found_checkpoints = list(Path(models_root_path).rglob('*.ckpt'))
found_checkpoints = natsorted(found_checkpoints, key=lambda y: y.name)
if model is None:
model = torch.load(models_root_path / 'model_class.obj')
assert model is not None
if not tags_file_path:
tag_files = models_root_path.rglob('meta_tags.csv')
tags_file_path = list(tag_files)[0]
return cls(weights=found_checkpoints[n], model=model, tags=tags_file_path)
return cls(weights=found_checkpoints[n], model=model)
def __init__(self, **kwargs):
self.weights: str = kwargs.get('weights', '')
self.tags: str = kwargs.get('tags', '')
self.model = kwargs.get('model', None)
assert self.model is not None
def restore(self):
pretrained_model = self.model.load_from_metrics(
weights_path=self.weights,
tags_csv=self.tags
)
pretrained_model = self.model.load_from_checkpoint(self.weights)
pretrained_model.eval()
pretrained_model.freeze()
return pretrained_model

View File

@ -1,5 +1,5 @@
from pathlib import Path
_ROOT = Path('..')
HOMOTOPIC = 0
ALTERNATIVE = 1
HOMOTOPIC = 1
ALTERNATIVE = 0

View File

@ -19,9 +19,9 @@ class GeneratorVisualizer(object):
def _build_column_dict_list(self):
dict_list = []
for idx in range(self.maps):
image = self.maps[idx] + self.trajectories[idx] + self.generated_alternatives
label = self.labels[idx]
for idx in range(self.maps.shape[0]):
image = (self.maps[idx] + self.trajectories[idx] + self.generated_alternatives[idx]).cpu().numpy().squeeze()
label = int(self.labels[idx])
dict_list.append(dict(image=image, label=label))
half_size = int(len(dict_list) // 2)
return dict_list[:half_size], dict_list[half_size:]
@ -33,10 +33,10 @@ class GeneratorVisualizer(object):
axes_pad=0.2, # pad between axes in inch.
)
for idx in grid.axes_all:
for idx in range(len(grid.axes_all)):
row, col = divmod(idx, len(self.column_dict_list))
current_image = self.column_dict_list[col]['image'][row]
current_label = self.column_dict_list[col]['label'][row]
current_image = self.column_dict_list[col][row]['image']
current_label = self.column_dict_list[col][row]['label']
grid[idx].imshow(current_image)
grid[idx].title.set_text(current_label)
fig.cbar_mode = 'single'

38
main.py
View File

@ -10,13 +10,11 @@ import warnings
import torch
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
from torch.utils.data import DataLoader
from lib.modules.utils import LightningBaseModule
from lib.utils.config import Config
from lib.utils.logging import Logger
from lib.evaluation.classification import ROCEvaluation
from lib.utils.model_io import SavedLightningModels
warnings.filterwarnings('ignore', category=FutureWarning)
warnings.filterwarnings('ignore', category=UserWarning)
@ -30,14 +28,14 @@ main_arg_parser = ArgumentParser(description="parser for fast-neural-style")
# Main Parameters
main_arg_parser.add_argument("--main_debug", type=strtobool, default=False, help="")
main_arg_parser.add_argument("--main_eval", type=strtobool, default=False, help="")
main_arg_parser.add_argument("--main_eval", type=strtobool, default=True, help="")
main_arg_parser.add_argument("--main_seed", type=int, default=69, help="")
# Data Parameters
main_arg_parser.add_argument("--data_worker", type=int, default=10, help="")
main_arg_parser.add_argument("--data_batchsize", type=int, default=100, help="")
main_arg_parser.add_argument("--data_root", type=str, default='/data/', help="")
main_arg_parser.add_argument("--data_map_root", type=str, default='/res/maps', help="")
main_arg_parser.add_argument("--data_dataset_length", type=int, default=100000, help="")
main_arg_parser.add_argument("--data_root", type=str, default='data', help="")
main_arg_parser.add_argument("--data_map_root", type=str, default='res/shapes', help="")
# Transformations
main_arg_parser.add_argument("--transformations_to_tensor", type=strtobool, default=False, help="")
@ -45,15 +43,16 @@ main_arg_parser.add_argument("--transformations_to_tensor", type=strtobool, defa
# Transformations
main_arg_parser.add_argument("--train_outpath", type=str, default="output", help="")
main_arg_parser.add_argument("--train_version", type=strtobool, required=False, help="")
main_arg_parser.add_argument("--train_epochs", type=int, default=10, help="")
main_arg_parser.add_argument("--train_epochs", type=int, default=12, help="")
main_arg_parser.add_argument("--train_batch_size", type=int, default=256, help="")
main_arg_parser.add_argument("--train_lr", type=float, default=0.002, help="")
# Model
main_arg_parser.add_argument("--model_type", type=str, default="classifier_cnn", help="")
main_arg_parser.add_argument("--model_type", type=str, default="generator_cnn", help="")
main_arg_parser.add_argument("--model_activation", type=str, default="relu", help="")
main_arg_parser.add_argument("--model_filters", type=str, default="[32, 16, 4]", help="")
main_arg_parser.add_argument("--model_classes", type=int, default=2, help="")
main_arg_parser.add_argument("--model_lat_dim", type=int, default=2, help="")
main_arg_parser.add_argument("--model_use_bias", type=strtobool, default=True, help="")
main_arg_parser.add_argument("--model_use_norm", type=strtobool, default=True, help="")
main_arg_parser.add_argument("--model_dropout", type=float, default=0.00, help="")
@ -77,9 +76,10 @@ def run_lightning_loop(config_obj):
# =============================================================================
# Checkpoint Saving
checkpoint_callback = ModelCheckpoint(
filepath=str(logger.log_dir / 'ckpt_weights'),
verbose=True, save_top_k=5,
filepath=str(logger.log_dir / 'ckpt_weights'),
verbose=True, save_top_k=0,
)
# =============================================================================
# Early Stopping
# TODO: For This to work, one must set a validation step and End Eval and Score
@ -94,6 +94,11 @@ def run_lightning_loop(config_obj):
# Init
model: LightningBaseModule = config_obj.model_class(config_obj.model_paramters)
model.init_weights()
if model.name == 'CNNRouteGenerator':
# ToDo: Make this dependent on the used seed
path = Path(Path(config_obj.train.outpath) / 'classifier_cnn' / 'version_0')
disc_model = SavedLightningModels.load_checkpoint(path).restore()
model.set_discriminator(disc_model)
# Trainer
# =============================================================================
@ -101,8 +106,9 @@ def run_lightning_loop(config_obj):
show_progress_bar=True,
weights_save_path=logger.log_dir,
gpus=[0] if torch.cuda.is_available() else None,
row_log_interval=(model.data_len * 0.01), # TODO: Better Value / Setting
log_save_interval=(model.data_len * 0.04), # TODO: Better Value / Setting
check_val_every_n_epoch=1,
# row_log_interval=(model.n_train_batches * 0.1), # TODO: Better Value / Setting
# log_save_interval=(model.n_train_batches * 0.2), # TODO: Better Value / Setting
checkpoint_callback=checkpoint_callback,
logger=logger,
fast_dev_run=config_obj.main.debug,
@ -110,14 +116,16 @@ def run_lightning_loop(config_obj):
)
# Train It
trainer.fit(model,)
trainer.fit(model)
# Save the last state & all parameters
trainer.save_checkpoint(logger.log_dir / 'weights.ckpt')
model.save_to_disk(logger.log_dir)
# Evaluate It
trainer.test()
if config_obj.main.eval:
trainer.test()
return model

View File

@ -16,7 +16,7 @@ if __name__ == '__main__':
# Model Settings
config = Config().read_namespace(args)
# use_bias, activation, model, use_norm, max_epochs, filters
cnn_classifier = dict(train_epochs=100, model_use_bias=True, model_use_norm=True, model_activation='leaky_relu',
cnn_classifier = dict(train_epochs=10, model_use_bias=True, model_use_norm=True, model_activation='leaky_relu',
model_type='classifier_cnn', model_filters=[16, 32, 64], data_batchsize=512)
# use_bias, activation, model, use_norm, max_epochs, sr, lat_dim, filters

BIN
res/shapes/inverted_1.bmp Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 29 KiB

BIN
res/shapes/inverted_10.bmp Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 29 KiB

BIN
res/shapes/inverted_2.bmp Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 29 KiB

BIN
res/shapes/inverted_3.bmp Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 29 KiB

BIN
res/shapes/inverted_4.bmp Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 29 KiB

BIN
res/shapes/inverted_5.bmp Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 29 KiB

BIN
res/shapes/inverted_6.bmp Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 29 KiB

BIN
res/shapes/inverted_7.bmp Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 29 KiB

BIN
res/shapes/inverted_8.bmp Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 29 KiB

BIN
res/shapes/inverted_9.bmp Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 29 KiB

BIN
res/shapes/shapes_10.bmp Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 29 KiB

BIN
res/shapes/shapes_7.bmp Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 29 KiB

BIN
res/shapes/shapes_8.bmp Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 29 KiB

BIN
res/shapes/shapes_9.bmp Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 29 KiB