fig clf inserted and not resize on kld
1
.gitignore
vendored
@ -3,6 +3,7 @@
|
||||
|
||||
# User-specific stuff
|
||||
.idea/**
|
||||
res/**
|
||||
|
||||
# CMake
|
||||
cmake-build-*/
|
||||
|
29
datasets/mnist.py
Normal file
@ -0,0 +1,29 @@
|
||||
from torchvision.datasets import MNIST
|
||||
import numpy as np
|
||||
|
||||
|
||||
class MyMNIST(MNIST):
|
||||
|
||||
@property
|
||||
def map_shapes_max(self):
|
||||
return np.asarray(self.test_dataset[0][0]).shape
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(MyMNIST, self).__init__('res', train=False, download=True)
|
||||
pass
|
||||
|
||||
def __getitem__(self, item):
|
||||
image = super(MyMNIST, self).__getitem__(item)
|
||||
return np.expand_dims(np.asarray(image[0]), axis=0).astype(np.float32), image[1]
|
||||
|
||||
@property
|
||||
def train_dataset(self):
|
||||
return self.__class__('res', train=True, download=True)
|
||||
|
||||
@property
|
||||
def test_dataset(self):
|
||||
return self.__class__('res', train=False, download=True)
|
||||
|
||||
@property
|
||||
def val_dataset(self):
|
||||
return self.__class__('res', train=False, download=True)
|
@ -1,6 +1,9 @@
|
||||
import shelve
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
from typing import Union, List
|
||||
from typing import Union
|
||||
|
||||
from torchvision.transforms import Normalize
|
||||
|
||||
import multiprocessing as mp
|
||||
|
||||
@ -24,16 +27,17 @@ class TrajDataShelve(Dataset):
|
||||
return self[0][0].shape
|
||||
|
||||
def __init__(self, file_path, **kwargs):
|
||||
assert Path(file_path).exists()
|
||||
super(TrajDataShelve, self).__init__()
|
||||
self._mutex = mp.Lock()
|
||||
self.file_path = str(file_path)
|
||||
|
||||
|
||||
def __len__(self):
|
||||
self._mutex.acquire()
|
||||
with shelve.open(self.file_path) as d:
|
||||
length = len(d)
|
||||
self._mutex.release()
|
||||
d.close()
|
||||
self._mutex.release()
|
||||
return length
|
||||
|
||||
def seed(self):
|
||||
@ -43,12 +47,20 @@ class TrajDataShelve(Dataset):
|
||||
self._mutex.acquire()
|
||||
with shelve.open(self.file_path) as d:
|
||||
sample = d[str(item)]
|
||||
self._mutex.release()
|
||||
d.close()
|
||||
self._mutex.release()
|
||||
return sample
|
||||
|
||||
|
||||
class TrajDataset(Dataset):
|
||||
|
||||
@property
|
||||
def _last_label_init(self):
|
||||
d = defaultdict(lambda: -1)
|
||||
d['generator_hom_all_in_map'] = V.ALTERNATIVE
|
||||
d['generator_alt_all_in_map'] = V.HOMOTOPIC
|
||||
return d[self.mode]
|
||||
|
||||
@property
|
||||
def map_shape(self):
|
||||
return self.map.as_array.shape
|
||||
@ -57,17 +69,18 @@ class TrajDataset(Dataset):
|
||||
length=100000, mode='separated_arrays', embedding_size=None, preserve_equal_samples=False,
|
||||
**kwargs):
|
||||
super(TrajDataset, self).__init__()
|
||||
assert mode.lower() in ['generator_all_in_map', 'generator_hom_all_in_map'
|
||||
'classifier_all_in_map']
|
||||
self.normalized = normalized
|
||||
assert mode.lower() in ['generator_all_in_map', 'generator_hom_all_in_map', 'generator_alt_all_in_map',
|
||||
'ae_no_label_in_map',
|
||||
'generator_alt_no_label_in_map', 'classifier_all_in_map', 'vae_no_label_in_map']
|
||||
self.normalize = Normalize(0.5, 0.5) if normalized else lambda x: x
|
||||
self.preserve_equal_samples = preserve_equal_samples
|
||||
self.mode = mode
|
||||
self.mapname = mapname if mapname.endswith('.bmp') else f'{mapname}.bmp'
|
||||
self.maps_root = maps_root
|
||||
self._len = length
|
||||
self.last_label = V.ALTERNATIVE if 'hom' in self.mode else choice([-1, V.ALTERNATIVE, V.HOMOTOPIC])
|
||||
self.last_label = self._last_label_init
|
||||
|
||||
self.map = Map(self.mapname).from_image(self.maps_root / self.mapname, embedding_size=embedding_size)
|
||||
self.map = Map.from_image(self.maps_root / self.mapname, embedding_size=embedding_size)
|
||||
|
||||
def __len__(self):
|
||||
return self._len
|
||||
@ -82,6 +95,7 @@ class TrajDataset(Dataset):
|
||||
map_array = torch.as_tensor(self.map.as_array).float()
|
||||
return (map_array, trajectory_space), label
|
||||
|
||||
# Produce an alternative.
|
||||
while True:
|
||||
trajectory = self.map.get_random_trajectory()
|
||||
alternative = self.map.generate_alternative(trajectory)
|
||||
@ -91,18 +105,19 @@ class TrajDataset(Dataset):
|
||||
else:
|
||||
break
|
||||
|
||||
self.last_label = label if self.mode != ['generator_hom_all_in_map'] else V.ALTERNATIVE
|
||||
if self.mode.lower() in ['classifier_all_in_map', 'generator_all_in_map']:
|
||||
self.last_label = label if self._last_label_init == V.ANY else self._last_label_init[self.mode]
|
||||
if 'in_map' in self.mode.lower():
|
||||
map_array = self.map.as_array
|
||||
trajectory = trajectory.draw_in_array(self.map_shape)
|
||||
alternative = alternative.draw_in_array(self.map_shape)
|
||||
label_as_array = np.full_like(map_array, label)
|
||||
if self.normalized:
|
||||
map_array = map_array / V.WHITE
|
||||
trajectory = trajectory / V.WHITE
|
||||
alternative = alternative / V.WHITE
|
||||
|
||||
if self.mode == 'generator_all_in_map':
|
||||
return np.concatenate((map_array, trajectory, label_as_array)), alternative
|
||||
elif self.mode in ['vae_no_label_in_map', 'ae_no_label_in_map']:
|
||||
return np.sum((map_array, trajectory, alternative), axis=0), 0
|
||||
elif self.mode in ['generator_alt_no_label_in_map', 'generator_hom_no_label_in_map']:
|
||||
return np.concatenate((map_array, trajectory)), alternative
|
||||
elif self.mode == 'classifier_all_in_map':
|
||||
return np.concatenate((map_array, trajectory, alternative)), label
|
||||
|
||||
@ -119,13 +134,13 @@ class TrajDataset(Dataset):
|
||||
class TrajData(object):
|
||||
@property
|
||||
def map_shapes(self):
|
||||
return [dataset.map_shape for dataset in self._train_dataset.datasets]
|
||||
return [dataset.map_shape for dataset in self.train_dataset.datasets]
|
||||
|
||||
@property
|
||||
def map_shapes_max(self):
|
||||
shapes = self.map_shapes
|
||||
shape_list = list(map(max, zip(*shapes)))
|
||||
if '_all_in_map' in self.mode:
|
||||
if '_all_in_map' in self.mode and not self.preprocessed:
|
||||
shape_list[0] += 2
|
||||
return shape_list
|
||||
|
||||
@ -139,14 +154,13 @@ class TrajData(object):
|
||||
self.mode = mode
|
||||
self.maps_root = Path(map_root)
|
||||
self.length = length
|
||||
self._test_dataset = self._load_datasets('train')
|
||||
self._val_dataset = self._load_datasets('val')
|
||||
self._train_dataset = self._load_datasets('test')
|
||||
self.test_dataset = self._load_datasets('test')
|
||||
self.val_dataset = self._load_datasets('val')
|
||||
self.train_dataset = self._load_datasets('train')
|
||||
|
||||
def _load_datasets(self, dataset_type=''):
|
||||
|
||||
map_files = list(self.maps_root.glob('*.bmp'))
|
||||
equal_split = int(self.length // len(map_files)) or 1
|
||||
|
||||
# find max image size among available maps:
|
||||
max_map_size = (1, ) + tuple(reversed(tuple(map(max, *[Image.open(map_file).size for map_file in map_files]))))
|
||||
@ -156,10 +170,11 @@ class TrajData(object):
|
||||
preprocessed_map_names = [p.name for p in preprocessed_map_files]
|
||||
datasets = []
|
||||
for map_file in map_files:
|
||||
new_pik_name = f'{dataset_type}_{str(map_file.name)[:-3]}.pik'
|
||||
equal_split = int(self.length // len(map_files)) or 5
|
||||
new_pik_name = f'{self.mode}_{map_file.name[:-4]}_{dataset_type}.pik'
|
||||
if dataset_type != 'train':
|
||||
equal_split *= 0.01
|
||||
if not [f'{new_pik_name[:-3]}.bmp' in preprocessed_map_names]:
|
||||
equal_split = max(int(equal_split * 0.01), 10)
|
||||
if not new_pik_name in preprocessed_map_names:
|
||||
traj_dataset = TrajDataset(maps_root=self.maps_root, mapname=map_file.name, length=equal_split,
|
||||
mode=self.mode, embedding_size=max_map_size, normalized=self.normalized,
|
||||
preserve_equal_samples=True)
|
||||
@ -168,6 +183,9 @@ class TrajData(object):
|
||||
dataset = TrajDataShelve(map_file.parent / new_pik_name)
|
||||
datasets.append(dataset)
|
||||
return ConcatDataset(datasets)
|
||||
|
||||
# Set the equal split so that all maps are visited with the same frequency
|
||||
equal_split = int(self.length // len(map_files)) or 5
|
||||
return ConcatDataset([TrajDataset(maps_root=self.maps_root, mapname=map_file.name, length=equal_split,
|
||||
mode=self.mode, embedding_size=max_map_size, normalized=self.normalized,
|
||||
preserve_equal_samples=True)
|
||||
@ -185,29 +203,14 @@ class TrajData(object):
|
||||
|
||||
def dump_n(self, file_path, traj_dataset: TrajDataset, n=100000):
|
||||
assert str(file_path).endswith('.pik')
|
||||
processes = mp.cpu_count() - 1
|
||||
mutex = mp.Lock()
|
||||
with mp.Pool(processes) as pool:
|
||||
async_results = [pool.apply_async(traj_dataset.__getitem__, kwds=dict(item=i)) for i in range(n)]
|
||||
for i in tqdm(range(n), total=n, desc=f'Generating {n} Samples'):
|
||||
sample = traj_dataset[i]
|
||||
mutex.acquire()
|
||||
write_to_shelve(file_path, sample)
|
||||
mutex.release()
|
||||
|
||||
for result_obj in tqdm(async_results, total=n, desc=f'Generating {n} Samples'):
|
||||
sample = result_obj.get()
|
||||
mutex.acquire()
|
||||
write_to_shelve(file_path, sample)
|
||||
mutex.release()
|
||||
print(f'{n} samples sucessfully dumped to "{file_path}"!')
|
||||
|
||||
@property
|
||||
def train_dataset(self):
|
||||
return self._train_dataset
|
||||
|
||||
@property
|
||||
def val_dataset(self):
|
||||
return self._val_dataset
|
||||
|
||||
@property
|
||||
def test_dataset(self):
|
||||
return self._test_dataset
|
||||
print(f'{n} samples successfully dumped to "{file_path}"!')
|
||||
|
||||
def get_datasets(self):
|
||||
return self._train_dataset, self._val_dataset, self._test_dataset
|
||||
|
@ -1,19 +1,22 @@
|
||||
from random import choices, seed
|
||||
import numpy as np
|
||||
|
||||
import torch
|
||||
from functools import reduce
|
||||
from operator import mul
|
||||
|
||||
from random import choices, choice
|
||||
|
||||
import torch
|
||||
|
||||
from torch import nn
|
||||
from torch.optim import Adam
|
||||
from torchvision.datasets import MNIST
|
||||
|
||||
from datasets.mnist import MyMNIST
|
||||
from datasets.trajectory_dataset import TrajData
|
||||
from lib.evaluation.classification import ROCEvaluation
|
||||
from lib.modules.blocks import ConvModule, ResidualModule, DeConvModule
|
||||
from lib.modules.utils import LightningBaseModule, Flatten
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import lib.variables as V
|
||||
from lib.visualization.generator_eval import GeneratorVisualizer
|
||||
|
||||
|
||||
class CNNRouteGeneratorModel(LightningBaseModule):
|
||||
@ -24,48 +27,71 @@ class CNNRouteGeneratorModel(LightningBaseModule):
|
||||
return Adam(self.parameters(), lr=self.hparams.train_param.lr)
|
||||
|
||||
def training_step(self, batch_xy, batch_nb, *args, **kwargs):
|
||||
batch_x, alternative = batch_xy
|
||||
batch_x, target = batch_xy
|
||||
generated_alternative, z, mu, logvar = self(batch_x)
|
||||
element_wise_loss = self.criterion(generated_alternative, alternative)
|
||||
# see Appendix B from VAE paper:
|
||||
# Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014
|
||||
# https://arxiv.org/abs/1312.6114
|
||||
# 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
|
||||
target = batch_x if 'ae' in self.hparams.data_param.mode else target
|
||||
element_wise_loss = self.criterion(generated_alternative, target)
|
||||
|
||||
kld_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
|
||||
# Dimensional Resizing TODO: Does This make sense? Sanity Check it!
|
||||
# kld_loss /= reduce(mul, self.in_shape)
|
||||
# kld_loss *= self.hparams.data_param.dataset_length / self.hparams.train_param.batch_size * 100
|
||||
if 'vae' in self.hparams.data_param.mode:
|
||||
# see Appendix B from VAE paper:
|
||||
# Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014
|
||||
# https://arxiv.org/abs/1312.6114
|
||||
# 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
|
||||
kld_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
|
||||
# Dimensional Resizing TODO: Does This make sense? Sanity Check it!
|
||||
# kld_loss /= reduce(mul, self.in_shape)
|
||||
# kld_loss *= self.hparams.data_param.dataset_length / self.hparams.train_param.batch_size
|
||||
|
||||
loss = kld_loss + element_wise_loss
|
||||
loss = kld_loss + element_wise_loss
|
||||
else:
|
||||
loss = element_wise_loss
|
||||
kld_loss = 0
|
||||
return dict(loss=loss, log=dict(element_wise_loss=element_wise_loss, loss=loss, kld_loss=kld_loss))
|
||||
|
||||
def _test_val_step(self, batch_xy, batch_nb, *args):
|
||||
batch_x, _ = batch_xy
|
||||
map_array = batch_x[:, 0].unsqueeze(1)
|
||||
trajectory = batch_x[:, 1].unsqueeze(1)
|
||||
labels = batch_x[:, 2].unsqueeze(1).max(dim=-1).values.max(-1).values
|
||||
if 'vae' in self.hparams.data_param.mode:
|
||||
z, mu, logvar = self.encode(batch_x)
|
||||
else:
|
||||
z = self.encode(batch_x)
|
||||
mu, logvar = z, z
|
||||
|
||||
_, mu, _ = self.encode(batch_x)
|
||||
generated_alternative = self.generate(mu)
|
||||
return dict(maps=map_array, trajectories=trajectory, batch_nb=batch_nb, labels=labels,
|
||||
generated_alternative=generated_alternative, pred_label=-1)
|
||||
return_dict = dict(input=batch_x, batch_nb=batch_nb, output=generated_alternative, z=z, mu=mu, logvar=logvar)
|
||||
|
||||
if 'hom' in self.hparams.data_param.mode:
|
||||
labels = torch.full((batch_x.shape[0], 1), V.HOMOTOPIC)
|
||||
elif 'alt' in self.hparams.data_param.mode:
|
||||
labels = torch.full((batch_x.shape[0], 1), V.ALTERNATIVE)
|
||||
elif 'vae' in self.hparams.data_param.mode:
|
||||
labels = torch.full((batch_x.shape[0], 1), V.ANY)
|
||||
elif 'ae' in self.hparams.data_param.mode:
|
||||
labels = torch.full((batch_x.shape[0], 1), V.ANY)
|
||||
else:
|
||||
labels = batch_x[:, 2].unsqueeze(1).max(dim=-1).values.max(-1).values
|
||||
|
||||
return_dict.update(labels=self._move_to_model_device(labels))
|
||||
return return_dict
|
||||
|
||||
def _test_val_epoch_end(self, outputs, test=False):
|
||||
val_restul_dict = self.generate_random()
|
||||
plt.close('all')
|
||||
|
||||
from lib.visualization.generator_eval import GeneratorVisualizer
|
||||
g = GeneratorVisualizer(**val_restul_dict)
|
||||
fig = g.draw()
|
||||
g = GeneratorVisualizer(choice(outputs))
|
||||
fig = g.draw_io_bundle()
|
||||
self.logger.log_image(f'{self.name}_Output', fig, step=self.global_step)
|
||||
plt.clf()
|
||||
|
||||
fig = g.draw_latent()
|
||||
self.logger.log_image(f'{self.name}_Latent', fig, step=self.global_step)
|
||||
plt.clf()
|
||||
|
||||
return dict(epoch=self.current_epoch)
|
||||
|
||||
def on_epoch_start(self):
|
||||
self.dataset.seed(self.logger.version)
|
||||
# self.dataset.seed(self.logger.version)
|
||||
# torch.random.manual_seed(self.logger.version)
|
||||
# np.random.seed(self.logger.version)
|
||||
pass
|
||||
|
||||
def validation_step(self, *args):
|
||||
return self._test_val_step(*args)
|
||||
@ -82,19 +108,23 @@ class CNNRouteGeneratorModel(LightningBaseModule):
|
||||
def __init__(self, *params, issubclassed=False):
|
||||
super(CNNRouteGeneratorModel, self).__init__(*params)
|
||||
|
||||
if not issubclassed:
|
||||
if False:
|
||||
# Dataset
|
||||
self.dataset = TrajData(self.hparams.data_param.map_root, mode='generator_all_in_map',
|
||||
self.dataset = TrajData(self.hparams.data_param.map_root,
|
||||
mode=self.hparams.data_param.mode,
|
||||
preprocessed=self.hparams.data_param.use_preprocessed,
|
||||
length=self.hparams.data_param.dataset_length, normalized=True)
|
||||
self.criterion = nn.MSELoss()
|
||||
self.criterion = nn.MSELoss()
|
||||
|
||||
self.dataset = MyMNIST()
|
||||
|
||||
# Additional Attributes #
|
||||
#######################################################
|
||||
self.in_shape = self.dataset.map_shapes_max
|
||||
self.use_res_net = self.hparams.model_param.use_res_net
|
||||
self.lat_dim = self.hparams.model_param.lat_dim
|
||||
self.feature_dim = self.lat_dim * 10
|
||||
self.feature_dim = self.lat_dim
|
||||
self.out_channels = 1 if 'generator' in self.hparams.data_param.mode else self.in_shape[0]
|
||||
########################################################
|
||||
|
||||
# NN Nodes
|
||||
@ -119,7 +149,7 @@ class CNNRouteGeneratorModel(LightningBaseModule):
|
||||
conv_filters=self.hparams.model_param.filters[1],
|
||||
use_norm=self.hparams.model_param.use_norm,
|
||||
use_bias=self.hparams.model_param.use_bias)
|
||||
self.enc_conv_1b = ConvModule(self.enc_conv_1a.shape, conv_kernel=3, conv_stride=2, conv_padding=0,
|
||||
self.enc_conv_1b = ConvModule(self.enc_conv_1a.shape, conv_kernel=3, conv_stride=1, conv_padding=0,
|
||||
conv_filters=self.hparams.model_param.filters[1],
|
||||
use_norm=self.hparams.model_param.use_norm,
|
||||
use_bias=self.hparams.model_param.use_bias)
|
||||
@ -137,20 +167,8 @@ class CNNRouteGeneratorModel(LightningBaseModule):
|
||||
use_norm=self.hparams.model_param.use_norm,
|
||||
use_bias=self.hparams.model_param.use_bias)
|
||||
|
||||
self.enc_res_3 = ResidualModule(self.enc_conv_2b.shape, ConvModule, 2, conv_kernel=7, conv_stride=1,
|
||||
conv_padding=3, conv_filters=self.hparams.model_param.filters[2],
|
||||
use_norm=self.hparams.model_param.use_norm,
|
||||
use_bias=self.hparams.model_param.use_bias)
|
||||
self.enc_conv_3a = ConvModule(self.enc_res_3.shape, conv_kernel=7, conv_stride=1, conv_padding=0,
|
||||
conv_filters=self.hparams.model_param.filters[2],
|
||||
use_norm=self.hparams.model_param.use_norm,
|
||||
use_bias=self.hparams.model_param.use_bias)
|
||||
self.enc_conv_3b = ConvModule(self.enc_conv_3a.shape, conv_kernel=7, conv_stride=1, conv_padding=0,
|
||||
conv_filters=self.hparams.model_param.filters[2],
|
||||
use_norm=self.hparams.model_param.use_norm,
|
||||
use_bias=self.hparams.model_param.use_bias)
|
||||
|
||||
self.enc_flat = Flatten(self.enc_conv_3b.shape)
|
||||
last_conv_shape = self.enc_conv_2b.shape
|
||||
self.enc_flat = Flatten(last_conv_shape)
|
||||
self.enc_lin_1 = nn.Linear(self.enc_flat.shape, self.feature_dim)
|
||||
|
||||
#
|
||||
@ -160,46 +178,43 @@ class CNNRouteGeneratorModel(LightningBaseModule):
|
||||
|
||||
#
|
||||
# Variational Bottleneck
|
||||
self.mu = nn.Linear(self.feature_dim, self.lat_dim)
|
||||
self.logvar = nn.Linear(self.feature_dim, self.lat_dim)
|
||||
if 'vae' in self.hparams.data_param.mode:
|
||||
self.mu = nn.Linear(self.feature_dim, self.lat_dim)
|
||||
self.logvar = nn.Linear(self.feature_dim, self.lat_dim)
|
||||
|
||||
#
|
||||
# Linear Bottleneck
|
||||
else:
|
||||
self.z = nn.Linear(self.feature_dim, self.lat_dim)
|
||||
|
||||
#
|
||||
# Alternative Generator
|
||||
self.gen_lin_1 = nn.Linear(self.hparams.model_param.lat_dim, self.feature_dim)
|
||||
self.gen_lin_1 = nn.Linear(self.lat_dim, self.enc_flat.shape)
|
||||
|
||||
self.gen_lin_2 = nn.Linear(self.feature_dim, self.enc_flat.shape)
|
||||
# self.gen_lin_2 = nn.Linear(self.feature_dim, self.enc_flat.shape)
|
||||
|
||||
self.reshape_to_last_conv = Flatten(self.enc_flat.shape, self.enc_conv_3b.shape)
|
||||
self.reshape_to_last_conv = Flatten(self.enc_flat.shape, last_conv_shape)
|
||||
|
||||
self.gen_deconv_1a = DeConvModule(self.enc_conv_3b.shape, self.hparams.model_param.filters[2],
|
||||
conv_padding=0, conv_kernel=11, conv_stride=1,
|
||||
use_norm=self.hparams.model_param.use_norm)
|
||||
self.gen_deconv_1b = DeConvModule(self.gen_deconv_1a.shape, self.hparams.model_param.filters[2],
|
||||
conv_padding=0, conv_kernel=9, conv_stride=2,
|
||||
self.gen_deconv_1a = DeConvModule(last_conv_shape, self.hparams.model_param.filters[2],
|
||||
conv_padding=1, conv_kernel=9, conv_stride=1,
|
||||
use_norm=self.hparams.model_param.use_norm)
|
||||
|
||||
self.gen_deconv_2a = DeConvModule(self.gen_deconv_1b.shape, self.hparams.model_param.filters[1],
|
||||
conv_padding=0, conv_kernel=7, conv_stride=1,
|
||||
use_norm=self.hparams.model_param.use_norm)
|
||||
self.gen_deconv_2b = DeConvModule(self.gen_deconv_2a.shape, self.hparams.model_param.filters[1],
|
||||
conv_padding=0, conv_kernel=7, conv_stride=1,
|
||||
self.gen_deconv_2a = DeConvModule(self.gen_deconv_1a.shape, self.hparams.model_param.filters[1],
|
||||
conv_padding=1, conv_kernel=7, conv_stride=1,
|
||||
use_norm=self.hparams.model_param.use_norm)
|
||||
|
||||
self.gen_deconv_3a = DeConvModule(self.gen_deconv_2b.shape, self.hparams.model_param.filters[0],
|
||||
conv_padding=1, conv_kernel=5, conv_stride=1,
|
||||
use_norm=self.hparams.model_param.use_norm)
|
||||
self.gen_deconv_3b = DeConvModule(self.gen_deconv_3a.shape, self.hparams.model_param.filters[0],
|
||||
conv_padding=1, conv_kernel=4, conv_stride=1,
|
||||
use_norm=self.hparams.model_param.use_norm)
|
||||
|
||||
self.gen_deconv_out = DeConvModule(self.gen_deconv_3b.shape, 1, activation=None,
|
||||
self.gen_deconv_out = DeConvModule(self.gen_deconv_2a.shape, self.out_channels, activation=None,
|
||||
conv_padding=0, conv_kernel=3, conv_stride=1,
|
||||
use_norm=self.hparams.model_param.use_norm)
|
||||
|
||||
def forward(self, batch_x):
|
||||
#
|
||||
# Encode
|
||||
z, mu, logvar = self.encode(batch_x)
|
||||
if 'vae' in self.hparams.data_param.mode:
|
||||
z, mu, logvar = self.encode(batch_x)
|
||||
else:
|
||||
z = self.encode(batch_x)
|
||||
mu, logvar = z, z
|
||||
|
||||
#
|
||||
# Generate
|
||||
@ -220,148 +235,46 @@ class CNNRouteGeneratorModel(LightningBaseModule):
|
||||
combined_tensor = self.enc_res_2(combined_tensor) if self.use_res_net else combined_tensor
|
||||
combined_tensor = self.enc_conv_2a(combined_tensor)
|
||||
combined_tensor = self.enc_conv_2b(combined_tensor)
|
||||
combined_tensor = self.enc_res_3(combined_tensor) if self.use_res_net else combined_tensor
|
||||
combined_tensor = self.enc_conv_3a(combined_tensor)
|
||||
combined_tensor = self.enc_conv_3b(combined_tensor)
|
||||
# combined_tensor = self.enc_res_3(combined_tensor) if self.use_res_net else combined_tensor
|
||||
# combined_tensor = self.enc_conv_3a(combined_tensor)
|
||||
# combined_tensor = self.enc_conv_3b(combined_tensor)
|
||||
|
||||
combined_tensor = self.enc_flat(combined_tensor)
|
||||
combined_tensor = self.enc_lin_1(combined_tensor)
|
||||
combined_tensor = self.enc_lin_2(combined_tensor)
|
||||
|
||||
combined_tensor = self.enc_norm(combined_tensor)
|
||||
combined_tensor = self.activation(combined_tensor)
|
||||
|
||||
combined_tensor = self.enc_lin_2(combined_tensor)
|
||||
combined_tensor = self.enc_norm(combined_tensor)
|
||||
combined_tensor = self.activation(combined_tensor)
|
||||
|
||||
#
|
||||
# Variational
|
||||
# Parameter and Sampling
|
||||
mu = self.mu(combined_tensor)
|
||||
logvar = self.logvar(combined_tensor)
|
||||
z = self.reparameterize(mu, logvar)
|
||||
return z, mu, logvar
|
||||
if 'vae' in self.hparams.data_param.mode:
|
||||
mu = self.mu(combined_tensor)
|
||||
logvar = self.logvar(combined_tensor)
|
||||
z = self.reparameterize(mu, logvar)
|
||||
return z, mu, logvar
|
||||
else:
|
||||
#
|
||||
# Linear Bottleneck
|
||||
z = self.z(combined_tensor)
|
||||
return z
|
||||
|
||||
def generate(self, z):
|
||||
alt_tensor = self.gen_lin_1(z)
|
||||
alt_tensor = self.activation(alt_tensor)
|
||||
alt_tensor = self.gen_lin_2(alt_tensor)
|
||||
alt_tensor = self.activation(alt_tensor)
|
||||
# alt_tensor = self.gen_lin_2(alt_tensor)
|
||||
# alt_tensor = self.activation(alt_tensor)
|
||||
alt_tensor = self.reshape_to_last_conv(alt_tensor)
|
||||
alt_tensor = self.gen_deconv_1a(alt_tensor)
|
||||
alt_tensor = self.gen_deconv_1b(alt_tensor)
|
||||
|
||||
alt_tensor = self.gen_deconv_2a(alt_tensor)
|
||||
alt_tensor = self.gen_deconv_2b(alt_tensor)
|
||||
alt_tensor = self.gen_deconv_3a(alt_tensor)
|
||||
alt_tensor = self.gen_deconv_3b(alt_tensor)
|
||||
|
||||
# alt_tensor = self.gen_deconv_3a(alt_tensor)
|
||||
# alt_tensor = self.gen_deconv_3b(alt_tensor)
|
||||
alt_tensor = self.gen_deconv_out(alt_tensor)
|
||||
# alt_tensor = self.activation(alt_tensor)
|
||||
alt_tensor = self.sigmoid(alt_tensor)
|
||||
# alt_tensor = self.sigmoid(alt_tensor)
|
||||
return alt_tensor
|
||||
|
||||
def generate_random(self, n=12):
|
||||
|
||||
samples, alternatives = zip(*[self.dataset.test_dataset[choice]
|
||||
for choice in choices(range(self.dataset.length), k=n)])
|
||||
samples = self._move_to_model_device(torch.stack([torch.as_tensor(x) for x in samples]))
|
||||
alternatives = self._move_to_model_device(torch.stack([torch.as_tensor(x) for x in alternatives]))
|
||||
|
||||
return self._test_val_step((samples, alternatives), -9999)
|
||||
|
||||
|
||||
class CNNRouteGeneratorDiscriminated(CNNRouteGeneratorModel):
|
||||
|
||||
name = 'CNNRouteGeneratorDiscriminated'
|
||||
|
||||
def training_step(self, batch_xy, batch_nb, *args, **kwargs):
|
||||
batch_x, label = batch_xy
|
||||
|
||||
generated_alternative, z, mu, logvar = self(batch_x)
|
||||
map_array, trajectory = batch_x
|
||||
|
||||
map_stack = torch.cat((map_array, trajectory, generated_alternative), dim=1)
|
||||
pred_label = self.discriminator(map_stack)
|
||||
discriminated_bce_loss = self.criterion(pred_label, label.float().unsqueeze(-1))
|
||||
|
||||
# see Appendix B from VAE paper:
|
||||
# Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014
|
||||
# https://arxiv.org/abs/1312.6114
|
||||
# 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
|
||||
kld_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
|
||||
# Dimensional Resizing
|
||||
kld_loss /= reduce(mul, self.in_shape)
|
||||
|
||||
loss = (kld_loss + discriminated_bce_loss) / 2
|
||||
return dict(loss=loss, log=dict(loss=loss,
|
||||
discriminated_bce_loss=discriminated_bce_loss,
|
||||
kld_loss=kld_loss)
|
||||
)
|
||||
|
||||
def _test_val_step(self, batch_xy, batch_nb, *args):
|
||||
batch_x, label = batch_xy
|
||||
|
||||
generated_alternative, z, mu, logvar = self(batch_x)
|
||||
map_array, trajectory = batch_x
|
||||
|
||||
map_stack = torch.cat((map_array, trajectory, generated_alternative), dim=1)
|
||||
pred_label = self.discriminator(map_stack)
|
||||
|
||||
discriminated_bce_loss = self.criterion(pred_label, label.float().unsqueeze(-1))
|
||||
return dict(discriminated_bce_loss=discriminated_bce_loss, batch_nb=batch_nb,
|
||||
pred_label=pred_label, label=label, generated_alternative=generated_alternative)
|
||||
|
||||
def validation_step(self, *args):
|
||||
return self._test_val_step(*args)
|
||||
|
||||
def validation_epoch_end(self, outputs: list):
|
||||
return self._test_val_epoch_end(outputs)
|
||||
|
||||
def _test_val_epoch_end(self, outputs, test=False):
|
||||
evaluation = ROCEvaluation(plot_roc=True)
|
||||
pred_label = torch.cat([x['pred_label'] for x in outputs])
|
||||
labels = torch.cat([x['label'] for x in outputs]).unsqueeze(1)
|
||||
mean_losses = torch.stack([x['discriminated_bce_loss'] for x in outputs]).mean()
|
||||
|
||||
# Sci-py call ROC eval call is eval(true_label, prediction)
|
||||
roc_auc, tpr, fpr = evaluation(labels.cpu().numpy(), pred_label.cpu().numpy(), )
|
||||
if test:
|
||||
# self.logger.log_metrics(score_dict)
|
||||
self.logger.log_image(f'{self.name}_ROC-Curve', plt.gcf(), step=self.global_step)
|
||||
plt.clf()
|
||||
|
||||
maps, trajectories, labels, val_restul_dict = self.generate_random()
|
||||
|
||||
from lib.visualization.generator_eval import GeneratorVisualizer
|
||||
g = GeneratorVisualizer(maps, trajectories, labels, val_restul_dict)
|
||||
fig = g.draw()
|
||||
self.logger.log_image(f'{self.name}_Output', fig, step=self.global_step)
|
||||
plt.clf()
|
||||
|
||||
return dict(mean_losses=mean_losses, roc_auc=roc_auc, epoch=self.current_epoch)
|
||||
|
||||
def test_step(self, *args):
|
||||
return self._test_val_step(*args)
|
||||
|
||||
def test_epoch_end(self, outputs):
|
||||
return self._test_val_epoch_end(outputs, test=True)
|
||||
|
||||
@property
|
||||
def discriminator(self):
|
||||
if self._disc is None:
|
||||
raise RuntimeError('Set the Discriminator first; "set_discriminator(disc_model)')
|
||||
return self._disc
|
||||
|
||||
def set_discriminator(self, disc_model):
|
||||
if self._disc is not None:
|
||||
raise RuntimeError('Discriminator has already been set... What are trying to do?')
|
||||
self._disc = disc_model
|
||||
|
||||
def __init__(self, *params):
|
||||
raise NotImplementedError
|
||||
super(CNNRouteGeneratorDiscriminated, self).__init__(*params, issubclassed=True)
|
||||
|
||||
self._disc = None
|
||||
|
||||
self.criterion = nn.BCELoss()
|
||||
|
||||
self.dataset = TrajData(self.hparams.data_param.map_root, mode='just_route', preprocessed=True,
|
||||
length=self.hparams.data_param.dataset_length, normalized=True)
|
||||
|
116
lib/models/generators/cnn_discriminated.py
Normal file
@ -0,0 +1,116 @@
|
||||
from random import choices, seed
|
||||
import numpy as np
|
||||
|
||||
import torch
|
||||
from functools import reduce
|
||||
from operator import mul
|
||||
|
||||
from torch import nn
|
||||
from torch.optim import Adam
|
||||
|
||||
from datasets.trajectory_dataset import TrajData
|
||||
from lib.evaluation.classification import ROCEvaluation
|
||||
from lib.models.generators.cnn import CNNRouteGeneratorModel
|
||||
from lib.modules.blocks import ConvModule, ResidualModule, DeConvModule
|
||||
from lib.modules.utils import LightningBaseModule, Flatten
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
|
||||
class CNNRouteGeneratorDiscriminated(CNNRouteGeneratorModel):
|
||||
|
||||
name = 'CNNRouteGeneratorDiscriminated'
|
||||
|
||||
def training_step(self, batch_xy, batch_nb, *args, **kwargs):
|
||||
batch_x, label = batch_xy
|
||||
|
||||
generated_alternative, z, mu, logvar = self(batch_x)
|
||||
map_array, trajectory = batch_x
|
||||
|
||||
map_stack = torch.cat((map_array, trajectory, generated_alternative), dim=1)
|
||||
pred_label = self.discriminator(map_stack)
|
||||
discriminated_bce_loss = self.criterion(pred_label, label.float().unsqueeze(-1))
|
||||
|
||||
# see Appendix B from VAE paper:
|
||||
# Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014
|
||||
# https://arxiv.org/abs/1312.6114
|
||||
# 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
|
||||
kld_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
|
||||
# Dimensional Resizing
|
||||
kld_loss /= reduce(mul, self.in_shape)
|
||||
|
||||
loss = (kld_loss + discriminated_bce_loss) / 2
|
||||
return dict(loss=loss, log=dict(loss=loss,
|
||||
discriminated_bce_loss=discriminated_bce_loss,
|
||||
kld_loss=kld_loss)
|
||||
)
|
||||
|
||||
def _test_val_step(self, batch_xy, batch_nb, *args):
|
||||
batch_x, label = batch_xy
|
||||
|
||||
generated_alternative, z, mu, logvar = self(batch_x)
|
||||
map_array, trajectory = batch_x
|
||||
|
||||
map_stack = torch.cat((map_array, trajectory, generated_alternative), dim=1)
|
||||
pred_label = self.discriminator(map_stack)
|
||||
|
||||
discriminated_bce_loss = self.criterion(pred_label, label.float().unsqueeze(-1))
|
||||
return dict(discriminated_bce_loss=discriminated_bce_loss, batch_nb=batch_nb,
|
||||
pred_label=pred_label, label=label, generated_alternative=generated_alternative)
|
||||
|
||||
def validation_step(self, *args):
|
||||
return self._test_val_step(*args)
|
||||
|
||||
def validation_epoch_end(self, outputs: list):
|
||||
return self._test_val_epoch_end(outputs)
|
||||
|
||||
def _test_val_epoch_end(self, outputs, test=False):
|
||||
evaluation = ROCEvaluation(plot_roc=True)
|
||||
pred_label = torch.cat([x['pred_label'] for x in outputs])
|
||||
labels = torch.cat([x['label'] for x in outputs]).unsqueeze(1)
|
||||
mean_losses = torch.stack([x['discriminated_bce_loss'] for x in outputs]).mean()
|
||||
|
||||
# Sci-py call ROC eval call is eval(true_label, prediction)
|
||||
roc_auc, tpr, fpr = evaluation(labels.cpu().numpy(), pred_label.cpu().numpy(), )
|
||||
if test:
|
||||
# self.logger.log_metrics(score_dict)
|
||||
self.logger.log_image(f'{self.name}_ROC-Curve', plt.gcf(), step=self.global_step)
|
||||
plt.clf()
|
||||
|
||||
maps, trajectories, labels, val_restul_dict = self.generate_random()
|
||||
|
||||
from lib.visualization.generator_eval import GeneratorVisualizer
|
||||
g = GeneratorVisualizer(maps, trajectories, labels, val_restul_dict)
|
||||
fig = g.draw()
|
||||
self.logger.log_image(f'{self.name}_Output', fig, step=self.global_step)
|
||||
plt.clf()
|
||||
|
||||
return dict(mean_losses=mean_losses, roc_auc=roc_auc, epoch=self.current_epoch)
|
||||
|
||||
def test_step(self, *args):
|
||||
return self._test_val_step(*args)
|
||||
|
||||
def test_epoch_end(self, outputs):
|
||||
return self._test_val_epoch_end(outputs, test=True)
|
||||
|
||||
@property
|
||||
def discriminator(self):
|
||||
if self._disc is None:
|
||||
raise RuntimeError('Set the Discriminator first; "set_discriminator(disc_model)')
|
||||
return self._disc
|
||||
|
||||
def set_discriminator(self, disc_model):
|
||||
if self._disc is not None:
|
||||
raise RuntimeError('Discriminator has already been set... What are trying to do?')
|
||||
self._disc = disc_model
|
||||
|
||||
def __init__(self, *params):
|
||||
raise NotImplementedError
|
||||
super(CNNRouteGeneratorDiscriminated, self).__init__(*params, issubclassed=True)
|
||||
|
||||
self._disc = None
|
||||
|
||||
self.criterion = nn.BCELoss()
|
||||
|
||||
self.dataset = TrajData(self.hparams.data_param.map_root, mode='just_route', preprocessed=True,
|
||||
length=self.hparams.data_param.dataset_length, normalized=True)
|
@ -189,5 +189,5 @@ class MapStorage(UserDict):
|
||||
)
|
||||
|
||||
for map_file in map_files:
|
||||
current_map = Map().from_image(map_file, embedding_size=self.max_map_size)
|
||||
current_map = Map.from_image(map_file, embedding_size=self.max_map_size)
|
||||
self.__setitem__(map_file.name, current_map)
|
||||
|
@ -5,7 +5,9 @@ from collections import defaultdict
|
||||
from configparser import ConfigParser
|
||||
from pathlib import Path
|
||||
|
||||
from lib.models.generators.cnn import CNNRouteGeneratorModel, CNNRouteGeneratorDiscriminated
|
||||
from lib.models.generators.cnn import CNNRouteGeneratorModel
|
||||
from lib.models.generators.cnn_discriminated import CNNRouteGeneratorDiscriminated
|
||||
|
||||
from lib.models.homotopy_classification.cnn_based import ConvHomDetector
|
||||
from lib.utils.model_io import ModelParameters
|
||||
from lib.utils.transforms import AsArray
|
||||
|
@ -37,7 +37,7 @@ class Logger(LightningLoggerBase):
|
||||
@property
|
||||
def outpath(self):
|
||||
# ToDo: Add further path modification such as dataset config etc.
|
||||
return Path(self.config.train.outpath)
|
||||
return Path(self.config.train.outpath) / self.config.data.mode
|
||||
|
||||
def __init__(self, config: Config):
|
||||
"""
|
||||
|
@ -9,6 +9,7 @@ def write_to_shelve(file_path, value):
|
||||
with shelve.open(str(file_path), protocol=pickle.HIGHEST_PROTOCOL) as f:
|
||||
new_key = str(len(f))
|
||||
f[new_key] = value
|
||||
f.close()
|
||||
|
||||
|
||||
def load_from_shelve(file_path, key):
|
||||
|
@ -1,9 +1,15 @@
|
||||
from pathlib import Path
|
||||
_ROOT = Path('..')
|
||||
|
||||
# Labels for classes
|
||||
HOMOTOPIC = 1
|
||||
ALTERNATIVE = 0
|
||||
ANY = -1
|
||||
|
||||
# Colors for img files
|
||||
WHITE = 255
|
||||
BLACK = 0
|
||||
|
||||
DPI = 100
|
||||
# Variables for plotting
|
||||
PADDING = 0.25
|
||||
DPI = 50
|
||||
|
@ -1,53 +1,106 @@
|
||||
from collections import defaultdict
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import matplotlib.cm as cmaps
|
||||
from mpl_toolkits.axisartist.axes_grid import ImageGrid
|
||||
from sklearn.cluster import Birch, DBSCAN, KMeans
|
||||
from sklearn.decomposition import PCA
|
||||
from sklearn.manifold import TSNE
|
||||
|
||||
import lib.variables as V
|
||||
import numpy as np
|
||||
|
||||
|
||||
class GeneratorVisualizer(object):
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
# val_results = dict(discriminated_bce_loss, batch_nb, pred_label, label, generated_alternative)
|
||||
self.alternatives = kwargs.get('generated_alternative')
|
||||
self.labels = kwargs.get('labels')
|
||||
self.trajectories = kwargs.get('trajectories')
|
||||
self.maps = kwargs.get('maps')
|
||||
def __init__(self, outputs, k=8):
|
||||
d = defaultdict(list)
|
||||
for key in outputs.keys():
|
||||
try:
|
||||
d[key] = outputs[key][:k].cpu().numpy()
|
||||
except AttributeError:
|
||||
d[key] = outputs[key][:k]
|
||||
except TypeError:
|
||||
self.batch_nb = outputs[key]
|
||||
for key in d.keys():
|
||||
self.__setattr__(key, d[key])
|
||||
|
||||
self._map_width, self._map_height = self.maps[0].squeeze().shape
|
||||
# val_results = dict(discriminated_bce_loss, batch_nb, pred_label, label, generated_alternative)
|
||||
self._map_width, self._map_height = self.input.shape[1], self.input.shape[2]
|
||||
self.column_dict_list = self._build_column_dict_list()
|
||||
self._cols = len(self.column_dict_list)
|
||||
self._rows = len(self.column_dict_list[0])
|
||||
|
||||
self.colormap = cmaps.tab20
|
||||
|
||||
def _build_column_dict_list(self):
|
||||
trajectories = []
|
||||
non_hom_alternatives = []
|
||||
hom_alternatives = []
|
||||
alternatives = []
|
||||
|
||||
for idx in range(self.alternatives.shape[0]):
|
||||
image = (self.alternatives[idx]).cpu().numpy().squeeze()
|
||||
label = self.labels[idx].item()
|
||||
# Dirty and Quick hack incomming.
|
||||
if label == V.HOMOTOPIC:
|
||||
hom_alternatives.append(dict(image=image, label='Homotopic'))
|
||||
non_hom_alternatives.append(None)
|
||||
else:
|
||||
non_hom_alternatives.append(dict(image=image, label='NonHomotopic'))
|
||||
hom_alternatives.append(None)
|
||||
for idx in range(max(len(hom_alternatives), len(non_hom_alternatives))):
|
||||
image = (self.maps[idx] + self.trajectories[idx]).cpu().numpy().squeeze()
|
||||
for idx in range(self.output.shape[0]):
|
||||
image = (self.output[idx]).squeeze()
|
||||
label = 'Homotopic' if self.labels[idx].item() == V.HOMOTOPIC else 'Alternative'
|
||||
alternatives.append(dict(image=image, label=label))
|
||||
|
||||
for idx in range(len(alternatives)):
|
||||
image = (self.input[idx]).squeeze()
|
||||
label = 'original'
|
||||
trajectories.append(dict(image=image, label=label))
|
||||
|
||||
return trajectories, hom_alternatives, non_hom_alternatives
|
||||
return trajectories, alternatives
|
||||
|
||||
def draw(self):
|
||||
padding = 0.25
|
||||
additional_size = self._cols * padding + 3 * padding
|
||||
width = (self._map_width * self._cols) / V.DPI + additional_size
|
||||
height = (self._map_height * self._rows) / V.DPI + additional_size
|
||||
@staticmethod
|
||||
def cluster_data(data):
|
||||
|
||||
cluster = Birch()
|
||||
|
||||
labels = cluster.fit_predict(data)
|
||||
return labels
|
||||
|
||||
def draw_latent(self):
|
||||
plt.close('all')
|
||||
clusterer = KMeans(10)
|
||||
try:
|
||||
labels = clusterer.fit_predict(self.logvar)
|
||||
except ValueError:
|
||||
fig = plt.figure()
|
||||
return fig
|
||||
if self.z.shape[-1] > 2:
|
||||
fig, axs = plt.subplots(ncols=2, nrows=1)
|
||||
transformers = [TSNE(2), PCA(2)]
|
||||
for idx, transformer in enumerate(transformers):
|
||||
transformed = transformer.fit_transform(self.z)
|
||||
|
||||
colored = self.colormap(labels)
|
||||
ax = axs[idx]
|
||||
ax.scatter(x=transformed[:, 0], y=transformed[:, 1], c=colored)
|
||||
ax.set_title(transformer.__class__.__name__)
|
||||
ax.set_xlim(np.min(transformed[:, 0])*1.1, np.max(transformed[:, 0]*1.1))
|
||||
ax.set_ylim(np.min(transformed[:, 1]*1.1), np.max(transformed[:, 1]*1.1))
|
||||
elif self.z.shape[-1] == 2:
|
||||
fig, axs = plt.subplots()
|
||||
|
||||
# TODO: Build transformation for lat_dim_size >= 3
|
||||
print('All Predictions sucesfully Gathered and Shaped ')
|
||||
axs.set_xlim(np.min(self.z[:, 0]), np.max(self.z[:, 0]))
|
||||
axs.set_ylim(np.min(self.z[:, 1]), np.max(self.z[:, 1]))
|
||||
# ToDo: Insert Normalization
|
||||
colored = self.colormap(labels)
|
||||
plt.scatter(self.z[:, 0], self.z[:, 1], c=colored)
|
||||
else:
|
||||
raise NotImplementedError("Latent Dimensions can not be one-dimensional (yet).")
|
||||
|
||||
return fig
|
||||
|
||||
def draw_io_bundle(self):
|
||||
width, height = self._cols * 5, self._rows * 5
|
||||
additional_size = self._cols * V.PADDING + 3 * V.PADDING
|
||||
# width = (self._map_width * self._cols) / V.DPI + additional_size
|
||||
# height = (self._map_height * self._rows) / V.DPI + additional_size
|
||||
fig = plt.figure(figsize=(width, height), dpi=V.DPI)
|
||||
grid = ImageGrid(fig, 111, # similar to subplot(111)
|
||||
nrows_ncols=(self._rows, self._cols),
|
||||
axes_pad=padding, # pad between axes in inch.
|
||||
axes_pad=V.PADDING, # pad between axes in inch.
|
||||
)
|
||||
|
||||
for idx in range(len(grid.axes_all)):
|
||||
|
16
main.py
@ -33,12 +33,13 @@ main_arg_parser.add_argument("--main_seed", type=int, default=69, help="")
|
||||
|
||||
# Data Parameters
|
||||
main_arg_parser.add_argument("--data_worker", type=int, default=10, help="")
|
||||
main_arg_parser.add_argument("--data_dataset_length", type=int, default=100000, help="")
|
||||
main_arg_parser.add_argument("--data_dataset_length", type=int, default=10000, help="")
|
||||
main_arg_parser.add_argument("--data_root", type=str, default='data', help="")
|
||||
main_arg_parser.add_argument("--data_map_root", type=str, default='res/shapes', help="")
|
||||
main_arg_parser.add_argument("--data_normalized", type=strtobool, default=True, help="")
|
||||
main_arg_parser.add_argument("--data_use_preprocessed", type=strtobool, default=True, help="")
|
||||
|
||||
main_arg_parser.add_argument("--data_mode", type=str, default='ae_no_label_in_map', help="")
|
||||
|
||||
# Transformations
|
||||
main_arg_parser.add_argument("--transformations_to_tensor", type=strtobool, default=False, help="")
|
||||
@ -46,7 +47,7 @@ main_arg_parser.add_argument("--transformations_to_tensor", type=strtobool, defa
|
||||
# Transformations
|
||||
main_arg_parser.add_argument("--train_outpath", type=str, default="output", help="")
|
||||
main_arg_parser.add_argument("--train_version", type=strtobool, required=False, help="")
|
||||
main_arg_parser.add_argument("--train_epochs", type=int, default=20, help="")
|
||||
main_arg_parser.add_argument("--train_epochs", type=int, default=200, help="")
|
||||
main_arg_parser.add_argument("--train_batch_size", type=int, default=164, help="")
|
||||
main_arg_parser.add_argument("--train_lr", type=float, default=0.002, help="")
|
||||
main_arg_parser.add_argument("--train_num_sanity_val_steps", type=int, default=0, help="")
|
||||
@ -54,9 +55,9 @@ main_arg_parser.add_argument("--train_num_sanity_val_steps", type=int, default=0
|
||||
# Model
|
||||
main_arg_parser.add_argument("--model_type", type=str, default="CNNRouteGenerator", help="")
|
||||
main_arg_parser.add_argument("--model_activation", type=str, default="elu", help="")
|
||||
main_arg_parser.add_argument("--model_filters", type=str, default="[16, 32, 64]", help="")
|
||||
main_arg_parser.add_argument("--model_filters", type=str, default="[16, 32, 32]", help="")
|
||||
main_arg_parser.add_argument("--model_classes", type=int, default=2, help="")
|
||||
main_arg_parser.add_argument("--model_lat_dim", type=int, default=8, help="")
|
||||
main_arg_parser.add_argument("--model_lat_dim", type=int, default=4, help="")
|
||||
main_arg_parser.add_argument("--model_use_bias", type=strtobool, default=True, help="")
|
||||
main_arg_parser.add_argument("--model_use_norm", type=strtobool, default=True, help="")
|
||||
main_arg_parser.add_argument("--model_use_res_net", type=strtobool, default=False, help="")
|
||||
@ -101,7 +102,7 @@ def run_lightning_loop(config_obj):
|
||||
model.init_weights(torch.nn.init.xavier_normal_)
|
||||
if model.name == 'CNNRouteGeneratorDiscriminated':
|
||||
# ToDo: Make this dependent on the used seed
|
||||
path = Path(Path(config_obj.train.outpath) / 'classifier_cnn' / 'version_0')
|
||||
path = logger.outpath / 'classifier_cnn' / 'version_0'
|
||||
disc_model = SavedLightningModels.load_checkpoint(path).restore()
|
||||
model.set_discriminator(disc_model)
|
||||
|
||||
@ -111,13 +112,12 @@ def run_lightning_loop(config_obj):
|
||||
show_progress_bar=True,
|
||||
weights_save_path=logger.log_dir,
|
||||
gpus=[0] if torch.cuda.is_available() else None,
|
||||
check_val_every_n_epoch=1,
|
||||
num_sanity_val_steps=config_obj.train.num_sanity_val_steps,
|
||||
check_val_every_n_epoch=10,
|
||||
# num_sanity_val_steps=config_obj.train.num_sanity_val_steps,
|
||||
# row_log_interval=(model.n_train_batches * 0.1), # TODO: Better Value / Setting
|
||||
# log_save_interval=(model.n_train_batches * 0.2), # TODO: Better Value / Setting
|
||||
checkpoint_callback=checkpoint_callback,
|
||||
logger=logger,
|
||||
val_percent_check=0.025,
|
||||
fast_dev_run=config_obj.main.debug,
|
||||
early_stop_callback=None
|
||||
)
|
||||
|
Before Width: | Height: | Size: 29 KiB |
Before Width: | Height: | Size: 29 KiB |
Before Width: | Height: | Size: 29 KiB |
Before Width: | Height: | Size: 29 KiB |
Before Width: | Height: | Size: 29 KiB |
Before Width: | Height: | Size: 29 KiB |
Before Width: | Height: | Size: 29 KiB |
Before Width: | Height: | Size: 29 KiB |
Before Width: | Height: | Size: 29 KiB |
Before Width: | Height: | Size: 29 KiB |
Before Width: | Height: | Size: 1.6 KiB |
Before Width: | Height: | Size: 29 KiB |
Before Width: | Height: | Size: 1.6 KiB |
Before Width: | Height: | Size: 1.6 KiB |
Before Width: | Height: | Size: 831 B |
Before Width: | Height: | Size: 1.6 KiB |
Before Width: | Height: | Size: 1.6 KiB |
Before Width: | Height: | Size: 1.6 KiB |
Before Width: | Height: | Size: 29 KiB |
Before Width: | Height: | Size: 29 KiB |
Before Width: | Height: | Size: 29 KiB |