CNN Classifier

This commit is contained in:
Si11ium
2020-02-21 09:44:09 +01:00
parent 537e5371c9
commit 7b3f781d19
12 changed files with 247 additions and 109 deletions

View File

@@ -33,13 +33,11 @@ class CNNRouteGeneratorModel(LightningBaseModule):
# NN Nodes
self.conv1 = ConvModule(self.in_shape, self.hparams.model_param.filters[0])
self.conv2 = ConvModule(self.conv1.shape, self.hparams.model_param.filters[0])
self.conv3 = ConvModule(self.conv2.shape, self.hparams.model_param.filters[0])
self.conv2 = ConvModule(self.conv1.shape, conv_kernel=3, conv_stride=1, conv_padding=0,
conv_filters=self.hparams.model_param.filters[0])
self.conv3 = ConvModule(self.conv2.shape, conv_kernel=3, conv_stride=1, conv_padding=0,
conv_filters=self.hparams.model_param.filters[0])
def forward(self, x):
pass

View File

@@ -1,12 +1,19 @@
from lib.modules.utils import LightningBaseModule
from lib.modules.blocks import ConvModule
import torch
from torch import nn
import torch.nn.functional as F
from torch.optim import Adam
from datasets.trajectory_dataset import TrajData
from lib.modules.utils import LightningBaseModule, Flatten
from lib.modules.blocks import ConvModule, ResidualModule
class ConvHomDetector(LightningBaseModule):
name = 'CNNHomotopyClassifier'
def configure_optimizers(self):
pass
return Adam(self.parameters(), lr=self.lr)
def validation_step(self, *args, **kwargs):
pass
@@ -15,18 +22,60 @@ class ConvHomDetector(LightningBaseModule):
pass
def training_step(self, batch_xy, batch_nb, *args, **kwargs):
pass
batch_x, batch_y = batch_xy
pred_y = self(batch_x)
loss = F.binary_cross_entropy(pred_y, batch_y)
return {'loss': loss, 'log': dict(loss=loss)}
def test_step(self, *args, **kwargs):
pass
def __init__(self, *params):
super(ConvHomDetector, self).__init__(*params)
# Dataset
self.dataset = TrajData(self.hparams.data_param.data_root)
self.conv1 = ConvModule(self.dataset.map_shape
# Additional Attributes
self.map_shape = self.dataset.map_shapes_max
)
# NN Nodes
# ============================
# Convolutional Map Processing
#
self.map_res_1 = ResidualModule(self.in_shape, ConvModule, 3,
**dict(conv_kernel=3, conv_stride=1,
conv_padding=0, conv_filters=self.hparams.model_param.filters[0]))
self.map_conv_1 = ConvModule(self.map_res_1.shape, conv_kernel=5, conv_stride=1,
conv_padding=0, conv_filters=self.hparams.model_param.filters[0])
self.map_res_2 = ResidualModule(self.map_conv_1.shape, ConvModule, 3,
**dict(conv_kernel=3, conv_stride=1,
conv_padding=0, conv_filters=self.hparams.model_param.filters[0]))
self.map_conv_2 = ConvModule(self.map_res_2.shape, conv_kernel=5, conv_stride=1,
conv_padding=0, conv_filters=self.hparams.model_param.filters[0])
self.map_res_3 = ResidualModule(self.map_conv_2.shape, ConvModule, 3,
**dict(conv_kernel=3, conv_stride=1,
conv_padding=0, conv_filters=self.hparams.model_param.filters[0]))
self.map_conv_3 = ConvModule(self.map_res_3.shape, conv_kernel=5, conv_stride=1,
conv_padding=0, conv_filters=self.hparams.model_param.filters[0])
self.flatten = Flatten(self.map_conv_3.shape)
# ============================
# Classifier
#
self.linear = nn.Linear(self.flatten.shape.item(), self.hparams.model_param.classes * 10)
self.classifier = nn.Linear(self.linear.shape, self.hparams.model_param.classes)
self.softmax = nn.Softmax()
def forward(self, x):
pass
tensor = self.map_res_1(x)
tensor = self.map_conv_1(tensor)
tensor = self.map_res_2(tensor)
tensor = self.map_conv_2(tensor)
tensor = self.map_conv_3(tensor)
tensor = self.flatten(tensor)
tensor = self.linear(tensor)
tensor = self.classifier(tensor)
tensor = self.softmax(tensor)
return tensor

View File

@@ -22,7 +22,7 @@ class ConvModule(nn.Module):
return output.shape[1:]
def __init__(self, in_shape, activation: nn.Module = nn.ELU, pooling_size=None, use_bias=True, use_norm=True,
dropout: Union[int, float] = 0,
dropout: Union[int, float] = 0, conv_class=nn.Conv2d,
conv_filters=64, conv_kernel=5, conv_stride=1, conv_padding=0):
super(ConvModule, self).__init__()
@@ -39,9 +39,9 @@ class ConvModule(nn.Module):
self.dropout = nn.Dropout2d(dropout) if dropout else False
self.pooling = nn.MaxPool2d(pooling_size) if pooling_size else False
self.norm = nn.BatchNorm2d(in_channels, eps=1e-04, affine=False) if use_norm else False
self.conv = nn.Conv2d(in_channels, conv_filters, conv_kernel, bias=use_bias,
padding=self.padding, stride=self.stride
)
self.conv = conv_class(in_channels, conv_filters, conv_kernel, bias=use_bias,
padding=self.padding, stride=self.stride
)
def forward(self, x):
x = self.norm(x) if self.norm else x
@@ -91,8 +91,30 @@ class DeConvModule(nn.Module):
tensor = self.activation(tensor) if self.activation else tensor
return tensor
def size(self):
return self.shape
class ResidualModule(nn.Module):
@property
def shape(self):
x = torch.randn(self.in_shape).unsqueeze(0)
output = self(x)
return output.shape[1:]
def __init__(self, in_shape, module_class, n, **module_paramters):
assert n >= 1
super(ResidualModule, self).__init__()
self.in_shape = in_shape
module_paramters.update(in_shape=in_shape)
self.residual_block = [module_class(**module_paramters) for x in range(n)]
assert self.in_shape == self.shape, f'The in_shape: {self.in_shape} - must match the out_shape: {self.shape}.'
def forward(self, x):
for module in self.residual_block:
tensor = module(x)
# noinspection PyUnboundLocalVariable
tensor = tensor + x
return tensor
class RecurrentModule(nn.Module):

View File

@@ -6,7 +6,6 @@ from torch import nn
from torch import functional as F
from torch.utils.data import DataLoader
from dataset.dataset import TrajDataset, TrajPairDataset
from lib.objects.map import MapStorage
import pytorch_lightning as pl
@@ -17,8 +16,20 @@ import pytorch_lightning as pl
class Flatten(nn.Module):
def __init__(self, to=(-1, )):
@property
def shape(self):
try:
x = torch.randn(self.in_shape).unsqueeze(0)
output = self(x)
return output.shape[1:]
except Exception as e:
print(e)
return -1
def __init__(self, in_shape, to=(-1, )):
super(Flatten, self).__init__()
self.in_shape = in_shape
self.to = to
def forward(self, x):

View File

@@ -110,6 +110,16 @@ class Map(object):
dest = self.get_valid_position()
return self.simple_trajectory_between(start, dest)
def generate_alternative(self, trajectory, mode='one_patching'):
start, dest = trajectory.endpoints
if mode == 'one_patching':
patch = self.get_valid_position()
alternative = self.get_trajectory_from_vertices(start, patch, dest)
else:
raise RuntimeError(f'mode checking went wrong...')
return alternative
def are_homotopic(self, trajectory, other_trajectory):
if not all(isinstance(x, Trajectory) for x in [trajectory, other_trajectory]):
raise TypeError
@@ -121,6 +131,7 @@ class Map(object):
draw.polygon(polyline, outline=255, fill=255)
a = (np.array(img) * np.where(self.map_array == self.white, 0, 1)).sum()
if a >= 1:
return False
else:

View File

@@ -5,6 +5,8 @@ from collections import defaultdict
from pathlib import Path
from tqdm import tqdm
from lib.objects.map import Map
@@ -26,11 +28,7 @@ class Generator:
with mp.Pool(processes) as pool:
async_results = [pool.apply_async(self.generate_n_alternatives, kwds=kwargs) for _ in range(n)]
# for _ in trange(n, desc='Processing Trajectories'):
# self.write_n_alternatives(m, dataset_name, **kwargs)
# This line is for error catching only
for result_obj in async_results:
for result_obj in tqdm(async_results, total=n, desc='Producing trajectories with Alternatives'):
trajectory, alternatives, labels = result_obj.get()
mutex.acquire()
self.write_to_disk(datafile_name, trajectory, alternatives, labels)
@@ -44,24 +42,14 @@ class Generator:
datafile.unlink()
pass
def generate_alternatives(self, trajectory, mode='one_patching'):
start, dest = trajectory.endpoints
if mode == 'one_patching':
patch = self.map.get_valid_position()
alternative = self.map.get_trajectory_from_vertices(start, patch, dest)
else:
raise RuntimeError(f'mode checking went wrong...')
return alternative
def generate_n_alternatives(self, n=None, datafile_name='', trajectory=None,
mode='one_patching', equal_samples=True, binary_check=True):
def generate_n_alternatives(self, n=None, datafile_name='', trajectory=None, is_sub_process=False,
mode='one_patching', equal_samples=True, binary_check=True):
assert n is not None, f'n is not allowed to be None but was: {n}'
assert mode in self.possible_modes, f'Parameter "mode" must be either {self.possible_modes}, but was {mode}.'
trajectory = trajectory if trajectory is not None else self.map.get_random_trajectory()
results = [self.generate_alternatives(trajectory=trajectory, mode=mode) for _ in range(n)]
results = [self.map.generate_alternative(trajectory=trajectory, mode=mode) for _ in range(n)]
# label per homotopic class
homotopy_classes = defaultdict(list)
@@ -91,9 +79,10 @@ class Generator:
alternatives.extend(homotopy_classes[key])
labels.extend([key] * len(homotopy_classes[key]))
if datafile_name:
if is_sub_process:
datafile_name = f'{str(datafile_name)}_{mp.current_process().pid}'
# Write to disk
subprocess_datafile_name = f'{str(datafile_name)}_{mp.current_process().pid}'
self.write_to_disk(subprocess_datafile_name, trajectory, alternatives, labels)
self.write_to_disk(datafile_name, trajectory, alternatives, labels)
return trajectory, alternatives, labels
def write_to_disk(self, datafile_name, trajectory, alternatives, labels):