project Refactor, CNN Classifier Basics

This commit is contained in:
Steffen Illium
2020-03-08 23:46:02 +01:00
parent 75e8a61628
commit cd4fdf2de3
20 changed files with 441 additions and 239 deletions

View File

@ -3,6 +3,7 @@ from pathlib import Path
from typing import Union, List
import torch
from random import choice
from torch.utils.data import ConcatDataset, Dataset
from lib.objects.map import Map
@ -16,10 +17,11 @@ class TrajDataset(Dataset):
return self.map.as_array.shape
def __init__(self, *args, maps_root: Union[Path, str] = '', mapname='tate_sw',
length=100000, all_in_map=True, embedding_size=None, preserve_equal_samples=False, **kwargs):
length=100000, mode='separated_arrays', embedding_size=None, preserve_equal_samples=False, **kwargs):
super(TrajDataset, self).__init__()
assert mode.lower() in ['vectors', 'all_in_map', 'separated_arrays', 'just_route']
self.preserve_equal_samples = preserve_equal_samples
self.all_in_map = all_in_map
self.mode = mode
self.mapname = mapname if mapname.endswith('.bmp') else f'{mapname}.bmp'
self.maps_root = maps_root
self._len = length
@ -31,8 +33,19 @@ class TrajDataset(Dataset):
return self._len
def __getitem__(self, item):
trajectory = self.map.get_random_trajectory()
if self.mode.lower() == 'just_route':
trajectory = self.map.get_random_trajectory()
label = choice([0, 1])
blank_trajectory_space = torch.zeros(self.map.shape)
for index in trajectory.vertices:
blank_trajectory_space[index] = 1
map_array = torch.as_tensor(self.map.as_array).float()
return (map_array, blank_trajectory_space), label
while True:
trajectory = self.map.get_random_trajectory()
# TODO: Sanity Check this while true loop...
alternative = self.map.generate_alternative(trajectory)
label = self.map.are_homotopic(trajectory, alternative)
@ -42,18 +55,26 @@ class TrajDataset(Dataset):
break
self.last_label = label
if self.all_in_map:
if self.mode.lower() in ['all_in_map', 'separated_arrays']:
blank_trajectory_space = torch.zeros(self.map.shape)
blank_alternative_space = torch.zeros(self.map.shape)
for index in trajectory.vertices:
blank_trajectory_space[index] = 1
for index in alternative.vertices:
blank_alternative_space[index] = 1
map_array = torch.as_tensor(self.map.as_array).float()
return torch.cat((map_array, blank_trajectory_space, blank_alternative_space)), int(label)
else:
if self.mode == 'separated_arrays':
return (map_array, blank_trajectory_space, int(label)), blank_alternative_space
else:
return torch.cat((map_array, blank_trajectory_space, blank_alternative_space)), int(label)
elif self.mode == 'vectors':
return trajectory.vertices, alternative.vertices, label, self.mapname
else:
raise ValueError
class TrajData(object):
@property
@ -64,7 +85,7 @@ class TrajData(object):
def map_shapes_max(self):
shapes = self.map_shapes
shape_list = list(map(max, zip(*shapes)))
if self.all_in_map:
if self.mode == 'all_in_map':
shape_list[0] += 2
return shape_list
@ -72,10 +93,10 @@ class TrajData(object):
def name(self):
return self.__class__.__name__
def __init__(self, *args, map_root: Union[Path, str] = '', length=100.000, all_in_map=True, **_):
def __init__(self, map_root, length=100000, mode='separated_arrays', **_):
self.all_in_map = all_in_map
self.maps_root = Path(map_root) if map_root else Path() / 'res' / 'maps'
self.mode = mode
self.maps_root = Path(map_root)
self.length = length
self._dataset = self._load_datasets()
@ -86,8 +107,8 @@ class TrajData(object):
# find max image size among available maps:
max_map_size = (1, ) + tuple(reversed(tuple(map(max, *[Image.open(map_file).size for map_file in map_files]))))
return ConcatDataset([TrajDataset(maps_root=self.maps_root, mapname=map_file.name, length=equal_split,
all_in_map=self.all_in_map, embedding_size=max_map_size,
preserve_equal_samples=False)
mode=self.mode, embedding_size=max_map_size,
preserve_equal_samples=True)
for map_file in map_files])
@property