Kurz vorm durchdrehen
This commit is contained in:
@@ -18,10 +18,12 @@ class TrajDataset(Dataset):
|
||||
def map_shape(self):
|
||||
return self.map.as_array.shape
|
||||
|
||||
def __init__(self, *args, maps_root: Union[Path, str] = '', mapname='tate_sw',
|
||||
length=100000, mode='separated_arrays', embedding_size=None, preserve_equal_samples=False, **kwargs):
|
||||
def __init__(self, *args, maps_root: Union[Path, str] = '', mapname='tate_sw', normalized=True,
|
||||
length=100000, mode='separated_arrays', embedding_size=None, preserve_equal_samples=False,
|
||||
**kwargs):
|
||||
super(TrajDataset, self).__init__()
|
||||
assert mode.lower() in ['vectors', 'all_in_map', 'separated_arrays', 'just_route']
|
||||
self.normalized = normalized
|
||||
self.preserve_equal_samples = preserve_equal_samples
|
||||
self.mode = mode
|
||||
self.mapname = mapname if mapname.endswith('.bmp') else f'{mapname}.bmp'
|
||||
@@ -58,6 +60,10 @@ class TrajDataset(Dataset):
|
||||
trajectory = trajectory.draw_in_array(self.map_shape)
|
||||
alternative = alternative.draw_in_array(self.map_shape)
|
||||
if self.mode == 'separated_arrays':
|
||||
if self.normalized:
|
||||
map_array = map_array / V.WHITE
|
||||
trajectory = trajectory / V.WHITE
|
||||
alternative = alternative / V.WHITE
|
||||
return (map_array, trajectory, label), alternative
|
||||
else:
|
||||
return np.concatenate((map_array, trajectory, alternative)), label
|
||||
@@ -86,8 +92,9 @@ class TrajData(object):
|
||||
def name(self):
|
||||
return self.__class__.__name__
|
||||
|
||||
def __init__(self, map_root, length=100000, mode='separated_arrays', **_):
|
||||
def __init__(self, map_root, length=100000, mode='separated_arrays', normalized=True, **_):
|
||||
|
||||
self.normalized = normalized
|
||||
self.mode = mode
|
||||
self.maps_root = Path(map_root)
|
||||
self.length = length
|
||||
@@ -100,7 +107,7 @@ class TrajData(object):
|
||||
# find max image size among available maps:
|
||||
max_map_size = (1, ) + tuple(reversed(tuple(map(max, *[Image.open(map_file).size for map_file in map_files]))))
|
||||
return ConcatDataset([TrajDataset(maps_root=self.maps_root, mapname=map_file.name, length=equal_split,
|
||||
mode=self.mode, embedding_size=max_map_size,
|
||||
mode=self.mode, embedding_size=max_map_size, normalized=self.normalized,
|
||||
preserve_equal_samples=True)
|
||||
for map_file in map_files])
|
||||
|
||||
|
||||
Reference in New Issue
Block a user