all debug and train running

This commit is contained in:
steffen
2020-03-04 20:47:20 +01:00
parent 3a6c65240f
commit cec3d54578
2 changed files with 50 additions and 6 deletions

View File

@ -3,7 +3,6 @@ from pathlib import Path
from typing import Union, List from typing import Union, List
import torch import torch
from random import choice
from torch.utils.data import ConcatDataset, Dataset from torch.utils.data import ConcatDataset, Dataset
from lib.objects.map import Map from lib.objects.map import Map
@ -17,12 +16,14 @@ class TrajDataset(Dataset):
return self.map.as_array.shape return self.map.as_array.shape
def __init__(self, *args, maps_root: Union[Path, str] = '', mapname='tate_sw', def __init__(self, *args, maps_root: Union[Path, str] = '', mapname='tate_sw',
length=100000, all_in_map=True, embedding_size=None, **kwargs): length=100000, all_in_map=True, embedding_size=None, preserve_equal_samples=False, **kwargs):
super(TrajDataset, self).__init__() super(TrajDataset, self).__init__()
self.preserve_equal_samples = preserve_equal_samples
self.all_in_map = all_in_map self.all_in_map = all_in_map
self.mapname = mapname if mapname.endswith('.bmp') else f'{mapname}.bmp' self.mapname = mapname if mapname.endswith('.bmp') else f'{mapname}.bmp'
self.maps_root = maps_root self.maps_root = maps_root
self._len = length self._len = length
self.last_label = -1
self.map = Map(self.mapname).from_image(self.maps_root / self.mapname, embedding_size=embedding_size) self.map = Map(self.mapname).from_image(self.maps_root / self.mapname, embedding_size=embedding_size)
@ -31,8 +32,16 @@ class TrajDataset(Dataset):
def __getitem__(self, item): def __getitem__(self, item):
trajectory = self.map.get_random_trajectory() trajectory = self.map.get_random_trajectory()
while True:
# TODO: Sanity Check this while true loop...
alternative = self.map.generate_alternative(trajectory) alternative = self.map.generate_alternative(trajectory)
label = choice([0, 1]) label = self.map.are_homotopic(trajectory, alternative)
if self.preserve_equal_samples and label == self.last_label:
continue
else:
break
self.last_label = label
if self.all_in_map: if self.all_in_map:
blank_trajectory_space = torch.zeros(self.map.shape) blank_trajectory_space = torch.zeros(self.map.shape)
blank_alternative_space = torch.zeros(self.map.shape) blank_alternative_space = torch.zeros(self.map.shape)
@ -41,7 +50,6 @@ class TrajDataset(Dataset):
blank_alternative_space[index] = 1 blank_alternative_space[index] = 1
map_array = torch.as_tensor(self.map.as_array).float() map_array = torch.as_tensor(self.map.as_array).float()
label = self.map.are_homotopic(trajectory, alternative)
return torch.cat((map_array, blank_trajectory_space, blank_alternative_space)), int(label) return torch.cat((map_array, blank_trajectory_space, blank_alternative_space)), int(label)
else: else:
return trajectory.vertices, alternative.vertices, label, self.mapname return trajectory.vertices, alternative.vertices, label, self.mapname
@ -78,7 +86,8 @@ class TrajData(object):
# find max image size among available maps: # find max image size among available maps:
max_map_size = (1, ) + tuple(reversed(tuple(map(max, *[Image.open(map_file).size for map_file in map_files])))) max_map_size = (1, ) + tuple(reversed(tuple(map(max, *[Image.open(map_file).size for map_file in map_files]))))
return ConcatDataset([TrajDataset(maps_root=self.maps_root, mapname=map_file.name, length=equal_split, return ConcatDataset([TrajDataset(maps_root=self.maps_root, mapname=map_file.name, length=equal_split,
all_in_map=self.all_in_map, embedding_size=max_map_size) all_in_map=self.all_in_map, embedding_size=max_map_size,
preserve_equal_samples=True)
for map_file in map_files]) for map_file in map_files])
@property @property

35
multi_run.py Normal file
View File

@ -0,0 +1,35 @@
import warnings
from lib.utils.config import Config
warnings.filterwarnings('ignore', category=FutureWarning)
warnings.filterwarnings('ignore', category=UserWarning)
# Imports
# =============================================================================
from pathlib import Path
import os
if __name__ == '__main__':
# Model Settings
warnings.filterwarnings('ignore', category=FutureWarning)
# use_bias, activation, model, use_norm, max_epochs, filters
cnn_classifier = [True, 'leaky_relu', 'classifier_cnn', False, 2, [16, 32, 64]]
# use_bias, activation, model, use_norm, max_epochs, sr, lat_dim, filters
# Data Settings
data_shortcodes = ['mid', 'mid_5']
# Iteration over
for data_shortcode in data_shortcodes:
for use_bias, activation, model, use_norm, max_epochs, filters in [cnn_classifier]:
for seed in range(5):
arg_dict = dict(main_seed=seed, train_max_epochs=max_epochs,
model_use_bias=use_bias, model_use_norm=use_norm,
model_activation=activation, model_type=model,
model_filters=filters,
data_batch_size=512)
os.system(f'/home/steffen/envs/traj_gen/bin/python main.py {arg_dict}')