VAE Debugging of Route Generator

This commit is contained in:
Si11ium
2020-04-08 08:53:20 +02:00
parent 934dadb558
commit c7971c063f
3 changed files with 25 additions and 26 deletions
datasets
lib
models
generators
utils

@ -3,6 +3,7 @@ from collections import defaultdict
from pathlib import Path
from typing import Union
from torchvision.datasets import VisionDataset
from torchvision.transforms import Normalize
import multiprocessing as mp
@ -20,7 +21,7 @@ from PIL import Image
from lib.utils.tools import write_to_shelve
class TrajDataShelve(Dataset):
class TrajDataShelve(VisionDataset):
@property
def map_shape(self):
@ -46,10 +47,22 @@ class TrajDataShelve(Dataset):
def __getitem__(self, item):
self._mutex.acquire()
with shelve.open(self.file_path) as d:
sample = d[str(item)]
img = d['data'][str(item)]
target = d['label'][str(item)]
d.close()
self._mutex.release()
return sample
# doing this so that it is consistent with all other datasets
# to return a PIL Image
img = Image.fromarray(img.numpy(), mode='L')
if self.transform is not None:
img = self.transform(img)
if self.target_transform is not None:
target = self.target_transform(target)
return img, target
class TrajDataset(Dataset):
@ -87,15 +100,6 @@ class TrajDataset(Dataset):
def __getitem__(self, item):
if self.mode.lower() == 'just_route':
raise NotImplementedError
trajectory = self.map.get_random_trajectory()
trajectory_space = trajectory.draw_in_array(self.map.shape)
label = choice([0, 1])
map_array = torch.as_tensor(self.map.as_array).float()
return (map_array, trajectory_space), label
# Produce an alternative.
while True:
trajectory = self.map.get_random_trajectory()
alternative = self.map.generate_alternative(trajectory)
@ -114,17 +118,13 @@ class TrajDataset(Dataset):
if self.mode == 'generator_all_in_map':
return np.concatenate((map_array, trajectory, label_as_array)), alternative
elif self.mode in ['vae_no_label_in_map', 'ae_no_label_in_map']:
elif self.mode in ['vae_no_label_in_map']:
return np.sum((map_array, trajectory, alternative), axis=0), 0
elif self.mode in ['generator_alt_no_label_in_map', 'generator_hom_no_label_in_map']:
return np.concatenate((map_array, trajectory)), alternative
elif self.mode == 'classifier_all_in_map':
return np.concatenate((map_array, trajectory, alternative)), label
elif self.mode == '_vectors':
raise NotImplementedError
return trajectory.vertices, alternative.vertices, label, self.mapname
raise ValueError(f'Mode was: {self.mode}')
def seed(self, seed):
@ -148,7 +148,7 @@ class TrajData(object):
def name(self):
return self.__class__.__name__
def __init__(self, map_root, length=100000, mode='separated_arrays', normalized=True, preprocessed=False, **_):
def __init__(self, map_root, length=100000, mode='', normalized=True, preprocessed=False, **_):
self.preprocessed = preprocessed
self.normalized = normalized
self.mode = mode

@ -79,15 +79,13 @@ class CNNRouteGeneratorModel(LightningBaseModule):
def __init__(self, *params, issubclassed=False):
super(CNNRouteGeneratorModel, self).__init__(*params)
if False:
# Dataset
self.dataset = TrajData(self.hparams.data_param.map_root,
mode=self.hparams.data_param.mode,
preprocessed=self.hparams.data_param.use_preprocessed,
length=self.hparams.data_param.dataset_length)
self.criterion = nn.BCELoss(reduction='sum')
# Dataset
self.dataset = TrajData(self.hparams.data_param.map_root,
mode=self.hparams.data_param.mode,
preprocessed=self.hparams.data_param.use_preprocessed,
length=self.hparams.data_param.dataset_length)
self.dataset = MyMNIST()
self.criterion = nn.BCELoss(reduction='sum')
# Additional Attributes
###################################################

@ -7,6 +7,7 @@ from pytorch_lightning.loggers.test_tube import TestTubeLogger
from lib.utils.config import Config
import numpy as np
class Logger(LightningLoggerBase):
media_dir = 'media'