train running

parent daed810958
commit 6cc978e464
@@ -46,7 +46,6 @@ class TrajDataset(Dataset):
         while True:
             trajectory = self.map.get_random_trajectory()
             # TODO: Sanity Check this while true loop...
             alternative = self.map.generate_alternative(trajectory)
             label = self.map.are_homotopic(trajectory, alternative)
             if self.preserve_equal_samples and label == self.last_label:
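
The loop above keeps redrawing trajectory/alternative pairs until the homotopy label differs from the previous sample whenever preserve_equal_samples is set, which keeps the two classes roughly balanced across the dataset. A toy sketch of that rejection-resampling idea (draw_pair is a stand-in, not repository code):

import random

def draw_pair():
    # stand-in for get_random_trajectory / generate_alternative / are_homotopic
    label = random.randint(0, 1)
    return ('trajectory', 'alternative'), label

last_label = None
preserve_equal_samples = True
while True:
    pair, label = draw_pair()
    if preserve_equal_samples and label == last_label:
        continue  # resample until the label flips
    break
last_label = label
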
@@ -56,18 +55,13 @@ class TrajDataset(Dataset):
         self.last_label = label
         if self.mode.lower() in ['all_in_map', 'separated_arrays']:
-            blank_trajectory_space = torch.zeros(self.map.shape)
-            blank_alternative_space = torch.zeros(self.map.shape)
-            for index in trajectory.vertices:
-                blank_trajectory_space[index] = 1
-            for index in alternative.vertices:
-                blank_alternative_space[index] = 1

             map_array = torch.as_tensor(self.map.as_array).float()
             if self.mode == 'separated_arrays':
-                return (map_array, blank_trajectory_space, int(label)), blank_alternative_space
+                return (map_array, trajectory.draw_in_array(self.map_shape), int(label)), \
+                    alternative.draw_in_array(self.map_shape)
             else:
-                return torch.cat((map_array, blank_trajectory_space, blank_alternative_space)), int(label)
+                return torch.cat((map_array, trajectory.draw_in_array(self.map_shape),
+                                  alternative.draw_in_array(self.map_shape))), int(label)

         elif self.mode == 'vectors':
             return trajectory.vertices, alternative.vertices, label, self.mapname
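
The two new return statements delegate rasterisation to the Trajectory.draw_in_array helper added later in this commit instead of filling a zero tensor vertex by vertex. A minimal sketch of the equivalence, with an illustrative 2D shape and vertex list (both names are made up for the example):

import numpy as np
import torch

map_shape = (8, 8)                        # illustrative
traj_vertices = [(2, 2), (2, 3), (3, 3)]  # illustrative (row, col) indices

# old style: allocate a tensor and set each vertex by hand
blank_space = torch.zeros(map_shape)
for index in traj_vertices:
    blank_space[index] = 1

# new style: draw_in_array builds the same raster in numpy,
# and the dataset only converts the result to a tensor
def draw_in_array(vertices, shape):
    space = np.zeros(shape)
    for index in vertices:
        space[index] = 1
    return space

assert torch.equal(blank_space, torch.as_tensor(draw_in_array(traj_vertices, map_shape)).float())
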
@@ -41,7 +41,7 @@ class CNNRouteGeneratorModel(LightningBaseModule):
         # 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
         kld_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
         # Dimensional Resizing
-        kld_loss /= self.in_shape
+        kld_loss /= reduce(mul, self.in_shape)

         loss = (kld_loss + discriminated_bce_loss) / 2
         return dict(loss=loss, log=dict(loss=loss,
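
Dividing the summed KL term by self.in_shape only works when in_shape is a single number; since it is a shape tuple here, the replacement divides by the product of its entries, i.e. by the number of input elements, keeping the KL term on a per-element scale comparable to the discriminator loss. A small sketch of the normalisation with an illustrative shape:

from functools import reduce
from operator import mul

import torch

in_shape = (1, 64, 64)          # illustrative (channel, height, width)
mu = torch.zeros(4, 16)         # illustrative latent statistics (batch, z_dim)
logvar = torch.zeros(4, 16)

# KL( N(mu, sigma^2) || N(0, 1) ), summed over batch and latent dimensions
kld_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

# normalise by the element count of the input, not by the tuple itself
kld_loss /= reduce(mul, in_shape)   # 1 * 64 * 64 = 4096
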
@@ -65,7 +65,10 @@ class CNNRouteGeneratorModel(LightningBaseModule):
     def validation_step(self, *args):
         return self._test_val_step(*args)

-    def validation_epoch_end(self, outputs):
+    def validation_epoch_end(self, outputs: list):
+        return self._test_val_epoch_end(outputs)
+
+    def _test_val_epoch_end(self, outputs, test=False):
         evaluation = ROCEvaluation(plot_roc=True)
         pred_label = torch.cat([x['pred_label'] for x in outputs])
         labels = torch.cat([x['label'] for x in outputs]).unsqueeze(1)
@@ -73,8 +76,9 @@ class CNNRouteGeneratorModel(LightningBaseModule):

         # Sci-py call ROC eval call is eval(true_label, prediction)
         roc_auc, tpr, fpr = evaluation(labels.cpu().numpy(), pred_label.cpu().numpy(), )
-        # self.logger.log_metrics(score_dict)
-        self.logger.log_image(f'{self.name}_ROC-Curve_E{self.current_epoch}', plt.gcf())
+        if test:
+            # self.logger.log_metrics(score_dict)
+            self.logger.log_image(f'{self.name}_ROC-Curve_E{self.current_epoch}', plt.gcf())
         plt.clf()
         maps, trajectories, labels, val_restul_dict = self.generate_random()

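
_test_val_epoch_end concatenates predictions and labels over the whole epoch and hands them to ROCEvaluation as (true_labels, predictions), matching the scikit-learn argument order mentioned in the comment. ROCEvaluation itself is not part of this diff; the sketch below shows the usual sklearn calls such a wrapper would make (an assumption, not the project's code):

import numpy as np
from sklearn.metrics import auc, roc_curve

labels = np.array([0, 0, 1, 1])               # illustrative true labels
pred_label = np.array([0.1, 0.4, 0.35, 0.8])  # illustrative scores

fpr, tpr, _ = roc_curve(labels, pred_label)   # true labels first, scores second
roc_auc = auc(fpr, tpr)
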
@@ -88,6 +92,10 @@ class CNNRouteGeneratorModel(LightningBaseModule):
     def test_step(self, *args):
         return self._test_val_step(*args)

+    def test_epoch_end(self, outputs):
+        return self._test_val_epoch_end(outputs, test=True)
+
+
     @property
     def discriminator(self):
         if self._disc is None:
@@ -247,8 +255,15 @@ class CNNRouteGeneratorModel(LightningBaseModule):
         return z, mu, logvar

     def generate_random(self, n=6):
-        maps = [self.map_storage[choice(self.map_storage.keys)] for _ in range(n)]
-        trajectories = torch.stack([x.get_random_trajectory() for x in maps] * 2)
-        maps = torch.stack([x.as_2d_array for x in maps] * 2)
-        labels = torch.as_tensor([0] * n + [1] * n)
-        return maps, trajectories, labels, self._test_val_step(maps, trajectories, labels)
+        maps = [self.map_storage[choice(self.map_storage.keys_list)] for _ in range(n)]
+
+        trajectories = [x.get_random_trajectory() for x in maps] * 2
+        trajectories = [x.draw_in_array(self.map_storage.max_map_size) for x in trajectories]
+        trajectories = [torch.as_tensor(x, dtype=torch.float32) for x in trajectories]
+        trajectories = self._move_to_model_device(torch.stack(trajectories))
+
+        maps = [torch.as_tensor(x.as_array, dtype=torch.float32) for x in maps] * 2
+        maps = self._move_to_model_device(torch.stack(maps))
+
+        labels = self._move_to_model_device(torch.as_tensor([0] * n + [1] * n))
+        return maps, trajectories, labels, self._test_val_step(([maps, trajectories], labels), -9999)
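
The rewritten generate_random samples n maps, rasterises one random trajectory per map into the shared max_map_size, doubles both lists so every map appears twice, labels the first half 0 and the second half 1, and moves everything onto the model device. A stripped-down sketch of the batch assembly (map objects, device handling and the _test_val_step call are left out; shapes are illustrative):

import numpy as np
import torch

n = 3
max_map_size = (1, 8, 8)   # illustrative shared embedding size

# one rasterised trajectory per sampled map, doubled so each map appears twice
drawn = [np.zeros(max_map_size) for _ in range(n)] * 2
trajectories = torch.stack([torch.as_tensor(x, dtype=torch.float32) for x in drawn])

labels = torch.as_tensor([0] * n + [1] * n)   # first half vs second half of the doubled batch
assert trajectories.shape[0] == labels.shape[0] == 2 * n
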
@@ -90,7 +90,7 @@ class LightningBaseModule(pl.LightningModule, ABC):
         # Data loading
         # =============================================================================
         # Map Object
-        self.map_storage = MapStorage(self.hparams.data_param.map_root, load_all=True)
+        self.map_storage = MapStorage(self.hparams.data_param.map_root)

     def size(self):
         return self.shape
@@ -1,4 +1,5 @@
 import shelve
+from collections import UserDict
 from pathlib import Path

 import copy
@@ -69,7 +70,7 @@ class Map(object):

         # Check pixels for their color (determine if walkable)
         for idx, value in np.ndenumerate(self.map_array):
-            if value == self.white:
+            if value != self.black:
                 # IF walkable, add node
                 graph.add_node(idx, count=0)
                 # Fully connect to all surrounding neighbors
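
With the binarisation added to from_image in the next hunk, walkable pixels carry 1 rather than the raw white value, so testing value != self.black admits them while value == self.white no longer would. A small sketch of building such a grid graph with networkx (the 8-neighbourhood wiring is an assumption about what "fully connect to all surrounding neighbors" does):

import numpy as np
import networkx as nx

map_array = np.array([[1, 1, 0],
                      [0, 1, 1]])   # illustrative 0/1 map, 1 = walkable
black = 0

graph = nx.Graph()
for idx, value in np.ndenumerate(map_array):
    if value != black:              # walkable -> add node
        graph.add_node(idx, count=0)

# connect every walkable cell to its walkable 8-neighbours
for (r, c) in list(graph.nodes):
    for dr in (-1, 0, 1):
        for dc in (-1, 0, 1):
            if (dr or dc) and (r + dr, c + dc) in graph:
                graph.add_edge((r, c), (r + dr, c + dc))
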
@@ -90,6 +91,7 @@ class Map(object):
         if image.mode != 'L':
             image = image.convert('L')
         map_array = np.expand_dims(np.array(image), axis=0)
+        map_array = np.where(np.asarray(map_array) == cls.white, 1, 0)
         if embedding_size:
             assert isinstance(embedding_size, tuple), f'embedding_size was of type: {type(embedding_size)}'
             embedding = np.zeros(embedding_size)
@@ -147,9 +149,9 @@ class Map(object):
         img = Image.new('L', (self.height, self.width), 0)
         draw = ImageDraw.Draw(img)

-        draw.polygon(polyline, outline=self.white, fill=self.white)
+        draw.polygon(polyline, outline=1, fill=1)

-        a = (np.where(np.asarray(img) == self.white, 1, 0) * np.where(self.as_2d_array == self.black, 1, 0)).sum()
+        a = (np.asarray(img) * np.where(self.as_2d_array == self.black, 1, 0)).sum()

         if a:
             return V.ALTERNATIVE # Non-Homotoph
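
are_homotopic rasterises the closed polyline spanned by the two trajectories and checks whether the filled region touches any non-walkable cell; since the map is now stored as 0/1, the filled PIL mask can be multiplied directly with the obstacle mask instead of being re-thresholded against self.white. A condensed sketch of that overlap test (how polyline is assembled from the two vertex lists is assumed, not shown in the diff):

import numpy as np
from PIL import Image, ImageDraw

as_2d_array = np.array([[1, 1, 1, 1],
                        [1, 0, 0, 1],
                        [1, 1, 1, 1]])   # illustrative map, 0 = obstacle
height, width = as_2d_array.shape
black = 0

# closed boundary: one trajectory followed by the other in reverse, as (x, y) points
polyline = [(0, 0), (3, 0), (3, 2), (0, 2)]

img = Image.new('L', (width, height), 0)
ImageDraw.Draw(img).polygon(polyline, outline=1, fill=1)

# any filled pixel that lands on an obstacle makes the pair non-homotopic
overlap = (np.asarray(img) * np.where(as_2d_array == black, 1, 0)).sum()
non_homotopic = bool(overlap)
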
@@ -165,32 +167,25 @@ class Map(object):
         return dict(img=img, fig=fig, ax=ax)


-class MapStorage(object):
+class MapStorage(UserDict):

     @property
-    def keys(self):
-        return list(self.data.keys())
+    def keys_list(self):
+        return list(super(MapStorage, self).keys())

-    def __init__(self, map_root, load_all=False):
-        self.data = dict()
+    def __init__(self, map_root, *args, **kwargs):
+        super(MapStorage, self).__init__(*args, **kwargs)
         self.map_root = Path(map_root)
-        if load_all:
-            for map_file in self.map_root.glob('*.bmp'):
-                _ = self[map_file.name]
-
-    def __getitem__(self, item):
-        if item in self.data.keys():
-            return self.data.get(item)
-        else:
-            current_map = Map().from_image(self.map_root / item)
-            self.data.__setitem__(item, np.asarray(current_map))
-            return self[item]
+        map_files = list(self.map_root.glob('*.bmp'))
+        self.max_map_size = (1, ) + tuple(
+            reversed(
+                tuple(
+                    map(
+                        max, *[Image.open(map_file).size for map_file in map_files])
+                )
+            )
+        )
+
+        for map_file in map_files:
+            current_map = Map().from_image(map_file, embedding_size=self.max_map_size)
+            self.__setitem__(map_file.name, current_map)
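
MapStorage now loads every *.bmp under map_root eagerly and embeds each map into a common max_map_size, obtained by taking the element-wise maximum over all PIL image sizes, flipping it from (width, height) to (height, width) and prepending a channel dimension. A compact sketch of just that size computation with illustrative sizes:

# PIL's Image.size is (width, height); three illustrative map files
sizes = [(30, 20), (25, 40), (32, 18)]

max_map_size = (1, ) + tuple(reversed(tuple(map(max, *sizes))))
# map(max, *sizes)  -> (32, 40)   element-wise maximum as (width, height)
# reversed + (1, )  -> (1, 40, 32) as (channel, height, width)
assert max_map_size == (1, 40, 32)
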
@@ -41,6 +41,12 @@ class Trajectory(object):
     def as_paired_list(self):
         return list(zip(self._vertices[:-1], self._vertices[1:]))

+    def draw_in_array(self, shape):
+        trajectory_space = np.zeros(shape)
+        for index in self.vertices:
+            trajectory_space[index] = 1
+        return trajectory_space
+
     @property
     def np_vertices(self):
         return [np.array(vertice) for vertice in self._vertices]
@@ -19,9 +19,9 @@ class GeneratorVisualizer(object):

     def _build_column_dict_list(self):
         dict_list = []
-        for idx in range(self.maps):
-            image = self.maps[idx] + self.trajectories[idx] + self.generated_alternatives
-            label = self.labels[idx]
+        for idx in range(self.maps.shape[0]):
+            image = (self.maps[idx] + self.trajectories[idx] + self.generated_alternatives[idx]).cpu().numpy().squeeze()
+            label = int(self.labels[idx])
             dict_list.append(dict(image=image, label=label))
         half_size = int(len(dict_list) // 2)
         return dict_list[:half_size], dict_list[half_size:]
@@ -33,10 +33,10 @@ class GeneratorVisualizer(object):
                          axes_pad=0.2, # pad between axes in inch.
                          )

-        for idx in grid.axes_all:
+        for idx in range(len(grid.axes_all)):
             row, col = divmod(idx, len(self.column_dict_list))
-            current_image = self.column_dict_list[col]['image'][row]
-            current_label = self.column_dict_list[col]['label'][row]
+            current_image = self.column_dict_list[col][row]['image']
+            current_label = self.column_dict_list[col][row]['label']
             grid[idx].imshow(current_image)
             grid[idx].title.set_text(current_label)
         fig.cbar_mode = 'single'
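
_build_column_dict_list returns two half-lists, and the drawing loop now walks integer indices so divmod can split each cell index into a row within a half and a column selecting which half. A small sketch of that indexing with illustrative content:

column_dict_list = (
    [dict(image='img0', label=0), dict(image='img1', label=0)],   # first half
    [dict(image='img2', label=1), dict(image='img3', label=1)],   # second half
)

n_cells = sum(len(col) for col in column_dict_list)
for idx in range(n_cells):
    row, col = divmod(idx, len(column_dict_list))
    current_image = column_dict_list[col][row]['image']
    current_label = column_dict_list[col][row]['label']
    print(idx, (row, col), current_image, current_label)
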
main.py
@@ -28,12 +28,12 @@ main_arg_parser = ArgumentParser(description="parser for fast-neural-style")

 # Main Parameters
 main_arg_parser.add_argument("--main_debug", type=strtobool, default=False, help="")
-main_arg_parser.add_argument("--main_eval", type=strtobool, default=False, help="")
+main_arg_parser.add_argument("--main_eval", type=strtobool, default=True, help="")
 main_arg_parser.add_argument("--main_seed", type=int, default=69, help="")

 # Data Parameters
 main_arg_parser.add_argument("--data_worker", type=int, default=10, help="")
-main_arg_parser.add_argument("--data_dataset_length", type=int, default=10000, help="")
+main_arg_parser.add_argument("--data_dataset_length", type=int, default=100000, help="")
 main_arg_parser.add_argument("--data_root", type=str, default='data', help="")
 main_arg_parser.add_argument("--data_map_root", type=str, default='res/shapes', help="")

@@ -43,7 +43,7 @@ main_arg_parser.add_argument("--transformations_to_tensor", type=strtobool, defa
 # Transformations
 main_arg_parser.add_argument("--train_outpath", type=str, default="output", help="")
 main_arg_parser.add_argument("--train_version", type=strtobool, required=False, help="")
-main_arg_parser.add_argument("--train_epochs", type=int, default=10, help="")
+main_arg_parser.add_argument("--train_epochs", type=int, default=12, help="")
 main_arg_parser.add_argument("--train_batch_size", type=int, default=256, help="")
 main_arg_parser.add_argument("--train_lr", type=float, default=0.002, help="")

@@ -123,7 +123,8 @@ def run_lightning_loop(config_obj):
     model.save_to_disk(logger.log_dir)

     # Evaluate It
-    trainer.test()
+    if config_obj.main.eval:
+        trainer.test()

     return model

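
With --main_eval defaulting to True and trainer.test() now gated on config_obj.main.eval, evaluation runs after training unless it is switched off explicitly; strtobool lets the flag accept strings such as "False" or "0" on the command line. A minimal sketch of that flag handling (the config_obj plumbing is assumed):

from argparse import ArgumentParser
from distutils.util import strtobool

parser = ArgumentParser()
parser.add_argument("--main_eval", type=strtobool, default=True, help="")

args = parser.parse_args(["--main_eval", "False"])
if args.main_eval:
    pass          # would call trainer.test() here
else:
    print("evaluation skipped")
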