train running

This commit is contained in:
Steffen Illium 2020-03-09 21:41:50 +01:00
parent daed810958
commit 6cc978e464
7 changed files with 68 additions and 57 deletions

View File

@ -46,7 +46,6 @@ class TrajDataset(Dataset):
while True:
trajectory = self.map.get_random_trajectory()
# TODO: Sanity Check this while true loop...
alternative = self.map.generate_alternative(trajectory)
label = self.map.are_homotopic(trajectory, alternative)
if self.preserve_equal_samples and label == self.last_label:
@ -56,18 +55,13 @@ class TrajDataset(Dataset):
self.last_label = label
if self.mode.lower() in ['all_in_map', 'separated_arrays']:
blank_trajectory_space = torch.zeros(self.map.shape)
blank_alternative_space = torch.zeros(self.map.shape)
for index in trajectory.vertices:
blank_trajectory_space[index] = 1
for index in alternative.vertices:
blank_alternative_space[index] = 1
map_array = torch.as_tensor(self.map.as_array).float()
if self.mode == 'separated_arrays':
return (map_array, blank_trajectory_space, int(label)), blank_alternative_space
return (map_array, trajectory.draw_in_array(self.map_shape), int(label)), \
alternative.draw_in_array(self.map_shape)
else:
return torch.cat((map_array, blank_trajectory_space, blank_alternative_space)), int(label)
return torch.cat((map_array, trajectory.draw_in_array(self.map_shape),
alternative.draw_in_array(self.map_shape))), int(label)
elif self.mode == 'vectors':
return trajectory.vertices, alternative.vertices, label, self.mapname

View File

@ -41,7 +41,7 @@ class CNNRouteGeneratorModel(LightningBaseModule):
# 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
kld_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
# Dimensional Resizing
kld_loss /= self.in_shape
kld_loss /= reduce(mul, self.in_shape)
loss = (kld_loss + discriminated_bce_loss) / 2
return dict(loss=loss, log=dict(loss=loss,
@ -65,7 +65,10 @@ class CNNRouteGeneratorModel(LightningBaseModule):
def validation_step(self, *args):
return self._test_val_step(*args)
def validation_epoch_end(self, outputs):
def validation_epoch_end(self, outputs: list):
return self._test_val_epoch_end(outputs)
def _test_val_epoch_end(self, outputs, test=False):
evaluation = ROCEvaluation(plot_roc=True)
pred_label = torch.cat([x['pred_label'] for x in outputs])
labels = torch.cat([x['label'] for x in outputs]).unsqueeze(1)
@ -73,8 +76,9 @@ class CNNRouteGeneratorModel(LightningBaseModule):
# Sci-py call ROC eval call is eval(true_label, prediction)
roc_auc, tpr, fpr = evaluation(labels.cpu().numpy(), pred_label.cpu().numpy(), )
# self.logger.log_metrics(score_dict)
self.logger.log_image(f'{self.name}_ROC-Curve_E{self.current_epoch}', plt.gcf())
if test:
# self.logger.log_metrics(score_dict)
self.logger.log_image(f'{self.name}_ROC-Curve_E{self.current_epoch}', plt.gcf())
plt.clf()
maps, trajectories, labels, val_restul_dict = self.generate_random()
@ -88,6 +92,10 @@ class CNNRouteGeneratorModel(LightningBaseModule):
def test_step(self, *args):
return self._test_val_step(*args)
def test_epoch_end(self, outputs):
return self._test_val_epoch_end(outputs, test=True)
@property
def discriminator(self):
if self._disc is None:
@ -247,8 +255,15 @@ class CNNRouteGeneratorModel(LightningBaseModule):
return z, mu, logvar
def generate_random(self, n=6):
maps = [self.map_storage[choice(self.map_storage.keys)] for _ in range(n)]
trajectories = torch.stack([x.get_random_trajectory() for x in maps] * 2)
maps = torch.stack([x.as_2d_array for x in maps] * 2)
labels = torch.as_tensor([0] * n + [1] * n)
return maps, trajectories, labels, self._test_val_step(maps, trajectories, labels)
maps = [self.map_storage[choice(self.map_storage.keys_list)] for _ in range(n)]
trajectories = [x.get_random_trajectory() for x in maps] * 2
trajectories = [x.draw_in_array(self.map_storage.max_map_size) for x in trajectories]
trajectories = [torch.as_tensor(x, dtype=torch.float32) for x in trajectories]
trajectories = self._move_to_model_device(torch.stack(trajectories))
maps = [torch.as_tensor(x.as_array, dtype=torch.float32) for x in maps] * 2
maps = self._move_to_model_device(torch.stack(maps))
labels = self._move_to_model_device(torch.as_tensor([0] * n + [1] * n))
return maps, trajectories, labels, self._test_val_step(([maps, trajectories], labels), -9999)

View File

@ -90,7 +90,7 @@ class LightningBaseModule(pl.LightningModule, ABC):
# Data loading
# =============================================================================
# Map Object
self.map_storage = MapStorage(self.hparams.data_param.map_root, load_all=True)
self.map_storage = MapStorage(self.hparams.data_param.map_root)
def size(self):
return self.shape

View File

@ -1,4 +1,5 @@
import shelve
from collections import UserDict
from pathlib import Path
import copy
@ -69,7 +70,7 @@ class Map(object):
# Check pixels for their color (determine if walkable)
for idx, value in np.ndenumerate(self.map_array):
if value == self.white:
if value != self.black:
# IF walkable, add node
graph.add_node(idx, count=0)
# Fully connect to all surrounding neighbors
@ -90,6 +91,7 @@ class Map(object):
if image.mode != 'L':
image = image.convert('L')
map_array = np.expand_dims(np.array(image), axis=0)
map_array = np.where(np.asarray(map_array) == cls.white, 1, 0)
if embedding_size:
assert isinstance(embedding_size, tuple), f'embedding_size was of type: {type(embedding_size)}'
embedding = np.zeros(embedding_size)
@ -147,9 +149,9 @@ class Map(object):
img = Image.new('L', (self.height, self.width), 0)
draw = ImageDraw.Draw(img)
draw.polygon(polyline, outline=self.white, fill=self.white)
draw.polygon(polyline, outline=1, fill=1)
a = (np.where(np.asarray(img) == self.white, 1, 0) * np.where(self.as_2d_array == self.black, 1, 0)).sum()
a = (np.asarray(img) * np.where(self.as_2d_array == self.black, 1, 0)).sum()
if a:
return V.ALTERNATIVE # Non-Homotoph
@ -165,32 +167,25 @@ class Map(object):
return dict(img=img, fig=fig, ax=ax)
class MapStorage(object):
class MapStorage(UserDict):
@property
def keys(self):
return list(self.data.keys())
def keys_list(self):
return list(super(MapStorage, self).keys())
def __init__(self, map_root, load_all=False):
self.data = dict()
def __init__(self, map_root, *args, **kwargs):
super(MapStorage, self).__init__(*args, **kwargs)
self.map_root = Path(map_root)
if load_all:
for map_file in self.map_root.glob('*.bmp'):
_ = self[map_file.name]
def __getitem__(self, item):
if item in self.data.keys():
return self.data.get(item)
else:
current_map = Map().from_image(self.map_root / item)
self.data.__setitem__(item, np.asarray(current_map))
return self[item]
map_files = list(self.map_root.glob('*.bmp'))
self.max_map_size = (1, ) + tuple(
reversed(
tuple(
map(
max, *[Image.open(map_file).size for map_file in map_files])
)
)
)
for map_file in map_files:
current_map = Map().from_image(map_file, embedding_size=self.max_map_size)
self.__setitem__(map_file.name, current_map)

View File

@ -41,6 +41,12 @@ class Trajectory(object):
def as_paired_list(self):
return list(zip(self._vertices[:-1], self._vertices[1:]))
def draw_in_array(self, shape):
trajectory_space = np.zeros(shape)
for index in self.vertices:
trajectory_space[index] = 1
return trajectory_space
@property
def np_vertices(self):
return [np.array(vertice) for vertice in self._vertices]

View File

@ -19,9 +19,9 @@ class GeneratorVisualizer(object):
def _build_column_dict_list(self):
dict_list = []
for idx in range(self.maps):
image = self.maps[idx] + self.trajectories[idx] + self.generated_alternatives
label = self.labels[idx]
for idx in range(self.maps.shape[0]):
image = (self.maps[idx] + self.trajectories[idx] + self.generated_alternatives[idx]).cpu().numpy().squeeze()
label = int(self.labels[idx])
dict_list.append(dict(image=image, label=label))
half_size = int(len(dict_list) // 2)
return dict_list[:half_size], dict_list[half_size:]
@ -33,10 +33,10 @@ class GeneratorVisualizer(object):
axes_pad=0.2, # pad between axes in inch.
)
for idx in grid.axes_all:
for idx in range(len(grid.axes_all)):
row, col = divmod(idx, len(self.column_dict_list))
current_image = self.column_dict_list[col]['image'][row]
current_label = self.column_dict_list[col]['label'][row]
current_image = self.column_dict_list[col][row]['image']
current_label = self.column_dict_list[col][row]['label']
grid[idx].imshow(current_image)
grid[idx].title.set_text(current_label)
fig.cbar_mode = 'single'

View File

@ -28,12 +28,12 @@ main_arg_parser = ArgumentParser(description="parser for fast-neural-style")
# Main Parameters
main_arg_parser.add_argument("--main_debug", type=strtobool, default=False, help="")
main_arg_parser.add_argument("--main_eval", type=strtobool, default=False, help="")
main_arg_parser.add_argument("--main_eval", type=strtobool, default=True, help="")
main_arg_parser.add_argument("--main_seed", type=int, default=69, help="")
# Data Parameters
main_arg_parser.add_argument("--data_worker", type=int, default=10, help="")
main_arg_parser.add_argument("--data_dataset_length", type=int, default=10000, help="")
main_arg_parser.add_argument("--data_dataset_length", type=int, default=100000, help="")
main_arg_parser.add_argument("--data_root", type=str, default='data', help="")
main_arg_parser.add_argument("--data_map_root", type=str, default='res/shapes', help="")
@ -43,7 +43,7 @@ main_arg_parser.add_argument("--transformations_to_tensor", type=strtobool, defa
# Transformations
main_arg_parser.add_argument("--train_outpath", type=str, default="output", help="")
main_arg_parser.add_argument("--train_version", type=strtobool, required=False, help="")
main_arg_parser.add_argument("--train_epochs", type=int, default=10, help="")
main_arg_parser.add_argument("--train_epochs", type=int, default=12, help="")
main_arg_parser.add_argument("--train_batch_size", type=int, default=256, help="")
main_arg_parser.add_argument("--train_lr", type=float, default=0.002, help="")
@ -123,7 +123,8 @@ def run_lightning_loop(config_obj):
model.save_to_disk(logger.log_dir)
# Evaluate It
trainer.test()
if config_obj.main.eval:
trainer.test()
return model