Transformer Implementation

Si11ium
2020-10-29 16:40:43 +01:00
parent f296ba78b9
commit 13812b83b5
5 changed files with 167 additions and 66 deletions


@ -8,90 +8,93 @@ from operator import mul
from torch import nn
from torch.nn import functional as F
# Utility - Modules
###################
from ..utils.model_io import ModelParameters

try:
    import pytorch_lightning as pl

    class LightningBaseModule(pl.LightningModule, ABC):

        @classmethod
        def name(cls):
            return cls.__name__

        @property
        def shape(self):
            try:
                x = torch.randn(self.in_shape).unsqueeze(0)
                output = self(x)
                return output.shape[1:]
            except Exception as e:
                print(e)
                return -1

        def __init__(self, hparams):
            super(LightningBaseModule, self).__init__()

            # Set Parameters
            ################################
            self.hparams = hparams
            self.params = ModelParameters(hparams)

            # Dataset Loading
            ################################
            # TODO: Find a way to push Class Name, library path and parameters (sometimes those are objects) in here

        def size(self):
            return self.shape

        def save_to_disk(self, model_path):
            # Create the model directory if needed, then persist the class object once.
            Path(model_path).mkdir(parents=True, exist_ok=True)
            if not (model_path / 'model_class.obj').exists():
                with (model_path / 'model_class.obj').open('wb') as f:
                    torch.save(self.__class__, f)
            return True

        @property
        def data_len(self):
            return len(self.dataset.train_dataset)

        @property
        def n_train_batches(self):
            return len(self.train_dataloader())

        def configure_optimizers(self):
            raise NotImplementedError

        def forward(self, *args, **kwargs):
            raise NotImplementedError

        def training_step(self, batch_xy, batch_nb, *args, **kwargs):
            raise NotImplementedError

        def test_step(self, *args, **kwargs):
            raise NotImplementedError

        def test_epoch_end(self, outputs):
            raise NotImplementedError

        def init_weights(self, in_place_init_func_=nn.init.xavier_uniform_):
            weight_initializer = WeightInit(in_place_init_function=in_place_init_func_)
            self.apply(weight_initializer)

    # Tuples, so they can be passed straight to isinstance() in ShapeMixin below.
    modules = (LightningBaseModule, nn.Module)

except ImportError:
    modules = (nn.Module, )
    pass  # Maybe post a hint to install pytorch-lightning.
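
A minimal sketch of how a concrete model might subclass LightningBaseModule, assuming pytorch-lightning is installed and that ModelParameters accepts the hparams passed in. The class name, layer sizes, learning rate and input shape below are illustrative only and are not part of this commit.

# Hypothetical subclass for illustration; names and sizes are not from this repository.
import torch
from torch import nn
from torch.nn import functional as F


class TinyClassifier(LightningBaseModule):

    def __init__(self, hparams):
        super(TinyClassifier, self).__init__(hparams)   # sets self.hparams and self.params
        self.in_shape = (28 * 28, )                     # flat input, assumed for this sketch
        self.net = nn.Sequential(
            nn.Linear(28 * 28, 64),
            nn.ReLU(),
            nn.Linear(64, 10),
        )

    def forward(self, x, *args, **kwargs):
        return self.net(x)

    def training_step(self, batch_xy, batch_nb, *args, **kwargs):
        batch_x, batch_y = batch_xy
        return F.cross_entropy(self(batch_x), batch_y)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)   # lr hard-coded for the sketch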


class ShapeMixin:

    @property
    def shape(self):
        assert isinstance(self, modules)

        def get_out_shape(output):
            return output.shape[1:] if len(output.shape[1:]) > 1 else output.shape[-1]

        in_shape = self.in_shape if hasattr(self, 'in_shape') else None
        if in_shape is not None:
            try:
                device = self.device
            except AttributeError:
@ -99,10 +102,11 @@ class ShapeMixin:
                try:
                    device = next(self.parameters()).device
                except StopIteration:
                    device = 'cuda' if torch.cuda.is_available() else 'cpu'
            x = torch.randn(in_shape, device=device)
            # This is needed for BatchNorm shape checking
            x = torch.stack((x, x))

            # noinspection PyCallingNonCallable
            y = self(x)
            if isinstance(y, tuple):
                shape = tuple([get_out_shape(y[i]) for i in range(len(y))])
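
A small sketch of how the shape property can probe a module's output shape, assuming the remainder of the property (not shown in this hunk) returns the computed shape. The module below and its in_shape are illustrative only.

# Hypothetical module for illustration; not part of this commit.
import torch
from torch import nn


class SmallConv(ShapeMixin, nn.Module):

    def __init__(self):
        super(SmallConv, self).__init__()
        self.in_shape = (3, 32, 32)                     # C x H x W, without the batch dimension
        self.conv = nn.Conv2d(3, 8, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(2)

    def forward(self, x):
        return self.pool(self.conv(x))


print(SmallConv().shape)    # expected: torch.Size([8, 16, 16]), probed with a stacked dummy batch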
@ -265,7 +269,7 @@ class Splitter(nn.Module):
        self.autopad = AutoPadToShape(self._out_shape)

    def forward(self, x: torch.Tensor):
        dim = self.dim + 1 if len(self.in_shape) == (x.ndim - 1) else self.dim
        x = x.transpose(0, dim)
        n_blocks = list()
        for block_idx in range(self.n):
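
The dim adjustment above compensates for a leading batch dimension: when the incoming tensor has one more dimension than self.in_shape, the split axis is shifted by one. A standalone check of just that expression, with illustrative shapes that are not taken from this file:

# Illustrative check of the dim-offset expression only; Splitter itself is not reproduced here.
import torch

in_shape = (20, 6)          # shape the module was configured with (no batch dimension)
split_dim = 0               # axis the module is meant to split along

x = torch.randn(4, 20, 6)   # batched input: one extra leading dimension
dim = split_dim + 1 if len(in_shape) == (x.ndim - 1) else split_dim
print(dim)                  # 1 -> the split skips the batch dimension

x = torch.randn(20, 6)      # unbatched input
dim = split_dim + 1 if len(in_shape) == (x.ndim - 1) else split_dim
print(dim)                  # 0 -> split along the first axis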