Transformer Implementation
This commit is contained in:
116
modules/util.py
116
modules/util.py
@ -8,90 +8,93 @@ from operator import mul
|
||||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
|
||||
import pytorch_lightning as pl
|
||||
|
||||
|
||||
# Utility - Modules
|
||||
###################
|
||||
from ..utils.model_io import ModelParameters
|
||||
|
||||
try:
|
||||
import pytorch_lightning as pl
|
||||
|
||||
class LightningBaseModule(pl.LightningModule, ABC):
|
||||
class LightningBaseModule(pl.LightningModule, ABC):
|
||||
|
||||
@classmethod
|
||||
def name(cls):
|
||||
return cls.__name__
|
||||
@classmethod
|
||||
def name(cls):
|
||||
return cls.__name__
|
||||
|
||||
@property
|
||||
def shape(self):
|
||||
try:
|
||||
x = torch.randn(self.in_shape).unsqueeze(0)
|
||||
output = self(x)
|
||||
return output.shape[1:]
|
||||
except Exception as e:
|
||||
print(e)
|
||||
return -1
|
||||
@property
|
||||
def shape(self):
|
||||
try:
|
||||
x = torch.randn(self.in_shape).unsqueeze(0)
|
||||
output = self(x)
|
||||
return output.shape[1:]
|
||||
except Exception as e:
|
||||
print(e)
|
||||
return -1
|
||||
|
||||
def __init__(self, hparams):
|
||||
super(LightningBaseModule, self).__init__()
|
||||
def __init__(self, hparams):
|
||||
super(LightningBaseModule, self).__init__()
|
||||
|
||||
# Set Parameters
|
||||
################################
|
||||
self.hparams = hparams
|
||||
self.params = ModelParameters(hparams)
|
||||
# Set Parameters
|
||||
################################
|
||||
self.hparams = hparams
|
||||
self.params = ModelParameters(hparams)
|
||||
|
||||
# Dataset Loading
|
||||
################################
|
||||
# TODO: Find a way to push Class Name, library path and parameters (sometimes those are objects) in here
|
||||
def size(self):
|
||||
return self.shape
|
||||
|
||||
def size(self):
|
||||
return self.shape
|
||||
def save_to_disk(self, model_path):
|
||||
Path(model_path, exist_ok=True).mkdir(parents=True, exist_ok=True)
|
||||
if not (model_path / 'model_class.obj').exists():
|
||||
with (model_path / 'model_class.obj').open('wb') as f:
|
||||
torch.save(self.__class__, f)
|
||||
return True
|
||||
|
||||
def save_to_disk(self, model_path):
|
||||
Path(model_path, exist_ok=True).mkdir(parents=True, exist_ok=True)
|
||||
if not (model_path / 'model_class.obj').exists():
|
||||
with (model_path / 'model_class.obj').open('wb') as f:
|
||||
torch.save(self.__class__, f)
|
||||
return True
|
||||
@property
|
||||
def data_len(self):
|
||||
return len(self.dataset.train_dataset)
|
||||
|
||||
@property
|
||||
def data_len(self):
|
||||
return len(self.dataset.train_dataset)
|
||||
@property
|
||||
def n_train_batches(self):
|
||||
return len(self.train_dataloader())
|
||||
|
||||
@property
|
||||
def n_train_batches(self):
|
||||
return len(self.train_dataloader())
|
||||
def configure_optimizers(self):
|
||||
raise NotImplementedError
|
||||
|
||||
def configure_optimizers(self):
|
||||
raise NotImplementedError
|
||||
def forward(self, *args, **kwargs):
|
||||
raise NotImplementedError
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
raise NotImplementedError
|
||||
def training_step(self, batch_xy, batch_nb, *args, **kwargs):
|
||||
raise NotImplementedError
|
||||
|
||||
def training_step(self, batch_xy, batch_nb, *args, **kwargs):
|
||||
raise NotImplementedError
|
||||
def test_step(self, *args, **kwargs):
|
||||
raise NotImplementedError
|
||||
|
||||
def test_step(self, *args, **kwargs):
|
||||
raise NotImplementedError
|
||||
def test_epoch_end(self, outputs):
|
||||
raise NotImplementedError
|
||||
|
||||
def test_epoch_end(self, outputs):
|
||||
raise NotImplementedError
|
||||
def init_weights(self, in_place_init_func_=nn.init.xavier_uniform_):
|
||||
weight_initializer = WeightInit(in_place_init_function=in_place_init_func_)
|
||||
self.apply(weight_initializer)
|
||||
|
||||
def init_weights(self, in_place_init_func_=nn.init.xavier_uniform_):
|
||||
weight_initializer = WeightInit(in_place_init_function=in_place_init_func_)
|
||||
self.apply(weight_initializer)
|
||||
modules = [LightningBaseModule, nn.Module]
|
||||
|
||||
except ImportError:
|
||||
modules = [nn.Module, ]
|
||||
pass # Maybe post a hint to install pytorch-lightning.
|
||||
|
||||
|
||||
class ShapeMixin:
|
||||
|
||||
@property
|
||||
def shape(self):
|
||||
assert isinstance(self, (LightningBaseModule, nn.Module))
|
||||
|
||||
assert isinstance(self, modules)
|
||||
|
||||
def get_out_shape(output):
|
||||
return output.shape[1:] if len(output.shape[1:]) > 1 else output.shape[-1]
|
||||
|
||||
if self.in_shape is not None:
|
||||
in_shape = self.in_shape if hasattr(self, 'in_shape') else None
|
||||
if in_shape is not None:
|
||||
try:
|
||||
device = self.device
|
||||
except AttributeError:
|
||||
@ -99,10 +102,11 @@ class ShapeMixin:
|
||||
device = next(self.parameters()).device
|
||||
except StopIteration:
|
||||
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
||||
x = torch.randn(self.in_shape, device=device)
|
||||
x = torch.randn(in_shape, device=device)
|
||||
# This is needed for BatchNorm shape checking
|
||||
x = torch.stack((x, x))
|
||||
|
||||
# noinspection PyCallingNonCallable
|
||||
y = self(x)
|
||||
if isinstance(y, tuple):
|
||||
shape = tuple([get_out_shape(y[i]) for i in range(len(y))])
|
||||
@ -265,7 +269,7 @@ class Splitter(nn.Module):
|
||||
self.autopad = AutoPadToShape(self._out_shape)
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
dim = self.dim + 1 if len(self.in_shape) == (x.ndim -1) else self.dim
|
||||
dim = self.dim + 1 if len(self.in_shape) == (x.ndim - 1) else self.dim
|
||||
x = x.transpose(0, dim)
|
||||
n_blocks = list()
|
||||
for block_idx in range(self.n):
|
||||
|
Reference in New Issue
Block a user