CNN Classifier

This commit is contained in:
Si11ium
2020-02-21 09:44:09 +01:00
parent 537e5371c9
commit 7b3f781d19
12 changed files with 247 additions and 109 deletions

100
datasets/paired_dataset.py Normal file
View File

@@ -0,0 +1,100 @@
import shelve
from pathlib import Path
from typing import Union
import torch
from torch.utils.data import Dataset, ConcatDataset
from datasets.utils import DatasetMapping
from lib.modules.model_parts import Generator
from lib.objects.map import Map
class TrajPairDataset(Dataset):
    """Wraps one pre-generated trajectory record as a torch ``Dataset``.

    Every item pairs the base trajectory with one of its alternatives,
    plus the stored label and the (prefix-stripped) map name.
    """

    @property
    def map_shape(self):
        # Shape of the rasterized map this record was generated on.
        return self.map.as_array.shape

    def __init__(self, data):
        """
        :param data: record dict with keys ``'alternatives'``,
            ``'trajectory'``, ``'labels'`` and ``'map'`` (itself a dict
            holding ``'name'`` and ``'map'``).
        """
        super(TrajPairDataset, self).__init__()
        self.alternatives = data['alternatives']
        self.trajectory = data['trajectory']
        self.labels = data['labels']
        # Drop a leading 'map_' prefix from the stored map name, if any.
        raw_name = data['map']['name']
        self.mapname = raw_name[4:] if raw_name.startswith('map_') else raw_name
        self.map = data['map']['map']

    def __len__(self):
        # One sample per alternative trajectory.
        return len(self.alternatives)

    def __getitem__(self, item):
        return (self.trajectory.vertices, self.alternatives[item].vertices,
                self.labels[item], self.mapname)
class TrajPairData(object):
    """Loads (building on demand) the shelve-backed paired-trajectory
    dataset for one map and exposes random train/val/test views over it.
    """

    @property
    def map_shapes(self):
        # Map raster shapes of every concatenated sub-dataset.
        return [dataset.map_shape for dataset in self._dataset.datasets]

    @property
    def map_shapes_max(self):
        # Per-dimension maximum over all map shapes (lazy `map` iterator,
        # kept as-is for backward compatibility with existing callers).
        shapes = self.map_shapes
        return map(max, zip(*shapes))

    @property
    def name(self):
        return self.__class__.__name__

    def __init__(self, data_root, map_root: Union[Path, str] = '', mapname='tate_sw', trajectories=1000, alternatives=10,
                 train_val_test_split=(0.6, 0.2, 0.2), rebuild=False, equal_samples=True, **_):
        """
        :param data_root: directory holding (or receiving) the shelve files.
        :param map_root: directory with the map bitmaps; defaults to ./res/maps.
        :param mapname: map to load/generate trajectories for.
        :param trajectories: number of base trajectories to generate.
        :param alternatives: alternatives generated per trajectory.
        :param train_val_test_split: fractions for the three splits.
        :param rebuild: delete and regenerate an existing dataset file.
        :param equal_samples: forwarded to the generator (class balancing).
        """
        self.rebuild = rebuild
        self.equal_samples = equal_samples
        self._alternatives = alternatives
        self._trajectories = trajectories
        self.mapname = mapname
        self.train_split, self.val_split, self.test_split = train_val_test_split
        self.data_root = Path(data_root)
        self.maps_root = Path(map_root) if map_root else Path() / 'res' / 'maps'
        self._dataset, self._train_map, self._val_map, self._test_map = self._load_dataset()

    def _build_data_on_demand(self):
        # Generate the trajectory shelve for this map unless it already
        # exists (or a rebuild was requested). Always returns True so the
        # caller can assert on it.
        map_object = Map(self.mapname).from_image(self.maps_root / f'{self.mapname}.bmp')
        assert self.maps_root.exists()
        dataset_file = Path(self.data_root) / f'{self.mapname}.pik'
        if dataset_file.exists() and self.rebuild:
            dataset_file.unlink()
        if not dataset_file.exists():
            generator = Generator(self.data_root, map_object)
            generator.generate_n_trajectories_m_alternatives(self._trajectories, self._alternatives,
                                                             self.mapname, equal_samples=self.equal_samples)
        return True

    def _load_dataset(self):
        # Read every per-trajectory record from the shelve (the 'map' key
        # is metadata, not a record) and split it by shuffled indices.
        assert self._build_data_on_demand()
        with shelve.open(str(self.data_root / f'{self.mapname}.pik')) as d:
            dataset = ConcatDataset([TrajPairDataset(d[key]) for key in d.keys() if key != 'map'])
        indices = torch.randperm(len(dataset))
        train_size = int(len(dataset) * self.train_split)
        val_size = int(len(dataset) * self.val_split)
        # BUGFIX: the slices previously overlapped -- the validation slice
        # ended at the absolute count `val_size` instead of the offset
        # `train_size + val_size`, and the test slice started at
        # `test_size`, re-serving training indices. Splits are now disjoint
        # and together cover the whole dataset.
        train_map = indices[:train_size]
        val_map = indices[train_size:train_size + val_size]
        test_map = indices[train_size + val_size:]
        return dataset, train_map, val_map, test_map

    @property
    def train_dataset(self):
        return DatasetMapping(self._dataset, self._train_map)

    @property
    def val_dataset(self):
        return DatasetMapping(self._dataset, self._val_map)

    @property
    def test_dataset(self):
        return DatasetMapping(self._dataset, self._test_map)

    def get_datasets(self):
        return self.train_dataset, self.val_dataset, self.test_dataset

View File

@@ -0,0 +1,91 @@
import shelve
from pathlib import Path
from typing import Union
import torch
from random import choice
from torch.utils.data import ConcatDataset, Dataset
from lib.objects.map import Map
from lib.preprocessing.generator import Generator
class TrajDataset(Dataset):
    """Generates (trajectory, alternative, label) samples on the fly from a
    single map bitmap; items are random, not indexed by ``item``.
    """

    @property
    def map_shape(self):
        # Shape of the rasterized map backing this dataset.
        return self.map.as_array.shape

    def __init__(self, *args, maps_root: Union[Path, str] = '', mapname='tate_sw',
                 length=100.000, all_in_map=True, **kwargs):
        """
        :param maps_root: directory containing the map bitmap.
        :param mapname: map file name (``.bmp`` suffix optional).
        :param length: nominal sample count reported by ``len()``.
        :param all_in_map: if True, items are rasterized into map-sized
            tensors with a computed homotopy label; otherwise raw vertex
            lists with a random label are returned.
        """
        super(TrajDataset, self).__init__()
        self.all_in_map = all_in_map
        self.mapname = mapname if mapname.endswith('.bmp') else f'{mapname}.bmp'
        # BUGFIX: coerce to Path so plain-string roots (as the annotation
        # allows) survive the `/` joins below.
        self.maps_root = Path(maps_root)
        self._len = length
        self.map = Map(self.mapname).from_image(self.maps_root / self.mapname)

    def __len__(self):
        # BUGFIX: `length` defaults to the float 100.000; __len__ must
        # return an int or len()/DataLoader raises TypeError.
        return int(self._len)

    def __getitem__(self, item):
        # `item` is ignored: each access draws a fresh random pair.
        trajectory = self.map.get_random_trajectory()
        alternative = self.map.generate_alternative(trajectory)
        if self.all_in_map:
            blank_trajectory_space = torch.zeros_like(self.map.as_array)
            blank_trajectory_space[trajectory.vertices] = 1
            blank_alternative_space = torch.zeros_like(self.map.as_array)
            # BUGFIX: previously re-drew `trajectory.vertices` here
            # (copy/paste), so the alternative channel duplicated the
            # trajectory channel and carried no information.
            blank_alternative_space[alternative.vertices] = 1
            map_array = torch.as_tensor(self.map.as_array)
            label = self.map.are_homotopic(trajectory, alternative)
            return torch.cat((map_array, blank_trajectory_space, blank_alternative_space)), label
        else:
            # Raw-vertex mode keeps the original behavior: a random label.
            label = choice([0, 1])
            return trajectory.vertices, alternative.vertices, label, self.mapname
class TrajData(object):
    """Builds one TrajDataset per ``*.bmp`` map under ``map_root`` and
    concatenates them; the same concatenated dataset backs all three
    train/val/test accessors.
    """

    @property
    def map_shapes(self):
        # Map raster shapes of every concatenated sub-dataset.
        return [dataset.map_shape for dataset in self._dataset.datasets]

    @property
    def map_shapes_max(self):
        # Per-dimension maximum over all map shapes (lazy `map` iterator,
        # kept as-is for backward compatibility with existing callers).
        shapes = self.map_shapes
        return map(max, zip(*shapes))

    @property
    def name(self):
        return self.__class__.__name__

    def __init__(self, *args, map_root: Union[Path, str] = '', length=100.000, all_in_map=True, **_):
        """
        :param map_root: directory with map bitmaps; defaults to ./res/maps.
        :param length: total nominal sample count, shared across maps.
        :param all_in_map: forwarded to every TrajDataset.
        """
        self.all_in_map = all_in_map
        self.maps_root = Path(map_root) if map_root else Path() / 'res' / 'maps'
        # BUGFIX: `length` must be set before _load_datasets(), which reads
        # self.length; the original order raised AttributeError on every
        # construction.
        self.length = length
        self._dataset = self._load_datasets()

    def _load_datasets(self):
        map_files = list(self.maps_root.glob('*.bmp'))
        if not map_files:
            # Fail with a clear message instead of ZeroDivisionError below.
            raise FileNotFoundError(f'No *.bmp maps found under {self.maps_root}')
        # Integer share per map (`length` may arrive as a float, e.g. 100.000).
        equal_split = int(self.length) // len(map_files)
        return ConcatDataset([TrajDataset(maps_root=self.maps_root, mapname=map_image.name, length=equal_split,
                                          all_in_map=self.all_in_map) for map_image in map_files])

    @property
    def train_dataset(self):
        return self._dataset

    @property
    def val_dataset(self):
        return self._dataset

    @property
    def test_dataset(self):
        return self._dataset

    def get_datasets(self):
        return self._dataset, self._dataset, self._dataset

18
datasets/utils.py Normal file
View File

@@ -0,0 +1,18 @@
from typing import Union
from torch.utils.data import Dataset, ConcatDataset
from datasets.paired_dataset import TrajPairDataset
class DatasetMapping(Dataset):
    """Index view over a dataset: item ``i`` of the view is item
    ``mapping[i]`` of the wrapped dataset (used to realize splits).
    """

    def __init__(self, dataset: Union[TrajPairDataset, ConcatDataset, Dataset], mapping):
        self._base = dataset
        self._indices = mapping

    def __len__(self):
        # `mapping` is a 1-D index tensor; its first dimension is the
        # view's length.
        return self._indices.shape[0]

    def __getitem__(self, item):
        # Translate the view-local index into the wrapped dataset's index.
        return self._base[self._indices[item]]