all debug and train running
This commit is contained in:
@ -3,7 +3,6 @@ from pathlib import Path
|
|||||||
from typing import Union, List
|
from typing import Union, List
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from random import choice
|
|
||||||
from torch.utils.data import ConcatDataset, Dataset
|
from torch.utils.data import ConcatDataset, Dataset
|
||||||
|
|
||||||
from lib.objects.map import Map
|
from lib.objects.map import Map
|
||||||
@ -17,12 +16,14 @@ class TrajDataset(Dataset):
|
|||||||
return self.map.as_array.shape
|
return self.map.as_array.shape
|
||||||
|
|
||||||
def __init__(self, *args, maps_root: Union[Path, str] = '', mapname='tate_sw',
|
def __init__(self, *args, maps_root: Union[Path, str] = '', mapname='tate_sw',
|
||||||
length=100000, all_in_map=True, embedding_size=None, **kwargs):
|
length=100000, all_in_map=True, embedding_size=None, preserve_equal_samples=False, **kwargs):
|
||||||
super(TrajDataset, self).__init__()
|
super(TrajDataset, self).__init__()
|
||||||
|
self.preserve_equal_samples = preserve_equal_samples
|
||||||
self.all_in_map = all_in_map
|
self.all_in_map = all_in_map
|
||||||
self.mapname = mapname if mapname.endswith('.bmp') else f'{mapname}.bmp'
|
self.mapname = mapname if mapname.endswith('.bmp') else f'{mapname}.bmp'
|
||||||
self.maps_root = maps_root
|
self.maps_root = maps_root
|
||||||
self._len = length
|
self._len = length
|
||||||
|
self.last_label = -1
|
||||||
|
|
||||||
self.map = Map(self.mapname).from_image(self.maps_root / self.mapname, embedding_size=embedding_size)
|
self.map = Map(self.mapname).from_image(self.maps_root / self.mapname, embedding_size=embedding_size)
|
||||||
|
|
||||||
@ -31,8 +32,16 @@ class TrajDataset(Dataset):
|
|||||||
|
|
||||||
def __getitem__(self, item):
|
def __getitem__(self, item):
|
||||||
trajectory = self.map.get_random_trajectory()
|
trajectory = self.map.get_random_trajectory()
|
||||||
alternative = self.map.generate_alternative(trajectory)
|
while True:
|
||||||
label = choice([0, 1])
|
# TODO: Sanity Check this while true loop...
|
||||||
|
alternative = self.map.generate_alternative(trajectory)
|
||||||
|
label = self.map.are_homotopic(trajectory, alternative)
|
||||||
|
if self.preserve_equal_samples and label == self.last_label:
|
||||||
|
continue
|
||||||
|
else:
|
||||||
|
break
|
||||||
|
|
||||||
|
self.last_label = label
|
||||||
if self.all_in_map:
|
if self.all_in_map:
|
||||||
blank_trajectory_space = torch.zeros(self.map.shape)
|
blank_trajectory_space = torch.zeros(self.map.shape)
|
||||||
blank_alternative_space = torch.zeros(self.map.shape)
|
blank_alternative_space = torch.zeros(self.map.shape)
|
||||||
@ -41,7 +50,6 @@ class TrajDataset(Dataset):
|
|||||||
blank_alternative_space[index] = 1
|
blank_alternative_space[index] = 1
|
||||||
|
|
||||||
map_array = torch.as_tensor(self.map.as_array).float()
|
map_array = torch.as_tensor(self.map.as_array).float()
|
||||||
label = self.map.are_homotopic(trajectory, alternative)
|
|
||||||
return torch.cat((map_array, blank_trajectory_space, blank_alternative_space)), int(label)
|
return torch.cat((map_array, blank_trajectory_space, blank_alternative_space)), int(label)
|
||||||
else:
|
else:
|
||||||
return trajectory.vertices, alternative.vertices, label, self.mapname
|
return trajectory.vertices, alternative.vertices, label, self.mapname
|
||||||
@ -78,7 +86,8 @@ class TrajData(object):
|
|||||||
# find max image size among available maps:
|
# find max image size among available maps:
|
||||||
max_map_size = (1, ) + tuple(reversed(tuple(map(max, *[Image.open(map_file).size for map_file in map_files]))))
|
max_map_size = (1, ) + tuple(reversed(tuple(map(max, *[Image.open(map_file).size for map_file in map_files]))))
|
||||||
return ConcatDataset([TrajDataset(maps_root=self.maps_root, mapname=map_file.name, length=equal_split,
|
return ConcatDataset([TrajDataset(maps_root=self.maps_root, mapname=map_file.name, length=equal_split,
|
||||||
all_in_map=self.all_in_map, embedding_size=max_map_size)
|
all_in_map=self.all_in_map, embedding_size=max_map_size,
|
||||||
|
preserve_equal_samples=True)
|
||||||
for map_file in map_files])
|
for map_file in map_files])
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
35
multi_run.py
Normal file
35
multi_run.py
Normal file
@ -0,0 +1,35 @@
|
|||||||
|
import warnings
|
||||||
|
|
||||||
|
from lib.utils.config import Config
|
||||||
|
|
||||||
|
warnings.filterwarnings('ignore', category=FutureWarning)
|
||||||
|
warnings.filterwarnings('ignore', category=UserWarning)
|
||||||
|
|
||||||
|
# Imports
|
||||||
|
# =============================================================================
|
||||||
|
from pathlib import Path
|
||||||
|
import os
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
|
||||||
|
# Model Settings
|
||||||
|
warnings.filterwarnings('ignore', category=FutureWarning)
|
||||||
|
# use_bias, activation, model, use_norm, max_epochs, filters
|
||||||
|
cnn_classifier = [True, 'leaky_relu', 'classifier_cnn', False, 2, [16, 32, 64]]
|
||||||
|
# use_bias, activation, model, use_norm, max_epochs, sr, lat_dim, filters
|
||||||
|
|
||||||
|
# Data Settings
|
||||||
|
data_shortcodes = ['mid', 'mid_5']
|
||||||
|
|
||||||
|
# Iteration over
|
||||||
|
for data_shortcode in data_shortcodes:
|
||||||
|
for use_bias, activation, model, use_norm, max_epochs, filters in [cnn_classifier]:
|
||||||
|
for seed in range(5):
|
||||||
|
arg_dict = dict(main_seed=seed, train_max_epochs=max_epochs,
|
||||||
|
model_use_bias=use_bias, model_use_norm=use_norm,
|
||||||
|
model_activation=activation, model_type=model,
|
||||||
|
model_filters=filters,
|
||||||
|
data_batch_size=512)
|
||||||
|
|
||||||
|
os.system(f'/home/steffen/envs/traj_gen/bin/python main.py {arg_dict}')
|
Reference in New Issue
Block a user