project Refactor, CNN Classifier Basics

This commit is contained in:
Steffen Illium
2020-03-08 23:46:02 +01:00
parent 75e8a61628
commit cd4fdf2de3
20 changed files with 441 additions and 239 deletions

View File

@@ -1,11 +1,7 @@
from abc import ABC
from pathlib import Path
from typing import Union
import torch
from torch import nn
import torch.nn.functional as F
import pytorch_lightning as pl
from lib.modules.utils import AutoPad, Interpolate
#
@@ -26,12 +22,12 @@ class ConvModule(nn.Module):
conv_filters=64, conv_kernel=5, conv_stride=1, conv_padding=0):
super(ConvModule, self).__init__()
# Module Paramters
# Module Parameters
self.in_shape = in_shape
in_channels, height, width = in_shape[0], in_shape[1], in_shape[2]
self.activation = activation()
# Convolution Paramters
# Convolution Parameters
self.padding = conv_padding
self.stride = conv_stride
@@ -44,7 +40,7 @@ class ConvModule(nn.Module):
)
def forward(self, x):
x = self.norm(x) if self.norm else x
x = self.norm(x)
tensor = self.conv(x)
tensor = self.dropout(tensor)
@@ -72,10 +68,10 @@ class DeConvModule(nn.Module):
self.in_shape = in_shape
self.conv_filters = conv_filters
self.autopad = AutoPad() if autopad else lambda x: x
self.interpolation = Interpolate(scale_factor=interpolation_scale) if interpolation_scale else lambda x: x
self.norm = nn.BatchNorm2d(in_channels, eps=1e-04, affine=False) if normalize else lambda x: x
self.dropout = nn.Dropout2d(dropout) if dropout else lambda x: x
self.autopad = AutoPad() if autopad else lambda x: x
self.interpolation = Interpolate(scale_factor=interpolation_scale) if interpolation_scale else lambda x: x
self.norm = nn.BatchNorm2d(in_channels, eps=1e-04, affine=False) if normalize else lambda x: x
self.dropout = nn.Dropout2d(dropout) if dropout else lambda x: x
self.de_conv = nn.ConvTranspose2d(in_channels, self.conv_filters, conv_kernel, bias=use_bias,
padding=self.padding, stride=self.stride)
@@ -100,13 +96,13 @@ class ResidualModule(nn.Module):
output = self(x)
return output.shape[1:]
def __init__(self, in_shape, module_class, n, activation=None, **module_paramters):
def __init__(self, in_shape, module_class, n, activation=None, **module_parameters):
assert n >= 1
super(ResidualModule, self).__init__()
self.in_shape = in_shape
module_paramters.update(in_shape=in_shape)
module_parameters.update(in_shape=in_shape)
self.activation = activation() if activation else lambda x: x
self.residual_block = nn.ModuleList([module_class(**module_paramters) for _ in range(n)])
self.residual_block = nn.ModuleList([module_class(**module_parameters) for _ in range(n)])
assert self.in_shape == self.shape, f'The in_shape: {self.in_shape} - must match the out_shape: {self.shape}.'
def forward(self, x):
@@ -143,5 +139,3 @@ class RecurrentModule(nn.Module):
def forward(self, x):
tensor = self.rnn(x)
return tensor

View File

@@ -1,8 +1,11 @@
from typing import List
import torch
from torch import nn
from lib.modules.utils import FlipTensor
from lib.objects.map import MapStorage
from lib.objects.map import MapStorage, Map
from lib.objects.trajectory import Trajectory
class BinaryHomotopicLoss(nn.Module):
@@ -11,7 +14,10 @@ class BinaryHomotopicLoss(nn.Module):
self.map_storage = map_storage
self.flipper = FlipTensor()
def forward(self, x:torch.Tensor, y: torch.Tensor, mapnames: str):
y_flipepd = self.flipper(y)
circle = torch.cat((x, y_flipepd), dim=-1)
masp = self.map_storage[mapnames].are
def forward(self, x: torch.Tensor, y: torch.Tensor, mapnames: str):
maps: List[Map] = [self.map_storage[mapname] for mapname in mapnames]
for basemap in maps:
basemap = basemap.as_2d_array

View File

@@ -83,9 +83,9 @@ class LightningBaseModule(pl.LightningModule, ABC):
print(e)
return -1
def __init__(self, params):
def __init__(self, hparams):
super(LightningBaseModule, self).__init__()
self.hparams = params
self.hparams = hparams
# Data loading
# =============================================================================
@@ -109,6 +109,10 @@ class LightningBaseModule(pl.LightningModule, ABC):
def data_len(self):
return len(self.dataset.train_dataset)
@property
def n_train_batches(self):
return len(self.train_dataloader())
def configure_optimizers(self):
raise NotImplementedError
@@ -121,7 +125,7 @@ class LightningBaseModule(pl.LightningModule, ABC):
def test_step(self, *args, **kwargs):
raise NotImplementedError
def test_end(self, outputs):
def test_epoch_end(self, outputs):
raise NotImplementedError
def init_weights(self):
@@ -134,6 +138,26 @@ class LightningBaseModule(pl.LightningModule, ABC):
m.bias.data.fill_(0.01)
self.apply(_weight_init)
# Dataloaders
# ================================================================================
# Train Dataloader
def train_dataloader(self):
return DataLoader(dataset=self.dataset.train_dataset, shuffle=True,
batch_size=self.hparams.data_param.batchsize,
num_workers=self.hparams.data_param.worker)
# Test Dataloader
def test_dataloader(self):
return DataLoader(dataset=self.dataset.test_dataset, shuffle=True,
batch_size=self.hparams.data_param.batchsize,
num_workers=self.hparams.data_param.worker)
# Validation Dataloader
def val_dataloader(self):
return DataLoader(dataset=self.dataset.val_dataset, shuffle=False,
batch_size=self.hparams.data_param.batchsize,
num_workers=self.hparams.data_param.worker)
class FilterLayer(nn.Module):