project Refactor, CNN Classifier Basics
@@ -1,11 +1,7 @@
from abc import ABC
from pathlib import Path
from typing import Union

import torch
from torch import nn
import torch.nn.functional as F
import pytorch_lightning as pl
from lib.modules.utils import AutoPad, Interpolate

#
@@ -26,12 +22,12 @@ class ConvModule(nn.Module):
conv_filters=64, conv_kernel=5, conv_stride=1, conv_padding=0):
super(ConvModule, self).__init__()

- # Module Paramters
+ # Module Parameters
self.in_shape = in_shape
in_channels, height, width = in_shape[0], in_shape[1], in_shape[2]
self.activation = activation()

- # Convolution Paramters
+ # Convolution Parameters
self.padding = conv_padding
self.stride = conv_stride

@@ -44,7 +40,7 @@ class ConvModule(nn.Module):
)

def forward(self, x):
- x = self.norm(x) if self.norm else x
+ x = self.norm(x)

tensor = self.conv(x)
tensor = self.dropout(tensor)
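Note on the simplified `x = self.norm(x)` call above: elsewhere in this commit the optional layers are constructed with an identity fallback (`... if normalize else lambda x: x`), so the forward pass can call them unconditionally and no longer needs the `if self.norm` guard. A minimal, self-contained sketch of that pattern (the class and parameter names below are invented for illustration, not taken from the repository):

    import torch
    from torch import nn

    class TinyConvBlock(nn.Module):
        """Illustrative only: optional layers default to a no-op callable."""

        def __init__(self, in_channels, filters=8, normalize=False, dropout=0.0):
            super().__init__()
            # Store an identity callable instead of None when a feature is disabled,
            # so forward() never has to check for None.
            self.norm = nn.BatchNorm2d(in_channels, eps=1e-04, affine=False) if normalize else (lambda x: x)
            self.dropout = nn.Dropout2d(dropout) if dropout else (lambda x: x)
            self.conv = nn.Conv2d(in_channels, filters, kernel_size=3)

        def forward(self, x):
            x = self.norm(x)              # no `if self.norm else x` guard needed
            tensor = self.conv(x)
            tensor = self.dropout(tensor)
            return tensor

    block = TinyConvBlock(in_channels=3, normalize=True, dropout=0.2)
    print(block(torch.randn(1, 3, 32, 32)).shape)   # torch.Size([1, 8, 30, 30])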
@@ -72,10 +68,10 @@ class DeConvModule(nn.Module):
self.in_shape = in_shape
self.conv_filters = conv_filters

self.autopad = AutoPad() if autopad else lambda x: x
self.interpolation = Interpolate(scale_factor=interpolation_scale) if interpolation_scale else lambda x: x
self.norm = nn.BatchNorm2d(in_channels, eps=1e-04, affine=False) if normalize else lambda x: x
self.dropout = nn.Dropout2d(dropout) if dropout else lambda x: x
self.autopad = AutoPad() if autopad else lambda x: x
self.interpolation = Interpolate(scale_factor=interpolation_scale) if interpolation_scale else lambda x: x
self.norm = nn.BatchNorm2d(in_channels, eps=1e-04, affine=False) if normalize else lambda x: x
self.dropout = nn.Dropout2d(dropout) if dropout else lambda x: x
self.de_conv = nn.ConvTranspose2d(in_channels, self.conv_filters, conv_kernel, bias=use_bias,
padding=self.padding, stride=self.stride)

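DeConvModule's forward pass is not part of this hunk. As a rough sketch of how such optional stages are commonly chained ahead of the transposed convolution (the ordering and names below are assumptions for illustration, not the repository's code):

    import torch
    from torch import nn

    class TinyDeConvBlock(nn.Module):
        """Illustrative upsampling block using the same identity-fallback style."""

        def __init__(self, in_channels, filters=8, kernel=3, scale=2, normalize=False, dropout=0.0):
            super().__init__()
            # nn.Upsample stands in for the project's Interpolate helper.
            self.interpolation = nn.Upsample(scale_factor=scale) if scale else (lambda x: x)
            self.norm = nn.BatchNorm2d(in_channels, eps=1e-04, affine=False) if normalize else (lambda x: x)
            self.dropout = nn.Dropout2d(dropout) if dropout else (lambda x: x)
            self.de_conv = nn.ConvTranspose2d(in_channels, filters, kernel)

        def forward(self, x):
            x = self.norm(x)             # optional batch norm
            x = self.dropout(x)          # optional dropout
            x = self.interpolation(x)    # optional upsampling
            return self.de_conv(x)       # transposed convolution last

    block = TinyDeConvBlock(in_channels=8, filters=3, scale=2)
    print(block(torch.randn(1, 8, 16, 16)).shape)   # torch.Size([1, 3, 34, 34])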
@@ -100,13 +96,13 @@ class ResidualModule(nn.Module):
output = self(x)
return output.shape[1:]

- def __init__(self, in_shape, module_class, n, activation=None, **module_paramters):
+ def __init__(self, in_shape, module_class, n, activation=None, **module_parameters):
assert n >= 1
super(ResidualModule, self).__init__()
self.in_shape = in_shape
- module_paramters.update(in_shape=in_shape)
+ module_parameters.update(in_shape=in_shape)
self.activation = activation() if activation else lambda x: x
- self.residual_block = nn.ModuleList([module_class(**module_paramters) for _ in range(n)])
+ self.residual_block = nn.ModuleList([module_class(**module_parameters) for _ in range(n)])
assert self.in_shape == self.shape, f'The in_shape: {self.in_shape} - must match the out_shape: {self.shape}.'

def forward(self, x):
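For orientation, this is roughly how such a residual wrapper is built and used: every one of the n inner blocks is constructed from the same keyword arguments, and the commit fixes the **module_paramters typo in every place those kwargs are forwarded. A self-contained sketch (the skip connection in forward and the stand-in block are assumptions, since the hunk cuts off before the body of forward):

    import torch
    from torch import nn

    class TinyResidualModule(nn.Module):
        """Sketch: n copies of module_class, all built from shared kwargs."""

        def __init__(self, in_shape, module_class, n, activation=None, **module_parameters):
            assert n >= 1
            super().__init__()
            self.in_shape = in_shape
            module_parameters.update(in_shape=in_shape)   # every block sees the same input shape
            self.activation = activation() if activation else (lambda x: x)
            self.residual_block = nn.ModuleList([module_class(**module_parameters) for _ in range(n)])

        def forward(self, x):
            tensor = x
            for block in self.residual_block:
                tensor = block(tensor)
            return self.activation(tensor + x)   # skip connection: assumed, not shown in the hunk

    class ShapePreservingBlock(nn.Module):
        """Stand-in for ConvModule that keeps the spatial shape."""

        def __init__(self, in_shape):
            super().__init__()
            self.conv = nn.Conv2d(in_shape[0], in_shape[0], kernel_size=3, padding=1)

        def forward(self, x):
            return self.conv(x)

    res = TinyResidualModule((3, 32, 32), ShapePreservingBlock, n=2, activation=nn.ReLU)
    print(res(torch.randn(1, 3, 32, 32)).shape)   # torch.Size([1, 3, 32, 32])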
@@ -143,5 +139,3 @@ class RecurrentModule(nn.Module):
def forward(self, x):
tensor = self.rnn(x)
return tensor
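If self.rnn is one of torch's recurrent layers (nn.RNN, nn.GRU, nn.LSTM), note that these return a tuple of (output, hidden state), so a forward written as `tensor = self.rnn(x)` hands that tuple back to the caller rather than a single tensor. A small illustration of the usual unpacking:

    import torch
    from torch import nn

    rnn = nn.GRU(input_size=16, hidden_size=32, batch_first=True)
    x = torch.randn(4, 10, 16)          # (batch, sequence, features)

    output, hidden = rnn(x)             # nn.GRU returns a tuple, not a tensor
    print(output.shape, hidden.shape)   # torch.Size([4, 10, 32]) torch.Size([1, 4, 32])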
@@ -1,8 +1,11 @@
from typing import List

import torch
from torch import nn

from lib.modules.utils import FlipTensor
- from lib.objects.map import MapStorage
+ from lib.objects.map import MapStorage, Map
from lib.objects.trajectory import Trajectory


class BinaryHomotopicLoss(nn.Module):
@@ -11,7 +14,10 @@ class BinaryHomotopicLoss(nn.Module):
self.map_storage = map_storage
self.flipper = FlipTensor()

- def forward(self, x:torch.Tensor, y: torch.Tensor, mapnames: str):
- y_flipepd = self.flipper(y)
- circle = torch.cat((x, y_flipepd), dim=-1)
- masp = self.map_storage[mapnames].are
+ def forward(self, x: torch.Tensor, y: torch.Tensor, mapnames: str):
+ maps: List[Map] = [self.map_storage[mapname] for mapname in mapnames]
+ for basemap in maps:
+ basemap = basemap.as_2d_array
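The rewritten forward resolves one map per entry in mapnames instead of a single lookup for the whole batch. A rough sketch of that lookup pattern with stand-in classes; the only behaviour assumed from the real Map/MapStorage objects is what the hunk itself uses (dict-style access and an as_2d_array property):

    from typing import Dict, List

    import numpy as np

    class FakeMap:
        """Stand-in for lib.objects.map.Map; exposes only as_2d_array."""

        def __init__(self, grid: np.ndarray):
            self._grid = grid

        @property
        def as_2d_array(self) -> np.ndarray:
            return self._grid

    class FakeMapStorage:
        """Stand-in for MapStorage: mapname -> Map lookup."""

        def __init__(self, maps: Dict[str, FakeMap]):
            self._maps = maps

        def __getitem__(self, name: str) -> FakeMap:
            return self._maps[name]

    storage = FakeMapStorage({"map_a": FakeMap(np.zeros((8, 8))), "map_b": FakeMap(np.ones((8, 8)))})
    mapnames: List[str] = ["map_a", "map_b", "map_a"]      # one map name per sample in the batch
    maps = [storage[name] for name in mapnames]            # mirrors the new forward
    print([m.as_2d_array.shape for m in maps])             # [(8, 8), (8, 8), (8, 8)]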
@@ -83,9 +83,9 @@ class LightningBaseModule(pl.LightningModule, ABC):
print(e)
return -1

- def __init__(self, params):
+ def __init__(self, hparams):
super(LightningBaseModule, self).__init__()
- self.hparams = params
+ self.hparams = hparams

# Data loading
# =============================================================================
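The params to hparams rename is more than cosmetic: in the older pytorch_lightning API this code appears to target, whatever object is assigned to self.hparams is what the trainer checkpoints and the loggers record as hyperparameters. A hedged sketch of the kind of nested Namespace such a module is typically constructed with (the field names are invented, except data_param.batchsize and data_param.worker, which the dataloaders below read):

    from argparse import Namespace

    # Invented example values; the real project presumably builds this from its own config.
    hparams = Namespace(
        data_param=Namespace(batchsize=64, worker=4),
        train_param=Namespace(lr=1e-3),
    )

    # model = MyClassifier(hparams)   # hypothetical subclass of LightningBaseModule
    print(hparams.data_param.batchsize)   # 64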
@@ -109,6 +109,10 @@ class LightningBaseModule(pl.LightningModule, ABC):
def data_len(self):
return len(self.dataset.train_dataset)

+ @property
+ def n_train_batches(self):
+ return len(self.train_dataloader())
+
def configure_optimizers(self):
raise NotImplementedError

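The new n_train_batches property simply measures the training DataLoader; len() of a DataLoader is the number of batches per epoch (for drop_last=False, the ceiling of dataset length over batch size). A quick check:

    import math

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    dataset = TensorDataset(torch.randn(100, 4))
    loader = DataLoader(dataset, batch_size=32, drop_last=False)

    print(len(loader))                    # 4
    print(math.ceil(len(dataset) / 32))   # 4, the same computation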
@@ -121,7 +125,7 @@ class LightningBaseModule(pl.LightningModule, ABC):
def test_step(self, *args, **kwargs):
raise NotImplementedError

- def test_end(self, outputs):
+ def test_epoch_end(self, outputs):
raise NotImplementedError

def init_weights(self):
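test_end is the older pytorch_lightning name for the epoch-level test hook; later versions call it test_epoch_end, which is what the base class now declares as abstract. A hedged sketch of how a concrete subclass might implement it (the metric names are invented):

    import torch
    import pytorch_lightning as pl

    class ExampleClassifier(pl.LightningModule):
        # Only the hook discussed here is sketched; the rest of the module is omitted.
        def test_epoch_end(self, outputs):
            # `outputs` collects whatever test_step returned for each batch.
            avg_loss = torch.stack([out["test_loss"] for out in outputs]).mean()
            return {"test_loss": avg_loss, "log": {"test_loss": avg_loss}}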
@@ -134,6 +138,26 @@ class LightningBaseModule(pl.LightningModule, ABC):
m.bias.data.fill_(0.01)
self.apply(_weight_init)

+ # Dataloaders
+ # ================================================================================
+ # Train Dataloader
+ def train_dataloader(self):
+ return DataLoader(dataset=self.dataset.train_dataset, shuffle=True,
+ batch_size=self.hparams.data_param.batchsize,
+ num_workers=self.hparams.data_param.worker)
+
+ # Test Dataloader
+ def test_dataloader(self):
+ return DataLoader(dataset=self.dataset.test_dataset, shuffle=True,
+ batch_size=self.hparams.data_param.batchsize,
+ num_workers=self.hparams.data_param.worker)
+
+ # Validation Dataloader
+ def val_dataloader(self):
+ return DataLoader(dataset=self.dataset.val_dataset, shuffle=False,
+ batch_size=self.hparams.data_param.batchsize,
+ num_workers=self.hparams.data_param.worker)
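These added hooks assume a self.dataset container exposing train_dataset, val_dataset and test_dataset, plus a DataLoader name in scope (not visible in the import hunk above; presumably from torch.utils.data). A self-contained sketch of that wiring; note that evaluation loaders are more commonly created with shuffle=False, whereas the added test_dataloader shuffles:

    from argparse import Namespace

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    # Invented stand-ins for the project's dataset container and hyperparameters.
    datasets = Namespace(
        train_dataset=TensorDataset(torch.randn(256, 3, 32, 32), torch.randint(0, 10, (256,))),
        val_dataset=TensorDataset(torch.randn(64, 3, 32, 32), torch.randint(0, 10, (64,))),
    )
    hparams = Namespace(data_param=Namespace(batchsize=32, worker=0))

    train_loader = DataLoader(dataset=datasets.train_dataset, shuffle=True,
                              batch_size=hparams.data_param.batchsize,
                              num_workers=hparams.data_param.worker)
    val_loader = DataLoader(dataset=datasets.val_dataset, shuffle=False,
                            batch_size=hparams.data_param.batchsize,
                            num_workers=hparams.data_param.worker)

    batch_x, batch_y = next(iter(val_loader))
    print(batch_x.shape, batch_y.shape)   # torch.Size([32, 3, 32, 32]) torch.Size([32])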

class FilterLayer(nn.Module):