Kurz vorm durchdrehen

This commit is contained in:
Si11ium
2020-03-11 17:10:19 +01:00
parent 1b5a7dc69e
commit 1f4edae95c
12 changed files with 157 additions and 93 deletions

View File

@@ -18,10 +18,12 @@ class TrajDataset(Dataset):
def map_shape(self):
return self.map.as_array.shape
def __init__(self, *args, maps_root: Union[Path, str] = '', mapname='tate_sw',
length=100000, mode='separated_arrays', embedding_size=None, preserve_equal_samples=False, **kwargs):
def __init__(self, *args, maps_root: Union[Path, str] = '', mapname='tate_sw', normalized=True,
length=100000, mode='separated_arrays', embedding_size=None, preserve_equal_samples=False,
**kwargs):
super(TrajDataset, self).__init__()
assert mode.lower() in ['vectors', 'all_in_map', 'separated_arrays', 'just_route']
self.normalized = normalized
self.preserve_equal_samples = preserve_equal_samples
self.mode = mode
self.mapname = mapname if mapname.endswith('.bmp') else f'{mapname}.bmp'
@@ -58,6 +60,10 @@ class TrajDataset(Dataset):
trajectory = trajectory.draw_in_array(self.map_shape)
alternative = alternative.draw_in_array(self.map_shape)
if self.mode == 'separated_arrays':
if self.normalized:
map_array = map_array / V.WHITE
trajectory = trajectory / V.WHITE
alternative = alternative / V.WHITE
return (map_array, trajectory, label), alternative
else:
return np.concatenate((map_array, trajectory, alternative)), label
@@ -86,8 +92,9 @@ class TrajData(object):
def name(self):
return self.__class__.__name__
def __init__(self, map_root, length=100000, mode='separated_arrays', **_):
def __init__(self, map_root, length=100000, mode='separated_arrays', normalized=True, **_):
self.normalized = normalized
self.mode = mode
self.maps_root = Path(map_root)
self.length = length
@@ -100,7 +107,7 @@ class TrajData(object):
# find max image size among available maps:
max_map_size = (1, ) + tuple(reversed(tuple(map(max, *[Image.open(map_file).size for map_file in map_files]))))
return ConcatDataset([TrajDataset(maps_root=self.maps_root, mapname=map_file.name, length=equal_split,
mode=self.mode, embedding_size=max_map_size,
mode=self.mode, embedding_size=max_map_size, normalized=self.normalized,
preserve_equal_samples=True)
for map_file in map_files])