all debug and train running

This commit is contained in:
steffen
2020-03-04 20:47:20 +01:00
parent 3a6c65240f
commit cec3d54578
2 changed files with 50 additions and 6 deletions

View File

@ -3,7 +3,6 @@ from pathlib import Path
from typing import Union, List
import torch
from random import choice
from torch.utils.data import ConcatDataset, Dataset
from lib.objects.map import Map
@ -17,12 +16,14 @@ class TrajDataset(Dataset):
return self.map.as_array.shape
def __init__(self, *args, maps_root: Union[Path, str] = '', mapname='tate_sw',
length=100000, all_in_map=True, embedding_size=None, **kwargs):
length=100000, all_in_map=True, embedding_size=None, preserve_equal_samples=False, **kwargs):
super(TrajDataset, self).__init__()
self.preserve_equal_samples = preserve_equal_samples
self.all_in_map = all_in_map
self.mapname = mapname if mapname.endswith('.bmp') else f'{mapname}.bmp'
self.maps_root = maps_root
self._len = length
self.last_label = -1
self.map = Map(self.mapname).from_image(self.maps_root / self.mapname, embedding_size=embedding_size)
@ -31,8 +32,16 @@ class TrajDataset(Dataset):
def __getitem__(self, item):
trajectory = self.map.get_random_trajectory()
while True:
# TODO: Sanity Check this while true loop...
alternative = self.map.generate_alternative(trajectory)
label = choice([0, 1])
label = self.map.are_homotopic(trajectory, alternative)
if self.preserve_equal_samples and label == self.last_label:
continue
else:
break
self.last_label = label
if self.all_in_map:
blank_trajectory_space = torch.zeros(self.map.shape)
blank_alternative_space = torch.zeros(self.map.shape)
@ -41,7 +50,6 @@ class TrajDataset(Dataset):
blank_alternative_space[index] = 1
map_array = torch.as_tensor(self.map.as_array).float()
label = self.map.are_homotopic(trajectory, alternative)
return torch.cat((map_array, blank_trajectory_space, blank_alternative_space)), int(label)
else:
return trajectory.vertices, alternative.vertices, label, self.mapname
@ -78,7 +86,8 @@ class TrajData(object):
# find max image size among available maps:
max_map_size = (1, ) + tuple(reversed(tuple(map(max, *[Image.open(map_file).size for map_file in map_files]))))
return ConcatDataset([TrajDataset(maps_root=self.maps_root, mapname=map_file.name, length=equal_split,
all_in_map=self.all_in_map, embedding_size=max_map_size)
all_in_map=self.all_in_map, embedding_size=max_map_size,
preserve_equal_samples=True)
for map_file in map_files])
@property

35
multi_run.py Normal file
View File

@ -0,0 +1,35 @@
import warnings
from lib.utils.config import Config
warnings.filterwarnings('ignore', category=FutureWarning)
warnings.filterwarnings('ignore', category=UserWarning)
# Imports
# =============================================================================
from pathlib import Path
import os
if __name__ == '__main__':
# Model Settings
warnings.filterwarnings('ignore', category=FutureWarning)
# use_bias, activation, model, use_norm, max_epochs, filters
cnn_classifier = [True, 'leaky_relu', 'classifier_cnn', False, 2, [16, 32, 64]]
# use_bias, activation, model, use_norm, max_epochs, sr, lat_dim, filters
# Data Settings
data_shortcodes = ['mid', 'mid_5']
# Iteration over
for data_shortcode in data_shortcodes:
for use_bias, activation, model, use_norm, max_epochs, filters in [cnn_classifier]:
for seed in range(5):
arg_dict = dict(main_seed=seed, train_max_epochs=max_epochs,
model_use_bias=use_bias, model_use_norm=use_norm,
model_activation=activation, model_type=model,
model_filters=filters,
data_batch_size=512)
os.system(f'/home/steffen/envs/traj_gen/bin/python main.py {arg_dict}')