Debugging

This commit is contained in:
Si11ium
2020-02-28 19:11:53 +01:00
parent 7b3f781d19
commit 44f6589259
18 changed files with 134 additions and 78 deletions

View File

@ -1,6 +1,6 @@
import shelve
from pathlib import Path
from typing import Union
from typing import Union, List
import torch
from random import choice
@ -17,7 +17,7 @@ class TrajDataset(Dataset):
return self.map.as_array.shape
def __init__(self, *args, maps_root: Union[Path, str] = '', mapname='tate_sw',
length=100.000, all_in_map=True, **kwargs):
length=100000, all_in_map=True, **kwargs):
super(TrajDataset, self).__init__()
self.all_in_map = all_in_map
self.mapname = mapname if mapname.endswith('.bmp') else f'{mapname}.bmp'
@ -34,11 +34,11 @@ class TrajDataset(Dataset):
alternative = self.map.generate_alternative(trajectory)
label = choice([0, 1])
if self.all_in_map:
blank_trajectory_space = torch.zeros_like(self.map.as_array)
blank_trajectory_space = torch.zeros(self.map.shape)
blank_trajectory_space[trajectory.vertices] = 1
blank_alternative_space = torch.zeros_like(self.map.as_array)
blank_alternative_space[trajectory.vertices] = 1
blank_alternative_space = torch.zeros(self.map.shape)
blank_alternative_space[trajectory.np_vertices] = 1
map_array = torch.as_tensor(self.map.as_array)
label = self.map.are_homotopic(trajectory, alternative)
@ -56,7 +56,7 @@ class TrajData(object):
@property
def map_shapes_max(self):
shapes = self.map_shapes
return map(max, zip(*shapes))
return list(map(max, zip(*shapes)))
@property
def name(self):
@ -66,12 +66,12 @@ class TrajData(object):
self.all_in_map = all_in_map
self.maps_root = Path(map_root) if map_root else Path() / 'res' / 'maps'
self._dataset = self._load_datasets()
self.length = length
self._dataset = self._load_datasets()
def _load_datasets(self):
map_files = list(self.maps_root.glob('*.bmp'))
equal_split = self.length // len(map_files)
equal_split = int(self.length // len(map_files))
return ConcatDataset([TrajDataset(maps_root=self.maps_root, mapname=map_image.name, length=equal_split,
all_in_map=self.all_in_map) for map_image in map_files])