Train Active

This commit is contained in:
Si11ium
2020-03-03 15:10:17 +01:00
parent 44f6589259
commit 1f612a968c
13 changed files with 102 additions and 98 deletions

View File

@ -7,7 +7,7 @@ from random import choice
from torch.utils.data import ConcatDataset, Dataset
from lib.objects.map import Map
from lib.preprocessing.generator import Generator
from PIL import Image
class TrajDataset(Dataset):
@ -17,14 +17,14 @@ class TrajDataset(Dataset):
return self.map.as_array.shape
def __init__(self, *args, maps_root: Union[Path, str] = '', mapname='tate_sw',
length=100000, all_in_map=True, **kwargs):
length=100000, all_in_map=True, embedding_size=None, **kwargs):
super(TrajDataset, self).__init__()
self.all_in_map = all_in_map
self.mapname = mapname if mapname.endswith('.bmp') else f'{mapname}.bmp'
self.maps_root = maps_root
self._len = length
self.map = Map(self.mapname).from_image(self.maps_root / self.mapname)
self.map = Map(self.mapname).from_image(self.maps_root / self.mapname, embedding_size=embedding_size)
def __len__(self):
return self._len
@ -35,15 +35,14 @@ class TrajDataset(Dataset):
label = choice([0, 1])
if self.all_in_map:
blank_trajectory_space = torch.zeros(self.map.shape)
blank_trajectory_space[trajectory.vertices] = 1
blank_alternative_space = torch.zeros(self.map.shape)
blank_alternative_space[trajectory.np_vertices] = 1
for index in trajectory.vertices:
blank_trajectory_space[index] = 1
blank_alternative_space[index] = 1
map_array = torch.as_tensor(self.map.as_array)
map_array = torch.as_tensor(self.map.as_array).float()
label = self.map.are_homotopic(trajectory, alternative)
return torch.cat((map_array, blank_trajectory_space, blank_alternative_space)), label
return torch.cat((map_array, blank_trajectory_space, blank_alternative_space)), int(label)
else:
return trajectory.vertices, alternative.vertices, label, self.mapname
@ -56,7 +55,10 @@ class TrajData(object):
@property
def map_shapes_max(self):
shapes = self.map_shapes
return list(map(max, zip(*shapes)))
shape_list = list(map(max, zip(*shapes)))
if self.all_in_map:
shape_list[0] += 2
return shape_list
@property
def name(self):
@ -72,8 +74,12 @@ class TrajData(object):
def _load_datasets(self):
map_files = list(self.maps_root.glob('*.bmp'))
equal_split = int(self.length // len(map_files))
return ConcatDataset([TrajDataset(maps_root=self.maps_root, mapname=map_image.name, length=equal_split,
all_in_map=self.all_in_map) for map_image in map_files])
# find max image size among available maps:
max_map_size = (1, ) + tuple(reversed(tuple(map(max, *[Image.open(map_file).size for map_file in map_files]))))
return ConcatDataset([TrajDataset(maps_root=self.maps_root, mapname=map_file.name, length=equal_split,
all_in_map=self.all_in_map, embedding_size=max_map_size)
for map_file in map_files])
@property
def train_dataset(self):