CNN Classifier
This commit is contained in:
100
datasets/paired_dataset.py
Normal file
100
datasets/paired_dataset.py
Normal file
@@ -0,0 +1,100 @@
|
||||
import shelve
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
|
||||
import torch
|
||||
from torch.utils.data import Dataset, ConcatDataset
|
||||
|
||||
from datasets.utils import DatasetMapping
|
||||
from lib.modules.model_parts import Generator
|
||||
from lib.objects.map import Map
|
||||
|
||||
|
||||
class TrajPairDataset(Dataset):
    """Dataset of (trajectory, alternative) pairs sharing a single map.

    Every item pairs the one reference trajectory with one alternative
    and its label, plus the map's (prefix-stripped) name.
    """

    @property
    def map_shape(self):
        # Shape of the underlying map raster.
        return self.map.as_array.shape

    def __init__(self, data):
        """Build the dataset from one unpickled shelve record.

        :param data: mapping with keys 'alternatives', 'trajectory',
            'labels' and 'map' (itself holding 'name' and 'map').
        """
        super(TrajPairDataset, self).__init__()
        self.alternatives = data['alternatives']
        self.trajectory = data['trajectory']
        self.labels = data['labels']
        # Strip a leading 'map_' prefix from the stored map name, if any.
        raw_name = data['map']['name']
        self.mapname = raw_name[len('map_'):] if raw_name.startswith('map_') else raw_name
        self.map = data['map']['map']

    def __len__(self):
        return len(self.alternatives)

    def __getitem__(self, item):
        # The reference trajectory is shared across all items.
        return (self.trajectory.vertices,
                self.alternatives[item].vertices,
                self.labels[item],
                self.mapname)
|
||||
|
||||
|
||||
class TrajPairData(object):
    """Loads (trajectory, alternative) pair data for one map — generating it
    on demand via `Generator` — and exposes shuffled train/val/test views."""

    @property
    def map_shapes(self):
        # Raster shapes of every sub-dataset's map.
        return [dataset.map_shape for dataset in self._dataset.datasets]

    @property
    def map_shapes_max(self):
        # Element-wise maximum over all map shapes (returned as an iterator,
        # matching the original interface).
        shapes = self.map_shapes
        return map(max, zip(*shapes))

    @property
    def name(self):
        return self.__class__.__name__

    def __init__(self, data_root, map_root: Union[Path, str] = '', mapname='tate_sw', trajectories=1000,
                 alternatives=10, train_val_test_split=(0.6, 0.2, 0.2), rebuild=False, equal_samples=True, **_):
        """
        :param data_root: folder holding (or receiving) the shelve dataset file.
        :param map_root: folder holding the map bitmaps; defaults to res/maps.
        :param mapname: name of the map to load/generate data for.
        :param trajectories: number of trajectories to generate on rebuild.
        :param alternatives: alternatives generated per trajectory.
        :param train_val_test_split: fractions for the three splits.
        :param rebuild: delete and regenerate an existing dataset file.
        :param equal_samples: forwarded to the generator.
        """
        self.rebuild = rebuild
        self.equal_samples = equal_samples
        self._alternatives = alternatives
        self._trajectories = trajectories
        self.mapname = mapname
        self.train_split, self.val_split, self.test_split = train_val_test_split
        self.data_root = Path(data_root)
        self.maps_root = Path(map_root) if map_root else Path() / 'res' / 'maps'
        self._dataset, self._train_map, self._val_map, self._test_map = self._load_dataset()

    def _build_data_on_demand(self):
        """Generate the shelve dataset file if it is missing (or `rebuild`)."""
        # Fail fast before touching the bitmap.
        assert self.maps_root.exists()
        map_object = Map(self.mapname).from_image(self.maps_root / f'{self.mapname}.bmp')
        dataset_file = Path(self.data_root) / f'{self.mapname}.pik'
        if dataset_file.exists() and self.rebuild:
            dataset_file.unlink()
        if not dataset_file.exists():
            generator = Generator(self.data_root, map_object)
            generator.generate_n_trajectories_m_alternatives(self._trajectories, self._alternatives,
                                                             self.mapname, equal_samples=self.equal_samples)
        return True

    def _load_dataset(self):
        """Read the shelve file (building it first if needed) and split it."""
        assert self._build_data_on_demand()
        with shelve.open(str(self.data_root / f'{self.mapname}.pik')) as d:
            dataset = ConcatDataset([TrajPairDataset(d[key]) for key in d.keys() if key != 'map'])
        train_map, val_map, test_map = self._split_indices(len(dataset))
        return dataset, train_map, val_map, test_map

    def _split_indices(self, n):
        """Shuffle range(n) and cut it into disjoint train/val/test index maps.

        Fixes the original slicing, which used absolute sizes as slice *ends*
        (`indices[train_size:val_size]` / `indices[test_size:]`), so the val
        slice was empty whenever train_split >= val_split and the test slice
        overlapped the training data.
        """
        indices = torch.randperm(n)
        train_size = int(n * self.train_split)
        val_size = int(n * self.val_split)
        train_map = indices[:train_size]
        val_map = indices[train_size:train_size + val_size]
        # Everything left over goes to test so no sample is dropped.
        test_map = indices[train_size + val_size:]
        return train_map, val_map, test_map

    @property
    def train_dataset(self):
        return DatasetMapping(self._dataset, self._train_map)

    @property
    def val_dataset(self):
        return DatasetMapping(self._dataset, self._val_map)

    @property
    def test_dataset(self):
        return DatasetMapping(self._dataset, self._test_map)

    def get_datasets(self):
        return self.train_dataset, self.val_dataset, self.test_dataset
|
||||
91
datasets/trajectory_dataset.py
Normal file
91
datasets/trajectory_dataset.py
Normal file
@@ -0,0 +1,91 @@
|
||||
import shelve
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
|
||||
import torch
|
||||
from random import choice
|
||||
from torch.utils.data import ConcatDataset, Dataset
|
||||
|
||||
from lib.objects.map import Map
|
||||
from lib.preprocessing.generator import Generator
|
||||
|
||||
|
||||
class TrajDataset(Dataset):
    """Synthetic dataset of random trajectory/alternative pairs on one map.

    Samples are generated on the fly: `length` only fixes the size reported
    by ``__len__``; each access produces a fresh random pair.
    """

    @property
    def map_shape(self):
        # Shape of the underlying map raster.
        return self.map.as_array.shape

    def __init__(self, *args, maps_root: Union[Path, str] = '', mapname='tate_sw',
                 length=100.000, all_in_map=True, **kwargs):
        """
        :param maps_root: directory holding the map bitmaps; defaults to res/maps.
        :param mapname: bitmap file name ('.bmp' is appended when missing).
        :param length: nominal number of samples reported by __len__.
        :param all_in_map: if True return stacked raster channels, else raw vertices.
        """
        super(TrajDataset, self).__init__()
        self.all_in_map = all_in_map
        self.mapname = mapname if mapname.endswith('.bmp') else f'{mapname}.bmp'
        # Accept both str and Path (the original stored a raw str, which broke
        # the `/` join below); fall back to the project default map folder.
        self.maps_root = Path(maps_root) if maps_root else Path() / 'res' / 'maps'
        self._len = length

        self.map = Map(self.mapname).from_image(self.maps_root / self.mapname)

    def __len__(self):
        # `length` may arrive as a float (default 100.000); __len__ must
        # return an int or len() raises TypeError.
        return int(self._len)

    def __getitem__(self, item):
        trajectory = self.map.get_random_trajectory()
        alternative = self.map.generate_alternative(trajectory)
        if self.all_in_map:
            # NOTE(review): assumes map.as_array is a torch tensor —
            # torch.zeros_like rejects a plain numpy array; confirm upstream.
            blank_trajectory_space = torch.zeros_like(self.map.as_array)
            blank_trajectory_space[trajectory.vertices] = 1

            blank_alternative_space = torch.zeros_like(self.map.as_array)
            # Fixed: the original rasterised trajectory.vertices here too, so
            # the alternative channel was a duplicate of the trajectory one.
            blank_alternative_space[alternative.vertices] = 1

            map_array = torch.as_tensor(self.map.as_array)
            label = self.map.are_homotopic(trajectory, alternative)
            return torch.cat((map_array, blank_trajectory_space, blank_alternative_space)), label
        else:
            # Raw-vertex variant: the label is a random placeholder, as before.
            label = choice([0, 1])
            return trajectory.vertices, alternative.vertices, label, self.mapname
|
||||
|
||||
|
||||
class TrajData(object):
    """Aggregates one TrajDataset per *.bmp map found under `map_root`."""

    @property
    def map_shapes(self):
        # Raster shapes of every sub-dataset's map.
        return [dataset.map_shape for dataset in self._dataset.datasets]

    @property
    def map_shapes_max(self):
        # Element-wise maximum over all map shapes (returned as an iterator,
        # matching the original interface).
        shapes = self.map_shapes
        return map(max, zip(*shapes))

    @property
    def name(self):
        return self.__class__.__name__

    def __init__(self, *args, map_root: Union[Path, str] = '', length=100.000, all_in_map=True, **_):
        """
        :param map_root: folder of *.bmp maps; defaults to res/maps.
        :param length: total nominal sample count, split evenly across maps.
        :param all_in_map: forwarded to each TrajDataset.
        """
        self.all_in_map = all_in_map
        self.maps_root = Path(map_root) if map_root else Path() / 'res' / 'maps'
        # Fixed: `length` must be assigned BEFORE _load_datasets(), which
        # reads self.length — the original set it afterwards, so every
        # construction raised AttributeError.
        self.length = length
        self._dataset = self._load_datasets()

    def _load_datasets(self):
        """Build one TrajDataset per bitmap, splitting `length` evenly."""
        map_files = list(self.maps_root.glob('*.bmp'))
        # int(): `length` may be a float, and float // float stays float,
        # which would make each sub-dataset's __len__ invalid.
        equal_split = int(self.length // len(map_files))
        return ConcatDataset([TrajDataset(maps_root=self.maps_root, mapname=map_image.name,
                                          length=equal_split, all_in_map=self.all_in_map)
                              for map_image in map_files])

    @property
    def train_dataset(self):
        return self._dataset

    @property
    def val_dataset(self):
        return self._dataset

    @property
    def test_dataset(self):
        # No real split here: train/val/test all alias the same dataset.
        return self._dataset

    def get_datasets(self):
        return self._dataset, self._dataset, self._dataset
|
||||
18
datasets/utils.py
Normal file
18
datasets/utils.py
Normal file
@@ -0,0 +1,18 @@
|
||||
from typing import Union
|
||||
|
||||
from torch.utils.data import Dataset, ConcatDataset
|
||||
|
||||
from datasets.paired_dataset import TrajPairDataset
|
||||
|
||||
|
||||
class DatasetMapping(Dataset):
    """A read-only view of `dataset` restricted to the indices in `mapping`."""

    def __init__(self, dataset: Union[TrajPairDataset, ConcatDataset, Dataset], mapping):
        # `mapping` is expected to be an index tensor (e.g. from randperm).
        self._dataset = dataset
        self._mapping = mapping

    def __len__(self):
        # The view's length is the first dimension of the index tensor.
        rows, *_ = self._mapping.shape
        return rows

    def __getitem__(self, item):
        # Translate the view index through the mapping, then defer to the
        # wrapped dataset.
        source_index = self._mapping[item]
        return self._dataset[source_index]
|
||||
Reference in New Issue
Block a user