all debug and train running
This commit is contained in:
@ -3,7 +3,6 @@ from pathlib import Path
|
||||
from typing import Union, List
|
||||
|
||||
import torch
|
||||
from random import choice
|
||||
from torch.utils.data import ConcatDataset, Dataset
|
||||
|
||||
from lib.objects.map import Map
|
||||
@ -17,12 +16,14 @@ class TrajDataset(Dataset):
|
||||
return self.map.as_array.shape
|
||||
|
||||
def __init__(self, *args, maps_root: Union[Path, str] = '', mapname='tate_sw',
|
||||
length=100000, all_in_map=True, embedding_size=None, **kwargs):
|
||||
length=100000, all_in_map=True, embedding_size=None, preserve_equal_samples=False, **kwargs):
|
||||
super(TrajDataset, self).__init__()
|
||||
self.preserve_equal_samples = preserve_equal_samples
|
||||
self.all_in_map = all_in_map
|
||||
self.mapname = mapname if mapname.endswith('.bmp') else f'{mapname}.bmp'
|
||||
self.maps_root = maps_root
|
||||
self._len = length
|
||||
self.last_label = -1
|
||||
|
||||
self.map = Map(self.mapname).from_image(self.maps_root / self.mapname, embedding_size=embedding_size)
|
||||
|
||||
@ -31,8 +32,16 @@ class TrajDataset(Dataset):
|
||||
|
||||
def __getitem__(self, item):
|
||||
trajectory = self.map.get_random_trajectory()
|
||||
alternative = self.map.generate_alternative(trajectory)
|
||||
label = choice([0, 1])
|
||||
while True:
|
||||
# TODO: Sanity Check this while true loop...
|
||||
alternative = self.map.generate_alternative(trajectory)
|
||||
label = self.map.are_homotopic(trajectory, alternative)
|
||||
if self.preserve_equal_samples and label == self.last_label:
|
||||
continue
|
||||
else:
|
||||
break
|
||||
|
||||
self.last_label = label
|
||||
if self.all_in_map:
|
||||
blank_trajectory_space = torch.zeros(self.map.shape)
|
||||
blank_alternative_space = torch.zeros(self.map.shape)
|
||||
@ -41,7 +50,6 @@ class TrajDataset(Dataset):
|
||||
blank_alternative_space[index] = 1
|
||||
|
||||
map_array = torch.as_tensor(self.map.as_array).float()
|
||||
label = self.map.are_homotopic(trajectory, alternative)
|
||||
return torch.cat((map_array, blank_trajectory_space, blank_alternative_space)), int(label)
|
||||
else:
|
||||
return trajectory.vertices, alternative.vertices, label, self.mapname
|
||||
@ -78,7 +86,8 @@ class TrajData(object):
|
||||
# find max image size among available maps:
|
||||
max_map_size = (1, ) + tuple(reversed(tuple(map(max, *[Image.open(map_file).size for map_file in map_files]))))
|
||||
return ConcatDataset([TrajDataset(maps_root=self.maps_root, mapname=map_file.name, length=equal_split,
|
||||
all_in_map=self.all_in_map, embedding_size=max_map_size)
|
||||
all_in_map=self.all_in_map, embedding_size=max_map_size,
|
||||
preserve_equal_samples=True)
|
||||
for map_file in map_files])
|
||||
|
||||
@property
|
||||
|
Reference in New Issue
Block a user