train running
This commit is contained in:
parent
daed810958
commit
6cc978e464
@ -46,7 +46,6 @@ class TrajDataset(Dataset):
|
|||||||
|
|
||||||
while True:
|
while True:
|
||||||
trajectory = self.map.get_random_trajectory()
|
trajectory = self.map.get_random_trajectory()
|
||||||
# TODO: Sanity Check this while true loop...
|
|
||||||
alternative = self.map.generate_alternative(trajectory)
|
alternative = self.map.generate_alternative(trajectory)
|
||||||
label = self.map.are_homotopic(trajectory, alternative)
|
label = self.map.are_homotopic(trajectory, alternative)
|
||||||
if self.preserve_equal_samples and label == self.last_label:
|
if self.preserve_equal_samples and label == self.last_label:
|
||||||
@ -56,18 +55,13 @@ class TrajDataset(Dataset):
|
|||||||
|
|
||||||
self.last_label = label
|
self.last_label = label
|
||||||
if self.mode.lower() in ['all_in_map', 'separated_arrays']:
|
if self.mode.lower() in ['all_in_map', 'separated_arrays']:
|
||||||
blank_trajectory_space = torch.zeros(self.map.shape)
|
|
||||||
blank_alternative_space = torch.zeros(self.map.shape)
|
|
||||||
for index in trajectory.vertices:
|
|
||||||
blank_trajectory_space[index] = 1
|
|
||||||
for index in alternative.vertices:
|
|
||||||
blank_alternative_space[index] = 1
|
|
||||||
|
|
||||||
map_array = torch.as_tensor(self.map.as_array).float()
|
map_array = torch.as_tensor(self.map.as_array).float()
|
||||||
if self.mode == 'separated_arrays':
|
if self.mode == 'separated_arrays':
|
||||||
return (map_array, blank_trajectory_space, int(label)), blank_alternative_space
|
return (map_array, trajectory.draw_in_array(self.map_shape), int(label)), \
|
||||||
|
alternative.draw_in_array(self.map_shape)
|
||||||
else:
|
else:
|
||||||
return torch.cat((map_array, blank_trajectory_space, blank_alternative_space)), int(label)
|
return torch.cat((map_array, trajectory.draw_in_array(self.map_shape),
|
||||||
|
alternative.draw_in_array(self.map_shape))), int(label)
|
||||||
|
|
||||||
elif self.mode == 'vectors':
|
elif self.mode == 'vectors':
|
||||||
return trajectory.vertices, alternative.vertices, label, self.mapname
|
return trajectory.vertices, alternative.vertices, label, self.mapname
|
||||||
|
@ -41,7 +41,7 @@ class CNNRouteGeneratorModel(LightningBaseModule):
|
|||||||
# 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
|
# 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
|
||||||
kld_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
|
kld_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
|
||||||
# Dimensional Resizing
|
# Dimensional Resizing
|
||||||
kld_loss /= self.in_shape
|
kld_loss /= reduce(mul, self.in_shape)
|
||||||
|
|
||||||
loss = (kld_loss + discriminated_bce_loss) / 2
|
loss = (kld_loss + discriminated_bce_loss) / 2
|
||||||
return dict(loss=loss, log=dict(loss=loss,
|
return dict(loss=loss, log=dict(loss=loss,
|
||||||
@ -65,7 +65,10 @@ class CNNRouteGeneratorModel(LightningBaseModule):
|
|||||||
def validation_step(self, *args):
|
def validation_step(self, *args):
|
||||||
return self._test_val_step(*args)
|
return self._test_val_step(*args)
|
||||||
|
|
||||||
def validation_epoch_end(self, outputs):
|
def validation_epoch_end(self, outputs: list):
|
||||||
|
return self._test_val_epoch_end(outputs)
|
||||||
|
|
||||||
|
def _test_val_epoch_end(self, outputs, test=False):
|
||||||
evaluation = ROCEvaluation(plot_roc=True)
|
evaluation = ROCEvaluation(plot_roc=True)
|
||||||
pred_label = torch.cat([x['pred_label'] for x in outputs])
|
pred_label = torch.cat([x['pred_label'] for x in outputs])
|
||||||
labels = torch.cat([x['label'] for x in outputs]).unsqueeze(1)
|
labels = torch.cat([x['label'] for x in outputs]).unsqueeze(1)
|
||||||
@ -73,6 +76,7 @@ class CNNRouteGeneratorModel(LightningBaseModule):
|
|||||||
|
|
||||||
# Sci-py call ROC eval call is eval(true_label, prediction)
|
# Sci-py call ROC eval call is eval(true_label, prediction)
|
||||||
roc_auc, tpr, fpr = evaluation(labels.cpu().numpy(), pred_label.cpu().numpy(), )
|
roc_auc, tpr, fpr = evaluation(labels.cpu().numpy(), pred_label.cpu().numpy(), )
|
||||||
|
if test:
|
||||||
# self.logger.log_metrics(score_dict)
|
# self.logger.log_metrics(score_dict)
|
||||||
self.logger.log_image(f'{self.name}_ROC-Curve_E{self.current_epoch}', plt.gcf())
|
self.logger.log_image(f'{self.name}_ROC-Curve_E{self.current_epoch}', plt.gcf())
|
||||||
plt.clf()
|
plt.clf()
|
||||||
@ -88,6 +92,10 @@ class CNNRouteGeneratorModel(LightningBaseModule):
|
|||||||
def test_step(self, *args):
|
def test_step(self, *args):
|
||||||
return self._test_val_step(*args)
|
return self._test_val_step(*args)
|
||||||
|
|
||||||
|
def test_epoch_end(self, outputs):
|
||||||
|
return self._test_val_epoch_end(outputs, test=True)
|
||||||
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def discriminator(self):
|
def discriminator(self):
|
||||||
if self._disc is None:
|
if self._disc is None:
|
||||||
@ -247,8 +255,15 @@ class CNNRouteGeneratorModel(LightningBaseModule):
|
|||||||
return z, mu, logvar
|
return z, mu, logvar
|
||||||
|
|
||||||
def generate_random(self, n=6):
|
def generate_random(self, n=6):
|
||||||
maps = [self.map_storage[choice(self.map_storage.keys)] for _ in range(n)]
|
maps = [self.map_storage[choice(self.map_storage.keys_list)] for _ in range(n)]
|
||||||
trajectories = torch.stack([x.get_random_trajectory() for x in maps] * 2)
|
|
||||||
maps = torch.stack([x.as_2d_array for x in maps] * 2)
|
trajectories = [x.get_random_trajectory() for x in maps] * 2
|
||||||
labels = torch.as_tensor([0] * n + [1] * n)
|
trajectories = [x.draw_in_array(self.map_storage.max_map_size) for x in trajectories]
|
||||||
return maps, trajectories, labels, self._test_val_step(maps, trajectories, labels)
|
trajectories = [torch.as_tensor(x, dtype=torch.float32) for x in trajectories]
|
||||||
|
trajectories = self._move_to_model_device(torch.stack(trajectories))
|
||||||
|
|
||||||
|
maps = [torch.as_tensor(x.as_array, dtype=torch.float32) for x in maps] * 2
|
||||||
|
maps = self._move_to_model_device(torch.stack(maps))
|
||||||
|
|
||||||
|
labels = self._move_to_model_device(torch.as_tensor([0] * n + [1] * n))
|
||||||
|
return maps, trajectories, labels, self._test_val_step(([maps, trajectories], labels), -9999)
|
||||||
|
@ -90,7 +90,7 @@ class LightningBaseModule(pl.LightningModule, ABC):
|
|||||||
# Data loading
|
# Data loading
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
# Map Object
|
# Map Object
|
||||||
self.map_storage = MapStorage(self.hparams.data_param.map_root, load_all=True)
|
self.map_storage = MapStorage(self.hparams.data_param.map_root)
|
||||||
|
|
||||||
def size(self):
|
def size(self):
|
||||||
return self.shape
|
return self.shape
|
||||||
|
@ -1,4 +1,5 @@
|
|||||||
import shelve
|
import shelve
|
||||||
|
from collections import UserDict
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import copy
|
import copy
|
||||||
@ -69,7 +70,7 @@ class Map(object):
|
|||||||
|
|
||||||
# Check pixels for their color (determine if walkable)
|
# Check pixels for their color (determine if walkable)
|
||||||
for idx, value in np.ndenumerate(self.map_array):
|
for idx, value in np.ndenumerate(self.map_array):
|
||||||
if value == self.white:
|
if value != self.black:
|
||||||
# IF walkable, add node
|
# IF walkable, add node
|
||||||
graph.add_node(idx, count=0)
|
graph.add_node(idx, count=0)
|
||||||
# Fully connect to all surrounding neighbors
|
# Fully connect to all surrounding neighbors
|
||||||
@ -90,6 +91,7 @@ class Map(object):
|
|||||||
if image.mode != 'L':
|
if image.mode != 'L':
|
||||||
image = image.convert('L')
|
image = image.convert('L')
|
||||||
map_array = np.expand_dims(np.array(image), axis=0)
|
map_array = np.expand_dims(np.array(image), axis=0)
|
||||||
|
map_array = np.where(np.asarray(map_array) == cls.white, 1, 0)
|
||||||
if embedding_size:
|
if embedding_size:
|
||||||
assert isinstance(embedding_size, tuple), f'embedding_size was of type: {type(embedding_size)}'
|
assert isinstance(embedding_size, tuple), f'embedding_size was of type: {type(embedding_size)}'
|
||||||
embedding = np.zeros(embedding_size)
|
embedding = np.zeros(embedding_size)
|
||||||
@ -147,9 +149,9 @@ class Map(object):
|
|||||||
img = Image.new('L', (self.height, self.width), 0)
|
img = Image.new('L', (self.height, self.width), 0)
|
||||||
draw = ImageDraw.Draw(img)
|
draw = ImageDraw.Draw(img)
|
||||||
|
|
||||||
draw.polygon(polyline, outline=self.white, fill=self.white)
|
draw.polygon(polyline, outline=1, fill=1)
|
||||||
|
|
||||||
a = (np.where(np.asarray(img) == self.white, 1, 0) * np.where(self.as_2d_array == self.black, 1, 0)).sum()
|
a = (np.asarray(img) * np.where(self.as_2d_array == self.black, 1, 0)).sum()
|
||||||
|
|
||||||
if a:
|
if a:
|
||||||
return V.ALTERNATIVE # Non-Homotoph
|
return V.ALTERNATIVE # Non-Homotoph
|
||||||
@ -165,32 +167,25 @@ class Map(object):
|
|||||||
return dict(img=img, fig=fig, ax=ax)
|
return dict(img=img, fig=fig, ax=ax)
|
||||||
|
|
||||||
|
|
||||||
class MapStorage(object):
|
class MapStorage(UserDict):
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def keys(self):
|
def keys_list(self):
|
||||||
return list(self.data.keys())
|
return list(super(MapStorage, self).keys())
|
||||||
|
|
||||||
def __init__(self, map_root, load_all=False):
|
def __init__(self, map_root, *args, **kwargs):
|
||||||
self.data = dict()
|
super(MapStorage, self).__init__(*args, **kwargs)
|
||||||
self.map_root = Path(map_root)
|
self.map_root = Path(map_root)
|
||||||
if load_all:
|
map_files = list(self.map_root.glob('*.bmp'))
|
||||||
for map_file in self.map_root.glob('*.bmp'):
|
self.max_map_size = (1, ) + tuple(
|
||||||
_ = self[map_file.name]
|
reversed(
|
||||||
|
tuple(
|
||||||
def __getitem__(self, item):
|
map(
|
||||||
if item in self.data.keys():
|
max, *[Image.open(map_file).size for map_file in map_files])
|
||||||
return self.data.get(item)
|
)
|
||||||
else:
|
)
|
||||||
current_map = Map().from_image(self.map_root / item)
|
)
|
||||||
self.data.__setitem__(item, np.asarray(current_map))
|
|
||||||
return self[item]
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
for map_file in map_files:
|
||||||
|
current_map = Map().from_image(map_file, embedding_size=self.max_map_size)
|
||||||
|
self.__setitem__(map_file.name, current_map)
|
||||||
|
@ -41,6 +41,12 @@ class Trajectory(object):
|
|||||||
def as_paired_list(self):
|
def as_paired_list(self):
|
||||||
return list(zip(self._vertices[:-1], self._vertices[1:]))
|
return list(zip(self._vertices[:-1], self._vertices[1:]))
|
||||||
|
|
||||||
|
def draw_in_array(self, shape):
|
||||||
|
trajectory_space = np.zeros(shape)
|
||||||
|
for index in self.vertices:
|
||||||
|
trajectory_space[index] = 1
|
||||||
|
return trajectory_space
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def np_vertices(self):
|
def np_vertices(self):
|
||||||
return [np.array(vertice) for vertice in self._vertices]
|
return [np.array(vertice) for vertice in self._vertices]
|
||||||
|
@ -19,9 +19,9 @@ class GeneratorVisualizer(object):
|
|||||||
|
|
||||||
def _build_column_dict_list(self):
|
def _build_column_dict_list(self):
|
||||||
dict_list = []
|
dict_list = []
|
||||||
for idx in range(self.maps):
|
for idx in range(self.maps.shape[0]):
|
||||||
image = self.maps[idx] + self.trajectories[idx] + self.generated_alternatives
|
image = (self.maps[idx] + self.trajectories[idx] + self.generated_alternatives[idx]).cpu().numpy().squeeze()
|
||||||
label = self.labels[idx]
|
label = int(self.labels[idx])
|
||||||
dict_list.append(dict(image=image, label=label))
|
dict_list.append(dict(image=image, label=label))
|
||||||
half_size = int(len(dict_list) // 2)
|
half_size = int(len(dict_list) // 2)
|
||||||
return dict_list[:half_size], dict_list[half_size:]
|
return dict_list[:half_size], dict_list[half_size:]
|
||||||
@ -33,10 +33,10 @@ class GeneratorVisualizer(object):
|
|||||||
axes_pad=0.2, # pad between axes in inch.
|
axes_pad=0.2, # pad between axes in inch.
|
||||||
)
|
)
|
||||||
|
|
||||||
for idx in grid.axes_all:
|
for idx in range(len(grid.axes_all)):
|
||||||
row, col = divmod(idx, len(self.column_dict_list))
|
row, col = divmod(idx, len(self.column_dict_list))
|
||||||
current_image = self.column_dict_list[col]['image'][row]
|
current_image = self.column_dict_list[col][row]['image']
|
||||||
current_label = self.column_dict_list[col]['label'][row]
|
current_label = self.column_dict_list[col][row]['label']
|
||||||
grid[idx].imshow(current_image)
|
grid[idx].imshow(current_image)
|
||||||
grid[idx].title.set_text(current_label)
|
grid[idx].title.set_text(current_label)
|
||||||
fig.cbar_mode = 'single'
|
fig.cbar_mode = 'single'
|
||||||
|
7
main.py
7
main.py
@ -28,12 +28,12 @@ main_arg_parser = ArgumentParser(description="parser for fast-neural-style")
|
|||||||
|
|
||||||
# Main Parameters
|
# Main Parameters
|
||||||
main_arg_parser.add_argument("--main_debug", type=strtobool, default=False, help="")
|
main_arg_parser.add_argument("--main_debug", type=strtobool, default=False, help="")
|
||||||
main_arg_parser.add_argument("--main_eval", type=strtobool, default=False, help="")
|
main_arg_parser.add_argument("--main_eval", type=strtobool, default=True, help="")
|
||||||
main_arg_parser.add_argument("--main_seed", type=int, default=69, help="")
|
main_arg_parser.add_argument("--main_seed", type=int, default=69, help="")
|
||||||
|
|
||||||
# Data Parameters
|
# Data Parameters
|
||||||
main_arg_parser.add_argument("--data_worker", type=int, default=10, help="")
|
main_arg_parser.add_argument("--data_worker", type=int, default=10, help="")
|
||||||
main_arg_parser.add_argument("--data_dataset_length", type=int, default=10000, help="")
|
main_arg_parser.add_argument("--data_dataset_length", type=int, default=100000, help="")
|
||||||
main_arg_parser.add_argument("--data_root", type=str, default='data', help="")
|
main_arg_parser.add_argument("--data_root", type=str, default='data', help="")
|
||||||
main_arg_parser.add_argument("--data_map_root", type=str, default='res/shapes', help="")
|
main_arg_parser.add_argument("--data_map_root", type=str, default='res/shapes', help="")
|
||||||
|
|
||||||
@ -43,7 +43,7 @@ main_arg_parser.add_argument("--transformations_to_tensor", type=strtobool, defa
|
|||||||
# Transformations
|
# Transformations
|
||||||
main_arg_parser.add_argument("--train_outpath", type=str, default="output", help="")
|
main_arg_parser.add_argument("--train_outpath", type=str, default="output", help="")
|
||||||
main_arg_parser.add_argument("--train_version", type=strtobool, required=False, help="")
|
main_arg_parser.add_argument("--train_version", type=strtobool, required=False, help="")
|
||||||
main_arg_parser.add_argument("--train_epochs", type=int, default=10, help="")
|
main_arg_parser.add_argument("--train_epochs", type=int, default=12, help="")
|
||||||
main_arg_parser.add_argument("--train_batch_size", type=int, default=256, help="")
|
main_arg_parser.add_argument("--train_batch_size", type=int, default=256, help="")
|
||||||
main_arg_parser.add_argument("--train_lr", type=float, default=0.002, help="")
|
main_arg_parser.add_argument("--train_lr", type=float, default=0.002, help="")
|
||||||
|
|
||||||
@ -123,6 +123,7 @@ def run_lightning_loop(config_obj):
|
|||||||
model.save_to_disk(logger.log_dir)
|
model.save_to_disk(logger.log_dir)
|
||||||
|
|
||||||
# Evaluate It
|
# Evaluate It
|
||||||
|
if config_obj.main.eval:
|
||||||
trainer.test()
|
trainer.test()
|
||||||
|
|
||||||
return model
|
return model
|
||||||
|
Loading…
x
Reference in New Issue
Block a user