CNN Classifier
This commit is contained in:
parent
537e5371c9
commit
7b3f781d19
13
.idea/deployment.xml
generated
13
.idea/deployment.xml
generated
@ -1,11 +1,18 @@
|
|||||||
<?xml version="1.0" encoding="UTF-8"?>
|
<?xml version="1.0" encoding="UTF-8"?>
|
||||||
<project version="4">
|
<project version="4">
|
||||||
<component name="PublishConfigData" serverName="steffen@aimachine:22">
|
<component name="PublishConfigData" serverName="traj_gen-AiMachine">
|
||||||
<serverData>
|
<serverData>
|
||||||
<paths name="steffen@aimachine:22">
|
<paths name="ErLoWa-AiMachine">
|
||||||
<serverdata>
|
<serverdata>
|
||||||
<mappings>
|
<mappings>
|
||||||
<mapping deploy="\" local="$PROJECT_DIR$" />
|
<mapping local="$PROJECT_DIR$" web="/" />
|
||||||
|
</mappings>
|
||||||
|
</serverdata>
|
||||||
|
</paths>
|
||||||
|
<paths name="traj_gen-AiMachine">
|
||||||
|
<serverdata>
|
||||||
|
<mappings>
|
||||||
|
<mapping deploy="/" local="$PROJECT_DIR$" web="/" />
|
||||||
</mappings>
|
</mappings>
|
||||||
</serverdata>
|
</serverdata>
|
||||||
</paths>
|
</paths>
|
||||||
|
@ -1,11 +0,0 @@
|
|||||||
from pathlib import Path
|
|
||||||
|
|
||||||
from lib.objects.map import Map
|
|
||||||
from lib.preprocessing.generator import Generator
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
data_root = Path() / 'data'
|
|
||||||
maps_root = Path() / 'res' / 'maps'
|
|
||||||
map_object = Map('Tate').from_image(maps_root / 'tate_sw.bmp')
|
|
||||||
generator = Generator(data_root, map_object)
|
|
||||||
generator.generate_n_trajectories_m_alternatives(100, 10, 'test')
|
|
@ -3,17 +3,17 @@ from pathlib import Path
|
|||||||
from typing import Union
|
from typing import Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from random import choice
|
from torch.utils.data import Dataset, ConcatDataset
|
||||||
from torch.utils.data import ConcatDataset, Dataset
|
|
||||||
|
|
||||||
|
from datasets.utils import DatasetMapping
|
||||||
|
from lib.modules.model_parts import Generator
|
||||||
from lib.objects.map import Map
|
from lib.objects.map import Map
|
||||||
from lib.preprocessing.generator import Generator
|
|
||||||
|
|
||||||
|
|
||||||
class TrajPairDataset(Dataset):
|
class TrajPairDataset(Dataset):
|
||||||
@property
|
@property
|
||||||
def map_shape(self):
|
def map_shape(self):
|
||||||
return self._dataset.map.as_array.shape
|
return self.map.as_array.shape
|
||||||
|
|
||||||
def __init__(self, data):
|
def __init__(self, data):
|
||||||
super(TrajPairDataset, self).__init__()
|
super(TrajPairDataset, self).__init__()
|
||||||
@ -30,19 +30,6 @@ class TrajPairDataset(Dataset):
|
|||||||
return self.trajectory.vertices, self.alternatives[item].vertices, self.labels[item], self.mapname
|
return self.trajectory.vertices, self.alternatives[item].vertices, self.labels[item], self.mapname
|
||||||
|
|
||||||
|
|
||||||
class DatasetMapping(Dataset):
|
|
||||||
|
|
||||||
def __init__(self, dataset: Union[TrajPairDataset, ConcatDataset], mapping):
|
|
||||||
self._dataset = dataset
|
|
||||||
self._mapping = mapping
|
|
||||||
|
|
||||||
def __len__(self):
|
|
||||||
return self._mapping.shape[0]
|
|
||||||
|
|
||||||
def __getitem__(self, item):
|
|
||||||
return self._dataset[self._mapping[item]]
|
|
||||||
|
|
||||||
|
|
||||||
class TrajPairData(object):
|
class TrajPairData(object):
|
||||||
@property
|
@property
|
||||||
def map_shapes(self):
|
def map_shapes(self):
|
||||||
@ -111,38 +98,3 @@ class TrajPairData(object):
|
|||||||
|
|
||||||
def get_datasets(self):
|
def get_datasets(self):
|
||||||
return self.train_dataset, self.val_dataset, self.test_dataset
|
return self.train_dataset, self.val_dataset, self.test_dataset
|
||||||
|
|
||||||
|
|
||||||
class TrajDataset(Dataset):
|
|
||||||
|
|
||||||
def __init__(self, data_root, maps_root: Union[Path, str] = '', mapname='tate_sw', length=100.000, **_):
|
|
||||||
super(TrajDataset, self).__init__()
|
|
||||||
self.mapname = mapname
|
|
||||||
self.maps_root = maps_root
|
|
||||||
self.data_root = data_root
|
|
||||||
self._len = length
|
|
||||||
|
|
||||||
self._map_obj = Map(self.mapname).from_image(self.maps_root / f'{self.mapname}.bmp')
|
|
||||||
|
|
||||||
def __len__(self):
|
|
||||||
return self._len
|
|
||||||
|
|
||||||
def __getitem__(self, item):
|
|
||||||
trajectory = self._map_obj.get_random_trajectory()
|
|
||||||
label = choice([0, 1])
|
|
||||||
return trajectory.vertices, None, label, self.mapname
|
|
||||||
|
|
||||||
@property
|
|
||||||
def train_dataset(self):
|
|
||||||
return self
|
|
||||||
|
|
||||||
@property
|
|
||||||
def val_dataset(self):
|
|
||||||
return self
|
|
||||||
|
|
||||||
@property
|
|
||||||
def test_dataset(self):
|
|
||||||
return self
|
|
||||||
|
|
||||||
def get_datasets(self):
|
|
||||||
return self, self, self
|
|
91
datasets/trajectory_dataset.py
Normal file
91
datasets/trajectory_dataset.py
Normal file
@ -0,0 +1,91 @@
|
|||||||
|
import shelve
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from random import choice
|
||||||
|
from torch.utils.data import ConcatDataset, Dataset
|
||||||
|
|
||||||
|
from lib.objects.map import Map
|
||||||
|
from lib.preprocessing.generator import Generator
|
||||||
|
|
||||||
|
|
||||||
|
class TrajDataset(Dataset):
|
||||||
|
|
||||||
|
@property
|
||||||
|
def map_shape(self):
|
||||||
|
return self.map.as_array.shape
|
||||||
|
|
||||||
|
def __init__(self, *args, maps_root: Union[Path, str] = '', mapname='tate_sw',
|
||||||
|
length=100.000, all_in_map=True, **kwargs):
|
||||||
|
super(TrajDataset, self).__init__()
|
||||||
|
self.all_in_map = all_in_map
|
||||||
|
self.mapname = mapname if mapname.endswith('.bmp') else f'{mapname}.bmp'
|
||||||
|
self.maps_root = maps_root
|
||||||
|
self._len = length
|
||||||
|
|
||||||
|
self.map = Map(self.mapname).from_image(self.maps_root / self.mapname)
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return self._len
|
||||||
|
|
||||||
|
def __getitem__(self, item):
|
||||||
|
trajectory = self.map.get_random_trajectory()
|
||||||
|
alternative = self.map.generate_alternative(trajectory)
|
||||||
|
label = choice([0, 1])
|
||||||
|
if self.all_in_map:
|
||||||
|
blank_trajectory_space = torch.zeros_like(self.map.as_array)
|
||||||
|
blank_trajectory_space[trajectory.vertices] = 1
|
||||||
|
|
||||||
|
blank_alternative_space = torch.zeros_like(self.map.as_array)
|
||||||
|
blank_alternative_space[trajectory.vertices] = 1
|
||||||
|
|
||||||
|
map_array = torch.as_tensor(self.map.as_array)
|
||||||
|
label = self.map.are_homotopic(trajectory, alternative)
|
||||||
|
|
||||||
|
return torch.cat((map_array, blank_trajectory_space, blank_alternative_space)), label
|
||||||
|
else:
|
||||||
|
return trajectory.vertices, alternative.vertices, label, self.mapname
|
||||||
|
|
||||||
|
|
||||||
|
class TrajData(object):
|
||||||
|
@property
|
||||||
|
def map_shapes(self):
|
||||||
|
return [dataset.map_shape for dataset in self._dataset.datasets]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def map_shapes_max(self):
|
||||||
|
shapes = self.map_shapes
|
||||||
|
return map(max, zip(*shapes))
|
||||||
|
|
||||||
|
@property
|
||||||
|
def name(self):
|
||||||
|
return self.__class__.__name__
|
||||||
|
|
||||||
|
def __init__(self, *args, map_root: Union[Path, str] = '', length=100.000, all_in_map=True, **_):
|
||||||
|
|
||||||
|
self.all_in_map = all_in_map
|
||||||
|
self.maps_root = Path(map_root) if map_root else Path() / 'res' / 'maps'
|
||||||
|
self._dataset = self._load_datasets()
|
||||||
|
self.length = length
|
||||||
|
|
||||||
|
def _load_datasets(self):
|
||||||
|
map_files = list(self.maps_root.glob('*.bmp'))
|
||||||
|
equal_split = self.length // len(map_files)
|
||||||
|
return ConcatDataset([TrajDataset(maps_root=self.maps_root, mapname=map_image.name, length=equal_split,
|
||||||
|
all_in_map=self.all_in_map) for map_image in map_files])
|
||||||
|
|
||||||
|
@property
|
||||||
|
def train_dataset(self):
|
||||||
|
return self._dataset
|
||||||
|
|
||||||
|
@property
|
||||||
|
def val_dataset(self):
|
||||||
|
return self._dataset
|
||||||
|
|
||||||
|
@property
|
||||||
|
def test_dataset(self):
|
||||||
|
return self._dataset
|
||||||
|
|
||||||
|
def get_datasets(self):
|
||||||
|
return self._dataset, self._dataset, self._dataset
|
18
datasets/utils.py
Normal file
18
datasets/utils.py
Normal file
@ -0,0 +1,18 @@
|
|||||||
|
from typing import Union
|
||||||
|
|
||||||
|
from torch.utils.data import Dataset, ConcatDataset
|
||||||
|
|
||||||
|
from datasets.paired_dataset import TrajPairDataset
|
||||||
|
|
||||||
|
|
||||||
|
class DatasetMapping(Dataset):
|
||||||
|
|
||||||
|
def __init__(self, dataset: Union[TrajPairDataset, ConcatDataset, Dataset], mapping):
|
||||||
|
self._dataset = dataset
|
||||||
|
self._mapping = mapping
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return self._mapping.shape[0]
|
||||||
|
|
||||||
|
def __getitem__(self, item):
|
||||||
|
return self._dataset[self._mapping[item]]
|
@ -33,13 +33,11 @@ class CNNRouteGeneratorModel(LightningBaseModule):
|
|||||||
|
|
||||||
# NN Nodes
|
# NN Nodes
|
||||||
|
|
||||||
self.conv1 = ConvModule(self.in_shape, self.hparams.model_param.filters[0])
|
|
||||||
self.conv2 = ConvModule(self.conv1.shape, self.hparams.model_param.filters[0])
|
|
||||||
self.conv3 = ConvModule(self.conv2.shape, self.hparams.model_param.filters[0])
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
self.conv2 = ConvModule(self.conv1.shape, conv_kernel=3, conv_stride=1, conv_padding=0,
|
||||||
|
conv_filters=self.hparams.model_param.filters[0])
|
||||||
|
self.conv3 = ConvModule(self.conv2.shape, conv_kernel=3, conv_stride=1, conv_padding=0,
|
||||||
|
conv_filters=self.hparams.model_param.filters[0])
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
pass
|
pass
|
||||||
|
@ -1,12 +1,19 @@
|
|||||||
from lib.modules.utils import LightningBaseModule
|
import torch
|
||||||
from lib.modules.blocks import ConvModule
|
from torch import nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from torch.optim import Adam
|
||||||
|
|
||||||
|
from datasets.trajectory_dataset import TrajData
|
||||||
|
from lib.modules.utils import LightningBaseModule, Flatten
|
||||||
|
from lib.modules.blocks import ConvModule, ResidualModule
|
||||||
|
|
||||||
|
|
||||||
class ConvHomDetector(LightningBaseModule):
|
class ConvHomDetector(LightningBaseModule):
|
||||||
|
|
||||||
name = 'CNNHomotopyClassifier'
|
name = 'CNNHomotopyClassifier'
|
||||||
|
|
||||||
def configure_optimizers(self):
|
def configure_optimizers(self):
|
||||||
pass
|
return Adam(self.parameters(), lr=self.lr)
|
||||||
|
|
||||||
def validation_step(self, *args, **kwargs):
|
def validation_step(self, *args, **kwargs):
|
||||||
pass
|
pass
|
||||||
@ -15,18 +22,60 @@ class ConvHomDetector(LightningBaseModule):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
def training_step(self, batch_xy, batch_nb, *args, **kwargs):
|
def training_step(self, batch_xy, batch_nb, *args, **kwargs):
|
||||||
pass
|
batch_x, batch_y = batch_xy
|
||||||
|
pred_y = self(batch_x)
|
||||||
|
loss = F.binary_cross_entropy(pred_y, batch_y)
|
||||||
|
return {'loss': loss, 'log': dict(loss=loss)}
|
||||||
|
|
||||||
def test_step(self, *args, **kwargs):
|
def test_step(self, *args, **kwargs):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def __init__(self, *params):
|
def __init__(self, *params):
|
||||||
super(ConvHomDetector, self).__init__(*params)
|
super(ConvHomDetector, self).__init__(*params)
|
||||||
|
# Dataset
|
||||||
|
self.dataset = TrajData(self.hparams.data_param.data_root)
|
||||||
|
|
||||||
self.conv1 = ConvModule(self.dataset.map_shape
|
# Additional Attributes
|
||||||
|
self.map_shape = self.dataset.map_shapes_max
|
||||||
|
|
||||||
)
|
# NN Nodes
|
||||||
|
# ============================
|
||||||
|
# Convolutional Map Processing
|
||||||
|
#
|
||||||
|
self.map_res_1 = ResidualModule(self.in_shape, ConvModule, 3,
|
||||||
|
**dict(conv_kernel=3, conv_stride=1,
|
||||||
|
conv_padding=0, conv_filters=self.hparams.model_param.filters[0]))
|
||||||
|
self.map_conv_1 = ConvModule(self.map_res_1.shape, conv_kernel=5, conv_stride=1,
|
||||||
|
conv_padding=0, conv_filters=self.hparams.model_param.filters[0])
|
||||||
|
self.map_res_2 = ResidualModule(self.map_conv_1.shape, ConvModule, 3,
|
||||||
|
**dict(conv_kernel=3, conv_stride=1,
|
||||||
|
conv_padding=0, conv_filters=self.hparams.model_param.filters[0]))
|
||||||
|
self.map_conv_2 = ConvModule(self.map_res_2.shape, conv_kernel=5, conv_stride=1,
|
||||||
|
conv_padding=0, conv_filters=self.hparams.model_param.filters[0])
|
||||||
|
self.map_res_3 = ResidualModule(self.map_conv_2.shape, ConvModule, 3,
|
||||||
|
**dict(conv_kernel=3, conv_stride=1,
|
||||||
|
conv_padding=0, conv_filters=self.hparams.model_param.filters[0]))
|
||||||
|
self.map_conv_3 = ConvModule(self.map_res_3.shape, conv_kernel=5, conv_stride=1,
|
||||||
|
conv_padding=0, conv_filters=self.hparams.model_param.filters[0])
|
||||||
|
|
||||||
|
self.flatten = Flatten(self.map_conv_3.shape)
|
||||||
|
|
||||||
|
# ============================
|
||||||
|
# Classifier
|
||||||
|
#
|
||||||
|
|
||||||
|
self.linear = nn.Linear(self.flatten.shape.item(), self.hparams.model_param.classes * 10)
|
||||||
|
self.classifier = nn.Linear(self.linear.shape, self.hparams.model_param.classes)
|
||||||
|
self.softmax = nn.Softmax()
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
pass
|
tensor = self.map_res_1(x)
|
||||||
|
tensor = self.map_conv_1(tensor)
|
||||||
|
tensor = self.map_res_2(tensor)
|
||||||
|
tensor = self.map_conv_2(tensor)
|
||||||
|
tensor = self.map_conv_3(tensor)
|
||||||
|
tensor = self.flatten(tensor)
|
||||||
|
tensor = self.linear(tensor)
|
||||||
|
tensor = self.classifier(tensor)
|
||||||
|
tensor = self.softmax(tensor)
|
||||||
|
return tensor
|
||||||
|
@ -22,7 +22,7 @@ class ConvModule(nn.Module):
|
|||||||
return output.shape[1:]
|
return output.shape[1:]
|
||||||
|
|
||||||
def __init__(self, in_shape, activation: nn.Module = nn.ELU, pooling_size=None, use_bias=True, use_norm=True,
|
def __init__(self, in_shape, activation: nn.Module = nn.ELU, pooling_size=None, use_bias=True, use_norm=True,
|
||||||
dropout: Union[int, float] = 0,
|
dropout: Union[int, float] = 0, conv_class=nn.Conv2d,
|
||||||
conv_filters=64, conv_kernel=5, conv_stride=1, conv_padding=0):
|
conv_filters=64, conv_kernel=5, conv_stride=1, conv_padding=0):
|
||||||
super(ConvModule, self).__init__()
|
super(ConvModule, self).__init__()
|
||||||
|
|
||||||
@ -39,9 +39,9 @@ class ConvModule(nn.Module):
|
|||||||
self.dropout = nn.Dropout2d(dropout) if dropout else False
|
self.dropout = nn.Dropout2d(dropout) if dropout else False
|
||||||
self.pooling = nn.MaxPool2d(pooling_size) if pooling_size else False
|
self.pooling = nn.MaxPool2d(pooling_size) if pooling_size else False
|
||||||
self.norm = nn.BatchNorm2d(in_channels, eps=1e-04, affine=False) if use_norm else False
|
self.norm = nn.BatchNorm2d(in_channels, eps=1e-04, affine=False) if use_norm else False
|
||||||
self.conv = nn.Conv2d(in_channels, conv_filters, conv_kernel, bias=use_bias,
|
self.conv = conv_class(in_channels, conv_filters, conv_kernel, bias=use_bias,
|
||||||
padding=self.padding, stride=self.stride
|
padding=self.padding, stride=self.stride
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
x = self.norm(x) if self.norm else x
|
x = self.norm(x) if self.norm else x
|
||||||
@ -91,8 +91,30 @@ class DeConvModule(nn.Module):
|
|||||||
tensor = self.activation(tensor) if self.activation else tensor
|
tensor = self.activation(tensor) if self.activation else tensor
|
||||||
return tensor
|
return tensor
|
||||||
|
|
||||||
def size(self):
|
|
||||||
return self.shape
|
class ResidualModule(nn.Module):
|
||||||
|
|
||||||
|
@property
|
||||||
|
def shape(self):
|
||||||
|
x = torch.randn(self.in_shape).unsqueeze(0)
|
||||||
|
output = self(x)
|
||||||
|
return output.shape[1:]
|
||||||
|
|
||||||
|
def __init__(self, in_shape, module_class, n, **module_paramters):
|
||||||
|
assert n >= 1
|
||||||
|
super(ResidualModule, self).__init__()
|
||||||
|
self.in_shape = in_shape
|
||||||
|
module_paramters.update(in_shape=in_shape)
|
||||||
|
self.residual_block = [module_class(**module_paramters) for x in range(n)]
|
||||||
|
assert self.in_shape == self.shape, f'The in_shape: {self.in_shape} - must match the out_shape: {self.shape}.'
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
for module in self.residual_block:
|
||||||
|
tensor = module(x)
|
||||||
|
|
||||||
|
# noinspection PyUnboundLocalVariable
|
||||||
|
tensor = tensor + x
|
||||||
|
return tensor
|
||||||
|
|
||||||
|
|
||||||
class RecurrentModule(nn.Module):
|
class RecurrentModule(nn.Module):
|
||||||
|
@ -6,7 +6,6 @@ from torch import nn
|
|||||||
from torch import functional as F
|
from torch import functional as F
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
|
|
||||||
from dataset.dataset import TrajDataset, TrajPairDataset
|
|
||||||
from lib.objects.map import MapStorage
|
from lib.objects.map import MapStorage
|
||||||
|
|
||||||
import pytorch_lightning as pl
|
import pytorch_lightning as pl
|
||||||
@ -17,8 +16,20 @@ import pytorch_lightning as pl
|
|||||||
|
|
||||||
|
|
||||||
class Flatten(nn.Module):
|
class Flatten(nn.Module):
|
||||||
def __init__(self, to=(-1, )):
|
|
||||||
|
@property
|
||||||
|
def shape(self):
|
||||||
|
try:
|
||||||
|
x = torch.randn(self.in_shape).unsqueeze(0)
|
||||||
|
output = self(x)
|
||||||
|
return output.shape[1:]
|
||||||
|
except Exception as e:
|
||||||
|
print(e)
|
||||||
|
return -1
|
||||||
|
|
||||||
|
def __init__(self, in_shape, to=(-1, )):
|
||||||
super(Flatten, self).__init__()
|
super(Flatten, self).__init__()
|
||||||
|
self.in_shape = in_shape
|
||||||
self.to = to
|
self.to = to
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
|
@ -110,6 +110,16 @@ class Map(object):
|
|||||||
dest = self.get_valid_position()
|
dest = self.get_valid_position()
|
||||||
return self.simple_trajectory_between(start, dest)
|
return self.simple_trajectory_between(start, dest)
|
||||||
|
|
||||||
|
def generate_alternative(self, trajectory, mode='one_patching'):
|
||||||
|
start, dest = trajectory.endpoints
|
||||||
|
if mode == 'one_patching':
|
||||||
|
patch = self.get_valid_position()
|
||||||
|
alternative = self.get_trajectory_from_vertices(start, patch, dest)
|
||||||
|
else:
|
||||||
|
raise RuntimeError(f'mode checking went wrong...')
|
||||||
|
|
||||||
|
return alternative
|
||||||
|
|
||||||
def are_homotopic(self, trajectory, other_trajectory):
|
def are_homotopic(self, trajectory, other_trajectory):
|
||||||
if not all(isinstance(x, Trajectory) for x in [trajectory, other_trajectory]):
|
if not all(isinstance(x, Trajectory) for x in [trajectory, other_trajectory]):
|
||||||
raise TypeError
|
raise TypeError
|
||||||
@ -121,6 +131,7 @@ class Map(object):
|
|||||||
draw.polygon(polyline, outline=255, fill=255)
|
draw.polygon(polyline, outline=255, fill=255)
|
||||||
|
|
||||||
a = (np.array(img) * np.where(self.map_array == self.white, 0, 1)).sum()
|
a = (np.array(img) * np.where(self.map_array == self.white, 0, 1)).sum()
|
||||||
|
|
||||||
if a >= 1:
|
if a >= 1:
|
||||||
return False
|
return False
|
||||||
else:
|
else:
|
||||||
|
@ -5,6 +5,8 @@ from collections import defaultdict
|
|||||||
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
from lib.objects.map import Map
|
from lib.objects.map import Map
|
||||||
|
|
||||||
|
|
||||||
@ -26,11 +28,7 @@ class Generator:
|
|||||||
with mp.Pool(processes) as pool:
|
with mp.Pool(processes) as pool:
|
||||||
async_results = [pool.apply_async(self.generate_n_alternatives, kwds=kwargs) for _ in range(n)]
|
async_results = [pool.apply_async(self.generate_n_alternatives, kwds=kwargs) for _ in range(n)]
|
||||||
|
|
||||||
# for _ in trange(n, desc='Processing Trajectories'):
|
for result_obj in tqdm(async_results, total=n, desc='Producing trajectories with Alternatives'):
|
||||||
# self.write_n_alternatives(m, dataset_name, **kwargs)
|
|
||||||
|
|
||||||
# This line is for error catching only
|
|
||||||
for result_obj in async_results:
|
|
||||||
trajectory, alternatives, labels = result_obj.get()
|
trajectory, alternatives, labels = result_obj.get()
|
||||||
mutex.acquire()
|
mutex.acquire()
|
||||||
self.write_to_disk(datafile_name, trajectory, alternatives, labels)
|
self.write_to_disk(datafile_name, trajectory, alternatives, labels)
|
||||||
@ -44,24 +42,14 @@ class Generator:
|
|||||||
datafile.unlink()
|
datafile.unlink()
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def generate_alternatives(self, trajectory, mode='one_patching'):
|
def generate_n_alternatives(self, n=None, datafile_name='', trajectory=None, is_sub_process=False,
|
||||||
start, dest = trajectory.endpoints
|
mode='one_patching', equal_samples=True, binary_check=True):
|
||||||
if mode == 'one_patching':
|
|
||||||
patch = self.map.get_valid_position()
|
|
||||||
alternative = self.map.get_trajectory_from_vertices(start, patch, dest)
|
|
||||||
else:
|
|
||||||
raise RuntimeError(f'mode checking went wrong...')
|
|
||||||
|
|
||||||
return alternative
|
|
||||||
|
|
||||||
def generate_n_alternatives(self, n=None, datafile_name='', trajectory=None,
|
|
||||||
mode='one_patching', equal_samples=True, binary_check=True):
|
|
||||||
assert n is not None, f'n is not allowed to be None but was: {n}'
|
assert n is not None, f'n is not allowed to be None but was: {n}'
|
||||||
assert mode in self.possible_modes, f'Parameter "mode" must be either {self.possible_modes}, but was {mode}.'
|
assert mode in self.possible_modes, f'Parameter "mode" must be either {self.possible_modes}, but was {mode}.'
|
||||||
|
|
||||||
trajectory = trajectory if trajectory is not None else self.map.get_random_trajectory()
|
trajectory = trajectory if trajectory is not None else self.map.get_random_trajectory()
|
||||||
|
|
||||||
results = [self.generate_alternatives(trajectory=trajectory, mode=mode) for _ in range(n)]
|
results = [self.map.generate_alternative(trajectory=trajectory, mode=mode) for _ in range(n)]
|
||||||
|
|
||||||
# label per homotopic class
|
# label per homotopic class
|
||||||
homotopy_classes = defaultdict(list)
|
homotopy_classes = defaultdict(list)
|
||||||
@ -91,9 +79,10 @@ class Generator:
|
|||||||
alternatives.extend(homotopy_classes[key])
|
alternatives.extend(homotopy_classes[key])
|
||||||
labels.extend([key] * len(homotopy_classes[key]))
|
labels.extend([key] * len(homotopy_classes[key]))
|
||||||
if datafile_name:
|
if datafile_name:
|
||||||
|
if is_sub_process:
|
||||||
|
datafile_name = f'{str(datafile_name)}_{mp.current_process().pid}'
|
||||||
# Write to disk
|
# Write to disk
|
||||||
subprocess_datafile_name = f'{str(datafile_name)}_{mp.current_process().pid}'
|
self.write_to_disk(datafile_name, trajectory, alternatives, labels)
|
||||||
self.write_to_disk(subprocess_datafile_name, trajectory, alternatives, labels)
|
|
||||||
return trajectory, alternatives, labels
|
return trajectory, alternatives, labels
|
||||||
|
|
||||||
def write_to_disk(self, datafile_name, trajectory, alternatives, labels):
|
def write_to_disk(self, datafile_name, trajectory, alternatives, labels):
|
||||||
|
5
main.py
5
main.py
@ -10,7 +10,7 @@ import warnings
|
|||||||
from pytorch_lightning import Trainer
|
from pytorch_lightning import Trainer
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
|
|
||||||
from dataset.dataset import TrajPairData
|
from dataset.dataset import TrajDataset
|
||||||
from lib.utils.config import Config
|
from lib.utils.config import Config
|
||||||
from lib.utils.logging import Logger
|
from lib.utils.logging import Logger
|
||||||
|
|
||||||
@ -49,6 +49,7 @@ main_arg_parser.add_argument("--train_lr", type=float, default=0.002, help="")
|
|||||||
main_arg_parser.add_argument("--model_type", type=str, default="LeNetAE", help="")
|
main_arg_parser.add_argument("--model_type", type=str, default="LeNetAE", help="")
|
||||||
main_arg_parser.add_argument("--model_activation", type=str, default="relu", help="")
|
main_arg_parser.add_argument("--model_activation", type=str, default="relu", help="")
|
||||||
main_arg_parser.add_argument("--model_filters", type=str, default="[32, 16, 4]", help="")
|
main_arg_parser.add_argument("--model_filters", type=str, default="[32, 16, 4]", help="")
|
||||||
|
main_arg_parser.add_argument("--model_classes", type=int, default=2, help="")
|
||||||
main_arg_parser.add_argument("--model_use_bias", type=strtobool, default=True, help="")
|
main_arg_parser.add_argument("--model_use_bias", type=strtobool, default=True, help="")
|
||||||
main_arg_parser.add_argument("--model_use_norm", type=strtobool, default=True, help="")
|
main_arg_parser.add_argument("--model_use_norm", type=strtobool, default=True, help="")
|
||||||
main_arg_parser.add_argument("--model_dropout", type=float, default=0.00, help="")
|
main_arg_parser.add_argument("--model_dropout", type=float, default=0.00, help="")
|
||||||
@ -66,7 +67,7 @@ config = Config.read_namespace(args)
|
|||||||
# TESTING ONLY #
|
# TESTING ONLY #
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
hparams = config.model_paramters
|
hparams = config.model_paramters
|
||||||
dataset = TrajPairData('data', mapname='tate', alternatives=10000, trajectories=2500)
|
dataset = TrajDataset('data', mapname='tate', alternatives=1000, trajectories=10000)
|
||||||
dataloader = DataLoader(dataset=dataset.train_dataset, shuffle=True,
|
dataloader = DataLoader(dataset=dataset.train_dataset, shuffle=True,
|
||||||
batch_size=hparams.data_param.batchsize,
|
batch_size=hparams.data_param.batchsize,
|
||||||
num_workers=hparams.data_param.worker)
|
num_workers=hparams.data_param.worker)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user