Offline datasets, ResNet optionality
@@ -6,7 +6,7 @@ import torch
 from torch.utils.data import Dataset, ConcatDataset
 
 from datasets.utils import DatasetMapping
-from lib.modules.model_parts import Generator
+from lib.preprocessing.generator import Generator
 from lib.objects.map import Map
 
 
@@ -2,15 +2,50 @@ import shelve
 from pathlib import Path
 from typing import Union, List
 
+import multiprocessing as mp
+
 import torch
 from random import choice
 from torch.utils.data import ConcatDataset, Dataset
 import numpy as np
+
+from tqdm import tqdm
 
 from lib.objects.map import Map
 import lib.variables as V
 from PIL import Image
 
 
+from lib.utils.tools import write_to_shelve
+
+
+class TrajDataShelve(Dataset):
+
+    @property
+    def map_shape(self):
+        return self[0][0].shape
+
+    def __init__(self, file_path, **kwargs):
+        super(TrajDataShelve, self).__init__()
+        self._mutex = mp.Lock()
+        self.file_path = str(file_path)
+
+    def __len__(self):
+        self._mutex.acquire()
+        with shelve.open(self.file_path) as d:
+            length = len(d)
+        self._mutex.release()
+        return length
+
+    def seed(self):
+        pass
+
+    def __getitem__(self, item):
+        self._mutex.acquire()
+        with shelve.open(self.file_path) as d:
+            sample = d[str(item)]
+        self._mutex.release()
+        return sample
+
+
 class TrajDataset(Dataset):
 
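Note: the new TrajDataShelve serves samples that were generated once and written to a shelve file, hiding the on-disk store behind the usual Dataset interface; the lock serializes shelve access across worker processes. A minimal usage sketch (the .pik path is illustrative, produced by TrajData.dump_n further down):

    from torch.utils.data import DataLoader

    ds = TrajDataShelve('res/shapes/train_some_map.pik')  # hypothetical dump file
    print(len(ds), ds.map_shape)                          # sample count, shape of one stored map array
    loader = DataLoader(ds, batch_size=32, shuffle=True)  # samples come back exactly as shelved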
@@ -22,14 +57,15 @@ class TrajDataset(Dataset):
                  length=100000, mode='separated_arrays', embedding_size=None, preserve_equal_samples=False,
                  **kwargs):
         super(TrajDataset, self).__init__()
-        assert mode.lower() in ['vectors', 'all_in_map', 'separated_arrays', 'just_route']
+        assert mode.lower() in ['generator_all_in_map', 'generator_hom_all_in_map',
+                                'classifier_all_in_map']
         self.normalized = normalized
         self.preserve_equal_samples = preserve_equal_samples
         self.mode = mode
         self.mapname = mapname if mapname.endswith('.bmp') else f'{mapname}.bmp'
         self.maps_root = maps_root
         self._len = length
-        self.last_label = -1
+        self.last_label = V.ALTERNATIVE if 'hom' in self.mode else choice([-1, V.ALTERNATIVE, V.HOMOTOPIC])
 
         self.map = Map(self.mapname).from_image(self.maps_root / self.mapname, embedding_size=embedding_size)
 
@@ -39,6 +75,7 @@ class TrajDataset(Dataset):
     def __getitem__(self, item):
 
         if self.mode.lower() == 'just_route':
+            raise NotImplementedError
             trajectory = self.map.get_random_trajectory()
             trajectory_space = trajectory.draw_in_array(self.map.shape)
             label = choice([0, 1])
@@ -54,37 +91,41 @@ class TrajDataset(Dataset):
             else:
                 break
 
-        self.last_label = label
-        if self.mode.lower() in ['all_in_map', 'separated_arrays']:
+        self.last_label = label if self.mode != 'generator_hom_all_in_map' else V.ALTERNATIVE
+        if self.mode.lower() in ['classifier_all_in_map', 'generator_all_in_map']:
             map_array = self.map.as_array
             trajectory = trajectory.draw_in_array(self.map_shape)
             alternative = alternative.draw_in_array(self.map_shape)
-            if self.mode == 'separated_arrays':
+            label_as_array = np.full_like(map_array, label)
             if self.normalized:
                 map_array = map_array / V.WHITE
                 trajectory = trajectory / V.WHITE
                 alternative = alternative / V.WHITE
-                return (map_array, trajectory, label), alternative
-            else:
+            if self.mode == 'generator_all_in_map':
+                return np.concatenate((map_array, trajectory, label_as_array)), alternative
+            elif self.mode == 'classifier_all_in_map':
                 return np.concatenate((map_array, trajectory, alternative)), label
 
-        elif self.mode == 'vectors':
+        elif self.mode == '_vectors':
+            raise NotImplementedError
             return trajectory.vertices, alternative.vertices, label, self.mapname
 
-        else:
-            raise ValueError
+        raise ValueError(f'Mode was: {self.mode}')
+
+    def seed(self, seed):
+        self.map.seed(seed)
 
 
 class TrajData(object):
     @property
     def map_shapes(self):
-        return [dataset.map_shape for dataset in self._dataset.datasets]
+        return [dataset.map_shape for dataset in self._train_dataset.datasets]
 
     @property
     def map_shapes_max(self):
         shapes = self.map_shapes
         shape_list = list(map(max, zip(*shapes)))
-        if self.mode in ['separated_arrays', 'all_in_map']:
+        if '_all_in_map' in self.mode:
             shape_list[0] += 2
         return shape_list
 
@@ -92,36 +133,81 @@ class TrajData(object):
     def name(self):
         return self.__class__.__name__
 
-    def __init__(self, map_root, length=100000, mode='separated_arrays', normalized=True, **_):
+    def __init__(self, map_root, length=100000, mode='separated_arrays', normalized=True, preprocessed=False, **_):
+        self.preprocessed = preprocessed
         self.normalized = normalized
         self.mode = mode
         self.maps_root = Path(map_root)
         self.length = length
-        self._dataset = self._load_datasets()
+        self._train_dataset = self._load_datasets('train')
+        self._val_dataset = self._load_datasets('val')
+        self._test_dataset = self._load_datasets('test')
 
-    def _load_datasets(self):
+    def _load_datasets(self, dataset_type=''):
         map_files = list(self.maps_root.glob('*.bmp'))
         equal_split = int(self.length // len(map_files)) or 1
 
         # find max image size among available maps:
         max_map_size = (1, ) + tuple(reversed(tuple(map(max, *[Image.open(map_file).size for map_file in map_files]))))
+
+        if self.preprocessed:
+            preprocessed_map_files = list(self.maps_root.glob('*.pik'))
+            preprocessed_map_names = [p.name for p in preprocessed_map_files]
+            if dataset_type != 'train':
+                # val/test caches only need a fraction of the training samples
+                equal_split = int(equal_split * 0.01) or 1
+            datasets = []
+            for map_file in map_files:
+                new_pik_name = f'{dataset_type}_{str(map_file.name)[:-4]}.pik'
+                if new_pik_name not in preprocessed_map_names:
+                    traj_dataset = TrajDataset(maps_root=self.maps_root, mapname=map_file.name, length=equal_split,
+                                               mode=self.mode, embedding_size=max_map_size, normalized=self.normalized,
+                                               preserve_equal_samples=True)
+                    self.dump_n(map_file.parent / new_pik_name, traj_dataset, n=equal_split)
+
+                dataset = TrajDataShelve(map_file.parent / new_pik_name)
+                datasets.append(dataset)
+            return ConcatDataset(datasets)
         return ConcatDataset([TrajDataset(maps_root=self.maps_root, mapname=map_file.name, length=equal_split,
                                           mode=self.mode, embedding_size=max_map_size, normalized=self.normalized,
                                           preserve_equal_samples=True)
                               for map_file in map_files])
 
+    def kill_em_all(self):
+        for pik_file in self.maps_root.glob('*.pik'):
+            pik_file.unlink()
+            print(pik_file.name, ' was deleted.')
+        print('Done.')
+
+    def seed(self, seed):
+        for concat_dataset in [self._train_dataset, self._val_dataset, self._test_dataset]:
+            for dataset in concat_dataset.datasets:
+                dataset.seed(seed)
+
+    def dump_n(self, file_path, traj_dataset: TrajDataset, n=100000):
+        assert str(file_path).endswith('.pik')
+        processes = mp.cpu_count() - 1
+        mutex = mp.Lock()
+        with mp.Pool(processes) as pool:
+            async_results = [pool.apply_async(traj_dataset.__getitem__, kwds=dict(item=i)) for i in range(n)]
+
+            for result_obj in tqdm(async_results, total=n, desc=f'Generating {n} Samples'):
+                sample = result_obj.get()
+                mutex.acquire()
+                write_to_shelve(file_path, sample)
+                mutex.release()
+        print(f'{n} samples successfully dumped to "{file_path}"!')
+
     @property
     def train_dataset(self):
-        return self._dataset
+        return self._train_dataset
 
     @property
     def val_dataset(self):
-        return self._dataset
+        return self._val_dataset
 
     @property
     def test_dataset(self):
-        return self._dataset
+        return self._test_dataset
 
     def get_datasets(self):
-        return self._dataset, self._dataset, self._dataset
+        return self._train_dataset, self._val_dataset, self._test_dataset
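Note: with preprocessed=True, each map's samples are generated exactly once (dump_n fans the work out over a multiprocessing pool and funnels results through write_to_shelve), and every later epoch reads from the shelve via TrajDataShelve. A sketch of the intended round trip (paths illustrative):

    data = TrajData('res/shapes', length=100000, mode='generator_all_in_map',
                    normalized=True, preprocessed=True)  # first run dumps *.pik caches per split
    train, val, test = data.get_datasets()               # ConcatDatasets of TrajDataShelve instances
    data.kill_em_all()                                   # delete the caches to force regeneration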
@@ -1,4 +1,5 @@
-from random import choice
+from random import choices, seed
+import numpy as np
 
 import torch
 from functools import reduce
@@ -36,28 +37,36 @@ class CNNRouteGeneratorModel(LightningBaseModule):
         # kld_loss /= reduce(mul, self.in_shape)
         # kld_loss *= self.hparams.data_param.dataset_length / self.hparams.train_param.batch_size * 100
 
-        loss = (kld_loss + element_wise_loss) / 2
+        loss = kld_loss + element_wise_loss
         return dict(loss=loss, log=dict(element_wise_loss=element_wise_loss, loss=loss, kld_loss=kld_loss))
 
     def _test_val_step(self, batch_xy, batch_nb, *args):
         batch_x, _ = batch_xy
-        map_array, trajectory, label = batch_x
+        map_array = batch_x[:, 0].unsqueeze(1)
+        trajectory = batch_x[:, 1].unsqueeze(1)
+        labels = batch_x[:, 2].unsqueeze(1).max(dim=-1).values.max(-1).values
 
-        generated_alternative, z, mu, logvar = self(batch_x)
-        return dict(batch_nb=batch_nb, label=label, generated_alternative=generated_alternative, pred_label=-1)
+        _, mu, _ = self.encode(batch_x)
+        generated_alternative = self.generate(mu)
+        return dict(maps=map_array, trajectories=trajectory, batch_nb=batch_nb, labels=labels,
+                    generated_alternative=generated_alternative, pred_label=-1)
 
     def _test_val_epoch_end(self, outputs, test=False):
-        maps, trajectories, labels, val_result_dict = self.generate_random()
+        val_result_dict = self.generate_random()
 
         from lib.visualization.generator_eval import GeneratorVisualizer
-        g = GeneratorVisualizer(maps, trajectories, labels, val_result_dict)
+        g = GeneratorVisualizer(**val_result_dict)
         fig = g.draw()
         self.logger.log_image(f'{self.name}_Output', fig, step=self.global_step)
         plt.clf()
 
         return dict(epoch=self.current_epoch)
 
+    def on_epoch_start(self):
+        self.dataset.seed(self.logger.version)
+        # torch.random.manual_seed(self.logger.version)
+        # np.random.seed(self.logger.version)
+
     def validation_step(self, *args):
         return self._test_val_step(*args)
 
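Note: the objective is the standard VAE ELBO, reconstruction error plus the closed-form KL divergence between the diagonal Gaussian posterior and the unit Gaussian; this commit stops halving the sum. A self-contained sketch of the same KL term used above:

    import torch

    def kld(mu, logvar):
        # KL(N(mu, sigma^2) || N(0, 1)) summed over batch and latent dims;
        # working with logvar = log(sigma^2) keeps the term numerically stable
        return -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())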
@@ -75,14 +84,18 @@ class CNNRouteGeneratorModel(LightningBaseModule):
 
         if not issubclassed:
             # Dataset
-            self.dataset = TrajData(self.hparams.data_param.map_root, mode='separated_arrays',
+            self.dataset = TrajData(self.hparams.data_param.map_root, mode='generator_all_in_map',
+                                    preprocessed=self.hparams.data_param.use_preprocessed,
                                     length=self.hparams.data_param.dataset_length, normalized=True)
             self.criterion = nn.MSELoss()
 
-        # Additional Attributes
+        # Additional Attributes #
+        #######################################################
         self.in_shape = self.dataset.map_shapes_max
-        # Todo: Better naming and size in Parameters
-        self.feature_dim = self.hparams.model_param.lat_dim * 10
+        self.use_res_net = self.hparams.model_param.use_res_net
+        self.lat_dim = self.hparams.model_param.lat_dim
+        self.feature_dim = self.lat_dim * 10
+        ########################################################
 
         # NN Nodes
         ###################################################
@@ -93,82 +106,100 @@ class CNNRouteGeneratorModel(LightningBaseModule):
 
         #
         # Map Encoder
-        self.map_conv_0 = ConvModule(self.in_shape, conv_kernel=3, conv_stride=1, conv_padding=1,
+        self.enc_conv_0 = ConvModule(self.in_shape, conv_kernel=3, conv_stride=1, conv_padding=1,
                                      conv_filters=self.hparams.model_param.filters[0],
                                      use_norm=self.hparams.model_param.use_norm,
                                      use_bias=self.hparams.model_param.use_bias)
 
-        self.map_res_1 = ResidualModule(self.map_conv_0.shape, ConvModule, 2, conv_kernel=5, conv_stride=1,
+        self.enc_res_1 = ResidualModule(self.enc_conv_0.shape, ConvModule, 2, conv_kernel=5, conv_stride=1,
                                         conv_padding=2, conv_filters=self.hparams.model_param.filters[0],
                                         use_norm=self.hparams.model_param.use_norm,
                                         use_bias=self.hparams.model_param.use_bias)
-        self.map_conv_1 = ConvModule(self.map_res_1.shape, conv_kernel=5, conv_stride=1, conv_padding=0,
+        self.enc_conv_1a = ConvModule(self.enc_res_1.shape, conv_kernel=3, conv_stride=1, conv_padding=0,
+                                      conv_filters=self.hparams.model_param.filters[1],
+                                      use_norm=self.hparams.model_param.use_norm,
+                                      use_bias=self.hparams.model_param.use_bias)
+        self.enc_conv_1b = ConvModule(self.enc_conv_1a.shape, conv_kernel=3, conv_stride=2, conv_padding=0,
                                       conv_filters=self.hparams.model_param.filters[1],
                                       use_norm=self.hparams.model_param.use_norm,
                                       use_bias=self.hparams.model_param.use_bias)
 
-        self.map_res_2 = ResidualModule(self.map_conv_1.shape, ConvModule, 2, conv_kernel=5, conv_stride=1,
+        self.enc_res_2 = ResidualModule(self.enc_conv_1b.shape, ConvModule, 2, conv_kernel=5, conv_stride=1,
                                         conv_padding=2, conv_filters=self.hparams.model_param.filters[1],
                                         use_norm=self.hparams.model_param.use_norm,
                                         use_bias=self.hparams.model_param.use_bias)
-        self.map_conv_2 = ConvModule(self.map_res_2.shape, conv_kernel=7, conv_stride=1, conv_padding=0,
+        self.enc_conv_2a = ConvModule(self.enc_res_2.shape, conv_kernel=5, conv_stride=1, conv_padding=0,
+                                      conv_filters=self.hparams.model_param.filters[2],
+                                      use_norm=self.hparams.model_param.use_norm,
+                                      use_bias=self.hparams.model_param.use_bias)
+        self.enc_conv_2b = ConvModule(self.enc_conv_2a.shape, conv_kernel=5, conv_stride=1, conv_padding=0,
                                       conv_filters=self.hparams.model_param.filters[2],
                                       use_norm=self.hparams.model_param.use_norm,
                                       use_bias=self.hparams.model_param.use_bias)
 
-        self.map_res_3 = ResidualModule(self.map_conv_2.shape, ConvModule, 2, conv_kernel=7, conv_stride=1,
+        self.enc_res_3 = ResidualModule(self.enc_conv_2b.shape, ConvModule, 2, conv_kernel=7, conv_stride=1,
                                         conv_padding=3, conv_filters=self.hparams.model_param.filters[2],
                                         use_norm=self.hparams.model_param.use_norm,
                                         use_bias=self.hparams.model_param.use_bias)
-        self.map_conv_3 = ConvModule(self.map_res_3.shape, conv_kernel=11, conv_stride=1, conv_padding=0,
+        self.enc_conv_3a = ConvModule(self.enc_res_3.shape, conv_kernel=7, conv_stride=1, conv_padding=0,
+                                      conv_filters=self.hparams.model_param.filters[2],
+                                      use_norm=self.hparams.model_param.use_norm,
+                                      use_bias=self.hparams.model_param.use_bias)
+        self.enc_conv_3b = ConvModule(self.enc_conv_3a.shape, conv_kernel=7, conv_stride=1, conv_padding=0,
                                       conv_filters=self.hparams.model_param.filters[2],
                                       use_norm=self.hparams.model_param.use_norm,
                                       use_bias=self.hparams.model_param.use_bias)
 
-        self.map_flat = Flatten(self.map_conv_3.shape)
-        self.map_lin = nn.Linear(reduce(mul, self.map_conv_3.shape), self.feature_dim)
+        self.enc_flat = Flatten(self.enc_conv_3b.shape)
+        self.enc_lin_1 = nn.Linear(self.enc_flat.shape, self.feature_dim)
 
         #
         # Mixed Encoder
-        self.mixed_lin = nn.Linear(self.feature_dim, self.feature_dim)
-        self.mixed_norm = nn.BatchNorm1d(self.feature_dim) if self.hparams.model_param.use_norm else lambda x: x
+        self.enc_lin_2 = nn.Linear(self.feature_dim, self.feature_dim)
+        self.enc_norm = nn.BatchNorm1d(self.feature_dim) if self.hparams.model_param.use_norm else lambda x: x
 
         #
         # Variational Bottleneck
-        self.mu = nn.Linear(self.feature_dim, self.hparams.model_param.lat_dim)
-        self.logvar = nn.Linear(self.feature_dim, self.hparams.model_param.lat_dim)
+        self.mu = nn.Linear(self.feature_dim, self.lat_dim)
+        self.logvar = nn.Linear(self.feature_dim, self.lat_dim)
 
         #
         # Alternative Generator
-        self.alt_lin_1 = nn.Linear(self.hparams.model_param.lat_dim, self.feature_dim)
-        # Todo Fix This Hack!!!!
-        reshape_shape = (1, self.map_conv_3.shape[1], self.map_conv_3.shape[2])
+        self.gen_lin_1 = nn.Linear(self.hparams.model_param.lat_dim, self.feature_dim)
 
-        self.alt_lin_2 = nn.Linear(self.feature_dim, reduce(mul, reshape_shape))
+        self.gen_lin_2 = nn.Linear(self.feature_dim, self.enc_flat.shape)
 
-        self.reshape_to_map = Flatten(reduce(mul, reshape_shape), reshape_shape)
+        self.reshape_to_last_conv = Flatten(self.enc_flat.shape, self.enc_conv_3b.shape)
 
-        self.alt_deconv_1 = DeConvModule(reshape_shape, self.hparams.model_param.filters[2],
-                                         conv_padding=0, conv_kernel=13, conv_stride=1,
-                                         use_norm=self.hparams.model_param.use_norm)
-        self.alt_deconv_2 = DeConvModule(self.alt_deconv_1.shape, self.hparams.model_param.filters[1],
+        self.gen_deconv_1a = DeConvModule(self.enc_conv_3b.shape, self.hparams.model_param.filters[2],
+                                          conv_padding=0, conv_kernel=11, conv_stride=1,
+                                          use_norm=self.hparams.model_param.use_norm)
+        self.gen_deconv_1b = DeConvModule(self.gen_deconv_1a.shape, self.hparams.model_param.filters[2],
+                                          conv_padding=0, conv_kernel=9, conv_stride=2,
+                                          use_norm=self.hparams.model_param.use_norm)
+
+        self.gen_deconv_2a = DeConvModule(self.gen_deconv_1b.shape, self.hparams.model_param.filters[1],
                                           conv_padding=0, conv_kernel=7, conv_stride=1,
                                           use_norm=self.hparams.model_param.use_norm)
-        self.alt_deconv_3 = DeConvModule(self.alt_deconv_2.shape, self.hparams.model_param.filters[0],
+        self.gen_deconv_2b = DeConvModule(self.gen_deconv_2a.shape, self.hparams.model_param.filters[1],
+                                          conv_padding=0, conv_kernel=7, conv_stride=1,
+                                          use_norm=self.hparams.model_param.use_norm)
+
+        self.gen_deconv_3a = DeConvModule(self.gen_deconv_2b.shape, self.hparams.model_param.filters[0],
                                           conv_padding=1, conv_kernel=5, conv_stride=1,
                                           use_norm=self.hparams.model_param.use_norm)
-        self.alt_deconv_out = DeConvModule(self.alt_deconv_3.shape, 1, activation=None,
-                                           conv_padding=1, conv_kernel=3, conv_stride=1,
+        self.gen_deconv_3b = DeConvModule(self.gen_deconv_3a.shape, self.hparams.model_param.filters[0],
+                                          conv_padding=1, conv_kernel=4, conv_stride=1,
+                                          use_norm=self.hparams.model_param.use_norm)
+
+        self.gen_deconv_out = DeConvModule(self.gen_deconv_3b.shape, 1, activation=None,
+                                           conv_padding=0, conv_kernel=3, conv_stride=1,
                                            use_norm=self.hparams.model_param.use_norm)
 
     def forward(self, batch_x):
-        #
-        # Sorting the Input
-        map_array, trajectory, label = batch_x
-
         #
         # Encode
-        z, mu, logvar = self.encode(map_array, trajectory, label)
+        z, mu, logvar = self.encode(batch_x)
 
         #
         # Generate
@@ -181,42 +212,26 @@ class CNNRouteGeneratorModel(LightningBaseModule):
         eps = torch.randn_like(std)
         return mu + eps * std
 
-    def generate(self, z):
-        alt_tensor = self.alt_lin_1(z)
-        alt_tensor = self.activation(alt_tensor)
-        alt_tensor = self.alt_lin_2(alt_tensor)
-        alt_tensor = self.activation(alt_tensor)
-        alt_tensor = self.reshape_to_map(alt_tensor)
-        alt_tensor = self.alt_deconv_1(alt_tensor)
-        alt_tensor = self.alt_deconv_2(alt_tensor)
-        alt_tensor = self.alt_deconv_3(alt_tensor)
-        alt_tensor = self.alt_deconv_out(alt_tensor)
-        # alt_tensor = self.activation(alt_tensor)
-        alt_tensor = self.sigmoid(alt_tensor)
-        return alt_tensor
-
-    def encode(self, map_array, trajectory, label):
-        label_array = torch.cat([torch.full((1, 1, self.in_shape[1], self.in_shape[2]), x.item())
-                                 for x in label], dim=0)
-        label_array = self._move_to_model_device(label_array)
-        combined_tensor = torch.cat((map_array, trajectory, label_array), dim=1)
-        combined_tensor = self.map_conv_0(combined_tensor)
-        combined_tensor = self.map_res_1(combined_tensor)
-        combined_tensor = self.map_conv_1(combined_tensor)
-        combined_tensor = self.map_res_2(combined_tensor)
-        combined_tensor = self.map_conv_2(combined_tensor)
-        combined_tensor = self.map_res_3(combined_tensor)
-        combined_tensor = self.map_conv_3(combined_tensor)
-
-        combined_tensor = self.map_flat(combined_tensor)
-        combined_tensor = self.map_lin(combined_tensor)
-
-        combined_tensor = self.mixed_lin(combined_tensor)
-
-        combined_tensor = self.mixed_norm(combined_tensor)
-        combined_tensor = self.activation(combined_tensor)
-        combined_tensor = self.mixed_lin(combined_tensor)
-        combined_tensor = self.mixed_norm(combined_tensor)
+    def encode(self, batch_x):
+        combined_tensor = self.enc_conv_0(batch_x)
+        combined_tensor = self.enc_res_1(combined_tensor) if self.use_res_net else combined_tensor
+        combined_tensor = self.enc_conv_1a(combined_tensor)
+        combined_tensor = self.enc_conv_1b(combined_tensor)
+        combined_tensor = self.enc_res_2(combined_tensor) if self.use_res_net else combined_tensor
+        combined_tensor = self.enc_conv_2a(combined_tensor)
+        combined_tensor = self.enc_conv_2b(combined_tensor)
+        combined_tensor = self.enc_res_3(combined_tensor) if self.use_res_net else combined_tensor
+        combined_tensor = self.enc_conv_3a(combined_tensor)
+        combined_tensor = self.enc_conv_3b(combined_tensor)
+
+        combined_tensor = self.enc_flat(combined_tensor)
+        combined_tensor = self.enc_lin_1(combined_tensor)
+        combined_tensor = self.enc_lin_2(combined_tensor)
+
+        combined_tensor = self.enc_norm(combined_tensor)
+        combined_tensor = self.activation(combined_tensor)
+        combined_tensor = self.enc_lin_2(combined_tensor)
+        combined_tensor = self.enc_norm(combined_tensor)
         combined_tensor = self.activation(combined_tensor)
 
         #
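Note: this is where the commit's ResNet optionality takes effect: each residual stage runs only when use_res_net is set and is skipped as an identity otherwise, which works because ResidualModule preserves its input shape. The pattern in isolation (a sketch; res_block stands for any of the enc_res_* modules):

    def maybe_residual(x, res_block, use_res_net):
        # shapes are identical on both branches, so downstream convs are unaffected
        return res_block(x) if use_res_net else x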
@@ -226,19 +241,31 @@ class CNNRouteGeneratorModel(LightningBaseModule):
         z = self.reparameterize(mu, logvar)
         return z, mu, logvar
 
-    def generate_random(self, n=6):
-        maps = [self.map_storage[choice(self.map_storage.keys_list)] for _ in range(n)]
-
-        trajectories = [x.get_random_trajectory() for x in maps]
-        trajectories = [x.draw_in_array(self.map_storage.max_map_size) for x in trajectories]
-        trajectories = [torch.as_tensor(x, dtype=torch.float32) for x in trajectories] * 2
-        trajectories = self._move_to_model_device(torch.stack(trajectories))
-
-        maps = [torch.as_tensor(x.as_array, dtype=torch.float32) for x in maps] * 2
-        maps = self._move_to_model_device(torch.stack(maps))
-
-        labels = self._move_to_model_device(torch.as_tensor([0] * n + [1] * n))
-        return maps, trajectories, labels, self._test_val_step(((maps, trajectories, labels), None), -9999)
+    def generate(self, z):
+        alt_tensor = self.gen_lin_1(z)
+        alt_tensor = self.activation(alt_tensor)
+        alt_tensor = self.gen_lin_2(alt_tensor)
+        alt_tensor = self.activation(alt_tensor)
+        alt_tensor = self.reshape_to_last_conv(alt_tensor)
+        alt_tensor = self.gen_deconv_1a(alt_tensor)
+        alt_tensor = self.gen_deconv_1b(alt_tensor)
+        alt_tensor = self.gen_deconv_2a(alt_tensor)
+        alt_tensor = self.gen_deconv_2b(alt_tensor)
+        alt_tensor = self.gen_deconv_3a(alt_tensor)
+        alt_tensor = self.gen_deconv_3b(alt_tensor)
+        alt_tensor = self.gen_deconv_out(alt_tensor)
+        # alt_tensor = self.activation(alt_tensor)
+        alt_tensor = self.sigmoid(alt_tensor)
+        return alt_tensor
+
+    def generate_random(self, n=12):
+
+        samples, alternatives = zip(*[self.dataset.test_dataset[choice]
+                                      for choice in choices(range(self.dataset.length), k=n)])
+        samples = self._move_to_model_device(torch.stack([torch.as_tensor(x) for x in samples]))
+        alternatives = self._move_to_model_device(torch.stack([torch.as_tensor(x) for x in alternatives]))
+
+        return self._test_val_step((samples, alternatives), -9999)
 
 
 class CNNRouteGeneratorDiscriminated(CNNRouteGeneratorModel):
@@ -329,11 +356,12 @@ class CNNRouteGeneratorDiscriminated(CNNRouteGeneratorModel):
         self._disc = disc_model
 
     def __init__(self, *params):
+        raise NotImplementedError
         super(CNNRouteGeneratorDiscriminated, self).__init__(*params, issubclassed=True)
 
         self._disc = None
 
         self.criterion = nn.BCELoss()
 
-        self.dataset = TrajData(self.hparams.data_param.map_root, mode='just_route',
+        self.dataset = TrajData(self.hparams.data_param.map_root, mode='just_route', preprocessed=True,
                                 length=self.hparams.data_param.dataset_length, normalized=True)
@@ -0,0 +1,348 @@
+from random import choice
+
+import torch
+from functools import reduce
+from operator import mul
+
+from torch import nn
+from torch.optim import Adam
+
+from datasets.trajectory_dataset import TrajData
+from lib.evaluation.classification import ROCEvaluation
+from lib.modules.blocks import ConvModule, ResidualModule, DeConvModule
+from lib.modules.utils import LightningBaseModule, Flatten
+
+import matplotlib.pyplot as plt
+
+
+class CNNRouteGeneratorModel(LightningBaseModule):
+
+    name = 'CNNRouteGenerator'
+
+    def configure_optimizers(self):
+        return Adam(self.parameters(), lr=self.hparams.train_param.lr)
+
+    def training_step(self, batch_xy, batch_nb, *args, **kwargs):
+        batch_x, alternative = batch_xy
+        generated_alternative, z, mu, logvar = self(batch_x)
+        element_wise_loss = self.criterion(generated_alternative, alternative)
+        # see Appendix B from VAE paper:
+        # Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014
+        # https://arxiv.org/abs/1312.6114
+        # 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
+
+        kld_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
+        # Dimensional Resizing TODO: Does this make sense? Sanity check it!
+        # kld_loss /= reduce(mul, self.in_shape)
+        # kld_loss *= self.hparams.data_param.dataset_length / self.hparams.train_param.batch_size * 100
+
+        loss = (kld_loss + element_wise_loss) / 2
+        return dict(loss=loss, log=dict(element_wise_loss=element_wise_loss, loss=loss, kld_loss=kld_loss))
+
+    def _test_val_step(self, batch_xy, batch_nb, *args):
+        batch_x, alternative = batch_xy
+        map_array = batch_x[0]
+        trajectory = batch_x[1]
+        label = batch_x[2].max()
+
+        z, _, _ = self.encode(batch_x)
+        generated_alternative = self.generate(z)
+
+        return dict(map_array=map_array, trajectory=trajectory, batch_nb=batch_nb, label=label,
+                    generated_alternative=generated_alternative, pred_label=-1, alternative=alternative
+                    )
+
+    def _test_val_epoch_end(self, outputs, test=False):
+        maps, trajectories, labels, val_result_dict = self.generate_random()
+
+        from lib.visualization.generator_eval import GeneratorVisualizer
+        g = GeneratorVisualizer(maps, trajectories, labels, val_result_dict)
+        fig = g.draw()
+        self.logger.log_image(f'{self.name}_Output', fig, step=self.global_step)
+        plt.clf()
+
+        return dict(epoch=self.current_epoch)
+
+    def validation_step(self, *args):
+        return self._test_val_step(*args)
+
+    def validation_epoch_end(self, outputs: list):
+        return self._test_val_epoch_end(outputs)
+
+    def test_step(self, *args):
+        return self._test_val_step(*args)
+
+    def test_epoch_end(self, outputs):
+        return self._test_val_epoch_end(outputs, test=True)
+
+    def __init__(self, *params, issubclassed=False):
+        super(CNNRouteGeneratorModel, self).__init__(*params)
+
+        if not issubclassed:
+            # Dataset
+            self.dataset = TrajData(self.hparams.data_param.map_root, mode='generator_all_in_map',
+                                    length=self.hparams.data_param.dataset_length, normalized=True)
+            self.criterion = nn.MSELoss()
+
+        # Additional Attributes #
+        #######################################################
+        self.map_shape = self.dataset.map_shapes_max
+        self.trajectory_features = 4
+        self.use_res_net = self.hparams.model_param.use_res_net
+        self.lat_dim = self.hparams.model_param.lat_dim
+        self.feature_dim = self.lat_dim * 10
+        ########################################################
+
+        # NN Nodes
+        ###################################################
+        #
+        # Utils
+        self.activation = nn.ReLU()
+        self.sigmoid = nn.Sigmoid()
+
+        #
+        # Map Encoder
+        self.enc_conv_0 = ConvModule(self.map_shape, conv_kernel=3, conv_stride=1, conv_padding=1,
+                                     conv_filters=self.hparams.model_param.filters[0],
+                                     use_norm=self.hparams.model_param.use_norm,
+                                     use_bias=self.hparams.model_param.use_bias)
+
+        self.enc_res_1 = ResidualModule(self.enc_conv_0.shape, ConvModule, 2, conv_kernel=5, conv_stride=1,
+                                        conv_padding=2, conv_filters=self.hparams.model_param.filters[0],
+                                        use_norm=self.hparams.model_param.use_norm,
+                                        use_bias=self.hparams.model_param.use_bias)
+        self.enc_conv_1a = ConvModule(self.enc_res_1.shape, conv_kernel=3, conv_stride=1, conv_padding=0,
+                                      conv_filters=self.hparams.model_param.filters[1],
+                                      use_norm=self.hparams.model_param.use_norm,
+                                      use_bias=self.hparams.model_param.use_bias)
+        self.enc_conv_1b = ConvModule(self.enc_conv_1a.shape, conv_kernel=3, conv_stride=2, conv_padding=0,
+                                      conv_filters=self.hparams.model_param.filters[1],
+                                      use_norm=self.hparams.model_param.use_norm,
+                                      use_bias=self.hparams.model_param.use_bias)
+
+        self.enc_res_2 = ResidualModule(self.enc_conv_1b.shape, ConvModule, 2, conv_kernel=5, conv_stride=1,
+                                        conv_padding=2, conv_filters=self.hparams.model_param.filters[1],
+                                        use_norm=self.hparams.model_param.use_norm,
+                                        use_bias=self.hparams.model_param.use_bias)
+        self.enc_conv_2a = ConvModule(self.enc_res_2.shape, conv_kernel=5, conv_stride=1, conv_padding=0,
+                                      conv_filters=self.hparams.model_param.filters[2],
+                                      use_norm=self.hparams.model_param.use_norm,
+                                      use_bias=self.hparams.model_param.use_bias)
+        self.enc_conv_2b = ConvModule(self.enc_conv_2a.shape, conv_kernel=5, conv_stride=1, conv_padding=0,
+                                      conv_filters=self.hparams.model_param.filters[2],
+                                      use_norm=self.hparams.model_param.use_norm,
+                                      use_bias=self.hparams.model_param.use_bias)
+
+        self.enc_res_3 = ResidualModule(self.enc_conv_2b.shape, ConvModule, 2, conv_kernel=7, conv_stride=1,
+                                        conv_padding=3, conv_filters=self.hparams.model_param.filters[2],
+                                        use_norm=self.hparams.model_param.use_norm,
+                                        use_bias=self.hparams.model_param.use_bias)
+        self.enc_conv_3a = ConvModule(self.enc_res_3.shape, conv_kernel=7, conv_stride=1, conv_padding=0,
+                                      conv_filters=self.hparams.model_param.filters[2],
+                                      use_norm=self.hparams.model_param.use_norm,
+                                      use_bias=self.hparams.model_param.use_bias)
+        self.enc_conv_3b = ConvModule(self.enc_conv_3a.shape, conv_kernel=7, conv_stride=1, conv_padding=0,
+                                      conv_filters=self.hparams.model_param.filters[2],
+                                      use_norm=self.hparams.model_param.use_norm,
+                                      use_bias=self.hparams.model_param.use_bias)
+
+        # Trajectory Encoder
+        self.env_gru_1 = nn.GRU(input_size=self.trajectory_features, hidden_size=self.feature_dim,
+                                num_layers=3, batch_first=True)
+
+        self.enc_flat = Flatten(self.enc_conv_3b.shape)
+        self.enc_lin_1 = nn.Linear(self.enc_flat.shape, self.feature_dim)
+
+        #
+        # Mixed Encoder
+        self.enc_lin_2 = nn.Linear(self.feature_dim, self.feature_dim)
+        self.enc_norm = nn.BatchNorm1d(self.feature_dim) if self.hparams.model_param.use_norm else lambda x: x
+
+        #
+        # Variational Bottleneck
+        self.mu = nn.Linear(self.feature_dim, self.lat_dim)
+        self.logvar = nn.Linear(self.feature_dim, self.lat_dim)
+
+        #
+        # Alternative Generator
+        self.gen_lin_1 = nn.Linear(self.hparams.model_param.lat_dim, self.feature_dim)
+
+        self.gen_lin_2 = nn.Linear(self.feature_dim, self.enc_flat.shape)
+
+        self.gen_gru_x = nn.GRU(None, None, batch_first=True)  # placeholder, sizes still to be defined
+
+    def forward(self, batch_x):
+        #
+        # Encode
+        z, mu, logvar = self.encode(batch_x)
+
+        #
+        # Generate
+        alt_tensor = self.generate(z)
+        return alt_tensor, z, mu, logvar
+
+    @staticmethod
+    def reparameterize(mu, logvar):
+        std = torch.exp(0.5 * logvar)
+        eps = torch.randn_like(std)
+        return mu + eps * std
+
+    def encode(self, batch_x):
+        combined_tensor = self.enc_conv_0(batch_x)
+        combined_tensor = self.enc_res_1(combined_tensor) if self.use_res_net else combined_tensor
+        combined_tensor = self.enc_conv_1a(combined_tensor)
+        combined_tensor = self.enc_conv_1b(combined_tensor)
+        combined_tensor = self.enc_res_2(combined_tensor) if self.use_res_net else combined_tensor
+        combined_tensor = self.enc_conv_2a(combined_tensor)
+        combined_tensor = self.enc_conv_2b(combined_tensor)
+        combined_tensor = self.enc_res_3(combined_tensor) if self.use_res_net else combined_tensor
+        combined_tensor = self.enc_conv_3a(combined_tensor)
+        combined_tensor = self.enc_conv_3b(combined_tensor)
+
+        combined_tensor = self.enc_flat(combined_tensor)
+        combined_tensor = self.enc_lin_1(combined_tensor)
+        combined_tensor = self.enc_lin_2(combined_tensor)
+
+        combined_tensor = self.enc_norm(combined_tensor)
+        combined_tensor = self.activation(combined_tensor)
+        combined_tensor = self.enc_lin_2(combined_tensor)
+        combined_tensor = self.enc_norm(combined_tensor)
+        combined_tensor = self.activation(combined_tensor)
+
+        #
+        # Parameter and Sampling
+        mu = self.mu(combined_tensor)
+        logvar = self.logvar(combined_tensor)
+        z = self.reparameterize(mu, logvar)
+        return z, mu, logvar
+
+    def generate(self, z):
+        alt_tensor = self.gen_lin_1(z)
+        alt_tensor = self.activation(alt_tensor)
+        alt_tensor = self.gen_lin_2(alt_tensor)
+        alt_tensor = self.activation(alt_tensor)
+        alt_tensor = self.reshape_to_last_conv(alt_tensor)
+        alt_tensor = self.gen_deconv_1a(alt_tensor)
+        alt_tensor = self.gen_deconv_1b(alt_tensor)
+        alt_tensor = self.gen_deconv_2a(alt_tensor)
+        alt_tensor = self.gen_deconv_2b(alt_tensor)
+        alt_tensor = self.gen_deconv_3a(alt_tensor)
+        alt_tensor = self.gen_deconv_3b(alt_tensor)
+        alt_tensor = self.gen_deconv_out(alt_tensor)
+        # alt_tensor = self.activation(alt_tensor)
+        alt_tensor = self.sigmoid(alt_tensor)
+        return alt_tensor
+
+    def generate_random(self, n=6):
+        maps = [self.map_storage[choice(self.map_storage.keys_list)] for _ in range(n)]
+
+        trajectories = [x.get_random_trajectory() for x in maps]
+        trajectories = [x.draw_in_array(self.map_storage.max_map_size) for x in trajectories]
+        trajectories = [torch.as_tensor(x, dtype=torch.float32) for x in trajectories] * 2
+        trajectories = self._move_to_model_device(torch.stack(trajectories))
+
+        maps = [torch.as_tensor(x.as_array, dtype=torch.float32) for x in maps] * 2
+        maps = self._move_to_model_device(torch.stack(maps))
+
+        labels = self._move_to_model_device(torch.as_tensor([0] * n + [1] * n))
+        return maps, trajectories, labels, self._test_val_step(((maps, trajectories, labels), None), -9999)
+
+
+class CNNRouteGeneratorDiscriminated(CNNRouteGeneratorModel):
+
+    name = 'CNNRouteGeneratorDiscriminated'
+
+    def training_step(self, batch_xy, batch_nb, *args, **kwargs):
+        batch_x, label = batch_xy
+
+        generated_alternative, z, mu, logvar = self(batch_x)
+        map_array, trajectory = batch_x
+
+        map_stack = torch.cat((map_array, trajectory, generated_alternative), dim=1)
+        pred_label = self.discriminator(map_stack)
+        discriminated_bce_loss = self.criterion(pred_label, label.float().unsqueeze(-1))
+
+        # see Appendix B from VAE paper:
+        # Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014
+        # https://arxiv.org/abs/1312.6114
+        # 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
+        kld_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
+        # Dimensional Resizing
+        kld_loss /= reduce(mul, self.in_shape)
+
+        loss = (kld_loss + discriminated_bce_loss) / 2
+        return dict(loss=loss, log=dict(loss=loss,
+                                        discriminated_bce_loss=discriminated_bce_loss,
+                                        kld_loss=kld_loss)
+                    )
+
+    def _test_val_step(self, batch_xy, batch_nb, *args):
+        batch_x, label = batch_xy
+
+        generated_alternative, z, mu, logvar = self(batch_x)
+        map_array, trajectory = batch_x
+
+        map_stack = torch.cat((map_array, trajectory, generated_alternative), dim=1)
+        pred_label = self.discriminator(map_stack)
+
+        discriminated_bce_loss = self.criterion(pred_label, label.float().unsqueeze(-1))
+        return dict(discriminated_bce_loss=discriminated_bce_loss, batch_nb=batch_nb,
+                    pred_label=pred_label, label=label, generated_alternative=generated_alternative)
+
+    def validation_step(self, *args):
+        return self._test_val_step(*args)
+
+    def validation_epoch_end(self, outputs: list):
+        return self._test_val_epoch_end(outputs)
+
+    def _test_val_epoch_end(self, outputs, test=False):
+        evaluation = ROCEvaluation(plot_roc=True)
+        pred_label = torch.cat([x['pred_label'] for x in outputs])
+        labels = torch.cat([x['label'] for x in outputs]).unsqueeze(1)
+        mean_losses = torch.stack([x['discriminated_bce_loss'] for x in outputs]).mean()
+
+        # The scikit ROC evaluation is called as evaluation(true_label, prediction)
+        roc_auc, tpr, fpr = evaluation(labels.cpu().numpy(), pred_label.cpu().numpy(), )
+        if test:
+            # self.logger.log_metrics(score_dict)
+            self.logger.log_image(f'{self.name}_ROC-Curve', plt.gcf(), step=self.global_step)
+            plt.clf()
+
+        maps, trajectories, labels, val_result_dict = self.generate_random()
+
+        from lib.visualization.generator_eval import GeneratorVisualizer
+        g = GeneratorVisualizer(maps, trajectories, labels, val_result_dict)
+        fig = g.draw()
+        self.logger.log_image(f'{self.name}_Output', fig, step=self.global_step)
+        plt.clf()
+
+        return dict(mean_losses=mean_losses, roc_auc=roc_auc, epoch=self.current_epoch)
+
+    def test_step(self, *args):
+        return self._test_val_step(*args)
+
+    def test_epoch_end(self, outputs):
+        return self._test_val_epoch_end(outputs, test=True)
+
+    @property
+    def discriminator(self):
+        if self._disc is None:
+            raise RuntimeError('Set the discriminator first via "set_discriminator(disc_model)"')
+        return self._disc
+
+    def set_discriminator(self, disc_model):
+        if self._disc is not None:
+            raise RuntimeError('Discriminator has already been set... What are you trying to do?')
+        self._disc = disc_model
+
+    def __init__(self, *params):
+        super(CNNRouteGeneratorDiscriminated, self).__init__(*params, issubclassed=True)
+
+        self._disc = None
+
+        self.criterion = nn.BCELoss()
+
+        self.dataset = TrajData(self.hparams.data_param.map_root, mode='just_route',
+                                length=self.hparams.data_param.dataset_length, normalized=True)
@@ -60,7 +60,7 @@ class ConvHomDetector(LightningBaseModule):
         super(ConvHomDetector, self).__init__(hparams)
 
         # Dataset
-        self.dataset = TrajData(self.hparams.data_param.map_root, mode='all_in_map', )
+        self.dataset = TrajData(self.hparams.data_param.map_root, mode='classifier_all_in_map', )
 
         # Additional Attributes
         self.map_shape = self.dataset.map_shapes_max
@@ -22,7 +22,7 @@ class Flatten(nn.Module):
         try:
             x = torch.randn(self.in_shape).unsqueeze(0)
             output = self(x)
-            return output.shape[1:]
+            return output.shape[1:] if len(output.shape[1:]) > 1 else output.shape[-1]
         except Exception as e:
             print(e)
             return -1
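Note: Flatten's probed shape now collapses to a bare integer whenever the output is one-dimensional past the batch axis; that is what allows declarations like nn.Linear(self.enc_flat.shape, ...) above. A sketch of the effect (the input shape is illustrative):

    flat = Flatten((64, 8, 8))        # hypothetical conv output shape (C, H, W)
    lin = nn.Linear(flat.shape, 128)  # flat.shape is now the int 4096, which nn.Linear accepts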
@@ -1,10 +1,9 @@
-import shelve
 from collections import UserDict
 from pathlib import Path
 
 import copy
 from math import sqrt
-from random import choice
+from random import Random
 
 import numpy as np
 
@@ -53,8 +52,12 @@ class Map(object):
         assert array_like_map_representation.ndim == 3
         self.map_array: np.ndarray = array_like_map_representation
         self.name = name
+        self.prng = Random()
        pass
 
+    def seed(self, seed):
+        self.prng.seed(seed)
+
     def __setattr__(self, key, value):
         super(Map, self).__setattr__(key, value)
         if key == 'map_array' and self.map_array is not None:
@@ -102,7 +105,7 @@ class Map(object):
         return trajectory
 
     def get_valid_position(self):
-        valid_position = choice(list(self._G.nodes))
+        valid_position = self.prng.choice(list(self._G.nodes))
         return valid_position
 
     def get_trajectory_from_vertices(self, *args):
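Note: swapping the module-level random.choice for a per-instance random.Random gives every Map its own seedable RNG, which is what the seed() plumbing through TrajDataset, TrajData and on_epoch_start ultimately drives. A reproducibility sketch (assuming m is a Map constructed as elsewhere in this diff, e.g. via Map(name).from_image(...)):

    m.seed(42)
    first = m.get_valid_position()
    m.seed(42)
    assert first == m.get_valid_position()  # same seed, same sampled graph node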
@@ -20,6 +20,8 @@ class Generator:
 
         self.data_root = Path(data_root)
 
+
+
     def generate_n_trajectories_m_alternatives(self, n, m, datafile_name, processes=0, **kwargs):
         datafile_name = datafile_name if datafile_name.endswith('.pik') else f'{str(datafile_name)}.pik'
         kwargs.update(n=m)
lib/utils/tools.py (new file, 22 lines)
@@ -0,0 +1,22 @@
+import pickle
+import shelve
+from pathlib import Path
+
+
+def write_to_shelve(file_path, value):
+    check_path(file_path)
+    file_path.parent.mkdir(exist_ok=True, parents=True)
+    with shelve.open(str(file_path), protocol=pickle.HIGHEST_PROTOCOL) as f:
+        new_key = str(len(f))
+        f[new_key] = value
+
+
+def load_from_shelve(file_path, key):
+    check_path(file_path)
+    with shelve.open(str(file_path)) as d:
+        return d[key]
+
+
+def check_path(file_path):
+    assert isinstance(file_path, Path)
+    assert str(file_path).endswith('.pik')
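Note: the shelve helpers append values under stringified integer keys, so a shelve file behaves like an ever-growing list; TrajDataShelve.__getitem__ relies on exactly this str(index) convention. A usage sketch (the path is illustrative; it must end in .pik to pass check_path):

    from pathlib import Path

    p = Path('data/samples.pik')
    write_to_shelve(p, ('sample', 0))  # stored under key '0'
    write_to_shelve(p, ('sample', 1))  # stored under key '1'
    print(load_from_shelve(p, '1'))    # -> ('sample', 1)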
@@ -5,12 +5,13 @@ import lib.variables as V
 
 class GeneratorVisualizer(object):
 
-    def __init__(self, maps, trajectories, labels, val_result_dict):
+    def __init__(self, **kwargs):
         # val_results = dict(discriminated_bce_loss, batch_nb, pred_label, label, generated_alternative)
-        self.alternatives = val_result_dict['generated_alternative']
-        self.labels = labels
-        self.trajectories = trajectories
-        self.maps = maps
+        self.alternatives = kwargs.get('generated_alternative')
+        self.labels = kwargs.get('labels')
+        self.trajectories = kwargs.get('trajectories')
+        self.maps = kwargs.get('maps')
 
         self._map_width, self._map_height = self.maps[0].squeeze().shape
         self.column_dict_list = self._build_column_dict_list()
         self._cols = len(self.column_dict_list)
@@ -24,10 +25,13 @@ class GeneratorVisualizer(object):
         for idx in range(self.alternatives.shape[0]):
             image = (self.alternatives[idx]).cpu().numpy().squeeze()
             label = self.labels[idx].item()
+            # Dirty and quick hack incoming.
             if label == V.HOMOTOPIC:
                 hom_alternatives.append(dict(image=image, label='Homotopic'))
+                non_hom_alternatives.append(None)
             else:
                 non_hom_alternatives.append(dict(image=image, label='NonHomotopic'))
+                hom_alternatives.append(None)
         for idx in range(max(len(hom_alternatives), len(non_hom_alternatives))):
             image = (self.maps[idx] + self.trajectories[idx]).cpu().numpy().squeeze()
             label = 'original'
@@ -48,10 +52,13 @@ class GeneratorVisualizer(object):
 
         for idx in range(len(grid.axes_all)):
             row, col = divmod(idx, len(self.column_dict_list))
+            if self.column_dict_list[col][row] is not None:
                 current_image = self.column_dict_list[col][row]['image']
                 current_label = self.column_dict_list[col][row]['label']
                 grid[idx].imshow(current_image)
                 grid[idx].title.set_text(current_label)
+            else:
+                continue
         fig.cbar_mode = 'single'
         fig.tight_layout()
         return fig
main.py (6 changed lines)
@@ -37,6 +37,7 @@ main_arg_parser.add_argument("--data_dataset_length", type=int, default=100000, help="")
 main_arg_parser.add_argument("--data_root", type=str, default='data', help="")
 main_arg_parser.add_argument("--data_map_root", type=str, default='res/shapes', help="")
 main_arg_parser.add_argument("--data_normalized", type=strtobool, default=True, help="")
+main_arg_parser.add_argument("--data_use_preprocessed", type=strtobool, default=True, help="")
 
 
 # Transformations
@@ -55,9 +56,10 @@ main_arg_parser.add_argument("--model_type", type=str, default="CNNRouteGenerator", help="")
 main_arg_parser.add_argument("--model_activation", type=str, default="elu", help="")
 main_arg_parser.add_argument("--model_filters", type=str, default="[16, 32, 64]", help="")
 main_arg_parser.add_argument("--model_classes", type=int, default=2, help="")
-main_arg_parser.add_argument("--model_lat_dim", type=int, default=4, help="")
+main_arg_parser.add_argument("--model_lat_dim", type=int, default=8, help="")
 main_arg_parser.add_argument("--model_use_bias", type=strtobool, default=True, help="")
 main_arg_parser.add_argument("--model_use_norm", type=strtobool, default=True, help="")
+main_arg_parser.add_argument("--model_use_res_net", type=strtobool, default=False, help="")
 main_arg_parser.add_argument("--model_dropout", type=float, default=0.00, help="")
 
 # Project
@@ -115,7 +117,7 @@ def run_lightning_loop(config_obj):
                          # log_save_interval=(model.n_train_batches * 0.2),  # TODO: Better Value / Setting
                          checkpoint_callback=checkpoint_callback,
                          logger=logger,
-                         val_percent_check=0.05,
+                         val_percent_check=0.025,
                          fast_dev_run=config_obj.main.debug,
                          early_stop_callback=None
                          )