VAE Debugging of Route Generator
This commit is contained in:
@ -3,6 +3,7 @@ from collections import defaultdict
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
|
||||
from torchvision.datasets import VisionDataset
|
||||
from torchvision.transforms import Normalize
|
||||
|
||||
import multiprocessing as mp
|
||||
@ -20,7 +21,7 @@ from PIL import Image
|
||||
from lib.utils.tools import write_to_shelve
|
||||
|
||||
|
||||
class TrajDataShelve(Dataset):
|
||||
class TrajDataShelve(VisionDataset):
|
||||
|
||||
@property
|
||||
def map_shape(self):
|
||||
@ -46,10 +47,22 @@ class TrajDataShelve(Dataset):
|
||||
def __getitem__(self, item):
|
||||
self._mutex.acquire()
|
||||
with shelve.open(self.file_path) as d:
|
||||
sample = d[str(item)]
|
||||
img = d['data'][str(item)]
|
||||
target = d['label'][str(item)]
|
||||
d.close()
|
||||
self._mutex.release()
|
||||
return sample
|
||||
|
||||
# doing this so that it is consistent with all other datasets
|
||||
# to return a PIL Image
|
||||
img = Image.fromarray(img.numpy(), mode='L')
|
||||
|
||||
if self.transform is not None:
|
||||
img = self.transform(img)
|
||||
|
||||
if self.target_transform is not None:
|
||||
target = self.target_transform(target)
|
||||
|
||||
return img, target
|
||||
|
||||
|
||||
class TrajDataset(Dataset):
|
||||
@ -87,15 +100,6 @@ class TrajDataset(Dataset):
|
||||
|
||||
def __getitem__(self, item):
|
||||
|
||||
if self.mode.lower() == 'just_route':
|
||||
raise NotImplementedError
|
||||
trajectory = self.map.get_random_trajectory()
|
||||
trajectory_space = trajectory.draw_in_array(self.map.shape)
|
||||
label = choice([0, 1])
|
||||
map_array = torch.as_tensor(self.map.as_array).float()
|
||||
return (map_array, trajectory_space), label
|
||||
|
||||
# Produce an alternative.
|
||||
while True:
|
||||
trajectory = self.map.get_random_trajectory()
|
||||
alternative = self.map.generate_alternative(trajectory)
|
||||
@ -114,17 +118,13 @@ class TrajDataset(Dataset):
|
||||
|
||||
if self.mode == 'generator_all_in_map':
|
||||
return np.concatenate((map_array, trajectory, label_as_array)), alternative
|
||||
elif self.mode in ['vae_no_label_in_map', 'ae_no_label_in_map']:
|
||||
elif self.mode in ['vae_no_label_in_map']:
|
||||
return np.sum((map_array, trajectory, alternative), axis=0), 0
|
||||
elif self.mode in ['generator_alt_no_label_in_map', 'generator_hom_no_label_in_map']:
|
||||
return np.concatenate((map_array, trajectory)), alternative
|
||||
elif self.mode == 'classifier_all_in_map':
|
||||
return np.concatenate((map_array, trajectory, alternative)), label
|
||||
|
||||
elif self.mode == '_vectors':
|
||||
raise NotImplementedError
|
||||
return trajectory.vertices, alternative.vertices, label, self.mapname
|
||||
|
||||
raise ValueError(f'Mode was: {self.mode}')
|
||||
|
||||
def seed(self, seed):
|
||||
@ -148,7 +148,7 @@ class TrajData(object):
|
||||
def name(self):
|
||||
return self.__class__.__name__
|
||||
|
||||
def __init__(self, map_root, length=100000, mode='separated_arrays', normalized=True, preprocessed=False, **_):
|
||||
def __init__(self, map_root, length=100000, mode='', normalized=True, preprocessed=False, **_):
|
||||
self.preprocessed = preprocessed
|
||||
self.normalized = normalized
|
||||
self.mode = mode
|
||||
|
@ -79,15 +79,13 @@ class CNNRouteGeneratorModel(LightningBaseModule):
|
||||
def __init__(self, *params, issubclassed=False):
|
||||
super(CNNRouteGeneratorModel, self).__init__(*params)
|
||||
|
||||
if False:
|
||||
# Dataset
|
||||
self.dataset = TrajData(self.hparams.data_param.map_root,
|
||||
mode=self.hparams.data_param.mode,
|
||||
preprocessed=self.hparams.data_param.use_preprocessed,
|
||||
length=self.hparams.data_param.dataset_length)
|
||||
self.criterion = nn.BCELoss(reduction='sum')
|
||||
# Dataset
|
||||
self.dataset = TrajData(self.hparams.data_param.map_root,
|
||||
mode=self.hparams.data_param.mode,
|
||||
preprocessed=self.hparams.data_param.use_preprocessed,
|
||||
length=self.hparams.data_param.dataset_length)
|
||||
|
||||
self.dataset = MyMNIST()
|
||||
self.criterion = nn.BCELoss(reduction='sum')
|
||||
|
||||
# Additional Attributes
|
||||
###################################################
|
||||
|
@ -7,6 +7,7 @@ from pytorch_lightning.loggers.test_tube import TestTubeLogger
|
||||
from lib.utils.config import Config
|
||||
import numpy as np
|
||||
|
||||
|
||||
class Logger(LightningLoggerBase):
|
||||
|
||||
media_dir = 'media'
|
||||
|
Reference in New Issue
Block a user