Debugging
This commit is contained in:
@ -1,6 +1,6 @@
|
||||
import shelve
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
from typing import Union, List
|
||||
|
||||
import torch
|
||||
from random import choice
|
||||
@ -17,7 +17,7 @@ class TrajDataset(Dataset):
|
||||
return self.map.as_array.shape
|
||||
|
||||
def __init__(self, *args, maps_root: Union[Path, str] = '', mapname='tate_sw',
|
||||
length=100.000, all_in_map=True, **kwargs):
|
||||
length=100000, all_in_map=True, **kwargs):
|
||||
super(TrajDataset, self).__init__()
|
||||
self.all_in_map = all_in_map
|
||||
self.mapname = mapname if mapname.endswith('.bmp') else f'{mapname}.bmp'
|
||||
@ -34,11 +34,11 @@ class TrajDataset(Dataset):
|
||||
alternative = self.map.generate_alternative(trajectory)
|
||||
label = choice([0, 1])
|
||||
if self.all_in_map:
|
||||
blank_trajectory_space = torch.zeros_like(self.map.as_array)
|
||||
blank_trajectory_space = torch.zeros(self.map.shape)
|
||||
blank_trajectory_space[trajectory.vertices] = 1
|
||||
|
||||
blank_alternative_space = torch.zeros_like(self.map.as_array)
|
||||
blank_alternative_space[trajectory.vertices] = 1
|
||||
blank_alternative_space = torch.zeros(self.map.shape)
|
||||
blank_alternative_space[trajectory.np_vertices] = 1
|
||||
|
||||
map_array = torch.as_tensor(self.map.as_array)
|
||||
label = self.map.are_homotopic(trajectory, alternative)
|
||||
@ -56,7 +56,7 @@ class TrajData(object):
|
||||
@property
|
||||
def map_shapes_max(self):
|
||||
shapes = self.map_shapes
|
||||
return map(max, zip(*shapes))
|
||||
return list(map(max, zip(*shapes)))
|
||||
|
||||
@property
|
||||
def name(self):
|
||||
@ -66,12 +66,12 @@ class TrajData(object):
|
||||
|
||||
self.all_in_map = all_in_map
|
||||
self.maps_root = Path(map_root) if map_root else Path() / 'res' / 'maps'
|
||||
self._dataset = self._load_datasets()
|
||||
self.length = length
|
||||
self._dataset = self._load_datasets()
|
||||
|
||||
def _load_datasets(self):
|
||||
map_files = list(self.maps_root.glob('*.bmp'))
|
||||
equal_split = self.length // len(map_files)
|
||||
equal_split = int(self.length // len(map_files))
|
||||
return ConcatDataset([TrajDataset(maps_root=self.maps_root, mapname=map_image.name, length=equal_split,
|
||||
all_in_map=self.all_in_map) for map_image in map_files])
|
||||
|
||||
|
Reference in New Issue
Block a user