diff --git a/datasets/trajectory_dataset.py b/datasets/trajectory_dataset.py index de019d2..69a80de 100644 --- a/datasets/trajectory_dataset.py +++ b/datasets/trajectory_dataset.py @@ -3,7 +3,6 @@ from pathlib import Path from typing import Union, List import torch -from random import choice from torch.utils.data import ConcatDataset, Dataset from lib.objects.map import Map @@ -17,12 +16,14 @@ class TrajDataset(Dataset): return self.map.as_array.shape def __init__(self, *args, maps_root: Union[Path, str] = '', mapname='tate_sw', - length=100000, all_in_map=True, embedding_size=None, **kwargs): + length=100000, all_in_map=True, embedding_size=None, preserve_equal_samples=False, **kwargs): super(TrajDataset, self).__init__() + self.preserve_equal_samples = preserve_equal_samples self.all_in_map = all_in_map self.mapname = mapname if mapname.endswith('.bmp') else f'{mapname}.bmp' self.maps_root = maps_root self._len = length + self.last_label = -1 self.map = Map(self.mapname).from_image(self.maps_root / self.mapname, embedding_size=embedding_size) @@ -31,8 +32,16 @@ class TrajDataset(Dataset): def __getitem__(self, item): trajectory = self.map.get_random_trajectory() - alternative = self.map.generate_alternative(trajectory) - label = choice([0, 1]) + while True: + # TODO: Sanity Check this while true loop... + alternative = self.map.generate_alternative(trajectory) + label = self.map.are_homotopic(trajectory, alternative) + if self.preserve_equal_samples and label == self.last_label: + continue + else: + break + + self.last_label = label if self.all_in_map: blank_trajectory_space = torch.zeros(self.map.shape) blank_alternative_space = torch.zeros(self.map.shape) @@ -41,7 +50,6 @@ class TrajDataset(Dataset): blank_alternative_space[index] = 1 map_array = torch.as_tensor(self.map.as_array).float() - label = self.map.are_homotopic(trajectory, alternative) return torch.cat((map_array, blank_trajectory_space, blank_alternative_space)), int(label) else: return trajectory.vertices, alternative.vertices, label, self.mapname @@ -78,7 +86,8 @@ class TrajData(object): # find max image size among available maps: max_map_size = (1, ) + tuple(reversed(tuple(map(max, *[Image.open(map_file).size for map_file in map_files])))) return ConcatDataset([TrajDataset(maps_root=self.maps_root, mapname=map_file.name, length=equal_split, - all_in_map=self.all_in_map, embedding_size=max_map_size) + all_in_map=self.all_in_map, embedding_size=max_map_size, + preserve_equal_samples=True) for map_file in map_files]) @property diff --git a/multi_run.py b/multi_run.py new file mode 100644 index 0000000..cf0eaef --- /dev/null +++ b/multi_run.py @@ -0,0 +1,35 @@ +import warnings + +from lib.utils.config import Config + +warnings.filterwarnings('ignore', category=FutureWarning) +warnings.filterwarnings('ignore', category=UserWarning) + +# Imports +# ============================================================================= +from pathlib import Path +import os + + +if __name__ == '__main__': + + # Model Settings + warnings.filterwarnings('ignore', category=FutureWarning) + # use_bias, activation, model, use_norm, max_epochs, filters + cnn_classifier = [True, 'leaky_relu', 'classifier_cnn', False, 2, [16, 32, 64]] + # use_bias, activation, model, use_norm, max_epochs, sr, lat_dim, filters + + # Data Settings + data_shortcodes = ['mid', 'mid_5'] + + # Iteration over + for data_shortcode in data_shortcodes: + for use_bias, activation, model, use_norm, max_epochs, filters in [cnn_classifier]: + for seed in range(5): + arg_dict = dict(main_seed=seed, train_max_epochs=max_epochs, + model_use_bias=use_bias, model_use_norm=use_norm, + model_activation=activation, model_type=model, + model_filters=filters, + data_batch_size=512) + + os.system(f'/home/steffen/envs/traj_gen/bin/python main.py {arg_dict}')